传送门

一边听主席的故事(本子)会一边调出来了。

题解里都是哈希,我这个大常数SAM+kmp就显得很另类(

考虑点分治时如何把两条路径拼起来。用$[P_i]$表示前缀,$[S_i]$表示后缀,$a[S]$表示$a$个$S$相连($a\in\mathbb{Z}$),一定是形如$a[S]+[P_i]$和$[S_{i+1}]+b[S]$的两条半路径拼成的路径符合要求:

$T_x$表示分治中心到点$x$的路径形成的字符串。如果$T_x$形如$[S_{i+1}]+b[S]$,则用$sl_x$表示$S_{i+1}$的长度。

对$S$串建$SAM$,并把$S$代表的节点(最后插入的字符新建的节点)及其$parent\ tree$上的祖先都打上标记。如果把$T_x$在$SAM$上跑一遍走到了有标记的节点,则$T_x$为$S$的一个后缀,$sl_x=deep_x$。

解决了后缀的问题,只要找出$[S]$即可。$kmp$匹配一下就行。如果$T_x$正好匹配$S$,则$sl_x=sl_y$($y$为$x$的$m$级祖先)。

而$a[S]+[P_i]$就把$S$反过来做一遍。开个$sl$的桶统计答案就做完了。

复杂度$O(n\log n)$,成功挤进了最优解最后一页。

代码:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>

#define maxn 1000005
#define inf 0x3f3f3f3f

using namespace std;

inline int read(){
    int x=0,y=0;
    char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')y=1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
    return y?-x:x;
}
char s[2][maxn],t[maxn],a[maxn];
int m;
struct SAM{
#define son(x,y) son[x][y]
    int son[maxn<<1][26],fa[maxn<<1],len[maxn<<1],cnt,last;
    bool en[maxn<<1];
    void clear(){
        for(register int i=1;i<=cnt;++i)memset(son[i],0,sizeof son[i]),fa[i]=len[i]=en[i]=0;
        cnt=last=1;
    }
    int insert(int c){
        int p=last,ne=last=++cnt;
        len[ne]=len[p]+1;
        while(p&&!son(p,c))son(p,c)=ne,p=fa[p];
        if(!p)fa[ne]=1;
        else {
            int q=son(p,c);
            if(len[q]==len[p]+1)fa[ne]=q;
            else {
                int sp=++cnt;
                memcpy(son[sp],son[q],sizeof son[q]);
                len[sp]=len[p]+1,fa[sp]=fa[q],fa[q]=fa[ne]=sp;
                while(p&&son(p,c)==q)son(p,c)=sp,p=fa[p];
            }
        }
        return ne;
    }
    void init(char *s){
        clear();
        for(register int i=1;i<m;++i)insert(s[i]);
        int node=insert(s[m]);
        while(node)en[node]=1,node=fa[node];
    }
}S[2];
int siz[maxn],h[maxn],st[maxn][2],tax[2][maxn],q[maxn],sl[maxn],nex[2][maxn],tail,num,mx,root,all,cnt;
long long ans;
bool vis[maxn];
struct edge{int pre,to;}e[maxn];
inline void add(int from,int to){e[++num]=(edge){h[from],to},h[from]=num;}
void getroot(int node,int fa=0){
    int ma=all-siz[node];
    for(register int i=h[node],x;i;i=e[i].pre){
        x=e[i].to;
        if(vis[x]||x==fa)continue;
        getroot(x,node),ma=max(ma,siz[x]);
    }
    if(ma<mx)mx=ma,root=node;
}
void modify(int id,int len){
    if(len==-1)return;
    ans+=tax[id^1][(m-len)%m];
    st[++cnt][0]=len,st[cnt][1]=id;
}
void calc(int node,int fa,int sam,int len,int id){
    q[++tail]=node,sam=S[id].son(sam,a[node]);
    while(len&&s[id][len+1]!=a[node])len=nex[id][len];
    if(s[id][len+1]==a[node])++len;
    if(len==m)sl[node]=sl[q[tail-m]],len=nex[id][m];
    else if(S[id].en[sam])sl[node]=tail%m;
    modify(id,sl[node]);
    for(register int i=h[node],x;i;i=e[i].pre){
        x=e[i].to;
        if(x==fa||vis[x])continue;
        calc(x,node,sam,len,id);
    }
    --tail,sl[node]=-1;
}
void dfs(int node,int fa){
    siz[node]=1,sl[node]=-1;
    for(register int i=h[node],x;i;i=e[i].pre){
        x=e[i].to;
        if(x==fa||vis[x])continue;
        dfs(x,node),siz[node]+=siz[x];
    }
}
void solve(int node){
    vis[node]=1,cnt=0,modify(0,0);
    ++tax[0][0];
    int sam=S[1].son(1,a[node]),len=(s[1][1]==a[node]),last=1;
    long long rec=ans;
    bool flag=0;
    if(len==m)len=nex[1][m];
    else if(S[1].en[sam])sl[node]=1%m;
    modify(1,sl[node]);
    for(register int i=h[node],x;i;i=e[i].pre){
        x=e[i].to;
        if(vis[x])continue;
        for(register int j=last+1;j<=cnt;++j)
            ++tax[st[j][1]][st[j][0]];
        last=cnt;
        calc(x,node,1,0,0);
        q[++tail]=node,calc(x,node,sam,len,1),--tail;
    }
    dfs(node,0);
    for(register int i=1;i<=last;++i)--tax[st[i][1]][st[i][0]];
    for(register int i=h[node],x;i;i=e[i].pre){
        x=e[i].to;
        if(vis[x])continue;
        all=siz[x],mx=inf,getroot(x,node),solve(root);
    }
}
int main(){
    memset(sl,-1,sizeof sl),sl[0]=0;
    int t=read();
    while(t--){
        memset(vis,0,sizeof vis);
        memset(h,0,sizeof h),num=ans=0;
        int n=read();
        m=read();
        scanf("%s",a+1);
        for(register int i=1;i<=n;++i)a[i]-='A';
        for(register int i=1,x,y;i<n;++i)x=read(),y=read(),add(x,y),add(y,x);
        scanf("%s",s[0]+1);
        for(register int i=1;i<=m;++i)s[1][i]=(s[0][m-i+1]-='A');
        S[0].init(s[0]),S[1].init(s[1]);
        int j=0;
        for(register int i=2;i<=m;++i){
            while(j&&s[0][i]!=s[0][j+1])j=nex[0][j];
            if(s[0][i]==s[0][j+1])++j;
            nex[0][i]=j;
        }
        j=0;
        for(register int i=2;i<=m;++i){
            while(j&&s[1][i]!=s[1][j+1])j=nex[1][j];
            if(s[1][i]==s[1][j+1])++j;
            nex[1][i]=j;
        }
        dfs(1,0),mx=inf,all=n,getroot(1,0),solve(root);
        printf("%lld\n",ans);
    }
}