magnus.hegdahl's blog

By magnus.hegdahl, 2 years ago, In English

Commutative Segment Trees

Preface

This blog is about an alternative way to implement segment trees. You should not need to know segment trees already to follow it, but it is mostly aimed at people who know segment trees and don't know how to, for example, create a merge sort tree that supports updates.

Pros:

  • You don't need a fast merge function (which allows you to for example extend to 2D)
  • Very easy to implement
  • Fast

Cons:

  • Merging has to be commutative
  • Can't easily be turned lazy (meaning only update or query can be on a range, not both)

Disclaimer: This way of implementing segment tree may be what you think of as "a standard segment tree". But the reason I made this blog is that I talked to people who were very familiar with segment trees but did not know how to make a 2D segment tree (which is a trivial modification of this way to implement segment trees).

If you just want to look at the code, it's at the bottom of the blog.

Range Sum, Point Add

Consider this simple example problem.

You have an array $$$a$$$ of length $$$n$$$ and have to process the following queries:

  • $$$p \enspace x$$$: $$$a_p := a_p + x$$$
  • $$$l \enspace r$$$: print $$$\sum_{i=l}^{r} a_i$$$

Solution

We use a binary tree, and look at $$$\mathcal{O}(\log n)$$$ nodes when querying a range or updating a value.

How does this work?

Let's define $$$rootpath(p)$$$ as the set of nodes we touch when walking from the leaf $$$p$$$ to the root (the nodes we look at when doing an update in this problem),

and $$$rangecover(l, r)$$$ as the set of nodes we look at when we find the sum of values at positions in $$$[l, r]$$$.

A more formal definition of $$$rangecover(l, r)$$$ is the smallest set of nodes such that

  • for all $$$p$$$ in $$$[l, r]$$$, there is exactly $$$1$$$ node in $$$rangecover(l, r)$$$ that is also in $$$rootpath(p)$$$.

  • for all $$$p$$$ not in $$$[l, r]$$$, there are no nodes in $$$rangecover(l, r)$$$ that are also in $$$rootpath(p)$$$.
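One concrete way to realize these two sets, as a sketch: assume the tree is stored heap-style in an array of size $$$2n$$$, with the root at node $$$1$$$, the children of node $$$i$$$ at $$$2i$$$ and $$$2i + 1$$$, and the 0-indexed position $$$p$$$ at leaf $$$n + p$$$. This layout and the extra parameter `n` are assumptions made for illustration; the definitions above do not depend on them.

```python
def rootpath(n, p):
    """Yield the nodes on the walk from leaf p up to the root (node 1)."""
    node = p + n
    while node >= 1:
        yield node
        node //= 2


def rangecover(n, l, r):
    """Yield nodes that together cover the positions l..r (inclusive)."""
    l += n
    r += n + 1  # switch to a half-open interval [l, r) of leaf indices
    while l < r:
        if l & 1:       # l is a right child: take it, step past it
            yield l
            l += 1
        if r & 1:       # r - 1 is a left child: take it
            r -= 1
            yield r
        l //= 2
        r //= 2
```

Both generators touch $$$\mathcal{O}(\log n)$$$ nodes, and for small $$$n$$$ the two properties in the definition are easy to check by brute force.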

Now we can do updates like this:

def update(p, x):
  for node in rootpath(p):
    # instead of only updating the bottom layer and
    # then merging up along the path,
    # we just add x directly to every node on the path.
    value[node] += x
  
def query(l, r):
  result = 0
  for node in rangecover(l, r):
    result += value[node]
  return result

This works intuitively from the illustrations in the spoiler tag above, but here is an argument for why it works, using the definition of $$$rangecover$$$.

When we ask for the sum of a range, that is the same as asking for the sum of updates to positions within the range (ignoring the initial values, we can just assume they were 0 and update them).

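Let's consider a query to $$$[l, r]$$$ and some position $$$p$$$ in $$$[l, r]$$$. Since an update to $$$p$$$ is applied to every node in $$$rootpath(p)$$$, and exactly one of those nodes is in $$$rangecover(l, r)$$$, the update is counted exactly once by the query. For a position $$$p$$$ outside $$$[l, r]$$$, no node of $$$rootpath(p)$$$ is in $$$rangecover(l, r)$$$, so its updates are not counted at all.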
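To make the argument concrete, here is a minimal runnable sketch of point add / range sum. It assumes a heap-style array of size $$$2n$$$ (root at node $$$1$$$, position $$$p$$$ at leaf $$$n + p$$$, 0-indexed positions); the class name `PointAddRangeSum` is made up for this example.

```python
class PointAddRangeSum:
    def __init__(self, n):
        self.n = n
        self.value = [0] * (2 * n)  # all a_i start at 0

    def update(self, p, x):
        """a[p] += x: add x to every node on the path from leaf p to the root."""
        node = p + self.n
        while node >= 1:
            self.value[node] += x
            node //= 2

    def query(self, l, r):
        """Return a[l] + ... + a[r] by summing the nodes that cover [l, r]."""
        result = 0
        l += self.n
        r += self.n + 1  # half-open interval of leaf indices
        while l < r:
            if l & 1:
                result += self.value[l]
                l += 1
            if r & 1:
                r -= 1
                result += self.value[r]
            l //= 2
            r //= 2
        return result
```

Note that `update` never reads a child to recompute a parent; every node on the path gets the same `+= x`.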

Range Add, Point Sum

This problem is the same as in the previous section, except query and update are swapped:

You have an array $$$a$$$ of length $$$n$$$ and have to process the following queries:

  • $$$l \enspace r \enspace x$$$: For each $$$i$$$ in $$$[l, r]$$$, $$$a_i := a_i + x$$$
  • $$$p$$$: print $$$a_p$$$

Solution

Here we can use the same set definitions as before, but use a different definition for what value the nodes in the tree should hold.

In the previous problem we said:

For all nodes, the value in the node is the sum of $$$a_p$$$ for positions $$$p$$$ such that the node is in $$$rootpath(p)$$$.

But what if we instead say:

For all positions $$$p$$$, the sum of values in nodes in $$$rootpath(p)$$$ is $$$a_p$$$.

If this holds, queries are easy: just sum up the values on the path to the root. But what about updates? Since $$$rangecover(l, r)$$$ contains exactly $$$1$$$ node from $$$rootpath(p)$$$ for each position $$$p$$$ in the range, and none for positions outside it, we can add $$$x$$$ to each node in $$$rangecover(l, r)$$$ and maintain that the sum over $$$rootpath(p)$$$ is $$$a_p$$$.

In code it will look like this:

def update(l, r, x):
  for node in rangecover(l, r):
    value[node] += x
  
def query(p):
  result = 0
  for node in rootpath(p):
    result += value[node]
  return result
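A runnable version of the same idea might look like the sketch below. The array layout (size $$$2n$$$, root at node $$$1$$$, position $$$p$$$ at leaf $$$n + p$$$, 0-indexed) and the class name `RangeAddPointQuery` are assumptions for this example.

```python
class RangeAddPointQuery:
    def __init__(self, n):
        self.n = n
        self.value = [0] * (2 * n)  # all a_i start at 0

    def update(self, l, r, x):
        """a[i] += x for every i in [l, r]: add x to each covering node."""
        l += self.n
        r += self.n + 1  # half-open interval of leaf indices
        while l < r:
            if l & 1:
                self.value[l] += x
                l += 1
            if r & 1:
                r -= 1
                self.value[r] += x
            l //= 2
            r //= 2

    def query(self, p):
        """Return a[p] as the sum of values on the path from leaf p to the root."""
        result = 0
        node = p + self.n
        while node >= 1:
            result += self.value[node]
            node //= 2
        return result
```

Compared with the point add / range sum version, the two loops have simply traded places.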

Commutativity

In the two problems above, the only properties of the updates we used are commutativity and associativity.

This means we run into problems with, for example, assignment on a range, because assigning $$$x$$$ and then $$$y$$$ is not the same as assigning $$$y$$$ and then $$$x$$$.

But in the point update / range query case, as long as merging would be commutative, we can often represent non-commutative updates as commutative updates by first examining the state of the array.

For range sum and point assign, we can do this:

$$$(a_p := x) \Leftrightarrow (a_p := a_p + (x - a_p))$$$
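As a sketch of this rewrite: keep a plain copy of the array next to the tree, turn each assignment into an addition of the difference, and leave the range sum untouched. The class name `PointAssignRangeSum` and the heap-style layout (size $$$2n$$$, leaf of position $$$p$$$ at $$$n + p$$$, 0-indexed) are assumptions for this example.

```python
class PointAssignRangeSum:
    def __init__(self, n):
        self.n = n
        self.a = [0] * n            # current array, needed to compute the delta
        self.value = [0] * (2 * n)

    def assign(self, p, x):
        """a[p] = x, expressed as the commutative update a[p] += x - a[p]."""
        delta = x - self.a[p]
        self.a[p] = x
        node = p + self.n
        while node >= 1:
            self.value[node] += delta
            node //= 2

    def query(self, l, r):
        """Return a[l] + ... + a[r]."""
        result = 0
        l += self.n
        r += self.n + 1  # half-open interval of leaf indices
        while l < r:
            if l & 1:
                result += self.value[l]
                l += 1
            if r & 1:
                r -= 1
                result += self.value[r]
            l //= 2
            r //= 2
        return result
```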

Dynamic Merge Sort Tree (Online)

The previous two problems were solvable with a standard merging segment tree as well. Now comes a problem that isn't (tell me in the comments if I'm wrong).

You have an array $$$a$$$ of length $$$n$$$ and have to process the following queries:

  • $$$p \enspace x$$$: $$$a_p := x$$$
  • $$$l \enspace r \enspace x$$$: print the number of positions $$$p$$$ in $$$[l, r]$$$ such that $$$a_p < x$$$.

Solution

If there were no updates we could use a merge sort tree on the array, and in each node binary search to find the number of positions $$$p$$$ covered by that node such that $$$a_p < x$$$.

A merge sort tree is just a regular merging segment tree where the merge operation is to weave together two sorted lists into a new sorted list.

Merging everything once takes only $$$\mathcal{O}(n \log n)$$$ time: merging into a single node takes $$$\mathcal{O}(w)$$$ time, where $$$w$$$ is the number of leaves below that node, the sum of $$$w$$$ over a layer is $$$n$$$, and the number of layers is $$$\mathcal{O}(\log n)$$$.
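For reference, a static merge sort tree (no updates) could be sketched as follows. The function names `build` and `count_smaller`, the 0-indexed positions, and the heap-style layout with position $$$p$$$ at leaf $$$n + p$$$ are assumptions for this example.

```python
from bisect import bisect_left
from heapq import merge


def build(a):
    """Build the tree bottom-up; each node holds the sorted values below it."""
    n = len(a)
    tree = [[] for _ in range(2 * n)]
    for p, x in enumerate(a):
        tree[n + p] = [x]
    for node in range(n - 1, 0, -1):
        # Weave the two sorted child lists into one sorted list.
        tree[node] = list(merge(tree[2 * node], tree[2 * node + 1]))
    return tree


def count_smaller(tree, l, r, x):
    """Count positions p in [l, r] with a[p] < x."""
    n = len(tree) // 2
    result = 0
    l += n
    r += n + 1  # half-open interval of leaf indices
    while l < r:
        if l & 1:
            result += bisect_left(tree[l], x)
            l += 1
        if r & 1:
            r -= 1
            result += bisect_left(tree[r], x)
        l //= 2
        r //= 2
    return result
```

Each query does one binary search in each of $$$\mathcal{O}(\log n)$$$ nodes, so it runs in $$$\mathcal{O}(\log^2 n)$$$.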

But changing anything in this merge sort tree takes $$$\mathcal{O}(n)$$$ time, since even though only $$$\mathcal{O}(\log n)$$$ nodes need to be visited, some of them have big $$$w$$$.

A commutative segment tree solves this, since we don't need to merge anything. Instead of merging, we keep a binary search tree in each node (one that allows duplicates and can count the elements smaller than $$$x$$$), and when we update a position we just replace the old value with the new one in every BST on its root path.

In code it will look like this:

def update(p, x):
  for node in rootpath(p):
    bst[node].erase(a[p])
    bst[node].insert(x)
  a[p] = x

def query(l, r, x):
  res = 0
  for node in rangecover(l, r):
    res += bst[node].count_smaller(x)
  return res
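Python has no built-in balanced BST, so the sketch below swaps in a plain sorted list per node, maintained with `bisect`. This is a stand-in, not the structure described above: erasing from and inserting into a list costs $$$\mathcal{O}(w)$$$ per node, so to get the intended $$$\mathcal{O}(\log^2 n)$$$ update you would replace the lists with an order-statistics tree. The class name `DynamicMergeSortTree` and the heap-style layout are assumptions for this example.

```python
from bisect import bisect_left, insort


class DynamicMergeSortTree:
    def __init__(self, a):
        self.n = len(a)
        self.a = list(a)
        self.nodes = [[] for _ in range(2 * self.n)]
        for p, x in enumerate(self.a):
            node = p + self.n
            while node >= 1:          # put a[p] into every node on its root path
                self.nodes[node].append(x)
                node //= 2
        for values in self.nodes:
            values.sort()

    def update(self, p, x):
        """a[p] = x: replace the old value in every node on the root path."""
        old = self.a[p]
        node = p + self.n
        while node >= 1:
            values = self.nodes[node]
            values.pop(bisect_left(values, old))  # erase one copy of old
            insort(values, x)                     # insert the new value
            node //= 2
        self.a[p] = x

    def query(self, l, r, x):
        """Count positions p in [l, r] with a[p] < x."""
        res = 0
        l += self.n
        r += self.n + 1  # half-open interval of leaf indices
        while l < r:
            if l & 1:
                res += bisect_left(self.nodes[l], x)
                l += 1
            if r & 1:
                r -= 1
                res += bisect_left(self.nodes[r], x)
            l //= 2
            r //= 2
        return res
```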

Harder Practice Problems (With Explanations)

Generalization of CF 406 Div.1 B. Legacy

In the original problem you are asked to find single source shortest paths to every node in a graph with some special edges. They state 3 different types of edges, all of which can be generalized to this:

$$$l_1 \enspace r_1 \enspace l_2 \enspace r_2 \enspace w$$$ means you can go from any node in $$$[l_1, r_1]$$$ to any node in $$$[l_2, r_2]$$$ for the cost of $$$w$$$.

The graph has up to $$$10^5$$$ nodes, and up to $$$10^5$$$ of these special edges.

Target time complexity (better than the one in the editorial)
Solution

IATI 21 — News

You are given a rooted tree with each node having the value $$$0$$$ or $$$1$$$. Let $$$s(x, k)$$$ be the set of nodes in $$$x$$$'s subtree that are within a distance of $$$k$$$ from node $$$x$$$. Process $$$q$$$ of the following queries:

  • $$$1 \enspace x \enspace k$$$: set the value of every node in $$$s(x, k)$$$ to $$$1$$$.
  • $$$2 \enspace x \enspace k$$$: print the sum of values of nodes in $$$s(x, k)$$$.

$$$n, q \le 2 \cdot 10^5$$$

Solution

Implementation

Python
C++
»
2 years ago, # |

Related problem: USACO 2021 — Tickets

»
2 years ago, # |

OP Blog

»
2 years ago, # |

Can't easily be turned lazy (meaning only update or query can be on a range, not both)

I do not agree with this statement. You'll need to compute the value for a given range fast, but it's the same for non-commutative functions. Lifehack: you can precompute the values for all powers of 2 in $$$O(\log n)$$$ using a bottom-up technique.

  • »
    »
    2 years ago, # ^ |

    For simple operations like sum I agree it's the same as in a merging segment tree. But for lazy to be fast you need to be able to merge updates. And merging is kinda the thing we were trying to avoid. So in many cases it can be done, but in the cases where this merge-free way to think about segment tree is actually useful it doesn't work I think.

»
2 years ago, # |

Isn't this essentially the same as Al.Cash's segment trees?

  • »
    »
    2 years ago, # ^ |

    I assume you mean this blog.

    I hadn't seen it before writing my blog and I see now that his section "Modification on interval, single element access" is about the same trick as "Range Add Point Sum" here.

    Not writing a merge function is nothing new. But some people I know view a segment tree as a machine working with an associative merge function and a set of update functions that is closed under composition, where $$$f(a) + f(b) = f(a + b)$$$. Thinking about it this way works well as long as merge and composition can be computed fast, but sometimes they can't. This blog is about what to do in some of those cases, and from skimming through Al.Cash's blog it seems it doesn't talk about that.

»
2 years ago, # |

Thank you for writing this blog, it is a really good explanation for range add/point sum segment trees and mergesort trees! But I want to clarify/weaken your assumptions about commutativity a bit. It is not entirely clear to me what you mean by "commutative." In the cons at the beginning, you say "Merging has to be commutative" which I thought at first meant that you want merging the values stored in each node to be commutative but then in the Commutativity section, you say "assigning something to x then to y is not the same as assigning it to y then x" which implies you want updates to be commutative. However, I don't see how this framework requires merging values or updates to be commutative and indeed the updates for a mergesort tree are point assignments which are not commutative.

I bring this up because you mention that commutative segment trees are useful for extending to 2D segment trees, but I think it is possible to make 2D segment trees without updates being commutative (although I do think 2D segment trees require merging values to be commutative). For example, let's say you want a 2D segment tree supporting the following operations:

  1. $$$x\ y\ v$$$: Set $$$a[x][y] := v$$$
  2. $$$x_1\ x_2\ y_1\ y_2$$$: Return $$$\min_{x_1\leq x\leq x_2,y_1\leq y\leq y_2} a[x][y]$$$

Here, updates are not commutative because updates are done through point assignment, and we can not use the trick of changing the update to $$$a[x][y] := a[x][y]+(v-a[x][y])$$$ because our queries are done with the $$$\min$$$ operation, not the sum operation.

However, I have implemented this 2D segment tree before (source code), and I think this kind of 2D segment tree can be implemented with your framework because (1) merging values is associative (the value at each node is a 1D segment tree supporting updates of the form $$$y\ n$$$: $$$a[y] := n$$$ and queries of the form $$$y_1\ y_2$$$: return $$$\min_{y_1\leq y\leq y_2} a[y]$$$; merging these 1D segment trees is associative because $$$\min$$$ is associative) and (2) there is a fast way to apply an update to the value of a node representing an entire interval of x-values $$$[x_l, x_r]$$$. To demonstrate (2), I wrote what I think the update function looks like in your framework:

def update(x, y, v):
    for node in rootpath(x):
        x_l, x_r = node.interval() # Get the range of x-values this interval represents
        # At leaf nodes, just apply the update at the appropriate y-value:
        if x_l == x_r:
            segment_tree[node].update(y, v)
        # At non-leaf nodes, query both the left and right child for a[y], merge the two results together using min, and then apply the update:
        # (this only works if child nodes have already been updated, so I assume rootpath goes in order from leaf to root node)
        else:
            left_val = segment_tree[node.left_child()].query(y, y)
            right_val = segment_tree[node.right_child()].query(y, y)
            segment_tree[node].update(y, min(left_val, right_val))

For each node, we do a constant number of queries/updates on 1D segment trees, so updating a single node takes $$$O(\log(Y))$$$ time (where $$$Y$$$ is the range of y-values). This is definitely better than naively merging the two segment trees at the left and right child nodes, which would take $$$O(Y)$$$ time. Since rootpath(x) contains $$$O(\log(X))$$$ nodes (where $$$X$$$ is the range of x-values), this means the entire update takes $$$O(\log(X)\log(Y))$$$ time, as expected.
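A self-contained sketch of such a 2D tree, following the update described above, could look like this. It assumes 0-indexed coordinates, an $$$n \times m$$$ grid that starts at $$$+\infty$$$, and heap-style arrays in both dimensions; the class and method names are made up for this example and are not taken from the linked source code.

```python
INF = float("inf")


class MinSegTree2D:
    def __init__(self, n, m):
        self.n, self.m = n, m
        # One 1D tree (a row of size 2m) per node of the outer x-tree.
        self.t = [[INF] * (2 * m) for _ in range(2 * n)]

    def _assign_1d(self, row, y, v):
        """Point assignment in one inner tree, recomputing min up the y-path."""
        j = y + self.m
        row[j] = v
        j //= 2
        while j >= 1:
            row[j] = min(row[2 * j], row[2 * j + 1])
            j //= 2

    def _query_1d(self, row, y1, y2):
        res = INF
        l, r = y1 + self.m, y2 + self.m + 1
        while l < r:
            if l & 1:
                res = min(res, row[l])
                l += 1
            if r & 1:
                r -= 1
                res = min(res, row[r])
            l //= 2
            r //= 2
        return res

    def update(self, x, y, v):
        """a[x][y] = v, walking the x-path from the leaf to the root."""
        i = x + self.n
        self._assign_1d(self.t[i], y, v)          # leaf in x: plain assignment
        i //= 2
        while i >= 1:
            leaf = y + self.m
            # Children are already up to date, so merge their values at column y.
            merged = min(self.t[2 * i][leaf], self.t[2 * i + 1][leaf])
            self._assign_1d(self.t[i], y, merged)
            i //= 2

    def query(self, x1, x2, y1, y2):
        """Return the minimum of a[x][y] over x1 <= x <= x2, y1 <= y <= y2."""
        res = INF
        l, r = x1 + self.n, x2 + self.n + 1
        while l < r:
            if l & 1:
                res = min(res, self._query_1d(self.t[l], y1, y2))
                l += 1
            if r & 1:
                r -= 1
                res = min(res, self._query_1d(self.t[r], y1, y2))
            l //= 2
            r //= 2
        return res
```

Here the child values at column $$$y$$$ are read directly from the leaves of the two inner trees, which plays the role of the `query(y, y)` calls in the pseudocode above.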