最近公共祖先 LCA 算法学习笔记

两个结点找共同的爸爸

LCA 的概念

图论计算机科学中,最近公共祖先(英语:lowest common ancestor)是指在一个或者有向无环图中同时拥有v和w作为后代的最深的节点。

——Wikipedia

看不懂没关系

简单的来说,就是两个节点v和w的最近的祖先节点

如下图

image

6和7的LCA是2,3和7的LCA是1

LCA 的求法

暴力求解

让他们一步一步往上爬,直到相遇

节点背着那重重的编号呀 一步一步地往上爬 ——《蜗牛与黄鹂鸟》

显然,这样的算法会T到飞起

所以我们要使用倍增优化

倍增求解

所谓倍增,就是按2的倍数来增大,也就是跳 1、2、4 、8 、16、32 …

但是在这里,我们要考虑开倒车从大到小跳

因为如果我们从小到大跳,就会出现要「回溯」的情况,因为我们不一定能精准地跳,而从大到小跳可以避开这种情况

图片来自cnblogs

图源cnblogs

对于上面这一棵更复杂的树,我们考虑17和18的LCA

1
2
17 ->(跳4) 3
18 ->(跳4) 5 ->(跳1) -> 3

是不是快多了,跳的次数大大减小

时间复杂度 O(nlogn)

LCA 的代码 & 实现流程

实现流程

首先我们要记录各个点的深度 depth[\ ] 和它们 2^i 级的祖先 father[\ ][\ ]

depth[i] 表示 i 点的深度, father[i][j] 表示 i 点的 2^i 级的祖先

1
2
3
4
5
6
7
8
9
10
11
// 预处理
void dfsInit(int root, int fa) {
depth[root] = depth[fa] + 1;
father[root][0] = fa;
for (int i = 1; (1 << i) <= depth[root]; ++i) {
father[root][i] = father[father[root][i-1]][i-1];
}
for (int e = head[root]; e; e = edge[e].next) {
if (edge[e].prev != fa) dfsInit(edge[e].prev, root);
}
}

接着我们就可以找LCA辣

对了,我们可以让它跑得更快

1
2
3
4
// 提前预处理出log2i + 1的值
for (int i = 1; i <= n; ++i) {
lg[i] = lg[i-1] + (1 << lg[i-1] == i);
}

在求 LCA 之前,我们先让两个节点蹦到同一层


但是跳的时候不能直接跳到 LCA 上,要跳到 LCA - 1 上,再输出 当前的父节点 就行了

因为直接蹦到 LCA 上可能会出现「误判」,比如上图中 4 8 ,若不判断,则在跳的时候会输出1,但是答案是3

所以我们就可以让它们跳到 2 5 ,然后输出父节点

1
2
3
4
5
6
7
8
9
10
11
12
int LCA(int x, int y) {
// 我们设x的深度大于y的深度
if (depth[x] < depth[y]) swap(x, y);
while (depth[x] > depth[y])
x = father[x][lg[depth[x] - depth[y]] - 1];
if (x == y) return x; // x 是 y 的祖先
for (int i = lg[depth[x]]; i >= 0; --i) {
if (father[x][i] != father[y][i]) x = father[x][i], y = father[y][i];
// 不相等就往上跳
}
return father[x][0];
}

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
using namespace std;

const int MAXN = 500000 + 10;
const int MAXM = 500000 + 10;

struct Edge {
int prev, next;
} edge[MAXM * 2];

int head[MAXN], father[MAXN][22], lg[MAXN], depth[MAXN];
int cnt, n, m, s;

inline int getint() {
int s = 0, x = 1;
char ch = getchar();
while (!isdigit(ch)) {
if (ch == '-') x = -1;
ch = getchar();
}
while (isdigit(ch)) {
s = s * 10 + ch - '0';
ch = getchar();
}
return s * x;
}

inline void putint(int x, bool returnValue) {
if (x < 0) {
x = -x;
putchar('-');
}
if (x >= 10) putint(x / 10, false);
putchar(x % 10 + '0');
if (returnValue) putchar('\n');
}

inline void addEdge(int prev, int next) {
edge[++cnt].prev = prev;
edge[cnt].next = head[next];
head[next] = cnt;
}

// 预处理
void dfsInit(int root, int fa) {
depth[root] = depth[fa] + 1;
father[root][0] = fa;
for (int i = 1; (1 << i) <= depth[root]; ++i) {
father[root][i] = father[father[root][i-1]][i-1];
}
for (int e = head[root]; e; e = edge[e].next) {
if (edge[e].prev != fa) dfsInit(edge[e].prev, root);
}
}

int LCA(int x, int y) {
// 我们设x的深度大于y的深度
if (depth[x] < depth[y]) swap(x, y);
while (depth[x] > depth[y])
x = father[x][lg[depth[x] - depth[y]] - 1];
if (x == y) return x; // x 是 y 的祖先
for (int i = lg[depth[x]]; i >= 0; --i) {
if (father[x][i] != father[y][i]) x = father[x][i], y = father[y][i];
// 不相等就往上跳
}
return father[x][0];
}

int main(int argc, char *const argv[]) {
n = getint(), m = getint(), s = getint();
for (int i = 1; i < n; ++i) {
int prev = getint(), next = getint();
addEdge(prev, next);
addEdge(next, prev);
}
dfsInit(s, 0);
for (int i = 1; i <= n; ++i) {
lg[i] = lg[i-1] + (1 << lg[i-1] == i);
}
for (int i = 1; i <= m; ++i) {
int x = getint(), y = getint();
putint(LCA(x, y), true);
}
return 0;
}