Segment tree を書く (1)
皆さん segment tree (セグ木)をご存じですか?
完全二分木にモノイドの元を乗せ、ある区間に演算した結果を求めることができるデータ構造です。 セグ木を貼れば これ とか これ みたいな問題が解けます。 詳細を説明している記事はググるとたくさん見つかるので、ここでは深入りしません。
この記事では一点更新・区間取得の場合における実装例とそれに至る気持ちを書きます。 タイトルに「(1)」とあるのは、今後遅延セグ木などについても書きたいという意思表示です。
問題設定
集合 に が単位元となる演算 を入れたモノイドを考え、 に対して次の操作を行います。
- 一点更新:update(i, val)
- を val に変更する
- 区間取得:query(l, r)
- を求める
が空なら を返すのが自然だと思います。
この記事での実装の要点
以下、少し詳しく書いていきます。
配列の領域確保と index の持ち方
簡単のため、 を 2 の冪に切り上げたものを とし、長さ の領域を確保することとします。 切り上げると完全二分木との対応(下図参照)が取れて見通しが良くなります。
以下ではこの配列をと呼び、番目()の要素を A[i] もしくは と表記します。
根の index は 1 とします。このとき、A[k] の親は A[k/2]、子は A[2*k] と A[2*k+1] となります(存在すれば)。 と が対応するとか、ビットシフトで祖先を辿れるとかの理由で根の index を 0 にするよりも 1 にする方が好みです。
一点更新
に対応する と、 を含む区間の結果を持つところ(A[L+i] の祖先)を更新します。
これは下からやると素直にできます。更新する場所の添字について、右に 1 つビットシフトしつつ 0 になるまでループすれば良いです。 上の図の index を2進数で書き直してみると自然に見えると思います。
区間取得
ここでも葉から遡っていくようにして計算します。
最初は最下段に着目するので、 と に を加えておきます。 ここで最下段の値を結果の計算に用いるのはどのような場合かを考えると、l または r が奇数であるときだと分かります。 しかも、用いられるのは または という区間の端に相当する値です。このことは最下段以外でも同様に成り立ちます。
したがって、次のような操作で計算できることが分かります。
l += L; r += L; v_l = e; v_r = e; while(l < r){ if(l & 1) v_l = v_l ∘ A[l++]; l >>= 1; if(r & 1) v_r = A[r-1] ∘ v_r; r >>= 1; } return v_l ∘ v_r;
C/C++ っぽく書いてますが、 は適切に置き換える必要がありますね (ところで私の環境だと と の区別が絶望的なフォントで表示されて悲しいです)。
「が奇数なら結果に反映させつつ区間を狭める」という操作を、上に遡りながら区間が十分狭くなるまで繰り返しています。*2
また、ここでは演算 に可換性を仮定していません。もし可換なら、v_l と v_r を区別する必要がなくなるので 若干簡潔になります。
実装例
template <typename T> class segment_tree{ T *tree; std::size_t len; T e; T op(T, T); public: segment_tree(std::size_t n, T ident) : e(ident){ for(len = 1; len < n; len <<= 1); tree = new T[2 * len]; for(std::size_t i = 1; i < 2 * len; ++i){ tree[i] = ident; } } ~segment_tree(){ delete[] tree; } void update(std::size_t pos, T val){ tree[len + pos] = val; for(pos = (len + pos) / 2; pos > 0; pos >>= 1){ tree[pos] = op(tree[2 * pos], tree[2 * pos + 1]); } } T query(std::size_t left, std::size_t right){ left += len; right += len; T v_l = e, v_r = e; while(left < right){ if(left & 1) v_l = op(v_l, tree[left++]); if(right & 1) v_r = op(tree[right - 1], v_r); left >>= 1; right >>= 1; } return op(v_l, v_r); } };