倍增

给定一棵包含 n 个节点的有根无向树,节点编号互不相同,但不一定是 1∼n。

有 m 个询问,每个询问给出了一对节点的编号 x 和 y 询问 x 与 y 的祖孙关系。

输入格式

输入第一行包括一个整数 表示节点个数;

接下来 n 行每行一对整数 a 和 b,表示 a 和 b 之间有一条无向边。如果 b 是 −1,那么 a 就是树的根;

第 n+2 行是一个整数 m 表示询问个数;

接下来 m 行,每行两个不同的正整数 x 和 y,表示一个询问。

输出格式

对于每一个询问,若 x 是 y 的祖先则输出 1,若 y 是 x 的祖先则输出 2,否则输出 0。

题目链接:AcWing 1172

f(i,j)f(i,j) 表示从结点 ii 开始向上走 2j2^j 步能走到的节点,depth(i)\text{depth}(i) 表示深度。

设超过跟节点的 f(i,j)=0,depth(0)=0f(i,j)=0, \text{depth}(0)=0,这样可以减少特判;递推关系求 f(i,j)=f[f(i,j1),j1]f(i,j)=f[f(i,j-1),j-1]

  1. 先让两个结点 a,ba, b 跳到同一层,如果相等直接返回。
  2. 让两个点同时向上跳,一直跳到最近公共祖先下一层。

每次跳的时候都将 kk 从大到小枚举,从二进制的角度来看一定能凑出来任意数字。

复杂度:预处理 O(nlogn)O(n \log n),查询 O(logn)O(\log n)

关于 f 数组的大小,这里可以举例说明:如果有 6=11026=110_2 应该让下标能达到 log26=2\lfloor\log_26\rfloor=2,因此应该开到 log2n+1\lfloor \log_2 n\rfloor+1 的大小。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <iostream>
#include <queue>
#include <cstring>
using namespace std;

const int N = 40010, M = 2*N;
int depth[N], f[N][16], h[N], e[M], ne[M], idx;

void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

void dfs(int u, int father) {
depth[u] = depth[father] + 1;
f[u][0] = father;
for (int k = 1; k <= 15; k++)
f[u][k] = f[f[u][k-1]][k-1];

for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (!depth[j]) dfs(j, u);
}
}

int lca(int a, int b) {
// 保证 a 比 b 深
if (depth[a] < depth[b]) swap(a, b);
for (int k = 15; k >= 0; k--) {
if (depth[f[a][k]] >= depth[b]) a = f[a][k];
}
if (a == b) return a;
for (int k = 15; k >= 0; k--) {
if (f[a][k] != f[b][k]) a = f[a][k], b = f[b][k];
}
return f[a][0];
}

int main() {
int n, m, root;
memset(h, -1, sizeof h);
cin >> n;
while (n--) {
int a, b;
cin >> a >> b;
if (b == -1) root = a;
else add(a, b), add(b, a);
}
dfs(root, 0);

cin >> m;
while (m--) {
int a, b;
cin >> a >> b;
int p = lca(a, b);
if (p == a) puts("1");
else if (p == b) puts("2");
else puts("0");
}
return 0;
}

Tarjan LCA

给出 n 个点的一棵树,多次询问两点之间的最短距离。

  • 边是无向的。
  • 所有节点的编号是 1,2,…,n。

输入格式

第一行为两个整数 n 和 m。n 表示点数,m 表示询问次数;

下来 n−1 行,每行三个整数 x,y,k,表示点 x 和点 y 之间存在一条边长度为 k;

再接下来 m 行,每行两个整数 x,y,表示询问点 x 到点 y 的最短距离。

树中结点编号从 1 到 n。

输出格式

共 m 行,对于每次询问,输出一行询问结果。

题目链接:AcWing 1171

Tarjan LCA 算法是一个离线算法,意思是它必须知道所有查询之后一次性求出所有查询结果。

它的复杂度是 O(n+m)O(n+m),比前面提到的倍增算法 O(nlogn+mlogn)O(n\log n+m\log n) 更优。

  1. DFS 搜索所有结点,已经回溯过的标记为 2,正在搜索的标记为 1,未搜索的标记为 0。

  2. 将所有回溯过的点和它的父结点用并查集合并。

  3. 每当处理与当前正在搜索的点 u 相关的查询时,如果另一个点 x 已经被回溯过,那么它们的公共祖先就是 find(x)

  4. 距离为 d(x)+d(u)2d(p)d(x)+d(u)-2d(p) 其中 d(x)d(x) 为某结点到根节点的距离。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include <iostream>
#include <queue>
#include <cstring>
#include <vector>
using namespace std;

typedef pair<int, int> pii;

const int N = 1e4+10, M = 4e4+10;
int h[N], e[2*N], ne[2*N], w[2*N], idx;
int dist[N], mark[N], p[N], res[M];
vector<pii> query[N];

void add(int a, int b, int c) {
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++;
}

int find(int x) {
if (p[x] != x) p[x] = find(p[x]);
return p[x];
}

void bfs(int root) {
queue<int> q;
q.push(root);
memset(dist, 0x3f, sizeof dist);
dist[root] = 0;

while (q.size()) {
int t = q.front(); q.pop();
for (int i = h[t]; ~i; i = ne[i]) {
int j = e[i];
if (dist[j] > dist[t] + w[i]) {
dist[j] = dist[t] + w[i];
q.push(j);
}
}
}
}

void tarjan(int u) {
mark[u] = 1;

for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (mark[j] == 0) {
tarjan(j);
// 这里写不写 find(j) 都无所谓, p[j] 一定等于 j
p[j] = u;
}
}

for (pii q : query[u]) {
int x = q.first, index = q.second;
if (mark[x] == 2) {
res[index] = dist[u] + dist[x] - 2*dist[find(x)];
}
}

// 已回溯
mark[u] = 2;
}

int main() {
int n, m;
memset(h, -1, sizeof h);
cin >> n >> m;
for (int i = 1; i <= n; i++) p[i] = i;
for (int i = 0; i < n-1; i++) {
int a, b, c;
cin >> a >> b >> c;
add(a, b, c), add(b, a, c);
}
bfs(1);
for (int i = 0; i < m; i++) {
int a, b;
cin >> a >> b;
if (a != b) {
query[a].push_back({b, i});
query[b].push_back({a, i});
}
}
tarjan(1);
for (int i = 0; i < m; i++) cout << res[i] << endl;
return 0;
}

次小生成树

给定一张 N 个点 M 条边的无向图,求无向图的严格次小生成树。

设最小生成树的边权之和为 sum,严格次小生成树就是指边权之和大于 sum 的生成树中最小的一个。

输入格式

第一行包含两个整数 N 和 M。

接下来 M 行,每行包含三个整数 x,y,z,表示点 x 和点 y 之前存在一条边,边的权值为 z。

输出格式

包含一行,仅一个数,表示严格次小生成树的边权和。(数据保证必定存在严格次小生成树)

题目链接:AcWing 356

用 LCA 可以优化的主要是求树上一个环中的最大值和次大值这一步骤,其它都和 LCA 没什么关系。

求最大值和次大值一般来说模板都是这样:

1
2
3
4
for (int v : vals) {
if (v > m1) m2 = m1, m1 = v;
else if (m2 < v && v < m1) m2 = v;
}

先用 Kruskal 求出来最小生成树,然后枚举每条非树边,找出两个顶点上的树形成的半个环上的最大边和次大边 ,用非树边替换最大边或次大边后形成的所有生成树中一定包含次小生成树。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
using namespace std;

typedef long long ll;
const int N = 1e5+10, M = 3e5+10, INF = 0x3f3f3f3f;
int n, m;
int h[N], e[2*N], ne[2*N], w[2*N], idx;
int p[N], d1[N][17], d2[N][17], depth[N], fa[N][17];

struct Edge {
int a, b, w;
bool used;
bool operator<(const Edge& e) const {
return w < e.w;
}
} edge[M];

void add(int a, int b, int c) {
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++;
}

int find(int x) {
if (p[x] != x) p[x] = find(p[x]);
return p[x];
}

// 求最小生成树长度的同时建图
ll kruskal() {
ll sum = 0;
for (int i = 1; i <= n; i++) p[i] = i;
memset(h, -1, sizeof h);
for (int i = 0; i < m; i++) {
int a = edge[i].a, b = edge[i].b, w = edge[i].w;
int pa = find(a), pb = find(b);
if (pa != pb) {
p[pa] = pb;
sum += w;
edge[i].used = true;
add(a, b, w), add(b, a, w);
}
}
return sum;
}

// 预处理 LCA 需要的数据
void bfs() {
queue<int> q;
q.push(1);
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[1] = 1;
while (q.size()) {
int t = q.front(); q.pop();
for (int i = h[t]; ~i; i = ne[i]) {
int j = e[i];
if (depth[j] > depth[t] + 1) {
depth[j] = depth[t] + 1;
q.push(j);

fa[j][0] = t, d1[j][0] = w[i], d2[j][0] = -INF;
for (int k = 1; k <= 16; k++) {
int mid = fa[j][k-1];
fa[j][k] = fa[mid][k-1];
int d[4] = {d1[j][k-1], d2[j][k-1], d1[mid][k-1], d2[mid][k-1]};
for (int u = 0; u < 4; u++) {
if (d[u] > d1[j][k]) d2[j][k] = d1[j][k], d1[j][k] = d[u];
else if (d[u] < d1[j][k] && d[u] > d2[j][k]) d2[j][k] = d[u];
}
}
}
}
}
}

int lca(int a, int b, int w) {
static int d[4*N];
int cnt = 0;
if (depth[a] < depth[b]) swap(a, b);
for (int k = 16; k >= 0; k--) {
if (depth[fa[a][k]] >= depth[b]) {
d[cnt++] = d1[a][k];
d[cnt++] = d2[a][k];
a = fa[a][k];
}
}

if (a != b) {
for (int k = 16; k >= 0; k--) {
if (fa[a][k] != fa[b][k]) {
d[cnt++] = d1[a][k];
d[cnt++] = d2[a][k];
d[cnt++] = d1[b][k];
d[cnt++] = d2[b][k];
a = fa[a][k], b = fa[b][k];
}
}
// 还要向上跳一级, d2[x][0] 恒等于 -INF
d[cnt++] = d1[a][0], d[cnt++] = d1[b][0];
}

int dist1 = -INF, dist2 = -INF;
for (int i = 0; i < cnt; i++) {
if (d[i] > dist1) dist2 = dist1, dist1 = d[i];
else if (d[i] < dist1 && d[i] > dist2) dist2 = d[i];
}

if (w > dist1) return w-dist1;
if (w > dist2) return w-dist2;

return INF;
}

int main() {
cin >> n >> m;
for (int i = 0; i < m; i++) {
int a, b, c;
cin >> a >> b >> c;
edge[i] = {a, b, c};
}
sort(edge, edge+m);
ll sum = kruskal();
bfs();
ll ans = 1e18;
for (int i = 0; i < m; i++) {
if (!edge[i].used) {
int a = edge[i].a, b = edge[i].b, w = edge[i].w;
ans = min(ans, sum + lca(a, b, w));
}
}
cout << ans << endl;
return 0;
}

换根 LCA

给定一颗 nn 个点的无根树,给 qq 组询问,每组询问包含 (u,v,w)(u,v,w),求以 ww 为根意义下的 LCA(u,v)LCA(u,v)

为了方便输入,第二行的第 ii 个点表示 ii 的父结点,保证 fa(1)=0fa(1)=0,即给出的是以 11 为根的树。

本题是在集训时 PPT 上看到的,并没在主流 OJ 上找到,我这里就自己写了个对拍验证是否正确了。

考虑以 11 为根时的 p=LCA(u,v),a=LCA(u,w),b=LCA(v,w)p=LCA(u,v),a=LCA(u,w),b=LCA(v,w)

  1. 如果 ww 在以 pp 为根的子树外面,答案就是 pp,此时 a=ba=b
  2. 如果 aaupu\to p 的路径上,答案是 aa,此时 b=pb=p
  3. 如果 bbvpv\to p 的路径上,答案是 bb,此时 a=pa=p

因为树上无环,不可能同时满足条件 2,32,3 我们可以归纳出如果 a=ba=b,说明答案是 pp,否则答案是 a,ba,b 中较深的那个。

1
2
3
4
5
6
int LCA(int u, int v, int w) {
int p = lca(u, v), a = lca(w, u), b = lca(w, v);
if (a == b) return p;
if (depth[a] > depth[b]) return a;
return b;
}

下面给出完整程序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <cstdio>
#include <cstring>
using namespace std;

const int N = 10010, M = 20010;
int h[N], e[M], ne[M], idx;
int fa[N][15], depth[N];
int n;

void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

void dfs(int u, int f) {
depth[u] = depth[f] + 1;
fa[u][0] = f;
for (int k = 1; k < 15; k++)
fa[u][k] = fa[fa[u][k-1]][k-1];

for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == f) continue;
dfs(j, u);
}
}

int lca(int a, int b) {
if (depth[a] < depth[b]) swap(a, b);
for (int k = 14; k >= 0; k--) {
if (depth[fa[a][k]] >= depth[b]) a = fa[a][k];
}
if (a == b) return a;
for (int k = 14; k >= 0; k--) {
if (fa[a][k] != fa[b][k]) a = fa[a][k], b = fa[b][k];
}
return fa[a][0];
}

int LCA(int u, int v, int w) {
int p = lca(u, v), a = lca(w, u), b = lca(w, v);
if (a == b) return p;
if (depth[a] > depth[b]) return a;
return b;
}

int main() {
memset(h, -1, sizeof h);
scanf("%d", &n);
for (int i = 1, j; i <= n; i++) {
scanf("%d", &j);
if (j) add(i, j), add(j, i);
}
int q, u, v, w;
scanf("%d", &q);
dfs(1, 0);
while (q--) {
scanf("%d%d%d", &u, &v, &w);
printf("%d\n", LCA(u, v, w));
}
return 0;
}

对拍程序中只是每次读入都重新预处理一遍:

1
2
3
4
5
while (q--) {
scanf("%d%d%d", &u, &v, &w);
dfs(w, 0);
printf("%d\n", lca(u, v));
}

数据生成程序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#include <cstdio>
#include <cstdlib>
using namespace std;

const int N = 10010;
int fa[N];

int main() {
srand(20230930);
int n = rand() % 10000 + 1;
fa[1] = 0;
for (int i = 2; i <= n; i++) {
int f;
fa[i] = rand() % (i-1) + 1;
}
printf("%d\n", n);
for (int i = 1; i <= n; i++) printf("%d ", fa[i]);
printf("\n");
int q = rand() % 10000 + 1;
printf("%d\n", q);
for (int i = 1, u, v, w; i <= q; i++) {
u = rand() % n + 1, v = rand() % n + 1, w = rand() % n + 1;
printf("%d %d %d\n", u, v, w);
}
return 0;
}

检查程序(MacOS 环境,Windows 应该要用 fc 之类的):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#include <cstdio>
#include <cstdlib>
#include <ctime>
using namespace std;

int main() {
srand(time(0));
for (int i = 1; i <= 100; i++) {
system("./data > lca.in");
system("./force < lca.in > lca.ans");
system("./std < lca.in > lca.out");
if (system("diff lca.out lca.ans")) {
printf("Wrong answer on #%d\n", i);
break;
}
else printf("Accepted on #%d\n", i);
}
return 0;
}

可自行验证。