原题大意
给你一棵由 N 个节点构成的树 T。节点按照 1 到 N 编号,每个节点要么是白色,要么是黑色。有 Q 组询问,每组询问形如 (s, b)。你需要检查是否存在一个连通子图,其大小恰好是 s,并且包含恰好 b 个黑色节点。
输入
输入第一行,包含一个整数 T,表示测试数据组数。对于每组测试数据:
第一行包含两个整数 N 和 Q,分别表示树的节点个数和询问个数。
接下来 N - 1 行,每行包含两个整数 ui 和 vi,表示在树中 ui 和 vi 之间存在一条边。
接下来一行包含 N 个整数,c1, c2, … , cN。如果 ci 为 0 表示第 i 个节点是白色的,如果 ci 为
1 表示第 i 个节点是黑色的。
接下来 Q 行,每行包含两个整数 si 和 bi,表示一组形如 (si, bi) 的询问。
对于每组询问输出一行字符串表示答案,其中 Yes 表示存在一个符合要求的连
通子图,No 表示不存在。
1 <= T <= 5, 2 <= N <= 5e3, 1 <= Q <= 1e5, 1 <= ui, vi <= N。
0 <= ci <= 1, 0 <= bi <= N, 1 <= si <= N, bi <= si。
算法分析
观察到一个现象,对于一个子树,在子树中子图的点数确定的情况下,可行的黑点数是一个连续的区间。那么很自然地用f[i][j]记下在子树i中用了j个点的情况下最多和最少的黑点数,背包一下就可以了。
剩下一个问题,就是复杂度。你会发现每个点对在dp的过 程中都只出现一次,所以复杂度是n^2的。
程序代码
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <cmath>
#include <iostream>
#include <algorithm>
using namespace std;
struct node {
node *next;
int where;
} *first[5001], a[10001];
int f[5001][5001], A[5001];
int v[5001][5001], g[5001], e[5001];
int l, size[5001], test, n, Q, dist[5001], c[5001];
//邻接表建图
inline void makelist(int x, int y) {
a[++l].where = y;
a[l].next = first[x];
first[x] = &a[l];
}
int main() {
//freopen("input.txt", "r", stdin);
scanf("%d", &test);
for (;test--;)
{
memset(first, 0, sizeof(first));
l = 0;
int Q;
scanf("%d", &n);
scanf("%d", &Q);
for (int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &x, &y);
makelist(x, y);
makelist(y, x);
}
/* for(int i = 0; i <= n; i++){
for(node *t = first[i];t;t=t->next){
printf("%d ", t->where);
}
printf("\n");
}
*/
for (int i = 1; i <= n; i++)
scanf("%d", &A[i]);
memset(dist, 0, sizeof(dist));
c[1] = 1; dist[1] = 1;
//c从树根往叶子存结点,类似于bfs层次遍历
//dist存距离或者深度,便于下面确定相邻两点谁是父亲谁是子孙
for (int k = 1, l = 1; l <= k; l++)
{
int m = c[l];
for (node *x = first[m]; x; x = x->next)
if (!dist[x->where])
dist[x->where] = dist[m] + 1,c[++k] = x->where;
}
/* printf("-------------\n");
for(int i = 1; i <= n; i++)
printf("%d ", dist[i]);
printf("-------------\n");*/
memset(f, 127, sizeof(f));//max
memset(v, 255, sizeof(v));//-1
memset(g, 127, sizeof(g));//max
memset(e, 255, sizeof(e));//-1
// printf("%d %d %d %d", f[1][1], v[1][1], g[1], e[1]);
for(int i = n; i; --i)
{
int m = c[i];
size[m] = 1;//已知的点数
if (A[m])
f[m][1] = v[m][1] = 1;
else
f[m][1] = v[m][1] = 0;
for (node *x = first[m]; x; x = x->next)
if (dist[x->where] == dist[m] + 1)
{
for (int j = size[m]; j >= 0; j--)
//也可以这样写for(int k = 0; k <= size[x->where]; k++)
for (int k = size[x->where]; k >= 0; k--)
f[m][j + k] = min(f[m][j + k], f[m][j] + f[x->where][k]),
v[m][j + k] = max(v[m][j + k], v[m][j] + v[x->where][k]);
size[m] += size[x->where];//已知的点数增加
}
f[m][0] = v[m][0] = 0;
for (int j = 0; j <= size[m]; j++)
g[j] = min(g[j], f[m][j]),
e[j] = max(e[j], v[m][j]);
//printf("%d\n", m);
//for (int j = 0; j <= size[m]; j++)
// printf("%d %d %d\n", j, v[m][j], f[m][j]);
}
for (;Q--;)
{
int x, y;
scanf("%d%d", &x, &y);
if (y >= g[x] && y <= e[x])
printf("Yes\n");
else
printf("No\n");
}
}
}