Black Nodes in Subgraphs

原题:Black Nodes in Subgraphs

原题大意

给你一棵由 N 个节点构成的树 T。节点按照 1 到 N 编号,每个节点要么是白色,要么是黑色。有 Q 组询问,每组询问形如 (s, b)。你需要检查是否存在一个连通子图,其大小恰好是 s,并且包含恰好 b 个黑色节点。

输入

输入第一行,包含一个整数 T,表示测试数据组数。对于每组测试数据:
第一行包含两个整数 N 和 Q,分别表示树的节点个数和询问个数。
接下来 N - 1 行,每行包含两个整数 ui 和 vi,表示在树中 ui 和 vi 之间存在一条边。
接下来一行包含 N 个整数,c1, c2, … , cN。如果 ci 为 0 表示第 i 个节点是白色的,如果 ci 为
1 表示第 i 个节点是黑色的。
接下来 Q 行,每行包含两个整数 si 和 bi,表示一组形如 (si, bi) 的询问。
对于每组询问输出一行字符串表示答案,其中 Yes 表示存在一个符合要求的连
通子图,No 表示不存在。
1 <= T <= 5, 2 <= N <= 5e3, 1 <= Q <= 1e5, 1 <= ui, vi <= N。
0 <= ci <= 1, 0 <= bi <= N, 1 <= si <= N, bi <= si。

算法分析

观察到一个现象,对于一个子树,在子树中子图的点数确定的情况下,可行的黑点数是一个连续的区间。那么很自然地用f[i][j]记下在子树i中用了j个点的情况下最多和最少的黑点数,背包一下就可以了。
剩下一个问题,就是复杂度。你会发现每个点对在dp的过 程中都只出现一次,所以复杂度是n^2的。

程序代码

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <cmath>
#include <iostream>
#include <algorithm>

using namespace std;

struct node {
    node *next;
    int where;
} *first[5001], a[10001];
int f[5001][5001], A[5001];
int v[5001][5001], g[5001], e[5001];
int l, size[5001], test, n, Q, dist[5001], c[5001];

//邻接表建图
inline void makelist(int x, int y) {
    a[++l].where = y;
    a[l].next = first[x];
    first[x] = &a[l];
}

int main() {
    //freopen("input.txt", "r", stdin);
    scanf("%d", &test);
    for (;test--;)
    {
        memset(first, 0, sizeof(first));
        l = 0;
        int Q;
        scanf("%d", &n);
        scanf("%d", &Q);
        for (int i = 1; i < n; i++) 
        {
            int x, y;
            scanf("%d%d", &x, &y);
            makelist(x, y);
            makelist(y, x);
        }
/*        for(int i = 0; i <= n; i++){
            for(node *t = first[i];t;t=t->next){
                printf("%d ", t->where);
            }
            printf("\n");
        }
*/

        for (int i = 1; i <= n; i++)
            scanf("%d", &A[i]);
        memset(dist, 0, sizeof(dist));
        c[1] = 1; dist[1] = 1;
        //c从树根往叶子存结点,类似于bfs层次遍历
        //dist存距离或者深度,便于下面确定相邻两点谁是父亲谁是子孙

        for (int k = 1, l = 1; l <= k; l++)
        {
            int m = c[l];
            for (node *x = first[m]; x; x = x->next)
                if (!dist[x->where])
                    dist[x->where] = dist[m] + 1,c[++k] = x->where;
        }
/*        printf("-------------\n");
        for(int i = 1; i <= n; i++)
            printf("%d ", dist[i]);
        printf("-------------\n");*/

        memset(f, 127, sizeof(f));//max
        memset(v, 255, sizeof(v));//-1
        memset(g, 127, sizeof(g));//max
        memset(e, 255, sizeof(e));//-1
//        printf("%d %d %d %d", f[1][1], v[1][1], g[1], e[1]);


        for(int i = n; i; --i)
        {
            int m = c[i];
            size[m] = 1;//已知的点数
            if (A[m])
                f[m][1] = v[m][1] = 1;
            else
                f[m][1] = v[m][1] = 0;
            for (node *x = first[m]; x; x = x->next)
                if (dist[x->where] == dist[m] + 1)
                {
                    for (int j = size[m]; j >= 0; j--)
                        //也可以这样写for(int k = 0; k <= size[x->where]; k++)                       
                        for (int k = size[x->where]; k >= 0; k--)
                            f[m][j + k] = min(f[m][j + k], f[m][j] + f[x->where][k]),
                            v[m][j + k] = max(v[m][j + k], v[m][j] + v[x->where][k]);
                    size[m] += size[x->where];//已知的点数增加
                }
            f[m][0] = v[m][0] = 0;
            for (int j = 0; j <= size[m]; j++)
                g[j] = min(g[j], f[m][j]),
                e[j] = max(e[j], v[m][j]);
            //printf("%d\n", m);
            //for (int j = 0; j <= size[m]; j++)
            //  printf("%d %d %d\n", j, v[m][j], f[m][j]);
        }
        for (;Q--;)
        {
            int x, y;
            scanf("%d%d", &x, &y);
            if (y >= g[x] && y <= e[x])
                printf("Yes\n");
            else
                printf("No\n");
        }
    }
}