线段树

写在最前

我承认这个封面的线段树有点过于抽象,但是真的没有别的办法乐。

先引用一个大佬说的一句话:

什么是线段树?
如果你在考提高组前一天还在问这个问题,那么你会与一等奖失之交臂;如果你还在冲击普及组一等奖,那么这篇博客会浪费你人生中宝贵的5~20分钟。

显而易见,线段树是一个 OIer 从萌新过渡到正式选手的标志性算法。

然而事实上,对于一个正式的 OIer 选手,线段树更应该是一个工具,所以……

线段树是世界上最好的数据结构!!!(bushi

线段树可以 $O(\log n)$ 的实现区间和、区间乘、区间最值、区间修改等操作,在很多题中都可以以略逊于正解的解法过掉

核心思想

线段树的核心思想是 分治。它将一个区间分成若干个子区间,每个节点代表一个区间,并通过递归的方式维护区间的信息。

  • 区间划分:对于一个区间 $[l, r]$,我们将其分成两个子区间 $[l, mid]$ 和 $[mid+1, r]$,其中 $mid = \lfloor \frac{l+r}{2} \rfloor$。
  • 递归维护:通过递归的方式,逐步将区间划分到最小单位(即叶子节点),并在回溯时合并子区间的信息。

不难发现复杂度为均 $O(n\log n)$。

正常情况下我们存一棵树可能会用到存图的方法,或者是存每一个点的子节点,然而为了维护一串序列而真的建一个图实在是又臭又长过于屎山,不难想到父子两倍,即父节点为 $n$ 的话,子节点分别是 $2n$ 和 $2n+1$。

如果用这种方法的话不妨思考一下他所需的空间(以下非重点,可以酌情跳过):

首先设这个序列长度为 $n$。
按照正常的逻辑,线段树的最下面是长度为 1 的区间,那么也就是说总共有 $n$ 个叶子节点。
不难发现总共有 $2n-1$ 个节点。
然而实际上,我们在建树中,会浪费掉一些空间,所以说开 $2n-1$ 是不一定够的。

  • 线段树的深度为 $\lceil \log_2 n \rceil$。
  • 在最坏情况下,线段树的最后一层可能会有 $2^{\lceil \log_2 n \rceil}$ 个节点。
  • 由于 $2^{\lceil \log_2 n \rceil} \leq 2n$,因此总节点数最多为 $2 \times 2n - 1 = 4n - 1$。

在实现线段树时,区间更新是一个常见的操作。然而,如果每次更新都递归到叶子节点,时间复杂度会退化为 $O(N)$,这显然无法接受。

例如我们进行两次操作:

操作 1:区间 [1, 3] 加 2

  • 递归到叶子节点 [1, 1][2, 2][3, 3],分别更新它们的值。
  • 更新后的数组为 a = [3, 4, 5, 4, 5]

操作 2:区间 [2, 4] 加 3

  • 递归到叶子节点 [2, 2][3, 3][4, 4],分别更新它们的值。
  • 更新后的数组为 a = [3, 7, 8, 7, 5]

可以看到在操作 2 中,节点 [2, 2][3, 3] 已经被操作 1 更新过,现在又被操作 2 更新了一次。这种重复操作会导致时间复杂度增加。

不难想到可以把一个节点的修改累加标记起来,直到需要访问时再下传,这就是懒标记。

懒标记下传是线段树的核心优化。它通过延迟更新的方式,避免不必要的递归操作。

实现

建树

建树的过程是通过递归实现的。我们从根节点开始,逐步将区间划分到最小单位,并在回溯时合并子区间的信息。

1
2
3
4
5
6
7
8
9
10
11
12
void build(int ro, int l, int r) {
t[ro].l = l;
t[ro].r = r;
if (l == r) {
t[ro].sum = a[l]; // 叶子节点,直接赋值
return;
}
int mid = (l + r) / 2;
build(ro * 2, l, mid); // 递归构建左子树
build(ro * 2 + 1, mid + 1, r); // 递归构建右子树
t[ro].sum = t[ro * 2].sum + t[ro * 2 + 1].sum; // 合并子区间信息
}

区间查询

区间查询的过程是通过递归实现的。我们从根节点开始,逐步向下查询目标区间,并合并子区间的信息。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int query(int ro, int l, int r) {
if (l <= t[ro].l && t[ro].r <= r) {
return t[ro].sum; // 当前节点区间完全包含在查询区间内
}
down(ro); // 下传懒标记
int mid = (t[ro].l + t[ro].r) / 2;
int sum = 0;
if (l <= mid) {
sum += query(ro * 2, l, r); // 递归查询左子树
}
if (r > mid) {
sum += query(ro * 2 + 1, l, r); // 递归查询右子树
}
return sum;
}

区间更新

区间更新的过程是通过递归实现的。我们从根节点开始,逐步向下更新目标区间,并通过懒标记延迟更新。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
void update(int ro, int l, int r, int x) {
if (l <= t[ro].l && t[ro].r <= r) {
t[ro].flag += x; // 更新懒标记
t[ro].sum += (t[ro].r - t[ro].l + 1) * x; // 更新当前节点的区间和
return;
}
down(ro); // 下传懒标记
int mid = (t[ro].l + t[ro].r) / 2;
if (l <= mid) {
update(ro * 2, l, r, x); // 递归更新左子树
}
if (r > mid) {
update(ro * 2 + 1, l, r, x); // 递归更新右子树
}
t[ro].sum = t[ro * 2].sum + t[ro * 2 + 1].sum; // 合并子区间信息
}

懒标记下传

将节点的懒标记下传到子节点,并清空当前节点的标记

1
2
3
4
5
6
7
8
9
void down(int ro) {
if (t[ro].flag != 0) {
t[ro * 2].flag += t[ro].flag;
t[ro * 2].sum += (t[ro * 2].r - t[ro * 2].l + 1) * t[ro].flag;
t[ro * 2 + 1].flag += t[ro].flag;
t[ro * 2 + 1].sum += (t[ro * 2 + 1].r - t[ro * 2 + 1].l + 1) * t[ro].flag;
t[ro].flag = 0; // 清空当前节点的懒标记
}
}

完整代码

喜闻乐见封装环节

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class SegmentTree {
#define ls ro*2
#define rs ro*2+1
struct node {
int l, r, sum, flag;
} c[4 * N];
void down(int ro) {
c[ls].flag += c[ro].flag;
c[ls].sum += (c[ls].r - c[ls].l + 1) * c[ro].flag;
c[rs].flag += c[ro].flag;
c[rs].sum += (c[rs].r - c[rs].l + 1) * c[ro].flag;
c[ro].flag = 0;
}
void build(int ro, int l, int r, int *a) {
c[ro].l = l; c[ro].r = r;
if(l == r) {
c[ro].sum = a[l];
return;
}
int mid = l + r >> 1;
build(ls, l, mid, a);
build(rs, mid + 1, r, a);
c[ro].sum = c[ls].sum + c[rs].sum;
}
void update(int ro, int l, int r, int x) {
if(l <= c[ro].l && c[ro].r <= r) {
c[ro].flag += x;
c[ro].sum += (c[ro].r - c[ro].l + 1) * x;
return ;
}
down(ro);
int mid = c[ro].l + c[ro].r >> 1;
if(l <= mid)
update(ls, l, r, x);
if(mid < r)
update(rs, l, r, x);
c[ro].sum = c[ls].sum + c[rs].sum;
}
int ask(int ro, int l, int r) {
if(l <= c[ro].l && c[ro].r <= r) {
return c[ro].sum;
}
down(ro);
int sum = 0;
int mid = c[ro].l + c[ro].r >> 1;
if(l <= mid)
sum += ask(ls, l, r);
if(mid < r)
sum += ask(rs, l, r);
return sum;
}
public:
SegmentTree(int n, int *a) {
build(1, 1, n, a);
}
void modify(int l, int r, int x) {
update(1, l, r, x);
}
int query(int l, int r) {
return ask(1, l, r);
}
};