ramchandra's blog

By ramchandra, 4 years ago, In English
Introduction

You may have heard of the sparse table data structure, which can answer min/gcd queries in $$$O(1)$$$ and sum queries in $$$O(\log n)$$$. Both the space complexity and the construction time of a sparse table are $$$O(n \log n)$$$.

However, the disjoint sparse table data structure can answer queries for any associative operation, such as sum/min/gcd, in $$$O(1)$$$ (assuming the operation itself takes $$$O(1)$$$), with the same $$$O(n \log n)$$$ space and construction time. The operation just needs to form a monoid (i.e. it is associative and has an identity element). Strictly speaking, an identity element is not required, but it makes empty-range queries convenient.

Structure

Here we assume the array has already been resized so that its size $$$n$$$ is a power of $$$2$$$. A disjoint sparse table consists of $$$\log_2 n$$$ levels numbered from $$$0$$$ onwards. In the $$$i$$$th level, we split the array into $$$2^i$$$ equal blocks. For each block, we store the operation sum of the elements between every index in the block and the middle of the block. For example, level $$$0$$$ contains one block, whose middle is at $$$mid = n/2$$$, and level $$$1$$$ contains two blocks, one with its middle at $$$n/4$$$ and the other with its middle at $$$3n/4$$$.

To be specific: if $$$idx < mid$$$, sum[level][idx] stores the sum of the elements in $$$[idx, mid)$$$; otherwise, if $$$idx \geq mid$$$, it stores the sum of the elements in $$$[mid, idx]$$$. In other words, for any middle $$$mid = 2^k\cdot o$$$ with $$$o$$$ odd, our table includes the sums between $$$idx$$$ and $$$mid$$$ for all $$$idx \in [2^k\cdot(o-1), 2^k\cdot(o+1))$$$. The block with middle $$$mid$$$ is at level $$$p-1-k$$$, where $$$n = 2^p$$$.

Since each level contains $$$n$$$ elements and there are $$$\log_2 n$$$ levels, both the space complexity and the construction time are $$$O(n \log n)$$$.

Implementation

To query for the sum of the elements in $$$[l,r]$$$ with $$$l < r$$$, we just need to find the block that contains both $$$l$$$ and $$$r$$$, with $$$l$$$ in its left half and $$$r$$$ in its right half.

Suppose we have a disjoint sparse table with $$$7$$$ levels and $$$2^7$$$ elements, and the binary representations of $$$l$$$ and $$$r$$$ are as follows (using $$$7$$$ bits for each index):

$$$l$$$: 1101[0]01

$$$r$$$: 1101[1]10

Note that the leftmost position where the bits differ is the $$$4$$$th bit (0-indexed) from the left, marked in square brackets. So the middle element of the block is $$$1101100$$$: the common prefix $$$1101$$$, then a $$$1$$$, then zeros. Since the $$$4$$$th bit is the first one that differs, this block is at level $$$4$$$. Its range of elements is $$$[1101000, 1101111]$$$. We simply need to combine two entries of the table: sum[level][l], which holds the sum of $$$[l, mid)$$$, and sum[level][r], which holds the sum of $$$[mid, r]$$$.

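So we just need to find the level, which is the position (counted from the left) of the leftmost bit that differs. This is simply the leftmost 1 in l^r = 0000111. To find it, we can use __builtin_clz(x), a GCC extension (also supported by Clang) that counts the number of leading zeros in the unsigned integer $$$x$$$; the name abbreviates Count Leading Zeros, and the result is undefined for $$$x = 0$$$. Since a 32-bit $$$x$$$ is wider than a $$$p$$$-bit index, its $$$32 - p$$$ extra leading zeros must be subtracted: here __builtin_clz(l ^ r) is $$$29$$$ for l ^ r = 7, and $$$29 - (32 - 7) = 4$$$ is the level. On x86 it compiles to a constant number of instructions (lzcnt where the target supports it, otherwise bsr), so this is $$$O(1)$$$.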

Code
#include <bits/stdc++.h>
using namespace std;
/*! T is the type of the elements
 * Monoid is the operation functor type
 * identity is the identity element of the Monoid (e.g. 0 for addition and inf for minimum)
 */
template <typename T, typename Monoid, auto identity>
class DisjointSparseTable {
public:
	explicit DisjointSparseTable(vector<T> arr) {
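		// Find the smallest cnt >= 1 such that pow2 = 2^cnt >= arr.size().
		// Starting from 2^1 guarantees at least one level, so sum.back()
		// is valid even for arrays with fewer than two elements.
		int pow2 = 2, cnt = 1;
		for (; pow2 < (int)arr.size(); pow2 *= 2, ++cnt);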
		
		arr.resize(pow2, identity);
		sum.resize(cnt, vector<T>(pow2));
		
		for(int level = 0; level < sum.size(); ++level) {
			for(int block = 0; block < 1 << level; ++block) {
				// The first half of the block contains suffix sums,
				// the second half contains prefix sums
				const auto start = block << (sum.size() - level);
				const auto end = (block + 1) << (sum.size() - level);
				const auto middle = (end + start) / 2;
				
				auto val = arr[middle];
				sum[level][middle] = val;
				for(int x = middle + 1; x < end; ++x) {
					val = Monoid{}(val, arr[x]);
					sum[level][x] = val;
				}
				
				val = arr[middle - 1];
				sum[level][middle - 1] = val;
				for(int x = middle-2; x >= start; --x) {
					val = Monoid{}(val, arr[x]);
					sum[level][x] = val;
				}
			}
		}
	}
	/*! Returns Monoid sum over range [l, r)*/
	T query(int l, int r) const {
		assert(l <= r); // empty ranges [l, l) are allowed and return identity
		// Convert half open interval to closed interval
		--r;
		if(r == l-1){
			return identity;
		}
		if(l == r){
			return sum.back()[l];
		}
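		// Bit index (0 = least significant) of the highest bit in which l and r differ.
		// l != r at this point, so l ^ r != 0 and __builtin_clzll is well defined.
		const auto pos_diff = (sizeof(long long) * CHAR_BIT) - 1 - __builtin_clzll(l ^ r);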
		const auto level = sum.size() - 1 - pos_diff;
		return Monoid{}(sum[level][l], sum[level][r]);
	}
private:
	vector<vector<T>> sum;
};
int main(){
	// Tests the DisjointSparseTable
	vector<int> data{6, 2, 4, 1, 7, 3, 4, 2, 7, 2, 4, 1, 6};
	DisjointSparseTable<int, plus<>, 0> sp{data};
	for(int start = 0; start < data.size(); ++start) {
		for(int end = start; end <= data.size(); ++end){
			assert(sp.query(start, end) == accumulate(begin(data) + start, begin(data) + end, 0));
		}
	}
}
Credits:

AsleepAdhyyan for telling me about this DS.

I learnt about this DS from Nilesh3105's Disjoint Sparse Table tutorial

EDIT: Fixed undefined behavior caused by calling __builtin_clzll(0) when l == r.


Comment, 5 months ago:

Please fix these minor bugs:

1) Insert using ll = long long;

2) When you test your structure, start end from start + 1, i.e. change it to int end = start + 1. Otherwise the test fails at assert(l < r).

Comment, 5 months ago:

The idea is really cool! For anyone who got confused, here is a more bitwise way of seeing how it works.

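Let's say the leftmost bit in which $$$l$$$ and $$$r$$$ differ is $$$X$$$, let $$$ld$$$ be the number of bits from $$$X$$$ down to the least significant bit (so bit $$$X$$$ has value $$$2^{ld-1}$$$), and let $$$Y$$$ be the number obtained by keeping only the bits before $$$X$$$ and setting $$$X$$$ and all lower bits to $$$0$$$.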

So we know that the bits before $$$X$$$ are the same in both $$$l$$$ and $$$r$$$. Since we can access every range sum of the form $$$[i, Y + 2^{ld-1} - 1]$$$ with $$$i \in [Y, Y + 2^{ld-1} - 1]$$$ and every range sum of the form $$$[Y + 2^{ld-1}, j]$$$ with $$$j \in [Y + 2^{ld-1}, Y + 2^{ld} - 1]$$$, we can also access the range sum $$$[l, r]$$$, because $$$l \in [Y, Y + 2^{ld-1} - 1]$$$ and $$$r \in [Y + 2^{ld-1}, Y + 2^{ld} - 1]$$$. This follows from the fact that bit $$$X$$$ of $$$l$$$ is $$$0$$$ and bit $$$X$$$ of $$$r$$$ is $$$1$$$, which always holds because $$$l < r$$$ and $$$X$$$ is the most significant bit in which they differ.

Note that $$$Y + 2^{ld-1}$$$ is the middle point of the segment $$$[Y, Y + 2^{ld} - 1]$$$.

You can also visualize it as follows:

$$$Y = "111000"$$$

$$$Y + 2^{ld-1} = "111100"$$$

$$$[ Y , Y + 2^{ld-1} - 1] = "1110??"$$$

$$$[Y + 2^{ld-1} , Y + 2^{ld} - 1 ] = "1111??"$$$

$$$[ Y , Y + 2^{ld} - 1] = "111???"$$$