线段树 数据结构详解与模板


转载注明出自bestsort.cn,谢谢合作


线段树

线段树是一个查询和修改复杂度都为log(n)的数据结构。主要用于数组的单点修改&&单点查询&&区间求和&&区间修改. 另外一个拥有类似功能的是树状数组,但是树状数组最常用的是单点修改&&区间求和. 线段树完全涵盖树状数组所有功能

和树状数组的区别和联系

1.两者在复杂度上同级, 但是树状数组的常数明显优于线段树, 其编程复杂度也远小于线段树. 2.树状数组的作用被线段树完全涵盖, 凡是可以使用树状数组解决的问题, 使用线段树一定可以解决, 但是线段树能够解决的问题树状数组未必能够解决. 说了这么多,其实线段树就是个二叉树而已,只不过叶子节点记录的是区间之间的和而已 先给一份样图 其中,矩形内的是区间之和,区间外的是数组下标(线段树用数组存数据).不难看出,线段树的左孩子=根节点下标_2,右孩子=根节点下标_2+1,而左右孩子则是根节点将区间二分的结果. 先给出线段树的结构体定义然后咱们再仔细讲讲各种(sao)操作

struct node {
    int l,r,w,flag;
} a[maxn<<2]; //4倍空间

结构体里有个延迟标记的东西,咱们下面再说这个问题 需要注意的是如果是n个数,那么线段树需要开4n的空间.理论上是2n-1的空间,但是你递归建立的时候当前节点为r,那么左右孩子分别是2_r,2_r+1,此时编译器并不知道递归已结束,因为你的结束条件是在递归之前的,所以编译器会认为下标访问出错,也就是空间开小了,应该再开大2倍。有时候可能你发现开2,3倍的空间也可以AC,那只是因为测试数据并没有那么大。 至于为什么开4倍,我从网上摘抄了一部分(反正我是看不懂


首先线段树是一棵二叉树,最底层有n个叶子节点(n为区间大小) 那么由此可知,此二叉树的高度为,可证然后通过等比数列求和求得二叉树的节点个数,具体公式为,(x为树的层数,为树的高度+1) 化简可得,整理之后即为$4n$(近似计算忽略掉-1) 证毕


线段树的基础操作主要有5个: 建树、单点查询、单点修改、区间查询、区间修改。


建树

会建二叉树的话这一条也就没什么说的了 主要就是递归建树而已 其中,k为根节点,l,r分别为左右区间 输入n个数将其建立为线段树只需要调用 build(1,1,n)即可 递归过程应该都能看懂(看不懂回去学二叉树去

1
2
3
4
5
6
7
8
9
10
11
void build(int k,int l,int r) {
a[k].l = l,a[k].r = r;
if(a[k].l == a[k].r) {
scanf("%d",&a[k].w);
//cin >> a[k].w;
return;
}
build(k*2,l,(l+r)/2); //左
build(k*2+1,(l+r)/2+1, r);//右
a[k].w += a[k*2].w+a[k*2+1].w;//求和
}

延迟标记

这里咱们开始用到上面的变量flag了 上面说了,线段树是支持区间修改的,比如说开始那张图,咱把[1,5]都加上3,总不能把[1,5],[1,3],[4,5],[1,2],[3,3],[4,4],[5,5],[1,1],[2,2]都修改了啊,这样从第二层一直到第四层那我还要这个线段树干嘛,时间早爆炸了. 这时候,精髓部分来了,诶咱就只修改a[2]这个地方,也就是[1,5],下面的暂时用不上,就不管它.然后让flag=3. 如果下一次需要用到这一部分数据的话,将flag下传,这样查询哪一部分咱就算哪一部分的和,其他的就不管 要将[1,5]这部分+3但是不查询他的话,那么[1,5]的左右孩子也就没有更改的必要了 这个flag就是延迟标记,有了它,我们就只需要将修改过的区域标记,等到查询此部分的时候再向下修改就行了 以线段树区间1-10,初值全为0,[1,5]全部+3为例: 可以看出,[1,5]的子区间内的区间和是不对的(修改后不应该为0~) 没关系,我们只需要修改[1,5]和包含[1,5]的区间的内容即可,然后我们让flag = 3,[1,5]的子区间暂时不用管 (黑色数字代表区间和,红色代表flag的值) 如果接下来查询[1,3]或者[1,5]的其他子区间,我们再向下计算区间和,对于查询[1,3]而言,图是这样子的: 结论已经呼之欲出了: 如果查询的区域有延迟标记flag,就将标记下传,并且左右孩子的和+=flag(左右孩子区间内所存的数) 比如说[1,5]的左孩子区间为1-3,则为3\(3-1+1) = 3*3 具体操作如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
void down(int k) {
a[k*2].flag += a[k].flag;            //标记下传
a[k*2+1].flag += a[k].flag;

a[k*2].w += a[k].flag*(a[k*2].r-a[k*2].l+1);    //标记求和
a[k*2+1].w += a[k].flag *(a[k*2+1].r-a[k*2+1].l+1);
a[k].flag = 0;                        //下传之后清空当前节点的标记
}
```

---

## 区间查询


有了延迟标记的基础我们就可以进行区间求和了 也是比较简单的过程,会二分应该就能看懂
```cpp
void askinterval(int k,int x,int y) {
if(a[k].l>=x && a[k].r<=y) {
ans += a[k].w;            ///ans为全局变量,记得每次查询令ans = 0;
return;
}
if(a[k].flag)
down(k);
int buf = (a[k].l+a[k].r)/2;
if(x <= buf)
askinterval(k*2,x,y);           ///递归查左子树
if(y > buf)
askinterval(k*2+1,x,y);         ///递归查右子树
}


区间修改

区间修改和上面的区间查询代码基本相同,自行研究咯~

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void changeinterval(int k,int x,int y,int z) {
if(a[k].l>=x &&a[k].r<=y) {
a[k].w += (a[k].r-a[k].l+1)*z;
a[k].flag += z;
return;
}
if(a[k].flag)
down(k);
int buf = (a[k].l+a[k].r)/2;
if(x <= buf)
changeinterval(k*2,x,y,z);
if(y > buf)
changeinterval(k*2+1,x,y,z);
a[k].w = a[k*2].w + a[k*2+1].w;
}


单点查询

其实单点查询完全可以使用上面区间查询的函数,反正都是一样的~ 不过毕竟是模板嘛,还是贴一份代码

1
2
3
4
5
6
7
8
9
10
11
12
13
void askinterval(int k,int x) {
if(a[k].l==x && a[k].r==x) {
ans = a[k].w;
return;
}
if(a[k].flag)
down(k);
int buf = (a[k].l+a[k].r)/2;
if(x <= buf)
askinterval(k*2,x);
if(y > buf)
askinterval(k*2+1,x);
}

单点修改

同样,单点修改也可以使用区间修改的代码,只需要让x和y一样就行.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
void changeinterval(int k,int x,int z) {
if(a[k].l==x &&a[k].r==x) {
a[k].w += (a[k].r-a[k].l+1)*z;
a[k].flag += z;
return;
}
if(a[k].flag)
down(k);
int buf = (a[k].l+a[k].r)/2;
if(x <= buf)
changeinterval(k*2,x,z);
if(y > buf)
changeinterval(k*2+1,x,z);
a[k].w = a[k*2].w + a[k*2+1].w;
}
```

老规矩,最后一道例题 [Hdu1754 I Hate It](http://acm.hdu.edu.cn/showproblem.php?pid=1754 "Hdu1754 I Hate It") 解题代码如下:
```cpp
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#include <string>
#include <vector>
#define For(a,b) for(ll a=0;a<b;a++)
#define mem(a,b) memset(a,b,sizeof(a))
#define _mem(a,b) memset(a,0,(b+1)<<2)
#define lowbit(a) ((a)&-(a))
#define IO do{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);}while(0)

using namespace std;
typedef long long ll;
const ll maxn = 2*1e5+5;
const ll INF = 0x3f3f3f3f;
struct node {
ll l,r,w,flag;
} a[maxn<<2]; //4倍空间
ll c[maxn];
ll cnt;
void build(ll k,ll l,ll r) {
a[k].l = l,a[k].r = r;
if(a[k].l == a[k].r) {
scanf("%lld",&a[k].w);
//cin >> a[k].w;
return;
}
build(k*2, l, (l+r)/2);
build(k*2+1, (l+r)/2+1, r);
a[k].w = max(a[k*2].w,a[k*2+1].w);
}

void changellerval(ll k,ll x,ll z) {
if(a[k].l==x &&a[k].r==x) {
a[k].w = z;
return;
}
ll buf = (a[k].l+a[k].r)/2;
if(x <= buf)
changellerval(k*2,x,z);
if(x > buf)
changellerval(k*2+1,x,z);
a[k].w = max(a[k*2].w, a[k*2+1].w);
}
ll ans;
void askllerval(ll k,ll x,ll y) {
if(a[k].l>=x && a[k].r<=y) {
ans = max(a[k].w,ans);
return;
}
ll buf = (a[k].l+a[k].r)/2;
if(x <= buf)
askllerval(k*2,x,y);
if(y > buf)
askllerval(k*2+1,x,y);
}

int main() {
//IO;

char buf;
ll n,m;
ll x,y,z;
while(cin >> n >> m) {
build(1,1,n);
For(i,m) {
getchar();
scanf("%c",&buf);
//cin >> buf;
if(buf == 'Q') {
scanf("%lld%lld",&x,&y);
//cin >> x >> y;
ans = 0;
askllerval(1,x,y);
printf("%lldn",ans);
//cout << ans << endl;
} else {
scanf("%lld%lld",&x,&z);
//cin >> x >> y >> z;
changellerval(1,x,z);
}
}
}
return 0;
}


线段树模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
//最好是全局开long long
#include <bits/stdc++.h>

using namespace std;
#define mem(a, b) memset(a, b, sizeof a)
#define IN freopen("in.txt", "r", stdin)
#define DEBUG(a) cout << (a) << endl

typedef long long ll;
int dir8[8][2] = {{1, 0}, {0, 1}, {-1, 0}, {0, -1}, {1, 1}, {1, -1}, {-1, 1}, {-1, -1}};
int dir4[4][2] = {1, 0, 0, 1, -1, 0, 0, -1};
const int INF = 0x3f3f3f3f;
int mod = 1e9 + 7;
const int maxn = 1e5 + 10;

struct node
{
int l, r, w, flag;
int dis() { return r - l + 1; }
int mid() { return (r + l) / 2; }
} a[maxn * 4];

void build(int k, int l, int r)
{ //当前节点的区间
a[k] = {l, r, 0, 0};
if (l == r)
{
cin >> a[k].w;
return;
}
build(k << 1, l, a[k].mid());
build(k << 1 | 1, a[k].mid() + 1, r);
a[k].w = a[k << 1].w + a[k << 1 | 1].w;
}
void down(int k)
{
a[k << 1].w += a[k << 1].dis() * a[k].flag;
a[k << 1 | 1].w += a[k << 1 | 1].dis() * a[k].flag;
a[k << 1].flag += a[k].flag;
a[k << 1 | 1].flag += a[k].flag;
a[k].flag = 0;
}

void update(int k, int l, int r, int w)
{ // 要更新的总区间.(l,r)不变
if (a[k].l >= l && a[k].r <= r)
{
a[k].w += w * a[k].dis();
a[k].flag += w;
return;
}
if (a[k].flag)
down(k);
if (a[k].mid() >= l)
update(k << 1, l, r, w);
if (a[k].mid() < r)
update(k << 1 | 1, l, r, w);
a[k].w = a[k << 1].w + a[k << 1 | 1].w;
}

int query(int k, int l, int r)
{
if (a[k].l >= l && a[k].r <= r)
return a[k].w;
if (a[k].flag)
down(k);
int sum = 0;
if (a[k].mid() >= l)
sum += query(k << 1, l, r);
if (a[k].mid() < r)
sum += query(k << 1 | 1, l, r);
a[k].w = a[k << 1].w + a[k << 1 | 1].w;
return sum;
}

int main()
{
//IN;
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);

int n, m;
cin >> n >> m;
build(1, 1, n);
while (m--)
{
int func;
int l, r, w;
cin >> func;
if (func == 1)
{
cin >> l >> r >> w;
update(1, l, r, w);
}
else if (func == 2)
{
cin >> l >> r;
cout << query(1, l, r) << endl;
}
/* code */
}

return 0;
}

最后

线段树重要的其实不是它本身,而是以后学习其他相关算法的基石,区间查询 / 修改这个特性使得线段树能够和 树链剖分 或者其他数据结构联动,从而达到解题的目的

觉得文章不错的话可以请我喝一杯茶哟~
  • 本文作者: bestsort
  • 本文链接: https://bestsort.cn/2019/04/28/482/
  • 版权声明: 本博客所有文章除特别声明外,均采用 BY-SA 许可协议。转载请注明出处!并保留本声明。感谢您的阅读和支持!
-------------本文结束感谢您的阅读-------------