简介

线段树是用于维护一个区间内具有结合律数据的高效数据结构,复杂度为 O(logn)O(\log n),但常数较大。

将一个区间不断以 mid=l+r2,[l,mid],[mid+1,r]mid=\lfloor \cfrac{l+r}{2}\rfloor,[l,mid],[mid+1,r] 为区间分为两个子树,由于除了最后一层所有层都是满的,所以用类似于存储堆的形式,即数组来存储线段树,一个结点编号为 nn 那么左子结点是 2n2n,右子结点是 2n+12n+1

设倒数第二层有 2k2^k 个结点,那么上面有 i=0k12i=2k1\sum_{i=0}^{k-1} 2^i = 2^k-1 个结点,下面一层最多有 2k+12^{k+1} 个结点,设总共有 NN 个元素,坏的情况下应该是倒数第二层很多,最后一层很少,因此保险起见开 4N4N 的空间绝对够用。

下面是例题。

单点修改

最大数

给定一个正整数数列 a1,a2,…,an,每一个数都在 0~p-1 之间。

可以对这列数进行两种操作:

  1. 添加操作:向序列后添加一个数,序列长度变成 n+1;
  2. 询问操作:询问这个序列中最后 L 个数中最大的数是多少。

程序运行的最开始,整数序列为空。

一共要对整数序列进行 m 次操作。

写一个程序,读入操作的序列,并输出询问操作的答案。

输入格式

第一行有两个正整数 m,p,意义如题目描述;

接下来 m 行,每一行表示一个操作。

如果该行的内容是 Q L,则表示这个操作是询问序列中最后 L 个数的最大数是多少;

如果是 A t,则表示向序列后面加一个数,加入的数是 (t+a)%p。其中,t 是输入的参数,a 是在这个添加操作之前最后一个询问操作的答案(如果之前没有询问操作,则 a=0)。

第一个操作一定是添加操作。对于询问操作,L>0 且不超过当前序列的长度。

输出格式

对于每一个询问操作,输出一行。该行只有一个数,即序列中最后 L 个数的最大数。

题目链接:AcWing 1275

先开 m 个元素需要的空间。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <iostream>
#include <algorithm>
using namespace std;

typedef long long ll;
const int M = 2e5+10;
int m, p;

struct Node {
int l, r;
ll v;
} tr[M<<2];

// 由子节点信息计算父结点信息, 传入父结点
void pushup(int u) {
// u << 1 和 u << 1 | 1 其实就是 2*u 和 2*u + 1
// 这点编译器肯定能优化的东西为什么要这么写呢
// 答案很简单 装x而已
tr[u].v = max(tr[u << 1].v, tr[u << 1 | 1].v);
}

// 建树
void build(int u, int l, int r) {
tr[u] = {l, r};
if (l == r) return;
int mid = l+r >> 1;
build(u << 1, l, mid); // 左子树
build(u << 1 | 1, mid+1, r); // 右子树
}

// 查询
ll query(int u, int l, int r) {
// 树中所有节点已经被完全包含
if (l <= tr[u].l && tr[u].r <= r) return tr[u].v;
int mid = tr[u].l + tr[u].r >> 1;
ll res = 0;
if (l <= mid) res = query(u << 1, l, r);
if (mid+1 <= r) res = max(res, query(u << 1 | 1, l, r));
return res;
}

void modify(int u, int p, ll v) {
if (tr[u].l == p && tr[u].r == p) tr[u].v = v;
else {
int mid = tr[u].l + tr[u].r >> 1;
// 判断处于哪个子树中
if (p <= mid) modify(u << 1, p, v);
else modify(u << 1 | 1, p, v);
// 用子节点信息更新当前结点
pushup(u);
}
}

int main() {
int last = 0, n = 0;
cin >> m >> p;
build(1, 1, m);
int x; char op;
while (m--) {
cin >> op >> x;
if (op == 'Q') {
last = query(1, n-x+1, n);
cout << last << endl;
} else {
modify(1, n+1, ((ll)x + last) % p);
n++; // 修改完之后再增加长度
}
}
return 0;
}

最大连续子区间和

给定长度为 N 的数列 A,以及 M 条指令,每条指令可能是以下两种之一:

  1. 1 x y,查询区间 [x,y] 中的最大连续子段和,即 maxxlry(i=lrA[i])\max _{x\le l \le r \le y}(\sum_{i=l}^rA[i])
  2. 2 x y,把 A[x] 改成 y。

对于每个查询指令,输出一个整数表示答案。

题目链接:AcWing 245

每个结点需要维护四个属性:最大连续子区间和 tmaxtmax,最大前缀和 lmaxlmax,最大后缀和 rmaxrmax,区间和sumsum,对于一个结点 uu 和它的两个子结点 l,rl, r 来说,满足下面的关系:

tmax(u)=max{tmax(l),tmax(r),rmax(l)+lmax(r)}lmax(u)=max{lmax(l),sum(l)+lmax(r)}rmax(u)=max{rmax(r),rmax(l)+sum(r)}sum(u)=sum(l)+sum(r)\begin{aligned} tmax(u)&=\max\{tmax(l), tmax(r), rmax(l)+lmax(r)\}\\ lmax(u)&=\max\{lmax(l), sum(l)+lmax(r)\}\\ rmax(u)&=max\{rmax(r), rmax(l)+sum(r)\}\\ sum(u)&=sum(l)+sum(r) \end{aligned}

在线段树中查询一段时,如果区间 [x,y][x,y] 覆盖了结点 uu 的两个子结点,那么需要获取到这两个子结点然后仿照上面计算最大值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <iostream>
#include <algorithm>
using namespace std;

const int N = 500010;

struct Node {
int l, r;
int tmax, lmax, rmax, sum;
void set(int v) {
tmax = lmax = rmax = sum = v;
}
} tr[N<<2];
int a[N], n, m;

void pushup(Node& u, Node l, Node r) {
u.tmax = max(l.tmax, max(r.tmax, l.rmax + r.lmax));
u.lmax = max(l.lmax, l.sum+r.lmax);
u.rmax = max(l.rmax+r.sum, r.rmax);
u.sum = l.sum + r.sum;
}

void pushup(int u) {
pushup(tr[u], tr[u<<1], tr[u<<1|1]);
}

void build(int u, int l, int r) {
// 不管哪个结点都要初始化 l 和 r
tr[u] = {l, r};
if (l == r) tr[u].set(a[l]);
else {
int mid = l+r >> 1;
build(u<<1, l, mid);
build(u<<1|1, mid+1, r);
pushup(u);
}
}

void modify(int u, int p, int v) {
if (tr[u].l == p && tr[u].r == p) tr[u].set(v);
else {
int mid = tr[u].l + tr[u].r >> 1;
if (p <= mid) modify(u<<1, p, v);
else modify(u<<1|1, p, v);
pushup(u);
}
}

Node query(int u, int l, int r) {
if (l <= tr[u].l && tr[u].r <= r) return tr[u];
else {
int mid = tr[u].l + tr[u].r >> 1;
if (l > mid) return query(u<<1|1, l, r);
else if (r < mid+1) return query(u<<1, l, r);
else {
Node res;
pushup(res, query(u<<1, l, r), query(u<<1|1, l, r));
return res;
}
}
}

int main() {
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);

int t, x, y;
while (m--) {
cin >> t >> x >> y;
if (t == 1) {
if (x > y) swap(x, y);
cout << query(1, x, y).tmax << endl;
} else modify(1, x, y);
}

return 0;
}

区间最大公约数

给定一个长度为 N 的数列 A,以及 M 条指令,每条指令可能是以下两种之一:

  1. C l r d,表示把 A[l],A[l+1],…,A[r],都加上 d。
  2. Q l r,表示询问 A[l],A[l+1],…,A[r] 的最大公约数(GCD)。

对于每个询问,输出一个整数表示答案。

题目链接:AcWing 246

结论是,一个数列与它的差分数列最大公约数相同。

gcd(a1,a2,,an)=gcd(a1,a2a1,,anan1)\gcd(a_1,a_2,\cdots,a_n)=\gcd(a_1,a_2-a_1,\cdots,a_n-a_{n-1})

考虑左边的任意一个约数 dd,它一定也是右边的约数;右边的任意一个约数 dd 也一定是左边的约数:dda1,a2a1a_1, a_2-a_1 的约数一定也是 (a2a1)+a1(a_2-a_1)+a_1 的约数。

但是,对于一个数列的区间 [L,R][L,R] 它的差分的 gcd 就不是整个 gcd 了,需要先求出来第一项。

用差分使得此问题转化为单点修改与单点查询。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include <iostream>
#include <algorithm>
using namespace std;

typedef long long ll;
const int N = 500010;

struct Node {
int l, r;
ll v, sum;
} tr[N<<2];

ll a[N];

ll _gcd(ll a, ll b) {
return a ? _gcd(b%a, a) : b;
}

ll gcd(ll a, ll b) {
return _gcd(abs(a), abs(b));
}

void pushup(Node& u, Node l, Node r) {
u.v = gcd(l.v, r.v);
u.sum = l.sum + r.sum;
}

void pushup(int u) {
pushup(tr[u], tr[u<<1], tr[u<<1|1]);
}

void build(int u, int l, int r) {
tr[u] = {l, r};
if (l == r) {
tr[u].v = tr[u].sum = a[l]-a[l-1];
} else {
int mid = l+r >> 1;
build(u<<1, l, mid);
build(u<<1|1, mid+1, r);
pushup(u);
}
}

void modify(int u, int p, ll v) {
if (tr[u].l == p && tr[u].r == p) tr[u].v += v, tr[u].sum += v;
else {
int mid = tr[u].l + tr[u].r >> 1;
if (p <= mid) modify(u<<1, p, v);
else modify(u<<1|1, p, v);
pushup(u);
}
}

Node query(int u, int l, int r) {
if (l <= tr[u].l && tr[u].r <= r) return tr[u];
else {
int mid = tr[u].l + tr[u].r >> 1;
if (l > mid) return query(u<<1|1, l, r);
else if (r <= mid) return query(u<<1, l, r);
else {
Node res;
pushup(res, query(u<<1, l, r), query(u<<1|1, l, r));
return res;
}
}
}

int main() {
int n, m;
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
char op;
int l, r;
ll d;
while (m--) {
cin >> op >> l >> r;
if (op == 'Q') {
ll res = query(1, 1, l).sum;
if (l < r) res = gcd(res, query(1, l+1, r).v);
cout << res << endl;
} else {
cin >> d;
modify(1, l, d);
if (r < n) modify(1, r+1, -d);
}
}
return 0;
}

区间修改

模板题

给定一个长度为 N 的数列 A,以及 M 条指令,每条指令可能是以下两种之一:

  1. C l r d,表示把 A[l],A[l+1],…,A[r] 都加上 d。
  2. Q l r,表示询问数列中第 l∼r 个数的和。

对于每个询问,输出一个整数表示答案。

题目链接:AcWing 243P3372(不完全相同)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include <cstdio>
using namespace std;

typedef long long LL;
const int N = 1e5+10;

struct Node {
int l, r;
LL sum, add;
void spread(int d) {
sum += (r-l+1) * (LL)d;
add += d;
}
} tr[N * 4];
int a[N], n, m;

void pushdown(int u) {
if (tr[u].add) {
tr[u<<1].spread(tr[u].add);
tr[u<<1|1].spread(tr[u].add);
tr[u].add = 0;
}
}

void pushup(int u) {
tr[u].sum = tr[u<<1].sum + tr[u<<1|1].sum;
}

void build(int u, int l, int r) {
tr[u] = {l, r};
if (l == r) tr[u].sum = a[l];
else {
int mid = l+r >> 1;
build(u<<1, l, mid);
build(u<<1|1, mid+1, r);
pushup(u);
}
}

void modify(int u, int l, int r, int d) {
if (l <= tr[u].l && tr[u].r <= r) {
tr[u].spread(d);
} else {
// 修改前先下放
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u<<1, l, r, d);
if (mid+1 <= r) modify(u<<1|1, l, r, d);
// 回溯时重计算
pushup(u);
}
}

LL query(int u, int l, int r) {
if (l <= tr[u].l && tr[u].r <= r) return tr[u].sum;

pushdown(u);
LL res = 0;
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) res = query(u<<1, l, r);
if (mid+1 <= r) res += query(u<<1|1, l, r);
return res;
}

int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", a+i);
build(1, 1, n);
char op[2];
int l, r, d;
while (m--) {
scanf("%s%d%d", op, &l, &r);
if (*op == 'C') {
scanf("%d", &d);
modify(1, l, r, d);
} else {
printf("%lld\n", query(1, l, r));
}
}
return 0;
}

维护序列

有长为 N 的数列,不妨设为 a1,a2,…,aN。

有如下三种操作形式:

  1. 把数列中的一段数全部乘一个值;
  2. 把数列中的一段数全部加一个值;
  3. 询问数列中的一段数的和,由于答案可能很大,你只需输出这个数模 P 的值。

输入格式

第一行两个整数 N 和 P;

第二行含有 N 个非负整数,从左到右依次为 a1,a2,…,aN;

第三行有一个整数 M,表示操作总数;

从第四行开始每行描述一个操作,输入的操作有以下三种形式:

  • 操作 1:1 t g c,表示把所有满足 tigt≤i≤gaia_i 改为 ai×ca_i\times c
  • 操作 2:2 t g c,表示把所有满足 tigt≤i≤gaia_i 改为 ai+ca_i+c
  • 操作 3:3 t g,询问所有满足 tigt≤i≤gaia_i 的和模 PP 的值。

同一行相邻两数之间用一个空格隔开,每行开头和末尾没有多余空格。

输出格式

对每个操作 3,按照它在输入中出现的顺序,依次输出一行一个整数表示询问结果。

题目链接:AcWing 1277

结点中的懒标记分别是 add,mul\text{add}, \text{mul} 这里有一个技巧,可以同时传递加和乘,所有的计算都看成先乘后加,那么:

a(mul x+add)+b=a mul x+a add+ba(\text{mul }x+\text{add}) + b=a\text{ mul }x+a\text{ add}+b

对应的懒标记更新就是:

mul=a muladd=a add+b\begin{aligned} \text{mul}&=a \text{ mul}\\ \text{add}&=a \text{ add}+b \end{aligned}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <iostream>
#include <algorithm>
using namespace std;

typedef long long LL;
const int N = 1e5+10;
int n, m, p, a[N];
struct Node {
int l, r;
LL sum, add, mul = 1;
// 这里不开 LL 给我 WA 了好几发
void spread(LL a, LL b) {
mul = (a * mul) % p;
add = (a * add + b) % p;
sum = (a * sum + (r-l+1) * b) % p;
}
} tr[N<<2];

void pushup(int u) {
tr[u].sum = (tr[u<<1].sum + tr[u<<1|1].sum) % p;
}

void pushdown(int u) {
if (tr[u].add || tr[u].mul != 1) {
tr[u<<1].spread(tr[u].mul, tr[u].add);
tr[u<<1|1].spread(tr[u].mul, tr[u].add);
tr[u].add = 0;
tr[u].mul = 1;
}
}

void build(int u, int l, int r) {
tr[u].l = l, tr[u].r = r;
if (l == r) tr[u].sum = a[l];
else {
int mid = l+r >> 1;
build(u<<1, l, mid);
build(u<<1|1, mid+1, r);
pushup(u);
}
}

void modify(int u, int l, int r, int a, int b) {
if (l <= tr[u].l && tr[u].r <= r) tr[u].spread(a, b);
else {
int mid = tr[u].l + tr[u].r >> 1;
pushdown(u);
if (l <= mid) modify(u<<1, l, r, a, b);
if (r >= mid+1) modify(u<<1|1, l, r, a, b);
pushup(u);
}
}

LL query(int u, int l, int r) {
if (l <= tr[u].l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
LL res = 0;
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) res = query(u<<1, l, r);
if (r >= mid+1) res = (res + query(u<<1|1, l, r)) % p;
return res;
}

int main() {
scanf("%d%d", &n, &p);
for (int i = 1; i <= n; i++) scanf("%d", a+i);
build(1, 1, n);
scanf("%d", &m);
int t, x, y, d;
while (m--) {
scanf("%d%d%d", &t, &x, &y);
if (t == 3) printf("%lld\n", query(1, x, y));
else {
scanf("%d", &d);
if (t == 2) modify(1, x, y, 1, d);
else modify(1, x, y, d, 0);
}
}
return 0;
}

主席树

第 k 小数

给定长度为 N 的整数序列 A,下标为 1∼N。

执行 M 次操作,每次给出三个整数 l, r, k, 求 A[l, r] 中的第 k 小数是多少。

题目链接:AcWing 255

主席树是一个权值线段树,意思是它的下标是一个具体的值,每个结点存储这个值出现了几次,这样求第 k 小数的时候就和 Splay 求第 k 小数的方式类似了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

const int N = 1e5+10;

struct Node {
int l, r, cnt;
} tr[N*4+N*17];

int root[N], a[N], idx;
vector<int> nums;

int build(int l, int r) {
int p = ++idx;
if (l == r) return p;
int mid = l+r >> 1;
tr[p].l = build(l, mid), tr[p].r = build(mid+1, r);
return p;
}

int insert(int p, int l, int r, int x) {
int q = ++idx, mid = l+r >> 1;
tr[q] = tr[p];
if (l == r) {
tr[q].cnt++;
return q;
}
if (x <= mid) tr[q].l = insert(tr[p].l, l, mid, x);
else tr[q].r = insert(tr[p].r, mid+1, r, x);
// pushup
tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt;
return q;
}

int query(int p, int q, int l, int r, int k) {
if (l == r) return r;
int cnt = tr[tr[q].l].cnt - tr[tr[p].l].cnt;
int mid = l+r >> 1;
if (k <= cnt) return query(tr[p].l, tr[q].l, l, mid, k);
else return query(tr[p].r, tr[q].r, mid+1, r, k-cnt);
}

int find(int n) {
return lower_bound(nums.begin(), nums.end(), n) - nums.begin();
}

int main() {
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
nums.push_back(a[i]);
}
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
root[0] = build(0, nums.size() - 1);
for (int i = 1; i <= n; i++)
root[i] = insert(root[i-1], 0, nums.size()-1, find(a[i]));
while (m--) {
int l, r, k;
scanf("%d%d%d", &l, &r, &k);
printf("%d\n", nums[query(root[l-1], root[r], 0, nums.size()-1, k)]);
}
return 0;
}