我又双叒叕尝试学习数论辣!
算是我的组合数学+容斥的入门题。
把可行方案数转成总方案数减去不可行方案数。
然后把不可行方案分成三类:不满足条件$1$、$2$、$3$的。容斥加加减减就完了。
推完长这样:
感觉我推的十分鬼畜。。。查阅一下$asuldb$的题解发现他推的长这样:
应该是可以化简的然而懒得化了反正复杂度一样
这个复杂度是$O(nmc\log nm)$的,会$T$,瓶颈在于
有两种方案:
1.瞎JB改式子
调换一下顺序:
倒序循环$i$,会发现$i$每减少$1$,后面的幂就会乘上$(c+1-k)^{m-j}$,那么在第二层循环里可以先算出$(c+1-k)^{m-j}$,乘起来即可,就能把$\log$去掉。甚至可以把前面的幂这样处理优化常数。
2.黑科技分块光速幂
这个我只会口胡
分块光速幂可以达到$O(n\sqrt{V})$($n$为底数值域,$V$为指数值域)预处理,$O(1)$回答$n^V$。
然后就能很暴力地$O(c\sqrt{mn})$预处理,$O(1)$求$(c+1-k)^{(m-j)(n-i)}$了。同样可以去掉前面的所有快速幂的$\log$。
最终复杂度为$O(nmc)$
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define maxn 405
#define inf 0x3f3f3f3f
const int mod = 1e9 + 7;
using namespace std;
inline int read(){
int x=0,y=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')y=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return y?-x:x;
}
long long quickpow(long long x,int y){
long long ans=1;
while(y){
if(y&1)ans=ans*x%mod;
x=x*x%mod;
y>>=1;
}
return ans;
}
inline int fac(int x){
return x&1?-1:1;
}
int C[maxn][maxn];
int main(){
C[0][0]=1;
int n=read(),m=read(),c=read(),M=max(max(n,m),c);
long long ans=0;
for(register int i=1;i<=M;++i){
C[0][i]=1;
for(register int j=1;j<=M;++j)
C[j][i]=(C[j][i-1]+C[j-1][i-1])%mod;
}
for(register int i=1;i<=n;++i)(ans+=quickpow(c+1,m*(n-i))*C[i][n]*fac(i+1))%=mod;
for(register int i=1;i<=m;++i)(ans+=quickpow(c+1,n*(m-i))*C[i][m]*fac(i+1))%=mod;
for(register int i=1;i<=c;++i)(ans+=quickpow(c+1-i,m*n)*C[i][c]*fac(i+1))%=mod;
for(register int i=1;i<=n;++i)
for(register int j=1;j<=m;++j)
(ans+=quickpow(c+1,(n-i)*(m-j))*C[i][n]%mod*C[j][m]*fac(i+j+1))%=mod;
for(register int i=1;i<=n;++i)
for(register int j=1;j<=c;++j)
(ans+=quickpow(c+1-j,m*(n-i))*C[i][n]%mod*C[j][c]*fac(i+j+1))%=mod;
for(register int i=1;i<=m;++i)
for(register int j=1;j<=c;++j){
(ans+=quickpow(c+1-j,n*(m-i))*C[i][m]%mod*C[j][c]*fac(i+j+1))%=mod;
}
for(register int k=1;k<=c;++k){
long long p=1;
for(register int j=m;j;--j){
long long q=1;
for(register int i=n;i;--i,q=q*p%mod)
(ans+=q*C[i][n]%mod*C[j][m]%mod*C[k][c]*fac(i+j+k+1))%=mod;
p=p*(c+1-k)%mod;
}
}
printf("%lld\n",(quickpow(c+1,n*m)-ans+mod)%mod);//注意模成正数
}