传送门

我又双叒叕尝试学习数论辣!

算是我的组合数学+容斥的入门题。

把可行方案数转成总方案数减去不可行方案数。

然后把不可行方案分成三类:不满足条件$1$、$2$、$3$的。容斥加加减减就完了。

推完长这样:

感觉我推的十分鬼畜。。。查阅一下$asuldb$的题解发现他推的长这样:

应该是可以化简的然而懒得化了反正复杂度一样

这个复杂度是$O(nmc\log nm)$的,会$T$,瓶颈在于

有两种方案:

1.瞎JB改式子

调换一下顺序:

倒序循环$i$,会发现$i$每减少$1$,后面的幂就会乘上$(c+1-k)^{m-j}$,那么在第二层循环里可以先算出$(c+1-k)^{m-j}$,乘起来即可,就能把$\log$去掉。甚至可以把前面的幂这样处理优化常数。

2.黑科技分块光速幂

这个我只会口胡

分块光速幂可以达到$O(n\sqrt{V})$($n$为底数值域,$V$为指数值域)预处理,$O(1)$回答$n^V$。

然后就能很暴力地$O(c\sqrt{mn})$预处理,$O(1)$求$(c+1-k)^{(m-j)(n-i)}$了。同样可以去掉前面的所有快速幂的$\log$。

最终复杂度为$O(nmc)$

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>

#define maxn 405
#define inf 0x3f3f3f3f

const int mod = 1e9 + 7;

using namespace std;

inline int read(){
    int x=0,y=0;
    char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')y=1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
    return y?-x:x;
}
long long quickpow(long long x,int y){
    long long ans=1;
    while(y){
        if(y&1)ans=ans*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return ans;
}
inline int fac(int x){
    return x&1?-1:1;
}
int C[maxn][maxn];
int main(){
    C[0][0]=1;
    int n=read(),m=read(),c=read(),M=max(max(n,m),c);
    long long ans=0;
    for(register int i=1;i<=M;++i){
        C[0][i]=1;
        for(register int j=1;j<=M;++j)
            C[j][i]=(C[j][i-1]+C[j-1][i-1])%mod;
    }
    for(register int i=1;i<=n;++i)(ans+=quickpow(c+1,m*(n-i))*C[i][n]*fac(i+1))%=mod;
    for(register int i=1;i<=m;++i)(ans+=quickpow(c+1,n*(m-i))*C[i][m]*fac(i+1))%=mod;
    for(register int i=1;i<=c;++i)(ans+=quickpow(c+1-i,m*n)*C[i][c]*fac(i+1))%=mod;
    for(register int i=1;i<=n;++i)
        for(register int j=1;j<=m;++j)
            (ans+=quickpow(c+1,(n-i)*(m-j))*C[i][n]%mod*C[j][m]*fac(i+j+1))%=mod;
    for(register int i=1;i<=n;++i)
        for(register int j=1;j<=c;++j)
            (ans+=quickpow(c+1-j,m*(n-i))*C[i][n]%mod*C[j][c]*fac(i+j+1))%=mod;
    for(register int i=1;i<=m;++i)
        for(register int j=1;j<=c;++j){
            (ans+=quickpow(c+1-j,n*(m-i))*C[i][m]%mod*C[j][c]*fac(i+j+1))%=mod;
        }
    for(register int k=1;k<=c;++k){
        long long p=1;
        for(register int j=m;j;--j){
            long long q=1;
            for(register int i=n;i;--i,q=q*p%mod)
                (ans+=q*C[i][n]%mod*C[j][m]%mod*C[k][c]*fac(i+j+k+1))%=mod;
            p=p*(c+1-k)%mod;
        }
    }
    printf("%lld\n",(quickpow(c+1,n*m)-ans+mod)%mod);//注意模成正数
}