一边听主席的故事(本子)会一边调出来了。
题解里都是哈希,我这个大常数SAM+kmp就显得很另类(
考虑点分治时如何把两条路径拼起来。用$[P_i]$表示前缀,$[S_i]$表示后缀,$a[S]$表示$a$个$S$相连($a\in\mathbb{Z}$),一定是形如$a[S]+[P_i]$和$[S_{i+1}]+b[S]$的两条半路径拼成的路径符合要求:
$T_x$表示分治中心到点$x$的路径形成的字符串。如果$T_x$形如$[S_{i+1}]+b[S]$,则用$sl_x$表示$S_{i+1}$的长度。
对$S$串建$SAM$,并把$S$代表的节点(最后插入的字符新建的节点)及其$parent\ tree$上的祖先都打上标记。如果把$T_x$在$SAM$上跑一遍走到了有标记的节点,则$T_x$为$S$的一个后缀,$sl_x=deep_x$。
解决了后缀的问题,只要找出$[S]$即可。$kmp$匹配一下就行。如果$T_x$正好匹配$S$,则$sl_x=sl_y$($y$为$x$的$m$级祖先)。
而$a[S]+[P_i]$就把$S$反过来做一遍。开个$sl$的桶统计答案就做完了。
复杂度$O(n\log n)$,成功挤进了最优解最后一页。
代码:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define maxn 1000005
#define inf 0x3f3f3f3f
using namespace std;
inline int read(){
int x=0,y=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')y=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return y?-x:x;
}
char s[2][maxn],t[maxn],a[maxn];
int m;
struct SAM{
#define son(x,y) son[x][y]
int son[maxn<<1][26],fa[maxn<<1],len[maxn<<1],cnt,last;
bool en[maxn<<1];
void clear(){
for(register int i=1;i<=cnt;++i)memset(son[i],0,sizeof son[i]),fa[i]=len[i]=en[i]=0;
cnt=last=1;
}
int insert(int c){
int p=last,ne=last=++cnt;
len[ne]=len[p]+1;
while(p&&!son(p,c))son(p,c)=ne,p=fa[p];
if(!p)fa[ne]=1;
else {
int q=son(p,c);
if(len[q]==len[p]+1)fa[ne]=q;
else {
int sp=++cnt;
memcpy(son[sp],son[q],sizeof son[q]);
len[sp]=len[p]+1,fa[sp]=fa[q],fa[q]=fa[ne]=sp;
while(p&&son(p,c)==q)son(p,c)=sp,p=fa[p];
}
}
return ne;
}
void init(char *s){
clear();
for(register int i=1;i<m;++i)insert(s[i]);
int node=insert(s[m]);
while(node)en[node]=1,node=fa[node];
}
}S[2];
int siz[maxn],h[maxn],st[maxn][2],tax[2][maxn],q[maxn],sl[maxn],nex[2][maxn],tail,num,mx,root,all,cnt;
long long ans;
bool vis[maxn];
struct edge{int pre,to;}e[maxn];
inline void add(int from,int to){e[++num]=(edge){h[from],to},h[from]=num;}
void getroot(int node,int fa=0){
int ma=all-siz[node];
for(register int i=h[node],x;i;i=e[i].pre){
x=e[i].to;
if(vis[x]||x==fa)continue;
getroot(x,node),ma=max(ma,siz[x]);
}
if(ma<mx)mx=ma,root=node;
}
void modify(int id,int len){
if(len==-1)return;
ans+=tax[id^1][(m-len)%m];
st[++cnt][0]=len,st[cnt][1]=id;
}
void calc(int node,int fa,int sam,int len,int id){
q[++tail]=node,sam=S[id].son(sam,a[node]);
while(len&&s[id][len+1]!=a[node])len=nex[id][len];
if(s[id][len+1]==a[node])++len;
if(len==m)sl[node]=sl[q[tail-m]],len=nex[id][m];
else if(S[id].en[sam])sl[node]=tail%m;
modify(id,sl[node]);
for(register int i=h[node],x;i;i=e[i].pre){
x=e[i].to;
if(x==fa||vis[x])continue;
calc(x,node,sam,len,id);
}
--tail,sl[node]=-1;
}
void dfs(int node,int fa){
siz[node]=1,sl[node]=-1;
for(register int i=h[node],x;i;i=e[i].pre){
x=e[i].to;
if(x==fa||vis[x])continue;
dfs(x,node),siz[node]+=siz[x];
}
}
void solve(int node){
vis[node]=1,cnt=0,modify(0,0);
++tax[0][0];
int sam=S[1].son(1,a[node]),len=(s[1][1]==a[node]),last=1;
long long rec=ans;
bool flag=0;
if(len==m)len=nex[1][m];
else if(S[1].en[sam])sl[node]=1%m;
modify(1,sl[node]);
for(register int i=h[node],x;i;i=e[i].pre){
x=e[i].to;
if(vis[x])continue;
for(register int j=last+1;j<=cnt;++j)
++tax[st[j][1]][st[j][0]];
last=cnt;
calc(x,node,1,0,0);
q[++tail]=node,calc(x,node,sam,len,1),--tail;
}
dfs(node,0);
for(register int i=1;i<=last;++i)--tax[st[i][1]][st[i][0]];
for(register int i=h[node],x;i;i=e[i].pre){
x=e[i].to;
if(vis[x])continue;
all=siz[x],mx=inf,getroot(x,node),solve(root);
}
}
int main(){
memset(sl,-1,sizeof sl),sl[0]=0;
int t=read();
while(t--){
memset(vis,0,sizeof vis);
memset(h,0,sizeof h),num=ans=0;
int n=read();
m=read();
scanf("%s",a+1);
for(register int i=1;i<=n;++i)a[i]-='A';
for(register int i=1,x,y;i<n;++i)x=read(),y=read(),add(x,y),add(y,x);
scanf("%s",s[0]+1);
for(register int i=1;i<=m;++i)s[1][i]=(s[0][m-i+1]-='A');
S[0].init(s[0]),S[1].init(s[1]);
int j=0;
for(register int i=2;i<=m;++i){
while(j&&s[0][i]!=s[0][j+1])j=nex[0][j];
if(s[0][i]==s[0][j+1])++j;
nex[0][i]=j;
}
j=0;
for(register int i=2;i<=m;++i){
while(j&&s[1][i]!=s[1][j+1])j=nex[1][j];
if(s[1][i]==s[1][j+1])++j;
nex[1][i]=j;
}
dfs(1,0),mx=inf,all=n,getroot(1,0),solve(root);
printf("%lld\n",ans);
}
}