大概是最后一次写题解了吧。(并不是)
题目规定的形如$AA$的串自然可以想到优秀的拆分,根据套路,枚举$A$串的长度$len$,每$len$的长度设置一个关键点,$AA$串一定会跨过关键点。
对于每个相邻关键点$i,j$,若$lcs(i,j)+lcp(i+1,j+1)\ge len$就可以产生$AA$串,且左右端点在$[i-lcs(i,j)+1,i+lcp(i+1,j+1)]$内移动,于是$\forall k\in[i-lcs(i,j)+1,i+lcp(i+1,j+1)]$,需连边$(k,k+len)$。
考虑这个操作:一个区间与另一个区间对应连边,容易想到萌萌哒。
用$f(i,k)$表示点集$[i,i+2^k-1]$,把每次连边拆成两次$ST$表上的连边。
最后从高层往底层推下去。若有边$(f(i,k),f(j,k))$,则必有边$(f(i,k-1),f(j,k-1)),(f(i+2^{k-1},k-1),f(i+2^{k-1},k-1))$。不过直接把边推到底层边数还是爆炸的,所以可以对每层做一遍最小生成树,把有用的边下放,可以保证每层的边数都是$O(n)$的。
枚举$AA$串长度是调和级数的$O(n\log n)$,$ST$表的下放是$O(n\log^2n)$的,总复杂度$O(n\log^2n)$,空间复杂度$O(n\log n)$。
为了卡时间卡空间,线段树替换$ST$表求$lcp$,归并排序合并边集,写垃圾桶回收废弃边,代码丑的一批:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>
#define maxn 300005
#define inf 0x3f3f3f3f
using namespace std;
inline int read(){
int x=0,y=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')y=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return y?-x:x;
}
int lg[maxn],fa[maxn],s[maxn],pool[maxn*30],edg[maxn<<4],rec[maxn<<3],top,n,m,cnt,all,num;
struct edge{
int x,y,l;
bool operator < (const edge &X)const{return l<X.l;}
}e[maxn*40];
vector<int>ed[21];
struct Suffix_Array{
#define ls(x) x<<1
#define rs(x) x<<1|1
int sa[maxn],rk[maxn],tp[maxn],tax[maxn],hei[maxn],mi[maxn<<2];
void Rsort(){
for(register int i=0;i<=m;++i)tax[i]=0;
for(register int i=1;i<=n;++i)++tax[rk[i]];
for(register int i=1;i<=m;++i)tax[i]+=tax[i-1];
for(register int i=n;i;--i)sa[tax[rk[tp[i]]]--]=tp[i];
}
void Ssort(){
for(register int i=1;i<=n;++i)rk[i]=s[i],tp[i]=i;
m=n,Rsort();
for(register int k=1,p=0;p<n;m=p,k<<=1){
p=0;
for(register int i=1;i<=k;++i)tp[++p]=n-k+i;
for(register int i=1;i<=n;++i)if(sa[i]>k)tp[++p]=sa[i]-k;
Rsort();
for(register int i=1;i<=n;++i)tp[i]=rk[i];
rk[sa[1]]=p=1;
for(register int i=2;i<=n;++i)
rk[sa[i]]=tp[sa[i]]==tp[sa[i-1]]&&tp[sa[i]+k]==tp[sa[i-1]+k]?p:++p;
}
}
void get_height(){
int k=0,x;
for(register int i=1;i<=n;++i){
if(rk[i]==1)continue;
if(k)--k;
x=sa[rk[i]-1];
while(i+k<=n&&x+k<=n&&s[i+k]==s[x+k])++k;
hei[rk[i]]=k;
}
}
void build(int l,int r,int node){
if(l==r){mi[node]=hei[l];return;}
int mid=l+r>>1;
build(l,mid,ls(node));
build(mid+1,r,rs(node));
mi[node]=min(mi[ls(node)],mi[rs(node)]);
}
int query(int L,int R,int l,int r,int node){
if(L<=l&&R>=r)return mi[node];
int mid=l+r>>1,ans=inf;
if(L<=mid)ans=query(L,R,l,mid,ls(node));
if(R>mid)ans=min(ans,query(L,R,mid+1,r,rs(node)));
return ans;
}
inline int lcp(int i,int j){
return query(min(rk[i],rk[j])+1,max(rk[i],rk[j]),1,n,1);
}
void clear(){
for(register int i=1;i<=n;++i)tp[i]=rk[i]=0;
}
void init(){
Ssort(),get_height(),build(1,n,1);
}
}pre,suf;
inline int LCS(int x,int y){
if(!x||!y)return 0;
return pre.lcp(n-x+1,n-y+1);
}
inline int LCP(int x,int y){return suf.lcp(x,y);}
inline void add(int l,int r,int len,int w){
int k=lg[r-l+1];
e[++cnt]=(edge){l,l+len,w};
ed[k].push_back(cnt);
e[++cnt]=(edge){r-(1<<k)+1,r+len-(1<<k)+1,w};
ed[k].push_back(cnt);
}
int find(int x){return fa[x]==x?x:fa[x]=find(fa[x]);}
inline int newn(){return top?pool[top--]:++cnt;}
void merge(vector<int>&b){
all=0;
int i=1,j=0;
while(i<=num&&j<b.size()){
if(e[rec[i]].l<e[b[j]].l)edg[++all]=rec[i++];
else edg[++all]=b[j++];
}
while(i<=num)edg[++all]=rec[i++];
while(j<b.size())edg[++all]=b[j++];
}
inline bool cmp(int x,int y){return e[x].l<e[y].l;}
int main(){
int t=read();
while(t--){
suf.clear(),pre.clear();
n=read(),cnt=top=0;
for(register int i=1;i<=n;++i)s[i]=read();
for(register int i=2;i<=n;++i)lg[i]=lg[i>>1]+1;
suf.init();
reverse(s+1,s+1+n),pre.init();
for(register int len=1;len<=(n>>1);++len){
int w=read();
for(register int i=len<<1;i<=n;i+=len){
int j=i-len,lcs=min(LCS(i,j),len),lcp=min(LCP(i+1,j+1),len-1);
if(lcs+lcp>=len)add(j-lcs+1,j+lcp,len,w);
}
}
vector<int>::iterator it;
long long ans=0;
for(register int j=lg[n];~j;--j){
sort(ed[j].begin(),ed[j].end(),cmp),merge(ed[j]),ed[j].clear(),num=0;
for(register int i=1;i<=n;++i)fa[i]=i;
for(register int i=1;i<=all;++i){
edge E=e[edg[i]];
int u=find(E.x),v=find(E.y);
if(u!=v){
fa[u]=v;
if(j){
rec[++num]=edg[i];
int p=newn();
e[p]=(edge){E.x+(1<<j-1),E.y+(1<<j-1),E.l},rec[++num]=p;
}
else ans+=E.l;
}
else pool[++top]=edg[i];
}
}
printf("%lld\n",ans);
}
}