题意

给定 $n,x,y\left(1\le n\le 10^9,0\le x,y\lt 998244353\right)$,多组询问 $a,b\left(a,b\le 5000\right)$,求

$$ \sum_{i=a}^{n}\binom{i}{a}x^{i-a}\binom{n-i}{b}y^{n-i-b}\pmod{998244353} $$

题解

所求即为 $\displaystyle \left[z^{n-a-b}\right]\dfrac{1}{\left(1-xz\right)^{a+1}\left(1-yz\right)^{b+1}}$

设这个东西为 $F\left(n,a,b\right)$

那么分别把 $1$ 拆成 $\left(1-xz\right)+xz$ 和 $\left(1-yz\right)+yz$ 可得

$$\begin{aligned}F\left(n,a,b\right)&=F\left(n-1,a-1,b\right)+xF\left(n-1,a,b\right)\\F\left(n,a,b\right)&=F\left(n-1,a,b-1\right)+yF\left(n-1,a,b\right)\end{aligned}$$

联立(两式相减,再把 $n-1$ 换成 $n$),当 $x\ne y$ 时可得($x=y$ 时答案直接为 $\binom{n+1}{a+b+1}x^{n-a-b}$,无需递推)

$$ F\left(n,a,b\right)=\dfrac{1}{x-y}\left(-F\left(n,a-1,b\right)+F\left(n,a,b-1\right)\right) $$

于是只要知道 $F\left(n,0,i\right)$ 和 $F\left(n,i,0\right)$ 即可

提出 $y$ 的幂次后(令公比为 $x/y$ 或其逆元),问题即为:给定 $m,p$ 和公比 $a\ne 1$(注意这里的 $a,p$ 是新的记号,与询问中的 $a$ 和模数无关),求 $\displaystyle\sum_{i=0}^m\binom{i+p}{i}a^i$

设答案为 $f\left(m,p\right)$,那么 $f\left(m,p\right)=\begin{cases}f\left(m-1,p\right)\cdot a+f\left(m,p-1\right) & p\ge 1\\\dfrac{a^{m+1}-1}{a-1} & p=0\end{cases}$

设 $F_m\left(z\right)=\sum_{p\ge 0}f\left(m,p\right) z^p$

那么有

$$ \begin{aligned} F_m\left(z\right)&=F_{m-1}\left(z\right)\cdot a+F_{m}\left(z\right)\cdot z+1\\ &=\dfrac{aF_{m-1}\left(z\right)+1}{1-z} \\ &=\dfrac{\left(\frac{a}{1-z}\right)^{m+1}-1}{z+a-1}\\ &=\dfrac{a^{m+1}}{\left(1-z\right)^{m+1}\left(z+a-1\right)}-\dfrac{1}{z+a-1} \end{aligned} $$

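提取 $\left[z^p\right]$ 时,$\dfrac{1}{\left(1-z\right)^{m+1}}$ 的第 $i$ 项系数为 $\binom{m+i}{i}$,而 $\dfrac{1}{z+a-1}=-\sum_{j\ge 0}\dfrac{z^j}{\left(1-a\right)^{j+1}}$ 是等比级数,因此做一次长度为 $p+1$ 的卷积求和,$\mathcal{O}\left(p\right)$ 线性计算即可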

总复杂度 $\mathcal{O}\left(\sum\max\left(a_i,b_i\right)^2\right)$

Code

#include <cstdio>
#include <algorithm>
using namespace std;
// N bounds the per-query parameters (tables are sized N or 2N);
// p is the prime modulus every result is reduced by.
constexpr int N=5010,p=998244353;
// inv/fac/ifac: modular inverses, factorials and inverse factorials, filled in main().
// n, x, y, q: parameters of the current test case, read in main().
int inv[N<<1],n,x,y,q,fac[N<<1],ifac[N<<1];
/**
 * Modular exponentiation by repeated squaring.
 *
 * @param a base; any int, reduced into [0, p) first, so negative bases are
 *          handled and the result is always canonical.
 * @param b exponent; a negative value is evaluated as a power of the modular
 *          inverse (Fermat: a^(-1) == a^(p-2) mod p) instead of looping
 *          forever on the sign bit.
 * @return a^b mod p in [0, p). fp(a, 0) is 1 for every a; a base divisible
 *         by p yields 0 for every non-zero exponent.
 */
int fp(int a,int b){
    long long base=a%p;
    if(base<0) base+=p;
    long long e=b;
    // a^b = a^((-b)*(p-2)) for b<0; the product is below 2^31 * p and fits in 64 bits.
    if(e<0) e=-e*(p-2);
    long long ans=1;
    while(e>0){
        if(e&1) ans=ans*base%p;
        base=base*base%p;
        e>>=1;
    }
    return static_cast<int>(ans);
}
int calc(int m,int k,int a){
    static int df[N<<1];
    df[0]=1;
    for(int i=1;i<=k;++i) df[i]=1ll*df[i-1]*(m+i)%p*inv[i]%p;
    int val=fp(p+1-a,p-2),ans=0;
    for(int i=0,j=1;i<=k;++i,j=1ll*j*val%p){
        ans=(ans+1ll*df[k-i]*j)%p;
    }
    ans=(fp(val,k)-1ll*ans*fp(a,m+1))%p;
    ans=1ll*val*(ans+p)%p;
    return ans;
}
int main(){
    inv[1]=1;
    for(int i=2;i<(N<<1);++i) inv[i]=1ll*inv[p%i]*(p-p/i)%p;
    fac[0]=ifac[0]=1;
    for(int i=1;i<(N<<1);++i) fac[i]=1ll*fac[i-1]*i%p,ifac[i]=1ll*ifac[i-1]*inv[i]%p;
    while(scanf("%d%d%d%d",&n,&x,&y,&q)!=EOF){
        int k=1ll*x*fp(y,p-2)%p;
        int mxa=0,mxb=0;
        constexpr int Q=200010;
        static int qs[Q][2];
        for(int i=1;i<=q;++i){
            scanf("%d%d",&qs[i][0],&qs[i][1]);
            mxa=max(mxa,qs[i][0]);
            mxb=max(mxb,qs[i][1]);
        }
        if(!x || !y){
            static int nfac[N<<1],nifac[N<<1];
            nfac[0]=nifac[0]=1;
            for(int i=1;i<=mxa+mxb+1;++i) nfac[i]=1ll*nfac[i-1]*(n-i+1)%p,nifac[i]=1ll*nifac[i-1]*fp(n-i+1,p-2)%p;
            for(int i=1;i<=q;++i){
                int a=qs[i][0],b=qs[i][1];
                if(a+b==n){printf("1\n");}
                else{
                    int ans=(1ll*nfac[a+b]*nifac[a]%p*ifac[b]%p*fp(y,n-a-b)
                            +1ll*nfac[a+b]*nifac[b]%p*ifac[a]%p*fp(x,n-a-b))%p;
                    printf("%d\n",ans);
                }
            }
            continue;
        }
        if(k==1){
            static int df[N<<1];
            df[0]=1;
            for(int i=1;i<=mxa+mxb+1;++i) df[i]=1ll*df[i-1]*inv[i]%p*(n-i+2)%p;
            for(int i=1;i<=q;++i){
                printf("%lld\n",1ll*df[qs[i][0]+qs[i][1]+1]*fp(x,n-qs[i][0]-qs[i][1])%p);
            }
            continue;
        }
        static int f[N][N];
        for(int i=0;i<=mxa;++i){
            f[i][0]=calc(n-i,i,k);
        }
        int ik=fp(k,p-2);
        for(int i=1;i<=mxb;++i){
            f[0][i]=1ll*calc(n-i,i,ik)*fp(k,n-i)%p;
        }
        int val=fp(k-1,p-2);
        for(int i=1;i<=mxa;++i) for(int j=1;j<=mxb;++j) if(i+j<=n){
            f[i][j]=1ll*val*(f[i][j-1]-f[i-1][j]+p)%p;
        }
        for(int i=1;i<=q;++i) printf("%lld\n",1ll*f[qs[i][0]][qs[i][1]]*fp(y,n-qs[i][0]-qs[i][1])%p);
    }
    return 0;
}
文章目录