写在最前
我承认这个封面的线段树
有点过于抽象,但是真的没有别的办法乐。
先引用一个大佬说的一句话:
什么是线段树?
如果你在考提高组前一天还在问这个问题,那么你会与一等奖失之交臂;如果你还在冲击普及组一等奖,那么这篇博客会浪费你人生中宝贵的5~20分钟。
显而易见,线段树是一个 OIer 从萌新过渡到正式选手的标志性算法。
然而事实上,对于一个正式的 OIer 选手,线段树更应该是一个工具,所以……
线段树是世界上最好的数据结构!!!(bushi
线段树可以 $O(\log n)$ 的实现区间和、区间乘、区间最值、区间修改等操作,在很多题中都可以以略逊于正解的解法过掉。
核心思想
线段树的核心思想是 分治。它将一个区间分成若干个子区间,每个节点代表一个区间,并通过递归的方式维护区间的信息。
- 区间划分:对于一个区间 $[l, r]$,我们将其分成两个子区间 $[l, mid]$ 和 $[mid+1, r]$,其中 $mid = \lfloor \frac{l+r}{2} \rfloor$。
- 递归维护:通过递归的方式,逐步将区间划分到最小单位(即叶子节点),并在回溯时合并子区间的信息。
不难发现复杂度为均 $O(n\log n)$。
正常情况下我们存一棵树可能会用到存图的方法,或者是存每一个点的子节点,然而为了维护一串序列而真的建一个图实在是又臭又长过于屎山,不难想到父子两倍,即父节点为 $n$ 的话,子节点分别是 $2n$ 和 $2n+1$。
如果用这种方法的话不妨思考一下他所需的空间(以下非重点,可以酌情跳过):
首先设这个序列长度为 $n$。
按照正常的逻辑,线段树的最下面是长度为 1 的区间,那么也就是说总共有 $n$ 个叶子节点。
不难发现总共有 $2n-1$ 个节点。
然而实际上,我们在建树中,会浪费掉一些空间,所以说开 $2n-1$ 是不一定够的。
- 线段树的深度为 $\lceil \log_2 n \rceil$。
- 在最坏情况下,线段树的最后一层可能会有 $2^{\lceil \log_2 n \rceil}$ 个节点。
- 由于 $2^{\lceil \log_2 n \rceil} \leq 2n$,因此总节点数最多为 $2 \times 2n - 1 = 4n - 1$。
在实现线段树时,区间更新是一个常见的操作。然而,如果每次更新都递归到叶子节点,时间复杂度会退化为 $O(N)$,这显然无法接受。
例如我们进行两次操作:
操作 1:区间 [1, 3]
加 2
- 递归到叶子节点
[1, 1]
、[2, 2]
和 [3, 3]
,分别更新它们的值。
- 更新后的数组为
a = [3, 4, 5, 4, 5]
。
操作 2:区间 [2, 4]
加 3
- 递归到叶子节点
[2, 2]
、[3, 3]
和 [4, 4]
,分别更新它们的值。
- 更新后的数组为
a = [3, 7, 8, 7, 5]
。
可以看到在操作 2 中,节点 [2, 2]
和 [3, 3]
已经被操作 1 更新过,现在又被操作 2 更新了一次。这种重复操作会导致时间复杂度增加。
不难想到可以把一个节点的修改累加标记起来,直到需要访问时再下传,这就是懒标记。
懒标记下传是线段树的核心优化。它通过延迟更新的方式,避免不必要的递归操作。
实现
建树
建树的过程是通过递归实现的。我们从根节点开始,逐步将区间划分到最小单位,并在回溯时合并子区间的信息。
1 2 3 4 5 6 7 8 9 10 11 12
| void build(int ro, int l, int r) { t[ro].l = l; t[ro].r = r; if (l == r) { t[ro].sum = a[l]; return; } int mid = (l + r) / 2; build(ro * 2, l, mid); build(ro * 2 + 1, mid + 1, r); t[ro].sum = t[ro * 2].sum + t[ro * 2 + 1].sum; }
|
区间查询
区间查询的过程是通过递归实现的。我们从根节点开始,逐步向下查询目标区间,并合并子区间的信息。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| int query(int ro, int l, int r) { if (l <= t[ro].l && t[ro].r <= r) { return t[ro].sum; } down(ro); int mid = (t[ro].l + t[ro].r) / 2; int sum = 0; if (l <= mid) { sum += query(ro * 2, l, r); } if (r > mid) { sum += query(ro * 2 + 1, l, r); } return sum; }
|
区间更新
区间更新的过程是通过递归实现的。我们从根节点开始,逐步向下更新目标区间,并通过懒标记延迟更新。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| void update(int ro, int l, int r, int x) { if (l <= t[ro].l && t[ro].r <= r) { t[ro].flag += x; t[ro].sum += (t[ro].r - t[ro].l + 1) * x; return; } down(ro); int mid = (t[ro].l + t[ro].r) / 2; if (l <= mid) { update(ro * 2, l, r, x); } if (r > mid) { update(ro * 2 + 1, l, r, x); } t[ro].sum = t[ro * 2].sum + t[ro * 2 + 1].sum; }
|
懒标记下传
将节点的懒标记下传到子节点,并清空当前节点的标记
1 2 3 4 5 6 7 8 9
| void down(int ro) { if (t[ro].flag != 0) { t[ro * 2].flag += t[ro].flag; t[ro * 2].sum += (t[ro * 2].r - t[ro * 2].l + 1) * t[ro].flag; t[ro * 2 + 1].flag += t[ro].flag; t[ro * 2 + 1].sum += (t[ro * 2 + 1].r - t[ro * 2 + 1].l + 1) * t[ro].flag; t[ro].flag = 0; } }
|
完整代码
喜闻乐见封装环节
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
| class SegmentTree { #define ls ro*2 #define rs ro*2+1 struct node { int l, r, sum, flag; } c[4 * N]; void down(int ro) { c[ls].flag += c[ro].flag; c[ls].sum += (c[ls].r - c[ls].l + 1) * c[ro].flag; c[rs].flag += c[ro].flag; c[rs].sum += (c[rs].r - c[rs].l + 1) * c[ro].flag; c[ro].flag = 0; } void build(int ro, int l, int r, int *a) { c[ro].l = l; c[ro].r = r; if(l == r) { c[ro].sum = a[l]; return; } int mid = l + r >> 1; build(ls, l, mid, a); build(rs, mid + 1, r, a); c[ro].sum = c[ls].sum + c[rs].sum; } void update(int ro, int l, int r, int x) { if(l <= c[ro].l && c[ro].r <= r) { c[ro].flag += x; c[ro].sum += (c[ro].r - c[ro].l + 1) * x; return ; } down(ro); int mid = c[ro].l + c[ro].r >> 1; if(l <= mid) update(ls, l, r, x); if(mid < r) update(rs, l, r, x); c[ro].sum = c[ls].sum + c[rs].sum; } int ask(int ro, int l, int r) { if(l <= c[ro].l && c[ro].r <= r) { return c[ro].sum; } down(ro); int sum = 0; int mid = c[ro].l + c[ro].r >> 1; if(l <= mid) sum += ask(ls, l, r); if(mid < r) sum += ask(rs, l, r); return sum; } public: SegmentTree(int n, int *a) { build(1, 1, n, a); } void modify(int l, int r, int x) { update(1, l, r, x); } int query(int l, int r) { return ask(1, l, r); } };
|