NOIP2014 联合权值

题意简述

给定一棵无根树,每个点有一个点权 w ,定义一个距离为 2 的有序点对的联合权值为两个点的点权之积,求图上的所有联合权值之和。

换句话说,求:

\sum_{\mathrm{dist}(u, v) = 2} w_u \times w_v

解题思路

一个经典小套路:枚举中间点。

考虑枚举一个点 m ,计算与它相连的所有点的联合权值之和,因为有这个中间点的存在,所以与它相连的所有点之间两两距离一定为 2。

那么问题就转化为了:每次枚举一个 m_i ,求

\sum_{i = 1}^{n} \sum_{\mathrm{dist}(u, m_i) = \mathrm{dist}(m_i, v) = 1}w_u\times w_v

这个柿子很容易让我们联想到 \sum_{i = 1}^na_i \times \sum_{i = 1}^nb_i 的展开形式,所以很容易发现,上式等于

\sum\left(\sum_{\mathrm{dist}(u, m_i) = 1}w_u \times \sum_{\mathrm{dist}(m_i, v) = 1}w_v - \sum_{\mathrm{dist}(u, m_i) = 1}w_u^2\right)

说人话就是:
我们每次枚举一个中间点 m ,那么与 m 连接的点集能产生的联合权值为 点集里所有点权的和的平方 减去 点集里所有点权的平方和。把这个加起来输出就完事了。


还有一个最大值,怎么求?

显然我们拿权值的最大值乘上权值的次大值是最优的。顺便维护一下就好了。

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <string>
#include <vector>

#define DEBUG(x) std::cerr << #x << " = " << x << std::endl;
#define forall(G,u) for (int i_ = 0, __for_siz__ = (int) G[u].size(); i_ < __for_siz__; ++i_)

using std::cin;
using std::cout;
using std::endl;

inline int read() {
int s = 0, x = 1; char ch = getchar();
while (!isdigit(ch)) { if (ch == '-') x = -x; ch = getchar(); }
while (isdigit(ch)) { s = s * 10 + ch - '0'; ch = getchar(); }
return s * x;
}

const int MAXN = 200000 + 10;

std::vector<int> G[MAXN];

int n;
int val[MAXN];
long long int ans, mxans;

int main() {
n = read();
for (int i = 1; i <= n - 1; ++i) {
int u = read(); int v = read();
G[u].push_back(v);
G[v].push_back(u);
}
for (int i = 1; i <= n; ++i) {
val[i] = read();
}
for (int mid = 1; mid <= n; ++mid) {
if (G[mid].size() < 2) continue;
long long int mans = 0, fmans = 0;
int maxx = 0, tmax = -1;
for (auto v : G[mid]) {
mans += val[v]; fmans += 1ll * val[v] * val[v];
if (val[v] > maxx) {
tmax = maxx; maxx = val[v];
} else if (val[v] > tmax) {
tmax = val[v];
}
}
(ans += ((mans * mans - fmans))) %= 10007;
mxans = std::max(mxans, 1ll * maxx * tmax);
}
printf("%lld %lld\n", mxans, ans);
return 0;
}