严格次短路

题意简述

给定一张无向联通图,求一条 1 到 n 的路径使得它的长度仅次于最短路。

解题思路

其实这个和最短路差不多,无非就是 BFS 的时候再维护一个次短路数组 dist[u][1] 表示 1 点到 u 点的次短路,主要的代码细节在维护这个数组上。

代码细节

次小值相对难维护一些,因为需要考虑它从哪里转移过来。
以下代码均令 dist[u][0] 表示最短路,dist[u][1] 表示次短路,而且都是从 1 开始的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
for (Edge E : G[u]) {
int v = E.v, w = E.w;
// 最短路有更短的
if (dist[v][0] > dist[u][0] + w) {
// 原先的最短路成了次短路
dist[v][1] = dist[v][0];
dist[v][0] = dist[u][0] + w;
if (!inQueue[v]) inQueue[v] = true, q.push(v);
}
if (dist[v][0] < dist[u][0] + w && dist[u][0] + w < dist[v][1]) {
// 这条路比次短路更短,但是不会影响最短路
dist[v][1] = dist[u][0] + w;
if (!inQueue[v]) inQueue[v] = true, q.push(v);
}
if (dist[v][1] > dist[u][1] + w) {
// 现在的次短路比原先的次短路更短
dist[v][1] = dist[u][1] + w;
if (!inQueue[v]) inQueue[v] = true, q.push(v);
}
}

其他的地方都差不多吧。

代码实现

例题:[USACO06NOV]Roadblocks G,洛谷上有。注意这题是从 n 跑次短路。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
// Accepted

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <string>
#include <vector>
#include <queue>

#define DEBUG(x) std::cerr << #x << " = " << x << std::endl;

using std::cin;
using std::cout;
using std::endl;

inline int read() {
int s = 0, x = 1; char ch = getchar();
while (!isdigit(ch)) { if (ch == '-') x = -x; ch = getchar(); }
while (isdigit(ch)) { s = s * 10 + ch - '0'; ch = getchar(); }
return s * x;
}

const int MAXN = 5000 + 10;

struct Edge {
int v, w;
Edge(int _v = 0, int _w = 0) : v(_v), w(_w) {}
};

std::vector<Edge> G[MAXN];

int n, m;
// dist[][0] 是最短路,dist[][1] 是次短路
int dist[MAXN][2];

void spfa(int s) {
static bool inq[MAXN], vis[MAXN];
memset(dist, 0x3f, sizeof dist);
memset(inq, 0, sizeof inq);
dist[s][0] = 0;
std::queue<int> q; q.push(s); inq[s] = true;
while (!q.empty()) {
int u = q.front(); q.pop(); inq[u] = false;
for (auto e : G[u]) {
int v = e.v, w = e.w;
if (dist[v][0] > dist[u][0] + w) {
dist[v][1] = dist[v][0];
dist[v][0] = dist[u][0] + w;
if (!inq[v]) inq[v] = true, q.push(v);
}
if (dist[v][0] < dist[u][0] + w && dist[v][1] > dist[u][0] + w) {
// 这个地方不能写等号,因为要求严格次小
dist[v][1] = dist[u][0] + w;
if (!inq[v]) inq[v] = true, q.push(v);
}
if (dist[v][1] > dist[u][1] + w) {
dist[v][1] = dist[u][1] + w;
if (!inq[v]) inq[v] = true, q.push(v);
}
}
}
}

int main() {
n = read(); m = read();
for (int i = 1; i <= m; ++i) {
int u = read(); int v = read(); int w = read();
G[u].push_back(Edge(v, w)); G[v].push_back(Edge(u, w));
}
spfa(n);
printf("%d\n", dist[1][1]);
return 0;
}