传送门

Enchant Arrow!

虚树显然。用虚树上的边长表示原树上的距离。

考虑每条虚树上的边的贡献。记边$(x,y)$在虚树上距离两端点最近的点分别为$u,v$。换根$DP$可求。

分类讨论:

  • $u=v$,原树上路径$(x,y)$的所有点包括挂着的子树都归$u$管理,直接查即可
  • $u\neq v$,原树上路径$(x,y)$一部分点包括其子树归$u$管理,剩下的归$v$管理。显然从$x$开始的$\dfrac{dis(u,v)+1}{2}-dis(u,x)$(根据$u,v$大小向上/下取整)个点包括其子树归$u$管理,剩下的归$v$管理。找个$k$级祖先子树和相减即可。比较懒倍增求的$k$级祖先。

有一个问题是虚树端点处会算重。可以只计算下面的端点,不管上面的端点。但是这样根节点的贡献算不上,建一个虚点和根节点连边作为新的根节点即可。

代码写了一天题解这么短,口嗨就是这么简单其实细节巨多

代码:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>

#define maxn 300005
#define inf 0x3f3f3f3f
#define px putchar
#define pn px('\n')

using namespace std;

inline int read(){
    int x=0,y=0;
    char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')y=1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
    return y?-x:x;
}
int seg[maxn],siz[maxn],deep[maxn],root;
namespace origin{
    int top[maxn],son[maxn],fa[maxn][20],h[maxn],num,cnt;
    struct edge{int pre,to;}e[maxn<<1];
    inline void add(int from,int to){e[++num]=(edge){h[from],to},h[from]=num;}
    void dfs1(int node=root){
        siz[node]=1;
        for(register int i=1;i<20;++i)fa[node][i]=fa[fa[node][i-1]][i-1];
        for(register int i=h[node],x;i;i=e[i].pre){
            x=e[i].to;
            if(siz[x])continue;
            fa[x][0]=node,deep[x]=deep[node]+1,dfs1(x),siz[node]+=siz[x];
            if(siz[x]>siz[son[node]])son[node]=x;
        }
    }
    void dfs2(int node=root){
        seg[node]=++cnt;
        if(!son[node])return;
        top[son[node]]=top[node],dfs2(son[node]);
        for(register int i=h[node],x;i;i=e[i].pre){
            x=e[i].to;
            if(!seg[x])top[x]=x,dfs2(x);
        }
    }
    inline int lca(int x,int y){
        while(top[x]!=top[y])deep[top[x]]<deep[top[y]]?y=fa[top[y]][0]:x=fa[top[x]][0];
        return deep[x]<deep[y]?x:y;
    }
    inline int dis(int x,int y){return deep[x]+deep[y]-(deep[lca(x,y)]<<1);}
    inline int anc(int x,int k){
        if(!k)return x;
        for(register int i=19;~i;--i)if(k>>i&1)x=fa[x][i];
        return x;
    }
    inline int sum(int x,int k){return siz[anc(x,k)]-siz[x];}
}
namespace virt{
    int h[maxn],p[maxn],sta[maxn],ans[maxn],rec[maxn],s[maxn],num,n,top;
    bool key[maxn];
    struct edge{int pre,to,l;}e[maxn<<1];
    struct near{
        int d,x;
        near(int D=inf,int X=0){d=D,x=X;}
        bool operator < (const near &X)const{
            if(d!=X.d)return d<X.d;
            return x<X.x;
        }
        bool operator == (const near &X){return d==X.d&&x==X.x;}
        bool operator != (const near &X){return !(*this==X);}
        near operator + (const int &X){return near(d+X,x);}
    }f[maxn];
    inline bool cmp(int x,int y){return seg[x]<seg[y];}
    inline void add(int from,int to){
        int l=origin::dis(from,to);
        e[++num]=(edge){h[from],to,l},h[from]=num;
        e[++num]=(edge){h[to],from,l},h[to]=num;
    }
    void push(int node){
        h[node]=0;
        int l=origin::lca(node,sta[top]);
        if(l==sta[top]){sta[++top]=node;return;}
        while(seg[l]<seg[sta[top-1]])add(sta[top],sta[top-1]),--top;
        if(sta[top-1]!=l)h[l]=0,add(l,sta[top]),sta[top]=l;
        else add(sta[top],l),--top;
        sta[++top]=node;
    }
    void build(){
        sta[top=1]=1,h[1]=h[root]=num=0;
        sort(p+1,p+1+n,cmp);
        for(register int i=1;i<=n;++i){
            key[p[i]]=1,ans[p[i]]=0;
            if(p[i]==1)continue;
            push(p[i]);
        }
        for(register int i=1;i<top;++i)add(sta[i],sta[i+1]);
        add(root,1);
    }
    void dp(int node=root,int fa=0){
        s[node]=siz[node];
        if(key[node])f[node]=near(0,node);
        else f[node]=near();
        for(register int i=h[node],x;i;i=e[i].pre){
            x=e[i].to;
            if(x==fa)continue;
            dp(x,node),s[node]-=siz[origin::anc(x,e[i].l-1)];
            f[node]=min(f[x]+e[i].l,f[node]);
        }
    }
    void check(near &m1,near &m2,near d){
        if(d<m1)m2=m1,m1=d;
        else if(d<m2)m2=d;
    }
    void _dp(int node=root,int fa=0,near F=near()){
        if(key[node])F=near(0,node);
        near m1=F,m2;
        for(register int i=h[node],x;i;i=e[i].pre){
            x=e[i].to;
            if(x==fa)continue;
            check(m1,m2,f[x]+e[i].l);
        }
        for(register int i=h[node],x;i;i=e[i].pre){
            x=e[i].to;
            if(x==fa)continue;
            _dp(x,node,(m1==f[x]+e[i].l?m2:m1)+e[i].l);
            near t=min(f[x],m1+e[i].l);
            if(t.x==m1.x)ans[t.x]+=s[x]+origin::sum(x,e[i].l-1);
            else {
                int d=t.x<m1.x?(origin::dis(t.x,m1.x)+2>>1):(origin::dis(t.x,m1.x)+1>>1),k=origin::sum(x,d-1-deep[t.x]+deep[x]);
                ans[t.x]+=s[x]+k;
                ans[m1.x]+=origin::sum(x,e[i].l-1)-k;
            }
        }
    }
    void solve(){
        build(),dp(),_dp();
        for(register int i=1;i<=n;++i)printf("%d ",ans[rec[i]]),key[rec[i]]=0;
        pn;
    }
}
using namespace origin;
int main(){
    int n=read();
    root=n+1;
    for(register int i=1,x,y;i<n;++i)x=read(),y=read(),add(x,y),add(y,x);
    add(root,1),dfs1(),dfs2(),n=read();
    while(n--){
        virt::n=read();
        for(register int i=1;i<=virt::n;++i)virt::p[i]=virt::rec[i]=read();
        virt::solve();
    }
}