写了一天快吐了$QAQ$
某个$zz$的假做法:
先考虑颜色不在边上而是在点上的做法。
$dis(i)$表示根节点到$i$的权值,$ma(i)$表示深度为$i$的点中最大的$dis$。
在合并路径时,根节点一定被经过。这样在计算$dis$时,不算上根节点的权值,计算时直接取恰当的$ma$加上即可,单调队列维护。
用$dis$更新$ma$时,再把根节点的权值加回来。
这是颜色在点上的情况,转到颜色在边上考虑拆边为点,拆出来的点颜色为原边的颜色。
那原有的点的颜色呢?
它父节点指向它的边的颜色。。。
根节点呢?
选一条出边为其颜色。。。
根节点不止一条出边呢?
选一个度数为$1$的点为根节点。。。
写出来后成功喜提$20$分,发现这个做法完全是假的。。。
正解:
记$col(i)$为根节点到$i$的路径第一条边的颜色。
显然合并两条路径$x,y$时,若$col(x)=col(y)$需要减去$col(x)$的权值。
把$col$相同的放在一起考虑,对它们用单调队列维护,统计答案时减去该颜色的权值。
对$col$不同的还是单调队列,直接加起来统计。
为保证复杂度,不同$col$之间按包含路径中最大深度排序,优先处理深度小的。
同样地,相同$col$内部也优先处理深度小的。
统计时还要把深度相同的放一块处理,用$bfs$消去排序的$\log$。
这样每次分治的复杂度为当前树大小。
总复杂度$O(n\log n)$。
话说我就对个拍hack掉了所有有代码的题解。。。
代码:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>
#define maxn 200005
#define inf 0x3f3f3f3f
using namespace std;
inline int read(){
int x=0,y=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')y=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return y?-x:x;
}
struct Monoqueue{
int l1[maxn],l2[maxn],head,tail;
void clear(){head=1,tail=0;}
void check(int x){
while(head<=tail&&l1[head]>x)++head;
}
void push(int pos,int d){
if(d==-inf)return;
while(head<=tail&&l2[tail]<d)--tail;
l1[++tail]=pos,l2[tail]=d;
}
int front(){
if(head<=tail)return l2[head];
return -inf;
}
}q1,q2;
struct edge{
int pre,to,l;
}e[maxn<<1];
int md[maxn],mmd[maxn],siz[maxn],v[maxn],c[maxn],h[maxn],col[maxn],L,R,mx,root,all,num,head,tail,ans;
int f[maxn],line[maxn],deep[maxn],dis[maxn],sma[maxn],srec[maxn],dma[maxn];
vector<int>poi[maxn];
bool vis[maxn];
inline bool cmp1(int x,int y){return md[x]<md[y];}
inline bool cmp2(int x,int y){return mmd[x]<mmd[y];}
inline void add(int from,int to,int l){
e[++num].pre=h[from],h[from]=num,e[num].to=to,e[num].l=l;
}
void getroot(int node,int fa){
siz[node]=1;
int x,ma=0;
for(register int i=h[node];i;i=e[i].pre){
x=e[i].to;
if(x==fa||vis[x])continue;
getroot(x,node),siz[node]+=siz[x],ma=max(ma,siz[x]);
}
ma=max(ma,all-siz[node]);
if(ma<mx)mx=ma,root=node;
}
void bfs(int node){
int x,y;
head=0,line[tail=1]=node,deep[node]=1;
while(head<tail){
x=line[++head];
if(deep[x]>R){md[node]=R;return;}
for(register int i=h[x];i;i=e[i].pre){
y=e[i].to;
if(y==f[x]||vis[y])continue;
c[y]=e[i].l,f[y]=x,line[++tail]=y,deep[y]=deep[x]+1;
}
}
md[node]=deep[line[tail]];
}
void calc(int node,int m){
q2.clear();
for(register int i=m;i>=L;--i)q2.push(i,sma[i]);
int x,y,l,r;
head=0,line[tail=1]=node;
while(head<tail){
x=line[++head];
if(deep[x]>R)return;
dis[x]=dis[f[x]]+v[c[x]]*bool(c[x]^c[f[x]]);
srec[deep[x]]=max(srec[deep[x]],dis[x]);
for(register int i=h[x];i;i=e[i].pre){
y=e[i].to;
if(y==f[x]||vis[y])continue;
line[++tail]=y;
}
if(head==tail||deep[x]!=deep[line[head+1]]){
l=L-deep[x],r=R-deep[x];
q2.check(r);
if(l>=0)q2.push(l,sma[l]);
ans=max(ans,q2.front()+srec[deep[x]]-dis[node]);
}
}
}
void calc_siz(int node,int f=0){
siz[node]=1;
int x;
for(register int i=h[node];i;i=e[i].pre){
x=e[i].to;
if(x==f||vis[x])continue;
calc_siz(x,node),siz[node]+=siz[x];
}
}
void solve(int node){
int x,y,z,len=0;
vis[node]=1;
for(register int i=h[node];i;i=e[i].pre){
x=e[i].to;
if(vis[x])continue;
if(poi[e[i].l].empty())col[++len]=e[i].l,mmd[e[i].l]=0;
c[x]=e[i].l,f[x]=node,poi[c[x]].push_back(x),bfs(x);
mmd[c[x]]=max(mmd[c[x]],md[x]);
}
c[node]=dis[node]=0;
sort(col+1,col+1+len,cmp2);
y=mmd[col[len]];
for(register int i=1;i<=len;++i){
x=col[i],z=0;
sort(poi[x].begin(),poi[x].end(),cmp1);
for(vector<int>::iterator iter=poi[x].begin();iter!=poi[x].end();++iter){
dis[*iter]=v[x],calc(*iter,z);
for(register int j=z=md[*iter];j;--j)
sma[j]=max(sma[j],srec[j]),srec[j]=-inf;
}
q1.clear();
for(register int j=mmd[col[i-1]];j>=L;--j)q1.push(j,dma[j]);
for(register int j=1;j<=mmd[x];++j){
q1.check(R-j);
if(L>=j)q1.push(L-j,dma[L-j]);
ans=max(q1.front()+sma[j],ans);
}
for(register int j=mmd[x];j;--j)
dma[j]=max(dma[j],sma[j]),sma[j]=-inf;
poi[x].clear();
}
for(register int i=y;i;--i)dma[i]=-inf;
calc_siz(node);
for(register int i=h[node];i;i=e[i].pre){
x=e[i].to;
if(vis[x])continue;
root=0,all=siz[x],mx=inf,getroot(x,node),solve(root);
}
}
int main(){
memset(sma,~0x3f,sizeof sma);
memset(srec,~0x3f,sizeof srec);
memset(dma,~0x3f,sizeof dma);
ans=-inf,dma[0]=0;
int n=read(),m=read(),x,y,z;
L=read(),R=read();
for(register int i=1;i<=m;++i)v[i]=read();
for(register int i=1;i<n;++i)x=read(),y=read(),z=read(),add(x,y,z),add(y,x,z);
mx=inf,all=n,getroot(1,0),solve(root);
printf("%d\n",ans);
}