0%

数据结构学习笔记 - 10 线段树

介绍

经典面试问题,一面墙,每次可以在其中一段区间进行染色,染色可以覆盖之前的染色,若干次染色后,问某个区间内有几种颜色?

这里涉及了两个操作,一个是更新,一个是查询。可以使用数组解决这个问题,染色就修改数组的指定区间,查询就是遍历整个数组,两种操作的复杂度都为 O(n)。

而如果使用线段树解决,复杂度为 O(logn)。

操作 使用数组复杂度 使用线段树复杂度
区间更新 O(n) O(logn)
区间查询 O(n) O(logn)

对于线段树,不考虑添加和删除操作,线段树解决的问题,区间本身是固定的,因此存储线段树使用静态数组即可。假定研究的问题是数组区间求和,数组共 8 个元素,则构造的线段树如下:

1
2
3
4
5
6
7
8
9
                   A[0...7]
/ \
A[0...3] A[4...7]
/ \ / \
A[0...1] A[2...3] A[4...5] A[6...7]
/ \ / \ / \ / \
A[0] A[1] A[2] A[3] A[4] A[5] A[6] A[7]

// 根节点存储全部区间的和,A[0...3]存储前半段的和,以此类推。

如果数组是 10 个元素,线段树则表示为:

1
2
3
4
5
6
7
8
9
                   A[0...9]
/ \
A[0...4] A[5...9]
/ \ / \
A[0...1] A[2...4] A[5...6] A[7...9]
/ \ / \ / \ / \
A[0] A[1] A[2] A[3,4] A[5] A[6] A[7] A[8,9]
/ \ / \
A[3] A[4] A[8] A[9]
  • 线段树不是满二叉树,更不是完全二叉树。
  • 平衡二叉树,指二叉树的每个节点的左右子树的高度差的绝对值不超过 1。平衡二叉树一定不会退化成链表。
  • 堆就是平衡二叉树,完全二叉树就一定是平衡二叉树。但是二分搜索树就不一定是平衡二叉树。
  • 线段树是一棵平衡二叉树。
  • 线段树虽然不是完全二叉树,但类似也是只有最下层是不「满」的,我们可以近似将其看作一个满二叉树,因此也可以使用数组表示。
  • 对于线段树来说,如果使用数组来表示,如果区间有 n 个元素,那么数组需要 4n 的空间来存储。对于线段树我们不考虑添加元素,即区间固定。所以创建线段树复杂度为 O(n),准确的讲是 O(4n)。当然,如果使用链式的方式存储线段树,则不需要 4n 的空间也可以。

应用

线段树常用于基于区间进行的统计查询。如果我们经常需要查询数组中某区间的最大值,最小值或者这个区间的数字和等,就可以使用线段树使操作复杂度由 O(n) 下降为 O(logn)。

具体的说,比如一个网站需要统计其2017年注册用户中至今为止消费最高的用户?消费最少的用户?学习时长最长的用户?

实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
public interface Merger<E> {

E merge(E a, E b);
}

public class SegmentTree<E> {

private E[] tree;
private E[] data;
// 融合器,区间中的元素进行何种融合存到线段树中。
private Merger<E> merger;

/**
* 构造函数。
*
* @param arr 传入的数组。
* @param merger 融合器。
*/
@SuppressWarnings("unchecked")
public SegmentTree(E[] arr, Merger<E> merger) {

this.merger = merger;

data = (E[]) new Object[arr.length];
for (int i = 0; i < arr.length; i++) {
data[i] = arr[i];
}

// 线段树所使用数组空间长度应该是传入数组长度的 4 倍
tree = (E[]) new Object[4 * arr.length];
buildSegmentTree(0, 0, data.length - 1);
}

/**
* 在 treeIndex 的位置创建表示区间[l...r]的线段树
*
* @param treeIndex 创建的线段树根节点所在的索引
* @param l 区间的左边
* @param r 区间的右边
*/
private void buildSegmentTree(int treeIndex, int l, int r) {
if (l == r) {
tree[treeIndex] = data[l];
return;
}

int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);

int mid = l + (r - l) / 2;
buildSegmentTree(leftTreeIndex, l, mid);
buildSegmentTree(rightTreeIndex, mid + 1, r);

tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}

// 区间长度。
public int getSize() {
return data.length;
}

// 获得指定索引的元素。
public E get(int index) {
if (index < 0 || index >= data.length) {
throw new IllegalArgumentException("Index is illegal.");
}
return data[index];
}

// 左孩子索引。
private int leftChild(int index) {
return 2 * index + 1;
}

// 右孩子索引。
private int rightChild(int index) {
return 2 * index + 2;
}

// 返回区间 [queryL, queryR]的值
public E query(int queryL, int queryR) {
if (queryL < 0 || queryL >= data.length || queryR < 0 || queryR >= data.length || queryL > queryR) {
throw new IllegalArgumentException("Index is illegal.");
}

return query(0, 0, data.length - 1, queryL, queryR);
}

// 在以treeIndex 为根的线段树中[l...r]的范围内,搜索区间 [queryL...queryR]的值
private E query(int treeIndex, int l, int r, int queryL, int queryR) {
if (l == queryL && r == queryR) {
return tree[treeIndex];
}

int mid = l + (r - l) / 2;
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);

if (queryL >= mid + 1) {
return query(rightTreeIndex, mid + 1, r, queryL, queryR);
} else if (queryR <= mid) {
return query(leftTreeIndex, l, mid, queryL, queryR);
}

E leftResult = query(leftTreeIndex, l, mid, queryL, mid);
E rightResult = query(rightTreeIndex, mid + 1, r, mid + 1, queryR);
return merger.merge(leftResult, rightResult);
}

// 将index位置的值,更新为e
public void set(int index, E e) {
if (index < 0 || index >= data.length) {
throw new IllegalArgumentException("Index is illegal");
}

data[index] = e;
set(0, 0, data.length - 1, index, e);
}

// 在以treeIndex为根的线段树中更新index的值为e
private void set(int treeIndex, int l, int r, int index, E e) {

if (l == r) {
tree[treeIndex] = e;
return;
}

int mid = l + (r - l) / 2;
// treeIndex的节点分为[l...mid]和[mid+1...r]两部分

int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
if (index >= mid + 1) {
set(rightTreeIndex, mid + 1, r, index, e);
} else { // index <= mid
set(leftTreeIndex, l, mid, index, e);
}

tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}

@Override
public String toString() {
StringBuilder res = new StringBuilder();
res.append('[');
for (int i = 0; i < tree.length; i++) {
if (tree[i] != null) {
res.append(tree[i]);
} else {
res.append("null");
}

if (i != tree.length - 1) {
res.append(", ");
}
}
res.append(']');
return res.toString();
}
}

使用线段树,可以以 O(logn) 的复杂度方便的进行数组区间求和、求区间最大值等操作:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
public class Main {

public static void main(String[] args) {

Integer[] nums = { -2, 0, 3, -5, 2, -1 };
SegmentTree<Integer> segTree = new SegmentTree<>(nums, new Merger<Integer>() {
@Override
// 用于求和的线段树。
public Integer merge(Integer a, Integer b) {
return a + b;
// return Math.max(a, b); // 用于求最大值的线段树。
}
});
// SegmentTree<Integer> segTree = new SegmentTree<>(nums, (a, b) -> a + b); // Lambda 表达式 的写法

System.out.println(segTree.query(0, 2)); // 1
System.out.println(segTree.query(2, 5)); // -1
System.out.println(segTree.query(0, 5)); // -3

System.out.println(segTree);
}
}

更多

  • 我们这里实现的线段树只能实现单个元素更新,没有实现区间更新。比如说,我们希望区间 [2,5]中所有元素 +3,在线段树中,需要将这个叶子节点以及他们的父节点都进行更新,复杂度会变为 O(n) 级别,比较慢。一个方式是进行懒惰更新,lazy 更新,我们只把线段树中具体到相应区间的节点进行更新,其下面的子节点先不进行更新,而使用一个 lazy 数组记录未更新的内容,之后如果进行查询时先查找 lazy 数组看是否有未更新的内容,如果有再更新一下。
  • 线段树不仅仅适用于一维,二维甚至三维的区间问题都可以使用线段树解决。
  • 关于区间操作,还有另外一个重要的数据结构:树状数组。