简单树套树

请你写出一种数据结构,来维护一个长度为 n 的序列,其中需要提供以下操作:

  1. 1 pos x,将 pos 位置的数修改为 x。
  2. 2 l r x,查询整数 x 在区间 [l,r] 内的前驱(前驱定义为小于 x,且最大的数)。

数列中的位置从左到右依次标号为 1∼n。

题目链接:AcWing 2488

外层是一个线段树,线段树的每个结点都包含一个平衡树 multiset

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <cstdio>
#include <set>
using namespace std;

#define lson (u<<1)
#define rson (u<<1|1)
const int N = 50010, INF = 1e9;

struct Node {
int l, r;
multiset<int> s;
} tr[N*4];
int a[N], n, m;

void build(int u, int l, int r) {
tr[u] = {l, r};
tr[u].s.insert(-INF), tr[u].s.insert(INF);
for (int i = l; i <= r; i++) tr[u].s.insert(a[i]);
if (l == r) return;
int mid = l+r >> 1;
build(lson, l, mid), build(rson, mid+1, r);
}

void modify(int u, int p, int x) {
// 如果用 erase(a[p]) 会删除掉所有元素
// 我们只希望删除一个
tr[u].s.erase(tr[u].s.find(a[p]));
tr[u].s.insert(x);
if (tr[u].l == p && tr[u].r == p) return;
int mid = tr[u].l + tr[u].r >> 1;
if (p <= mid) modify(lson, p, x);
else modify(rson, p, x);
}

int query(int u, int l, int r, int x) {
if (l <= tr[u].l && tr[u].r <= r) {
auto it = tr[u].s.lower_bound(x);
return *--it;
}
int res = -INF, mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) res = max(res, query(lson, l, r, x));
if (mid+1 <= r) res = max(res, query(rson, l, r, x));
return res;
}

int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
build(1, 1, n);
while (m--) {
int opt, l, r, p, x;
scanf("%d", &opt);
if (opt == 1) {
scanf("%d%d", &p, &x);
modify(1, p, x);
a[p] = x;
}
else {
scanf("%d%d%d", &l, &r, &x);
printf("%d\n", query(1, l, r, x));
}
}
return 0;
}

二逼平衡树

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:

  1. 查询 x 在区间内的排名;
  2. 查询区间 [l, r] 内排名为 k 的值;
  3. 修改某一位置上的数值;
  4. 查询 x 在区间 [l, r] 内的前驱(前驱定义为小于 x,且最大的数);
  5. 查询 x 在区间 [l, r] 内的后继(后继定义为大于 x,且最小的数)。

如果不存在前驱输出 -2147483647,不存在后继输出 2147483647

题目链接:LOJ 106P3380AcWing 2476

按照理论上来看,平衡树需要的点数这么计算:

  1. 对于 4N4N 个线段树的结点,每个都需要两个哨兵,这样算下来是 8N8N
  2. 对于 logN\log N 层线段树的结点,每层需要 NN 个结点。
  3. 加起来是 8N+NlogN120w8N+N\log N\approx 120w,但是洛谷的数据比较强,我们这个代码没做内存回收(做内存回收就太慢了),因此删一个点就要多开一个,开到 200w200w 才够。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
#include <cstdio>
#include <cstdlib>
#include <algorithm>
using namespace std;

int n, m, a[50010];
const int INF = 2147483647;

namespace splay {
const int N = 2000010;
struct Node {
int s[2], p, v;
int size, cnt;

void init(int _p, int _v) {
p = _p, v = _v;
size = cnt = 1;
}
} tr[N];
int idx;

void pushup(int u) {
tr[u].size = tr[u].cnt + tr[tr[u].s[0]].size + tr[tr[u].s[1]].size;
}

void rotate(int x) {
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x;
tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
tr[y].s[k] = tr[x].s[k^1], tr[tr[x].s[k^1]].p = y;
tr[x].s[k^1] = y, tr[y].p = x;
pushup(y), pushup(x);
}

void splay(int& root, int x, int k) {
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k) {
if ((tr[y].s[0] == x) ^ (tr[z].s[0] == y)) rotate(x);
else rotate(y);
}
rotate(x);
}
if (k == 0) root = x;
}

void insert(int& root, int v) {
int u = root, p = 0;
while (u && tr[u].v != v) p = u, u = tr[u].s[v > tr[u].v];
if (u) tr[u].cnt++;
else {
u = ++idx;
tr[u].init(p, v);
if (p) tr[p].s[v > tr[p].v] = u;
}
splay(root, u, 0);
}

int count_lower(int& root, int v) {
int u = root, res = 0;
while (u) {
if (tr[u].v > v) u = tr[u].s[0];
else {
res += tr[tr[u].s[0]].size;
if (tr[u].v == v) {
splay(root, u, 0);
break;
}
res += tr[u].cnt;
u = tr[u].s[1];
}
}
return res;
}

void del(int& root, int v) {
int u = root;
while (u && tr[u].v != v) u = tr[u].s[v > tr[u].v];
if (!u) exit(-1);

splay(root, u, 0);
if (tr[u].cnt > 1) tr[u].cnt--;
else {
int l = tr[u].s[0], r = tr[u].s[1];
while (tr[l].s[1]) l = tr[l].s[1];
while (tr[r].s[0]) r = tr[r].s[0];
splay(root, l, 0), splay(root, r, l);
tr[r].s[0] = 0;
pushup(r), pushup(l);
}
}

int pre(int& root, int v) {
insert(root, v);

int u = root;
while (tr[u].v != v) u = tr[u].s[v > tr[u].v];
splay(root, u, 0);

int l = tr[u].s[0];
while (tr[l].s[1]) l = tr[l].s[1];
del(root, v);

return tr[l].v;
}

int suc(int& root, int v) {
insert(root, v);

int u = root;
while (tr[u].v != v) u = tr[u].s[v > tr[u].v];
splay(root, u, 0);

int r = tr[u].s[1];
while (tr[r].s[0]) r = tr[r].s[0];
del(root, v);

return tr[r].v;
}

// 调试用的
void _output(int u) {
if (tr[u].s[0]) _output(tr[u].s[0]);
for (int i = 0; i < tr[u].cnt; i++) {
printf("%d ", tr[u].v);
}
if (tr[u].s[1]) _output(tr[u].s[1]);
}

void output(int u) {
_output(u);
puts("");
}
};

namespace seg {
const int N = 200010;
struct Node {
int l, r;
int rt;
} tr[N];

void build(int u, int l, int r) {
tr[u] = {l, r};
splay::insert(tr[u].rt, -INF);
splay::insert(tr[u].rt, INF);
for (int i = l; i <= r; i++) {
splay::insert(tr[u].rt, a[i]);
}
if (l == r) return;
int mid = (l+r) >> 1;
build(u<<1, l, mid), build(u<<1|1, mid+1, r);
}

int count_lower(int u, int a, int b, int v) {
// 注意这里一定要减去那个 -INF
// 设置哨兵能减少很多麻烦
if (a <= tr[u].l && tr[u].r <= b)
return splay::count_lower(tr[u].rt, v) - 1;
int mid = (tr[u].l + tr[u].r) >> 1;
int res = 0;
if (a <= mid) res += count_lower(u<<1, a, b, v);
if (mid+1 <= b) res += count_lower(u<<1|1, a, b, v);
return res;
}

int get_kth(int u, int a, int b, int k) {
// 二分答案
int l = 0, r = 1e8;
while (l < r) {
int mid = (l+r+1) >> 1;
if (count_lower(u, a, b, mid) + 1 > k) r = mid-1;
else l = mid;
}
return r;
}

// 调用完后需要更新 a[p] = v
void modify(int u, int p, int v) {
splay::del(tr[u].rt, a[p]);
splay::insert(tr[u].rt, v);
if (tr[u].l == p && tr[u].r == p) return;
int mid = (tr[u].l + tr[u].r) >> 1;
if (p <= mid) modify(u<<1, p, v);
else modify(u<<1|1, p, v);
}

int pre(int u, int a, int b, int v) {
if (a <= tr[u].l && tr[u].r <= b) return splay::pre(tr[u].rt, v);
int mid = (tr[u].l + tr[u].r) >> 1, res = -INF;
if (a <= mid) res = max(res, pre(u<<1, a, b, v));
if (mid+1 <= b) res = max(res, pre(u<<1|1, a, b, v));
return res;
}

int suc(int u, int a, int b, int v) {
if (a <= tr[u].l && tr[u].r <= b) return splay::suc(tr[u].rt, v);
int mid = (tr[u].l + tr[u].r) >> 1, res = INF;
if (a <= mid) res = min(res, suc(u<<1, a, b, v));
if (mid+1 <= b) res = min(res, suc(u<<1|1, a, b, v));
return res;
}
};

int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
seg::build(1, 1, n);
while (m--) {
int opt, l, r, p, x;
scanf("%d", &opt);
if (opt == 1) {
scanf("%d%d%d", &l, &r, &x);
printf("%d\n", seg::count_lower(1, l, r, x) + 1);
}
else if (opt == 2) {
scanf("%d%d%d", &l, &r, &x);
printf("%d\n", seg::get_kth(1, l, r, x));
}
else if (opt == 3) {
scanf("%d%d", &p, &x);
seg::modify(1, p, x);
a[p] = x;
}
else if (opt == 4) {
scanf("%d%d%d", &l, &r, &x);
printf("%d\n", seg::pre(1, l, r, x));
}
else {
scanf("%d%d%d", &l, &r, &x);
printf("%d\n", seg::suc(1, l, r, x));
}
}
return 0;
}

探究:对于 seg::get_kth 二分的写法,下面这种是错的:

1
2
3
4
5
6
7
8
9
10
11
int get_kth(int u, int a, int b, int k) {
// 二分答案
int l = 0, r = 1e8;
while (l < r) {
int mid = (l+r) >> 1;
// 区别是 <= 换成 <
if (count_lower(u, a, b, mid) + 1 < k) l = mid+1;
else r = mid;
}
return r;
}

改了之后样例的第一个输入就过不了,原因在于当我们二分到答案之时,即样例中的 mid=2 时,由于有重复的 2 它的 rank(2)=2 会被认为是不满足条件的,所以会二分到 mid+1=3

如果用的是小于等于,这样右边界会跨过 3 这个值到达 2,相当于二分出来了一个右边界。

所以这告诉我们什么问题呢?考场上想不出来就照着样例看哪种写法能过就写哪个 hh,纠结这些是平时学习该做的事。

K大数查询(待补)