목차

(old) segmenttree.py

SegmentTree v1.2

코드

from typing import Callable, Iterable, TypeVar

# Element type stored in the tree; ``merge`` combines two values of this type.
T = TypeVar("T")


# N SegmentTree
# I {"version": "1.2", "typing": ["Callable", "Iterable", "TypeVar"], "const": ["T"]}
class SegmentTree:
    """Non-recursive segment tree supporting point update and range query.

    Leaves are stored at ``_tree[n:2n]``; internal node ``i`` holds
    ``merge(_tree[2*i], _tree[2*i + 1])`` and slot 0 is unused.  ``merge``
    must be associative.  It need not be commutative: every operation below
    combines segments in left-to-right index order.
    """

    def __init__(self,
                 values: Iterable[T],
                 merge: Callable[[T, T], T] = min):
        """Build the tree from ``values`` in O(n) merges.

        Args:
            values: Initial leaf values, in index order.
            merge: Associative operation combining a left segment with the
                segment immediately to its right. Defaults to ``min``.
        """
        leaves = list(values)
        self._size = len(leaves)
        # The second copy supplies the leaf slots; the first half is
        # overwritten with internal nodes by the loop below.
        self._tree = leaves + leaves
        self._merge = merge
        for i in range(self._size - 1, 0, -1):
            self._tree[i] = merge(self._tree[i * 2], self._tree[i * 2 + 1])

    def _check_pos(self, pos: int) -> None:
        """Raise IndexError unless ``0 <= pos < len``.

        A negative ``pos`` would otherwise address an internal node, or, for
        ``pos < -len``, make the climb loop spin forever (``-1 >> 1 == -1``).
        """
        if not 0 <= pos < self._size:
            raise IndexError(f"position {pos} out of range [0, {self._size})")

    def _pull(self, i: int) -> None:
        """Recompute every ancestor of tree index ``i``, bottom-up."""
        tree, merge = self._tree, self._merge
        i >>= 1
        while i:
            tree[i] = merge(tree[2 * i], tree[2 * i + 1])
            i >>= 1

    def update(self, pos: int, value: T) -> None:
        """Combine ``value`` into position ``pos``: ``a[pos] = merge(a[pos], value)``.

        Ancestors are rebuilt from their children rather than merged with
        ``value`` directly, so the result is correct for non-commutative
        ``merge`` as well.

        Raises:
            IndexError: if ``pos`` is not in ``[0, len)``.
        """
        self._check_pos(pos)
        i = pos + self._size
        self._tree[i] = self._merge(self._tree[i], value)
        self._pull(i)

    def set(self, pos: int, value: T) -> None:
        """Assign ``value`` to position ``pos`` and refresh its ancestors.

        Raises:
            IndexError: if ``pos`` is not in ``[0, len)``.
        """
        self._check_pos(pos)
        i = pos + self._size
        self._tree[i] = value
        self._pull(i)

    def query(self, beg: int, end: int) -> T:
        """Return the merge of the half-open range ``[beg, end)``, left to right.

        Raises:
            ValueError: if the range is empty (``end <= beg``); there is no
                identity element to return.
            IndexError: if ``beg < 0`` or ``end > len``.
        """
        if end <= beg:
            raise ValueError(f"empty range [{beg}, {end})")
        if beg < 0 or end > self._size:
            raise IndexError(
                f"range [{beg}, {end}) out of bounds for size {self._size}")
        tree, merge = self._tree, self._merge
        lo, hi = beg + self._size, end + self._size - 1  # inclusive leaves
        if lo == hi:
            return tree[lo]
        # Seed separate accumulators with the two end leaves so that
        # left-side and right-side segments keep their relative order.
        acc_left, acc_right = tree[lo], tree[hi]
        left, right = lo + 1, hi - 1
        while left <= right:
            if left % 2:  # right child: its parent overhangs to the left
                acc_left = merge(acc_left, tree[left])
            if not right % 2:  # left child: its parent overhangs to the right
                acc_right = merge(tree[right], acc_right)
            left, right = (left + 1) >> 1, (right - 1) >> 1
        return merge(acc_left, acc_right)

SegmentTree v1.3

코드

# N SegmentTree
# I {"version": "1.3", "typing": ["Callable", "Iterable", "TypeVar"], "const": ["T"]}
class SegmentTree:
    """Non-recursive segment tree supporting point update and range query.

    Leaves are stored at ``_tree[n:2n]``; internal node ``i`` holds
    ``merge(_tree[2*i], _tree[2*i + 1])`` and slot 0 is unused.  ``merge``
    must be associative.  It need not be commutative: queries combine
    segments in left-to-right index order.
    """

    def __init__(self,
                 values: Iterable[T],
                 merge: Callable[[T, T], T] = min):
        """Build the tree from ``values`` in O(n) merges.

        Args:
            values: Initial leaf values, in index order.
            merge: Associative operation combining a left segment with the
                segment immediately to its right. Defaults to ``min``.
        """
        leaves = list(values)
        self._size = len(leaves)
        # The second copy supplies the leaf slots; the first half is
        # overwritten with internal nodes by the loop below.
        self._tree = leaves + leaves
        self._merge = merge
        for i in range(self._size - 1, 0, -1):
            self._tree[i] = merge(self._tree[i * 2], self._tree[i * 2 + 1])

    def set(self, pos: int, value: T) -> None:
        """Assign ``value`` to position ``pos`` and refresh its ancestors.

        Raises:
            IndexError: if ``pos`` is not in ``[0, len)``.
        """
        if not 0 <= pos < self._size:
            # A negative pos would otherwise overwrite an internal node, or,
            # for pos < -len, spin forever since -1 >> 1 == -1.
            raise IndexError(f"position {pos} out of range [0, {self._size})")
        tree, merge = self._tree, self._merge
        i = pos + self._size
        tree[i] = value
        i >>= 1
        while i:
            # Rebuild each ancestor from its two children, left one first.
            tree[i] = merge(tree[2 * i], tree[2 * i + 1])
            i >>= 1

    def query(self, beg: int, end: int) -> T:
        """Return the merge of the half-open range ``[beg, end)``, left to right.

        Raises:
            ValueError: if the range is empty (``end <= beg``); there is no
                identity element to return.
            IndexError: if ``beg < 0`` or ``end > len``.
        """
        if end <= beg:
            raise ValueError(f"empty range [{beg}, {end})")
        if beg < 0 or end > self._size:
            raise IndexError(
                f"range [{beg}, {end}) out of bounds for size {self._size}")
        tree, merge = self._tree, self._merge
        lo, hi = beg + self._size, end + self._size - 1  # inclusive leaves
        if lo == hi:
            return tree[lo]
        # Seed separate accumulators with the two end leaves so that
        # left-side and right-side segments keep their relative order.
        acc_left, acc_right = tree[lo], tree[hi]
        left, right = lo + 1, hi - 1
        while left <= right:
            if left % 2:  # right child: its parent overhangs to the left
                acc_left = merge(acc_left, tree[left])
            if not right % 2:  # left child: its parent overhangs to the right
                acc_right = merge(tree[right], acc_right)
            left, right = (left + 1) >> 1, (right - 1) >> 1
        return merge(acc_left, acc_right)