传送门

写了一天快吐了$QAQ$


某个$zz$的假做法:

先考虑颜色不在边上而是在点上的做法。

$dis(i)$表示根节点到$i$的权值,$ma(i)$表示深度为$i$的点中最大的$dis$。

在合并路径时,根节点一定被经过。这样在计算$dis$时,不算上根节点的权值,计算时直接取恰当的$ma$加上即可,单调队列维护。

用$dis$更新$ma$时,再把根节点的权值加回来。

这是颜色在点上的情况,转到颜色在边上考虑拆边为点,拆出来的点颜色为原边的颜色。

那原有的点的颜色呢?

它父节点指向它的边的颜色。。。

根节点呢?

选一条出边为其颜色。。。

根节点不止一条出边呢?

选一个度数为$1$的点为根节点。。。

写出来后成功喜提$20$分,发现这个做法完全是假的。。。


正解:

记$col(i)$为根节点到$i$的路径第一条边的颜色。

显然合并两条路径$x,y$时,若$col(x)=col(y)$需要减去$col(x)$的权值。

把$col$相同的放在一起考虑,对它们用单调队列维护,统计答案时减去该颜色的权值。

对$col$不同的还是单调队列,直接加起来统计。

为保证复杂度,不同$col$之间按包含路径中最大深度排序,优先处理深度小的。

同样地,相同$col$内部也优先处理深度小的。

统计时还要把深度相同的放一块处理,用$bfs$消去排序的$\log$。

这样每次分治的复杂度为当前树大小。

总复杂度$O(n\log n)$。

话说我就对个拍hack掉了所有有代码的题解。。。

代码:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>

#define maxn 200005
#define inf 0x3f3f3f3f

using namespace std;

inline int read(){
    int x=0,y=0;
    char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')y=1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
    return y?-x:x;
}
struct Monoqueue{
    int l1[maxn],l2[maxn],head,tail;
    void clear(){head=1,tail=0;}
    void check(int x){
        while(head<=tail&&l1[head]>x)++head;
    }
    void push(int pos,int d){
        if(d==-inf)return;
        while(head<=tail&&l2[tail]<d)--tail;
        l1[++tail]=pos,l2[tail]=d;
    }
    int front(){
        if(head<=tail)return l2[head];
        return -inf;
    }
}q1,q2;
struct edge{
    int pre,to,l;
}e[maxn<<1];
int md[maxn],mmd[maxn],siz[maxn],v[maxn],c[maxn],h[maxn],col[maxn],L,R,mx,root,all,num,head,tail,ans;
int f[maxn],line[maxn],deep[maxn],dis[maxn],sma[maxn],srec[maxn],dma[maxn];
vector<int>poi[maxn];
bool vis[maxn];
inline bool cmp1(int x,int y){return md[x]<md[y];}
inline bool cmp2(int x,int y){return mmd[x]<mmd[y];}
inline void add(int from,int to,int l){
    e[++num].pre=h[from],h[from]=num,e[num].to=to,e[num].l=l;
}
void getroot(int node,int fa){
    siz[node]=1;
    int x,ma=0;
    for(register int i=h[node];i;i=e[i].pre){
        x=e[i].to;
        if(x==fa||vis[x])continue;
        getroot(x,node),siz[node]+=siz[x],ma=max(ma,siz[x]);
    }
    ma=max(ma,all-siz[node]);
    if(ma<mx)mx=ma,root=node;
}
void bfs(int node){
    int x,y;
    head=0,line[tail=1]=node,deep[node]=1;
    while(head<tail){
        x=line[++head];
        if(deep[x]>R){md[node]=R;return;}
        for(register int i=h[x];i;i=e[i].pre){
            y=e[i].to;
            if(y==f[x]||vis[y])continue;
            c[y]=e[i].l,f[y]=x,line[++tail]=y,deep[y]=deep[x]+1;
        }    
    }
    md[node]=deep[line[tail]];
}
void calc(int node,int m){
    q2.clear();
    for(register int i=m;i>=L;--i)q2.push(i,sma[i]);
    int x,y,l,r;
    head=0,line[tail=1]=node;
    while(head<tail){
        x=line[++head];
        if(deep[x]>R)return;
        dis[x]=dis[f[x]]+v[c[x]]*bool(c[x]^c[f[x]]);
        srec[deep[x]]=max(srec[deep[x]],dis[x]);
        for(register int i=h[x];i;i=e[i].pre){
            y=e[i].to;
            if(y==f[x]||vis[y])continue;
            line[++tail]=y;
        }
        if(head==tail||deep[x]!=deep[line[head+1]]){
            l=L-deep[x],r=R-deep[x];
            q2.check(r);
            if(l>=0)q2.push(l,sma[l]);
            ans=max(ans,q2.front()+srec[deep[x]]-dis[node]);
        }
    }
}
void calc_siz(int node,int f=0){
    siz[node]=1;
    int x;
    for(register int i=h[node];i;i=e[i].pre){
        x=e[i].to;
        if(x==f||vis[x])continue;
        calc_siz(x,node),siz[node]+=siz[x];
    }
}
void solve(int node){
    int x,y,z,len=0;
    vis[node]=1;
    for(register int i=h[node];i;i=e[i].pre){
        x=e[i].to;
        if(vis[x])continue;
        if(poi[e[i].l].empty())col[++len]=e[i].l,mmd[e[i].l]=0;
        c[x]=e[i].l,f[x]=node,poi[c[x]].push_back(x),bfs(x);
        mmd[c[x]]=max(mmd[c[x]],md[x]);
    }
    c[node]=dis[node]=0;
    sort(col+1,col+1+len,cmp2);
    y=mmd[col[len]];
    for(register int i=1;i<=len;++i){
        x=col[i],z=0;
        sort(poi[x].begin(),poi[x].end(),cmp1);
        for(vector<int>::iterator iter=poi[x].begin();iter!=poi[x].end();++iter){
            dis[*iter]=v[x],calc(*iter,z);
            for(register int j=z=md[*iter];j;--j)
                sma[j]=max(sma[j],srec[j]),srec[j]=-inf;
        }
        q1.clear();
        for(register int j=mmd[col[i-1]];j>=L;--j)q1.push(j,dma[j]);
        for(register int j=1;j<=mmd[x];++j){
            q1.check(R-j);
            if(L>=j)q1.push(L-j,dma[L-j]);
            ans=max(q1.front()+sma[j],ans);
        }
        for(register int j=mmd[x];j;--j)
            dma[j]=max(dma[j],sma[j]),sma[j]=-inf;
        poi[x].clear();
    }
    for(register int i=y;i;--i)dma[i]=-inf;
    calc_siz(node);
    for(register int i=h[node];i;i=e[i].pre){
        x=e[i].to;
        if(vis[x])continue;
        root=0,all=siz[x],mx=inf,getroot(x,node),solve(root);
    }
}
int main(){
    memset(sma,~0x3f,sizeof sma);
    memset(srec,~0x3f,sizeof srec);
    memset(dma,~0x3f,sizeof dma);
    ans=-inf,dma[0]=0;
    int n=read(),m=read(),x,y,z;
    L=read(),R=read();
    for(register int i=1;i<=m;++i)v[i]=read();
    for(register int i=1;i<n;++i)x=read(),y=read(),z=read(),add(x,y,z),add(y,x,z);
    mx=inf,all=n,getroot(1,0),solve(root);
    printf("%d\n",ans);
}