사용자 도구

사이트 도구


ps:teflib:segmenttree

segmenttree.py

imports and globals

from typing import Callable, Iterable, TypeVar

T = TypeVar('T')

SegmentTree

코드

# N SegmentTree
# I {"version": "1.41", "typing": ["Callable", "Iterable", "TypeVar"], "const": ["T"]}
class SegmentTree:
    """Bottom-up segment tree supporting point update and range query."""

    def __init__(self,
                 values: Iterable[T],
                 merge: Callable[[T, T], T] = min):
        l = list(values)
        self._size = len(l)
        self._tree = l + l
        self._merge = merge
        for i in range(self._size - 1, 0, -1):
            self._tree[i] = merge(self._tree[i * 2], self._tree[i * 2 + 1])

    def set(self, pos: int, value: T):
        i = pos + self._size
        while i:
            self._tree[i] = value
            value = (self._merge(self._tree[i - 1], value) if i % 2
                     else self._merge(value, self._tree[i + 1]))
            i >>= 1

    def get(self, pos: int) -> T:
        return self._tree[pos + self._size]

    def query(self, beg: int, end: int) -> T:
        if end == beg + 1:
            return self._tree[beg + self._size]
        l, r = beg + self._size + 1, end + self._size - 2
        ret_l, ret_r = self._tree[l - 1], self._tree[r + 1]
        while l <= r:
            if l % 2:
                ret_l = self._merge(ret_l, self._tree[l])
            if not r % 2:
                ret_r = self._merge(self._tree[r], ret_r)
            l, r = (l + 1) >> 1, (r - 1) >> 1
        return self._merge(ret_l, ret_r)

설명

  • merge인자를 안 주면 default로 min이 들어가서 구간 최솟값을 계산하게 된다.
  • 구간합 트리가 필요한 경우에는 merge에 operator.sum을 넘겨서 만드는 것 보다, teflib.fenwicktree.FenwickTree를 사용하는 것이 효율적이다.
  • v1.2 → v1.3의 변화
    • merge에서 처리하는 함수가 교환법칙이 성립하지 않을 때에도 (merge(a,b) != merge(b,a) 일때) 정확히 처리하도록 수정했다. 그러나 이 수정으로 인해 속도는 v1.2보다 조금 느려졌다.
    • update 함수를 제거했다. 기본적으로 set으로 커버가 되기도 하고, 쓸일이 별로 없기도 하다. update가 주로 쓰이는 것은 그나마 구간합 쿼리의 경우인데, 이 경우는 어차피 펜윅트리를 쓸것이므로.
  • v1.3 → v1.4 변화
    • get 함수 추가

이 코드를 사용하는 문제

출처문제 번호Page레벨
BOJ10167금광다이아몬드 5
BOJ11503가장 긴 증가하는 부분 수열실버 2
BOJ11055가장 큰 증가 부분 수열실버 2
BOJ11505구간 곱 구하기골드 1
BOJ12986화려한 마을2플래티넘 2
BOJ13557수열과 쿼리 10플래티넘 1
BOJ14002가장 긴 증가하는 부분 수열 4골드 4
BOJ14427수열과 쿼리 15골드 1
BOJ14428수열과 쿼리 16골드 1
BOJ14438수열과 쿼리 17골드 1
BOJ15560구간 합 최대? 1골드 2
BOJ15561구간 합 최대? 2플래티넘 2
BOJ16933연속합과 쿼리플래티넘 2
BOJ17407괄호 문자열과 쿼리플래티넘 2
BOJ17975Strike Zone다이아몬드 5
BOJ19651수열과 쿼리 39다이아몬드 5
BOJ6519Frequent values플래티넘 1

MinSegmentTree

코드

# N MinSegmentTree
# I {"version": "1.01", "typing": ["Iterable"]}
class MinSegmentTree:
    """Bottom-up segment tree supporting point update and range min query."""
    __slots__ = ('_size', '_tree')

    def __init__(self,
                 nums_or_size: Iterable[float] | int,
                 *,
                 default: float = 0):
        if isinstance(nums_or_size, int):
            self._size = nums_or_size
            self._tree = [default] * (nums_or_size + nums_or_size)
        else:
            l = list(nums_or_size)
            self._size = len(l)
            self._tree = l + l
            it = reversed(self._tree)
            for i in range(self._size - 1, 0, -1):
                self._tree[i] = min(next(it), next(it))

    def set(self, pos: int, value: float):
        i = pos + self._size
        while i:
            self._tree[i] = value
            adj_value = self._tree[i - 1] if i % 2 else self._tree[i + 1]
            if adj_value < value:
                value = adj_value
            i >>= 1

    def get(self, pos: int) -> float:
        return self._tree[pos + self._size]

    def query(self, beg: int, end: int) -> float:
        if end == beg + 1:
            return self._tree[beg + self._size]
        l, r = beg + self._size + 1, end + self._size - 2
        ret_l, ret_r = self._tree[l - 1], self._tree[r + 1]
        while l <= r:
            if l % 2 and self._tree[l] < ret_l:
                ret_l = self._tree[l]
            if not r % 2 and self._tree[r] < ret_r:
                ret_r = self._tree[r]
            l, r = (l + 1) >> 1, (r - 1) >> 1
        return min(ret_l, ret_r)

설명

  • SegmentTree 를 min 연산에 최적화 시킨 버전.
  • 연산자는 고정되어있으므로 인자로 넘길 필요가 없고, 인자로는 초기값들을 넘겨주거나, 크기와 디폴트값을 넘겨주거나 하면 된다,
  • merge 함수를 적용하던 부분을 그냥 비교연산을 통해서 작은값으로 업데이트하게 된다. op1과 op2의 순서가 상관 없으므로, 그부분도 간단해졌다. 그러면 이제 self._tree[i - 1] if i % 2 else self._tree[i + 1] 도 self._tree[i ^ 1] 로 간단하게 써도 되긴 하는데, 놀랍게도 속도가 더 느려진다..;

이 코드를 사용하는 문제

LazySegmentTree

코드

# N LazySegmentTree
# I {"version": "1.1", "typing": ["Callable", "Iterable", "TypeVar"], "const": ["ValueType", "ParamType"]}
class LazySegmentTree:
    def __init__(self,
                 values: Iterable[ValueType],
                 merge: Callable[[ValueType, ValueType], ValueType],
                 update_value: Callable[[ValueType, ParamType, int], ValueType],
                 update_param: Callable[[ParamType, ParamType], ParamType],
                 should_keep_update_order: bool = True):

        l = list(values)
        self._size = len(l)
        self._tree = l + l
        self._param = [None] * self._size
        self._merge = merge
        self._update_value = update_value
        self._update_param = update_param
        self._should_keep_update_order = should_keep_update_order
        for i in range(self._size - 1, 0, -1):
            self._tree[i] = merge(self._tree[i * 2], self._tree[i * 2 + 1])

    def _apply(self, pos: int, param: ParamType, size: int):
        self._tree[pos] = self._update_value(self._tree[pos], param, size)
        if pos < self._size:
            cur_param = self._param[pos]
            self._param[pos] = (param if cur_param is None
                                else self._update_param(cur_param, param))

    def _push_down(self, pos: int):
        h = self._size.bit_length()
        size = 1 << (h - 1)
        for i in range(h, 0, -1):
            parent = pos >> i
            param = self._param[parent]
            if param is not None:
                self._apply(parent * 2, param, size)
                self._apply(parent * 2 + 1, param, size)
                self._param[parent] = None
            size >>= 1

    def _build_up(self, pos: int):
        s = 1
        while pos > 1:
            pos >>= 1
            s *= 2
            t = self._merge(self._tree[pos * 2], self._tree[pos * 2 + 1])
            self._tree[pos] = (t if self._param[pos] is None
                               else self._update_value(t, self._param[pos], s))

    def range_update(self, beg: int, end: int, param: ParamType):
        l, r = beg + self._size, end + self._size - 1
        if self._should_keep_update_order:
            self._push_down(l)
            self._push_down(r)
        l2, r2, size = l, r, 1
        while l2 <= r2:
            if l2 % 2:
                self._apply(l2, param, size)
            if not r2 % 2:
                self._apply(r2, param, size)
            l2, r2 = (l2 + 1) >> 1, (r2 - 1) >> 1
            size *= 2
        self._build_up(l)
        self._build_up(r)

    def get(self, pos: int) -> ValueType:
        self._push_down(pos + self._size)
        return self._tree[pos + self._size]

    def query(self, beg: int, end: int) -> ValueType:
        if end == beg + 1:
            return self.get(beg)
        l, r = beg + self._size + 1, end + self._size - 2
        self._push_down(l - 1)
        self._push_down(r + 1)
        ret_l, ret_r = self._tree[l - 1], self._tree[r + 1]
        while l <= r:
            if l % 2:
                ret_l = self._merge(ret_l, self._tree[l])
            if not r % 2:
                ret_r = self._merge(self._tree[r], ret_r)
            l, r = (l + 1) >> 1, (r - 1) >> 1
        return self._merge(ret_l, ret_r)

설명

이 코드를 사용하는 문제

출처문제 번호Page레벨
BOJ12844XOR플래티넘 3
BOJ12895화려한 마을플래티넘 3
BOJ13309트리플래티넘 1
BOJ13925수열과 쿼리 13다이아몬드 5
BOJ1395스위치플래티넘 3
BOJ17429국제 메시 기구다이아몬드 4
BOJ18407가로 블록 쌓기플래티넘 3
BOJ18437회사 문화 5플래티넘 3
BOJ24320Rectpoints플래티넘 2

OrderStatisticTree

코드

# N OrderStatisticTree
# I {"version": "1.0"}
class OrderStatisticTree:
    def __init__(self, counts_or_max_num):
        if isinstance(counts_or_max_num, int):
            self._size = 1 << ((counts_or_max_num + 1).bit_length())
            self._tree = [0] * (self._size * 2)
        else:
            l = list(counts_or_max_num)
            self._size = 1 << (len(l) - 1).bit_length()
            self._tree = [0] * (self._size) + l + [0] * (self._size - len(l))
            for i in range(self._size - 1, 0, -1):
                self._tree[i] = self._tree[i + i] + self._tree[i + i + 1]

    def size(self) -> int:
        return self._tree[1]

    def count(self, num: int) -> int:
        return self._tree[num + self._size]

    def add(self, num: int, count: int = 1):
        i = num + self._size
        while i:
            self._tree[i] += count
            i >>= 1

    def kth(self, k: int) -> int:
        i = 1
        while i < self._size:
            i += i
            t = self._tree[i]
            if t < k:
                k -= t
                i += 1
        return i - self._size

    def count_less_than(self, num: int) -> int:
        ret = 0
        i = num + self._size - 1
        while i:
            if not i % 2:
                ret += self._tree[i]
                i -= 1
            i >>= 1
        return ret

설명

  • teflib.fenwicktree.OrderStatisticTree도 동일한 메소드들을 갖고 있고 시간 복잡도도 동일하다. 이쪽 구현이 count_less_than()에 대해서는 약간 더 빠르게 동작한다. 따라서 주로 사용할 연산이 count_less_than() 이라면 이쪽 구현체를 사용하자.

이 코드를 사용하는 문제

출처문제 번호Page레벨
BOJ1158요세푸스 문제실버 5
BOJ1168요세푸스 문제 2플래티넘 4
BOJ11866요세푸스 문제 0실버 4
BOJ1321군인플래티넘 4
프로그래머스81303표 편집Level 3

토론

댓글을 입력하세요:
S D E P D
 
ps/teflib/segmenttree.txt · 마지막으로 수정됨: 2023/08/31 05:39 저자 teferi