洛谷P3177《[HAOI2015]树上染色》

我推式子推了半个小时。。。

Description

有一棵点数为N的树,树边有边权。给你一个在0~N之内的正整数K,你要在这棵树中选择K个点,将其染成黑色,并

将其他的N-K个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。

问收益最大值是多少。

Input

第一行两个整数N,K。

接下来N-1行每行三个正整数fr,to,dis,表示该树中存在一条长度为dis的边(fr,to)。

输入保证所有点之间是联通的。

N<=2000,0<=K<=N

Output

输出一个正整数,表示收益的最大值。

Sample Input

1
2
3
4
5
5 2  
1 2 3
1 5 1
2 3 1
2 4 2

Sample Output

1
17  

【样例解释】
将点1,2染黑就能获得最大收益。

解析

第一反应设 \text{dp[i][j]} 表示以 i 为根的子树选 j 个黑点的最大收益
但是是错的

康了一眼这个我就瞬间明白了
关于式子的推导,组成部分的意义,还有循环顺序的选择,这篇文章都讲得很清楚

老规矩,题解都在代码里

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
//
// BZOJ4033.cpp
// Title: [HAOI2015]树上染色
// Alternatives: Luogu-P3177
// Debugging
//
// Created by HandwerSTD on 2019/7/31.
// Copyright © 2019 HandwerSTD. All rights reserved.
//

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>

#define FILE_IN(__fname) freopen(__fname, "r", stdin)
#define FILE_OUT(__fname) freopen(__fname, "w", stdout)
#define rap(a,s,t,i) for (int a = s; a <= t; a += i)
#define basketball(a,t,s,i) for (int a = t; a > s; a -= i)
#define countdown(s) while (s --> 0)
#define IMPROVE_IO() std::ios::sync_with_stdio(false)

using std::cin;
using std::cout;
using std::endl;

typedef long long int lli;

int getint() { int x; scanf("%d", &x); return x; }
lli getll() { long long int x; scanf("%lld", &x); return x; }

/**
*
* 参考资料:https://www.luogu.org/blog/ahaha254/solution-p3177
* 关于 val(x,y) 和枚举顺序的解释可以康一康这篇文章
*
* 设 f[i][j] 表示以 i 为根的子树中取了 j 个黑点「对答案的贡献」
* 转移方程:
* f[u][j] = max(
* f[u][j],
* f[u][j - k] + f[v][k] + val(u,v)
* )
* 其中 v 是 u 的儿子,k 是枚举出来的
* j = min(m,size(x)) -> 0, k = 0 -> min(j,size(y))
* 其中 val(x,y) 表示边 (x,y) 对答案的贡献,它等于
* 「该边两边黑点数量的乘积 乘以 边长 加上 该边两边白点数量的乘积 乘以 边长」
* 也就是 val(x,y) = k * (m - k) * weight(x,y) + (size(y) - k) * ((n - m) - (size(y) - k)) * weight(x,y)
* 其中 m 是总的黑点数,k 是边 (x,y) 另一边的黑点数,那么 (m - k) 就是这一边的黑点数
* size(y) 是以 y 为根的子树的大小,也就意味着 (size(y) - k) 是另一边的白点数(另一边的肯定不是黑点就是白点)
* (n - m) 是总的白点数,(size(y) - k) 是另一边的白点数,也就意味着 ((n - m) - (size(y) - k)) 是边 (x,y) 这一边的白点数(白点肯定不在那边就在这边)
*
*/

const int MAXN = 2000 + 10;
const int MAXK = 2000 + 10;

struct Edge {
int v;
lli w;

Edge(int v = 0, lli w = 0) : v(v), w(w) {}
};

std::vector<Edge> head[MAXN];

int n, m, size[MAXN];
lli dp[MAXN][MAXK];
bool vis[MAXN][MAXK];
// dp 数组大概 31 MB
// size 数组大概 8 KB
// vis 数组大概 4 MB

void DFS(int root = 1, int father = 0) {
size[root] = 1;
vis[root][0] = vis[root][1] = true;
for (int i = 0, siz = (int) head[root].size(); i < siz; ++i) {
int next = head[root][i].v;
if (next == father) continue;
DFS(next, root);
size[root] += size[next];
}
for (int i = 0, siz = (int) head[root].size(); i < siz; ++i) {
int next = head[root][i].v;
if (next == father) continue;
lli weight = head[root][i].w;
for (int j = std::min(m, size[root]); j >= 0; --j) {
int up = std::min(j, size[next]);
for (int k = 0; k <= up; ++k) {
if (!vis[root][j - k]) continue;
lli val = k * (m - k) * weight + (size[next] - k) * ((n - m) - (size[next] - k)) * weight;
dp[root][j] = std::max(dp[root][j], dp[root][j - k] + dp[next][k] + val);
vis[root][j] = true;
}
}
}
}

int main() {
n = getint(); m = getint();
rap (i, 1, n - 1, 1) {
int prev = getint(), next = getint(), weight = getint();
head[prev].push_back(Edge(next, weight));
head[next].push_back(Edge(prev, weight));
}
DFS();
printf("%lld\n", dp[1][m]);
return 0;
}