严格次小生成树

题意简述

给定一个无向连通图,求出它的一棵生成树,满足边权之和仅次于最小生成树。

解题思路

一个很自然的想法就是,枚举每一条选中的边,用另外一条没选中的边替换,求出次小值。

考虑怎么优化这个算法。


首先换一下枚举顺序,改为枚举每一条没选中的边,用它来替换一条选中的边,因为维护树上的信息比维护一堆零散的边的信息要简单得多。

接着我们来考虑要更换哪条边。很显然,要更换的一定是最大的那条边,这样可以对最小生成树的边权和产生的影响最小,让次小生成树「只比它大恰好一点点」。那么我们维护一个最小生成树的链上边权最大值就没了。


有小朋友可能会问:如果换边的时候搞成了一个基环森林咋办?
答案是:对于一条未选的边 (u, v, w) ,我们只替换 u \rightarrow v 这条链上的所有边。这个比较显然,自己画图理解一下。

行了,本题的大概思路到这里就结束了,大家洗洗睡吧。


真的吗?
再来读一遍题:严格次小生成树,生成树权值之和要严格小于最小生成树
如果按照上面的操作,我们可能会碰到这样一种情况——当前枚举的未选择边边权 = 最小生成树中的最大边权,此时对边权和产生的影响恰好为零,此时我们求出的仅仅是另外一棵最小生成树而已。

怎么办?


很简单,再维护一个次大值即可,如果最大值等于枚举的边权,就拿次大值算影响。

行了,本题到这里真的结束了。

代码细节

主要的实现细节是求链上最大值和次大值。这个用倍增维护比较舒服。

关于如何求最大值就不说了,和求倍增的fa[u][i]都差不多。

以下的代码都是令 maxw[u][i][0] 表示 u 向上跳 2^i 层经过的边权最大值,maxw[u][i][1] 表示 u 向上跳 2^i 层经过的边权次大值。

1
2
3
4
for (int  i  =  1; depth[u] >= twon[i]; ++i) {
// 更新 fa
fa[u][i] = fa[fa[u][i - 1]][i - 1];
}

维护次大值比较麻烦。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// 更新最大值、次大值
for (int i = 1; depth[u] >= twon[i]; ++i) {;
// maxw[u][i][0] 可以从 maxw[u][i - 1][0] 和 maxw[fa[u][i - 1]][i - 1][0] 继承过来
maxw[u][i][0] = std::max(maxw[u][i - 1][0], maxw[fa[u][i - 1]][i - 1][0])
// maxw[u][i][1] 可以从 maxw[u][i - 1][1] maxw[fa[u][i - 1]][i - 1][1]
// 和 maxw[u][i - 1][0] maxw[fa[u][i - 1]][i - 1][0] 中的较小值 继承过来
if (maxw[u][i - 1][0] == maxw[fa[u][i - 1]][i - 1][0]) {
// 为了保证严格次小,当 maxw[u][i - 1][0] == maxw[fa[u][i - 1]][i - 1][0] 时
// 无法转移 maxw[u][i][1]
maxw[u][i][1] = std::max(maxw[u][i - 1][1], maxw[fa[u][i - 1]][i - 1][1]);
} else maxw[u][i][1] = std::max(
std::min(maxw[u][i - 1][0], maxw[fa[u][i - 1]][i - 1][0]),
std::max(maxw[u][i - 1][1], maxw[fa[u][i - 1]][i - 1][1])
);
}

接下来是查询。查询可以分成两个部分 u \rightarrow \text{LCA} v \rightarrow \text{LCA} ,就成了一个链上倍增。

1
2
3
4
5
6
7
8
9
10
11
12
13
// 最小值,次小值,高度差(x 为 u 和 v 之一,y 为 LCA)
int mx = 0, smx = 0, k = depth[x] - depth[y];
for (int i = 0; i <= LOGM; ++i) {
if (k & (1 << i)) {
if (smx < maxw[x][i][1]) {
smx = maxw[x][i][1];
}
if (mx < maxw[x][i][0]) {
smx = std::max(smx, mx);
mx = maxw[x][i][0];
}
}
}

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
// Accepted

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <string>
#include <vector>

#define DEBUG(x) std::cerr << #x << " = " << x << std::endl;

using std::cin;
using std::cout;
using std::endl;

inline int read() {
int s = 0, x = 1; char ch = getchar();
while (!isdigit(ch)) { if (ch == '-') x = -x; ch = getchar(); }
while (isdigit(ch)) { s = s * 10 + ch - '0'; ch = getchar(); }
return s * x;
}

const int MAXN = 100000 + 10;
const int MAXM = 300000 + 10;
const int MAXLOG = 17 + 3;
const int LOGM = 17;

struct Edge {
int v, w; Edge(int _v = 0, int _w = 0) : v(_v), w(_w) {}
};

struct REdge {
int u, v, w; bool chose;
REdge() { u = v = w = 0; chose = 0; }
void _r() {
u = read(); v = read(); w = read();
}
bool operator < (const REdge &th) const {
return w < th.w;
}
} edge[MAXM];

struct DSU {
int u[MAXN];
DSU() { memset(u, 0, sizeof u); }
int find(int x) { return !u[x] ? x : u[x] = find(u[x]); }
bool merge(int x, int y) {
x = find(x); y = find(y);
if (x == y) return false;
u[x] = y; return true;
}
} kk;

int n, m;
// twon[i] = 2^i
int twon[MAXLOG];
// 生成树上做 LCA
std::vector<Edge> G[MAXN];
int fa[MAXN][MAXLOG], depth[MAXN];
// maxw[u][i][0] 为 u 往上跳 2^i 次经过的边权最大值,maxw[][][1] 为次大值
int maxw[MAXN][MAXLOG][2];
long long int minDelta = 0x3f3f3f3f3f3f3f3f;

long long int Kruskal() {
long long int ans = 0; int ch = 0;
std::sort(edge + 1, edge + 1 + m);
for (int i = 1; i <= m; ++i) {
if (kk.merge(edge[i].u, edge[i].v)) {
++ch;
edge[i].chose = true;
ans += 1ll * edge[i].w;
G[edge[i].u].push_back(Edge(edge[i].v, edge[i].w));
G[edge[i].v].push_back(Edge(edge[i].u, edge[i].w));
}
if (ch == n - 1) break;
}
return ans;
}

void dfs(int u) {
// depth[u][0] maxw[u][0][0/1] 的处理放在枚举边的时候更方便一些
for (int i = 1; depth[u] >= twon[i]; ++i) {
// 更新 fa
fa[u][i] = fa[fa[u][i - 1]][i - 1];
// 更新最大值、次大值
// maxw[u][i][0] 可以从 maxw[u][i - 1][0] 和 maxw[fa[u][i - 1]][i - 1][0] 继承过来
// maxw[u][i][1] 可以从 maxw[u][i - 1][1] maxw[fa[u][i - 1]][i - 1][1]
// 和 maxw[u][i - 1][0] maxw[fa[u][i - 1]][i - 1][0] 中的较小值 继承过来
// 为了保证严格次小,当 maxw[u][i - 1][0] == maxw[fahter[u][i - 1]][i - 1][0] 时
// 无法转移 maxw[u][i][1]
maxw[u][i][0] = std::max(maxw[u][i - 1][0], maxw[fa[u][i - 1]][i - 1][0]);
if (maxw[u][i - 1][0] == maxw[fa[u][i - 1]][i - 1][0]) {
maxw[u][i][1] = std::max(maxw[u][i - 1][1], maxw[fa[u][i - 1]][i - 1][1]);
} else maxw[u][i][1] = std::max(
std::min(maxw[u][i - 1][0], maxw[fa[u][i - 1]][i - 1][0]),
std::max(maxw[u][i - 1][1], maxw[fa[u][i - 1]][i - 1][1])
);
}
for (int i = 0, siz = (int) G[u].size(); i < siz; ++i) {
int v = G[u][i].v, w = G[u][i].w;
if (v == fa[u][0]) continue;
maxw[v][0][0] = w; maxw[v][0][1] = -1;
fa[v][0] = u; depth[v] = depth[u] + 1;
dfs(v);
}
}

void initLCA() {
twon[0] = 1;
for (int i = 1; i <= LOGM; ++i) {
twon[i] = twon[i - 1] * 2;
}
depth[1] = 1;
dfs(1);
}

int LCA(int x, int y) {
if (depth[x] < depth[y]) std::swap(x, y);
int k = depth[x] - depth[y];
for (int i = 0; i <= LOGM; ++i) {
if (k & (1 << i)) x = fa[x][i];
}
if (x == y) return x;
for (int i = LOGM; i >= 0; --i) {
if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
}
return fa[x][0];
}

void GetLeastDelta(int x, int y, int dw) {
int mx = 0, smx = 0, k = depth[x] - depth[y];
for (int i = 0; i <= LOGM; ++i) {
if (k & (1 << i)) {
if (smx < maxw[x][i][1]) {
smx = maxw[x][i][1];
}
if (mx < maxw[x][i][0]) {
smx = std::max(smx, mx);
mx = maxw[x][i][0];
}
}
}
if (mx == dw) minDelta = std::min(minDelta, 1ll * (dw - smx));
else minDelta = std::min(minDelta, 1ll * (dw - mx));
}

int main() {
n = read(); m = read();
for (int i = 1; i <= m; ++i) edge[i]._r();
long long int mst = Kruskal();
// printf("%d\n", mst);
initLCA();
for (int i = 1; i <= m; ++i) {
if (!edge[i].chose) {
int l = LCA(edge[i].u, edge[i].v);
GetLeastDelta(edge[i].u, l, edge[i].w);
GetLeastDelta(edge[i].v, l, edge[i].w);
}
}
printf("%lld\n", mst + minDelta);
return 0;
}