Wavelet Tree
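A wavelet tree is a binary tree-based data structure for answering range queries on a sequence efficiently. In the bit-based variant, each level of the tree corresponds to one bit position of the elements, starting from the most significant bit. At each node, the elements are split into two groups according to that bit: elements with a 0 go to the left child and elements with a 1 go to the right child, keeping their original relative order. The node records these bits in a bitmap. The division continues recursively until a node holds a single distinct value. Queries are answered by following the binary representation of a value down the tree and using the bitmaps to translate positions from a node to its child. The same idea works for any ordered alphabet by splitting the alphabet in half at each node instead of testing a bit.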
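To make the level-by-level division concrete, here is a minimal sketch that prints the bitmaps a bit-based wavelet tree would hold at each level. The helper name level_bitmaps is an assumption made for illustration, and the sketch assumes non-negative integers; it only shows the partitioning and does not answer queries.

```python
def level_bitmaps(values):
    """Return, for each level, the bitmap of every non-empty node (top bit first)."""
    bits = max(max(values).bit_length(), 1)
    levels = []
    groups = [list(values)]  # one group of elements per node at the current level
    for shift in range(bits - 1, -1, -1):
        # The bitmap of a node is the current bit of each of its elements.
        levels.append([[(v >> shift) & 1 for v in group] for group in groups])
        next_groups = []
        for group in groups:
            zeros = [v for v in group if ((v >> shift) & 1) == 0]  # left child
            ones = [v for v in group if ((v >> shift) & 1) == 1]   # right child
            next_groups.extend(g for g in (zeros, ones) if g)
        groups = next_groups
    return levels


for depth, bitmaps in enumerate(level_bitmaps([5, 2, 6, 1])):
    print(depth, bitmaps)
# 0 [[1, 0, 1, 0]]
# 1 [[1, 0], [0, 1]]
# 2 [[1], [0], [1], [0]]
```

Note that each split is stable: elements keep their original order inside a child, which is what later allows positions to be translated from a node to its children.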
Let's implement a basic Wavelet Tree in Python:
class WaveletTree:
    def __init__(self, sequence):
        self.sequence = sequence
        self.alphabet = sorted(set(sequence))
        self.root = self._build(list(sequence), self.alphabet)

    def _build(self, symbols, alphabet):
        # A node whose alphabet has a single symbol (or no data) is a leaf.
        if len(alphabet) <= 1 or not symbols:
            return None
        # Split the alphabet at its median: symbols below the pivot go left.
        half = len(alphabet) // 2
        left_alphabet, right_alphabet = alphabet[:half], alphabet[half:]
        pivot = right_alphabet[0]
        # prefix[i] = how many of the first i symbols go to the left child
        prefix = [0]
        for s in symbols:
            prefix.append(prefix[-1] + (s < pivot))
        left_child = self._build([s for s in symbols if s < pivot], left_alphabet)
        right_child = self._build([s for s in symbols if s >= pivot], right_alphabet)
        return (pivot, prefix, left_child, right_child)

    def _prefix_rank(self, symbol, i):
        # Number of occurrences of symbol in sequence[0:i]
        node = self.root
        while node is not None and i > 0:
            pivot, prefix, left_child, right_child = node
            if symbol < pivot:
                i, node = prefix[i], left_child
            else:
                i, node = i - prefix[i], right_child
        return i

    def rank(self, left, right, symbol):
        # Number of occurrences of symbol in sequence[left..right], inclusive
        if symbol not in self.alphabet or left > right:
            return 0
        return self._prefix_rank(symbol, right + 1) - self._prefix_rank(symbol, left)

# Create a Wavelet Tree
sequence = "abracadabra"
wavelet_tree = WaveletTree(sequence)
# Perform rank queries
print("Rank of 'a' from index 0 to 6:", wavelet_tree.rank(0, 6, 'a'))
print("Rank of 'b' from index 0 to 6:", wavelet_tree.rank(0, 6, 'b'))
print("Rank of 'c' from index 0 to 6:", wavelet_tree.rank(0, 6, 'c'))
In this example, we define a WaveletTree class with methods for building the tree and performing rank queries. The _build method recursively builds the tree by splitting the alphabet at its median symbol: symbols in the lower half go to the left child and the rest go to the right child. Each internal node therefore represents a contiguous range of the alphabet and has two children. The rank method counts the occurrences of a symbol within an index range (both ends inclusive) by walking down the tree and choosing a child at each node according to the symbol. We create a wavelet tree from the sequence "abracadabra" and then run rank queries for several symbols over the index range 0 to 6.
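A brute-force version of the same query is handy as a reference when testing a wavelet tree implementation. The sketch below uses a hypothetical helper, naive_rank, which is not part of the class above; it simply slices the sequence and counts, using the same inclusive range convention.

```python
def naive_rank(sequence, left, right, symbol):
    """Count occurrences of symbol in sequence[left..right], both ends inclusive."""
    return sequence[left:right + 1].count(symbol)


sequence = "abracadabra"
for symbol in "abc":
    print(symbol, naive_rank(sequence, 0, 6, symbol))
# a 3
# b 1
# c 1
```

This takes time proportional to the length of the range, whereas the tree-based query visits one node per level regardless of how long the range is.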
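Let's dive deeper into this topic. The next version works on integers and splits them by bit, from the most significant bit down, with each node storing the bitmap of those bits.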
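class WaveletTree:
    def __init__(self, arr):
        self.arr = arr
        self.min_val = min(arr)
        self.max_val = max(arr)
        # Values are shifted by min_val so that negative inputs also work.
        self.bits = self._get_bits(self.max_val - self.min_val)
        shifted = [num - self.min_val for num in arr]
        self.root = self._build_tree(shifted, self.bits)

    def _get_bits(self, num):
        # Number of bits required to represent num (at least 1)
        bits = 0
        while num > 0:
            bits += 1
            num >>= 1
        return max(bits, 1)

    def _build_tree(self, arr, bits):
        if bits == 0 or not arr:
            return None
        left_arr, right_arr = [], []
        bitmap = []
        for num in arr:
            bit = (num >> (bits - 1)) & 1
            bitmap.append(bit)
            if bit == 0:
                left_arr.append(num)
            else:
                right_arr.append(num)
        node = WaveletTreeNode(bitmap)
        node.left = self._build_tree(left_arr, bits - 1)
        node.right = self._build_tree(right_arr, bits - 1)
        return node

    def _count_less_or_equal(self, node, bits, lo, hi, value):
        # Count elements <= value among positions [lo, hi) of this node.
        if lo >= hi:
            return 0
        if bits == 0:
            # Every bit matched, so these elements are equal to value.
            return hi - lo
        bit = (value >> (bits - 1)) & 1
        left_lo, left_hi = node.zeros[lo], node.zeros[hi]
        if bit == 0:
            return self._count_less_or_equal(node.left, bits - 1, left_lo, left_hi, value)
        # bit is 1: everything in the left child is smaller than value.
        return (left_hi - left_lo) + self._count_less_or_equal(
            node.right, bits - 1, lo - left_lo, hi - left_hi, value)

    def range_count(self, left, right, value):
        # Number of elements <= value in arr[left..right], inclusive
        value = min(value, self.max_val) - self.min_val
        if value < 0 or left > right:
            return 0
        return self._count_less_or_equal(self.root, self.bits, left, right + 1, value)

    def count_less_or_equal(self, value):
        # Number of elements <= value in the whole array
        return self.range_count(0, len(self.arr) - 1, value)


class WaveletTreeNode:
    def __init__(self, bitmap):
        self.bitmap = bitmap
        # zeros[i] = number of 0 bits among the first i entries of bitmap
        self.zeros = [0]
        for bit in bitmap:
            self.zeros.append(self.zeros[-1] + (1 - bit))
        self.left = None
        self.right = None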
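The key step in a query on a bitmap-based wavelet tree is translating a range of positions at a node into the matching range in each child. The following is a minimal sketch of that step, assuming half-open ranges [lo, hi) and a hypothetical helper named map_range; a real implementation would precompute the prefix counts once per node rather than on every call.

```python
def map_range(bitmap, lo, hi):
    """Map positions [lo, hi) of a node to the ranges they occupy in its children."""
    # zeros[i] = number of 0 bits among the first i entries of bitmap
    zeros = [0]
    for bit in bitmap:
        zeros.append(zeros[-1] + (1 - bit))
    left_range = (zeros[lo], zeros[hi])             # elements whose bit is 0
    right_range = (lo - zeros[lo], hi - zeros[hi])  # elements whose bit is 1
    return left_range, right_range


# Bitmap of the root for [5, 2, 6, 1] when splitting on the top bit of 3-bit values
print(map_range([1, 0, 1, 0], 1, 4))  # ((0, 2), (1, 2))
```

Positions 1 to 3 of the node hold two 0 bits and one 1 bit, so they map to two positions in the left child and one position in the right child.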
Now, let's solve a question from LeetCode using the Wavelet Tree data structure. The problem we'll solve is "Count of Smaller Numbers After Self," which asks to count the number of smaller elements to the right of each element in an input array.
class WaveletTree:
    # ... (previous implementation remains the same)

    def count_smaller(self):
        # For each index i, count the elements strictly smaller than arr[i]
        # among the positions to its right, i + 1 .. n - 1.
        n = len(self.arr)
        counts = []
        for i, num in enumerate(self.arr):
            # "strictly smaller than num" means "less than or equal to num - 1"
            counts.append(self.range_count(i + 1, n - 1, num - 1))
        return counts

# Usage:
arr = [5, 2, 6, 1]
wt = WaveletTree(arr)
counts = wt.count_smaller()
print(counts)  # [2, 1, 1, 0]

In this example, we create a WaveletTree instance wt with the input array arr and call its count_smaller method, which counts the number of smaller elements to the right of each element in arr. For the element at index i, the method asks range_count for the number of elements less than or equal to arr[i] - 1 among positions i + 1 to n - 1, which is exactly the number of strictly smaller elements to its right. Because the query is restricted by position, nothing has to be deleted from the tree as elements are processed, so the tree stays static. For [5, 2, 6, 1] the result is [2, 1, 1, 0].
The time complexity of building the wavelet tree is O(n log m), where n is the number of elements in the input array and m is the maximum value. The count_less_or_equal operation takes O(log m) time, since it visits one node per bit level, and the count_smaller method takes O(n log m) time overall. The space complexity is O(n log m) to store the wavelet tree nodes.
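For comparison, a direct solution to the same problem compares every pair of elements and runs in O(n^2) time. The sketch below uses a hypothetical helper, count_smaller_naive, which is not part of the WaveletTree class; it is useful mainly as a baseline for checking results on small inputs.

```python
def count_smaller_naive(nums):
    """For each element, count the strictly smaller elements to its right."""
    return [
        sum(1 for later in nums[i + 1:] if later < num)
        for i, num in enumerate(nums)
    ]


print(count_smaller_naive([5, 2, 6, 1]))  # [2, 1, 1, 0]
```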