Link to the question: Luogu, AtCoder
Preface
The very first generating function and polynomial problem solved in my life!
This blog is a detailed explanation and extension of the official editorial. I will try my best to explain the mathematical expressions and their deeper meanings, so that you can follow along even if you are also new to generating functions and polynomials.
Our Goal
Let $$$X$$$ be our random variable: the number of rolls after which all $$$N$$$ sides have shown up at least once for the first time. Its probability mass function
$$$p_i=\Pr(X=i)$$$
is just the probability that all $$$N$$$ sides have shown up for the first time at exactly the $$$i$$$-th roll.
Then, what we are looking for is
$$$\mathbb{E}[X]=\sum_{i=0}^{\infty}i\,p_i,$$$
$$$\mathbb{E}[X^2]=\sum_{i=0}^{\infty}i^2p_i,$$$
$$$\mathbb{E}[X^3]=\sum_{i=0}^{\infty}i^3p_i,$$$
$$$\vdots$$$
$$$\mathbb{E}[X^M]=\sum_{i=0}^{\infty}i^Mp_i.$$$
(Actually, $$$p_i=0$$$ if $$$i<N,$$$ but it doesn't matter.)
Derivation: Generating Functions
Ordinary and Exponential Generating Functions
For a sequence
$$$\{a_i\}_{i=0}^{\infty},$$$
there are two formal power series associated with it:
- Its Ordinary Generating Function (OGF) is
$$$f(x)=\sum_{i=0}^{\infty}a_ix^i,$$$
and we denote it
$$$\{a_i\}\xrightarrow{\text{OGF}}f(x).$$$
- Its Exponential Generating Function (EGF) is
$$$F(x)=\sum_{i=0}^{\infty}a_i\frac{x^i}{i!},$$$
and we denote it
$$$\{a_i\}\xrightarrow{\text{EGF}}F(x).$$$
We can see that the EGF of $$$\{a_i\}$$$ is just the OGF of $$$\left\{\frac{a_i}{i!}\right\}.$$$
Probability Generating Functions
In particular, if the sequence $$$\{p_i\}_{i=0}^{\infty}$$$ is the probability mass function of a discrete random variable $$$X$$$ taking non-negative integer values, then its OGF is also called the Probability Generating Function (PGF) of the random variable $$$X,$$$ written
$$$g(x)=\sum_{i=0}^{\infty}p_ix^i.$$$
Our first goal is to find the PGF of our random variable $$$X,$$$ and then we will show how to use that function to derive the final answer.
Finding the PGF of $$$X$$$
It is difficult to consider "**the first time** when all $$$N$$$ sides have shown up" directly, so we drop that condition. We keep rolling even after all $$$N$$$ sides have shown up, and let $$$a_i$$$ be the probability that all $$$N$$$ sides have shown up within the first $$$i$$$ rolls.
Then, we have
$$$p_i=a_i-a_{i-1}\quad(i\ge1).$$$
This is because subtracting the previous term removes the probability that all $$$N$$$ sides had already shown up before the $$$i$$$-th roll, so what remains is the probability that all $$$N$$$ sides have shown up for the first time at exactly the $$$i$$$-th roll.
We try to find the OGF of $$$\{a_i\}.$$$
(A subtlety: although $$$a_i$$$ stores the probability of something, its OGF is not a PGF, because $$$\{a_i\}$$$ is not a probability mass function but just a tool for us to find $$$\{p_i\}.$$$)
However, it is easier to find its EGF first than to find its OGF directly. This is due to the properties of products of OGFs and EGFs.
Products of OGFs and EGFs
Let $$$\{a_i\}_{i=0}^{\infty}$$$ and $$$\{b_i\}_{i=0}^{\infty}$$$ be sequences.
OGFs
Let $$$f_a(x)$$$ and $$$f_b(x)$$$ be their OGFs. Then their product $$$f_a(x)f_b(x)$$$ is the OGF of $$$\{c_i\}_{i=0}^{\infty},$$$ where
$$$c_i=\sum_{k=0}^{i}a_kb_{i-k}.$$$
The way to understand its meaning is: let $$$a_i$$$ be the number of ways to take $$$i$$$ balls from a box, and $$$b_i$$$ be the number of ways to take $$$i$$$ balls from another box; then $$$c_i$$$ is the number of ways to take a total of $$$i$$$ balls from the two boxes.
Indeed, you can take $$$k$$$ balls from the first box, in $$$a_k$$$ ways, and $$$i-k$$$ balls from the second box, in $$$b_{i-k}$$$ ways. So the number of ways to take $$$k$$$ balls from the first box and $$$i-k$$$ balls from the second box is $$$a_kb_{i-k},$$$ and you sum over all possible $$$k,$$$ which ranges from $$$0$$$ to $$$i.$$$
EGFs
Let $$$F_a(x)$$$ and $$$F_b(x)$$$ be their EGFs. Then their product
$$$F_a(x)F_b(x)=\left(\sum_{i=0}^{\infty}a_i\frac{x^i}{i!}\right)\left(\sum_{j=0}^{\infty}b_j\frac{x^j}{j!}\right)$$$
is the EGF of $$$\{c_i\}_{i=0}^{\infty},$$$ where
$$$c_i=\sum_{k=0}^{i}\binom{i}{k}a_kb_{i-k}.$$$
The difference between the products of OGFs and EGFs is a binomial coefficient. The way to understand its meaning is: let $$$a_i$$$ be the number of ways to take $$$i$$$ balls from a box and align them in some order, and $$$b_i$$$ be the number of ways to take $$$i$$$ balls from another box and align them in some order; then $$$c_i$$$ is the number of ways to take a total of $$$i$$$ balls from the two boxes and align them in some order.
Similarly, the number of ways to take $$$k$$$ balls from the first box in some order and $$$i-k$$$ balls from the second box in some order is $$$a_kb_{i-k}.$$$ Next, you have $$$\binom{i}{k}$$$ ways to choose $$$k$$$ slots out of the $$$i$$$ slots for the balls from the first box. Thus, the total number of ways to choose and align them is $$$\binom{i}{k}a_kb_{i-k}.$$$
When we roll the dice, we get a sequence recording which side shows up at each roll, so there is an order. That's why it is easier to find the EGF of $$$\{a_i\}.$$$
When we roll the dice once, each side shows up with probability $$$\frac{1}{N}.$$$ If we consider a specific side, for example $$$1,$$$ then the probability of getting all $$$1$$$'s in $$$i$$$ rolls is $$$\frac{1}{N^i}.$$$ The EGF of the probability of getting all $$$1$$$'s in $$$i$$$ rolls is therefore
$$$\sum_{i=1}^{\infty}\frac{1}{N^i}\cdot\frac{x^i}{i!}=e^{x/N}-1.$$$
Note that we drop the case $$$i=0$$$ because we want that side to show up at least once.
By symmetry, all $$$N$$$ sides have the same EGF, and the EGF of the probability of getting all $$$N$$$ sides in $$$i$$$ rolls is
$$$F(x)=\left(e^{x/N}-1\right)^N.$$$
We are just taking the product of the EGFs corresponding to each side. As they are EGFs, their product automatically deals with the order of the sides that show up.
An example
If the derivation above seems a bit non-intuitive, we may verify it with $$$N=2,$$$ a dice with two sides.
Trivially, $$$a_0=a_1=0.$$$
If we roll the dice twice, then $$$(1,2)$$$ and $$$(2,1)$$$ are the two ways that make both sides show up. There are in total $$$4$$$ equally likely results ($$$(1,1),(1,2),(2,1),(2,2)$$$), so $$$a_2=\frac{2}{4}=\frac{1}{2}.$$$
If we roll the dice three times, then $$$(1,1,2),(1,2,1),(1,2,2),(2,1,1),(2,1,2),(2,2,1)$$$ are the ways to get both sides showing up, so $$$a_3=\frac{6}{8}=\frac{3}{4}.$$$
Similarly, $$$a_4=\frac{14}{16}=\frac{7}{8}.$$$
Therefore, the EGF of $$$\{a_i\}$$$ is
$$$\sum_{i=0}^{\infty}a_i\frac{x^i}{i!}$$$
$$$=\frac{1}{2}\cdot\frac{x^2}{2!}+\frac{3}{4}\cdot\frac{x^3}{3!}+\frac{7}{8}\cdot\frac{x^4}{4!}+\cdots$$$
$$$=\frac{x^2}{4}+\frac{x^3}{8}+\frac{7x^4}{192}+\cdots.$$$
Using our formula,
$$$\left(e^{x/2}-1\right)^2=\left(\frac{x}{2}+\frac{x^2}{8}+\frac{x^3}{48}+\cdots\right)^2$$$
$$$=\frac{x^2}{4}+2\cdot\frac{x}{2}\cdot\frac{x^2}{8}+\left(2\cdot\frac{x}{2}\cdot\frac{x^3}{48}+\frac{x^2}{8}\cdot\frac{x^2}{8}\right)+\cdots$$$
$$$=\frac{x^2}{4}+\frac{x^3}{8}+\frac{7x^4}{192}+\cdots,$$$
which matches our brute-force values of $$$\{a_i\}.$$$
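The same hand check can be automated for other small $$$N.$$$ The sketch below is my own illustration (the names `brute_force_a` and `egf_a` are hypothetical): it enumerates all $$$N^i$$$ roll sequences to get $$$a_i,$$$ and separately expands $$$\left(e^{x/N}-1\right)^N$$$ as a truncated power series in floating point, so the two can be compared.

```cpp
#include <cassert>
#include <vector>

// a_i by brute force: fraction of the n^i roll sequences in which every side appears.
double brute_force_a(int n, int i) {
    assert(n >= 1 && n <= 10 && i >= 0);  // seen sides are kept in a bitmask; keep n small
    long long total = 1;
    for (int t = 0; t < i; ++t) total *= n;
    long long good = 0;
    for (long long code = 0; code < total; ++code) {
        long long c = code;  // the base-n digits of code are the i rolls
        int seen = 0;
        for (int t = 0; t < i; ++t) { seen |= 1 << (c % n); c /= n; }
        if (seen == (1 << n) - 1) ++good;
    }
    return double(good) / double(total);
}

// a_i from the formula a_i = i! * [x^i] (e^{x/n} - 1)^n, for i = 0..terms-1.
std::vector<double> egf_a(int n, int terms) {
    assert(n >= 1 && terms >= 1);
    // base[j] = coefficient of x^j in e^{x/n} - 1, i.e. 1 / (n^j j!) for j >= 1.
    std::vector<double> base(terms, 0.0), cur(terms, 0.0);
    double c = 1.0;
    for (int j = 1; j < terms; ++j) { c /= double(n) * j; base[j] = c; }
    cur[0] = 1.0;
    for (int rep = 0; rep < n; ++rep) {  // multiply by the base series n times
        std::vector<double> nxt(terms, 0.0);
        for (int u = 0; u < terms; ++u)
            for (int v = 0; u + v < terms; ++v)
                nxt[u + v] += cur[u] * base[v];
        cur = nxt;
    }
    double fact = 1.0;
    for (int i = 0; i < terms; ++i) {  // turn EGF coefficients into a_i
        if (i > 0) fact *= i;
        cur[i] *= fact;
    }
    return cur;
}
```

For $$$N=2$$$ this reproduces $$$a_2=\frac12,a_3=\frac34,a_4=\frac78$$$ up to rounding. The brute force is exponential, so it is only meant for tiny inputs.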
Now that we have the EGF of $$$\{a_i\},$$$ we convert it to its OGF.
Conversion between OGFs and EGFs
There are two laws:
- If $$$f_1(x)\leftrightarrow F_1(x)$$$ (meaning that $$$f_1(x)$$$ and $$$F_1(x)$$$ are the OGF and EGF of the same sequence) and $$$f_2(x)\leftrightarrow F_2(x),$$$ then for all constants $$$\alpha,\beta,$$$
$$$\alpha f_1(x)+\beta f_2(x)\leftrightarrow\alpha F_1(x)+\beta F_2(x).$$$
This law tells us there is a sense of 'linearity' between sequences and their GFs. When doing conversions, we can deal with separate terms and add them up.
- For every constant $$$k,$$$
$$$\frac{1}{1-kx}\leftrightarrow e^{kx}.$$$
Both are generating functions of the sequence $$$\{k^i\}$$$: the OGF is a geometric series and the EGF is the exponential function.
With the two rules above, we have
$$$F(x)=\left(e^{x/N}-1\right)^N=\sum_{k=0}^{N}\binom{N}{k}(-1)^{N-k}e^{kx/N}$$$
$$$\leftrightarrow f(x)=\sum_{k=0}^{N}\binom{N}{k}(-1)^{N-k}\frac{1}{1-\frac{k}{N}x},$$$
where $$$F(x)$$$ and $$$f(x)$$$ are the EGF and OGF of $$$\{a_i\}.$$$
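Since the coefficient of $$$x^i$$$ in the geometric series $$$\frac{1}{1-rx}$$$ is $$$r^i,$$$ an OGF that is a sum of such fractions yields a closed form for each term: $$$a_i=\sum_{k=0}^{N}\binom{N}{k}(-1)^{N-k}\left(\frac{k}{N}\right)^i.$$$ The following sketch (my own, in floating point, with the hypothetical name `a_closed_form`) evaluates that sum so it can be compared with small hand-computed values.

```cpp
#include <cassert>
#include <cmath>

// a_i = sum_{k=0}^{n} C(n,k) * (-1)^{n-k} * (k/n)^i, evaluated in doubles.
double a_closed_form(int n, int i) {
    assert(n >= 1 && i >= 0);
    double sum = 0.0;
    double binom = 1.0;  // C(n,k), updated as k grows
    for (int k = 0; k <= n; ++k) {
        double sign = ((n - k) % 2 == 0) ? 1.0 : -1.0;
        sum += sign * binom * std::pow(double(k) / n, i);  // pow(0, 0) is 1
        binom = binom * (n - k) / (k + 1);                 // C(n,k) -> C(n,k+1)
    }
    return sum;
}
```

This is only a sanity check: the alternating sum loses precision for larger $$$N,$$$ which is one reason the real solution works modulo a prime instead.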
And finally, we compute the PGF of $$$X.$$$ Writing $$$f(x)$$$ for the OGF of $$$\{a_i\},$$$ it is
$$$g(x)=\sum_{i=0}^{\infty}p_ix^i$$$
$$$=\sum_{i=1}^{\infty}(a_i-a_{i-1})x^i$$$
(since $$$p_0=a_0=0$$$)
$$$=\sum_{i=1}^{\infty}a_ix^i-x\sum_{i=1}^{\infty}a_{i-1}x^{i-1}$$$
$$$=f(x)-xf(x)$$$
$$$=(1-x)f(x)$$$
$$$=(1-x)\sum_{k=0}^{N}\binom{N}{k}(-1)^{N-k}\frac{1}{1-\frac{k}{N}x}.$$$
(Note: multiplying an OGF by $$$1-x$$$ is the same as subtracting from each term of the sequence its previous term. Conversely, the inverse operation, multiplying by $$$\frac{1}{1-x},$$$ is the same as replacing each term by the prefix sum up to it.)
Though it is a 'nasty' formula, we will show later how to compute it in code.
Spoiler alert: there is a much easier derivation of $$$g(x)$$$ at the end of this blog.
Here is the final step: calculating the expected values of $$$X,X^2,\dots,X^M$$$ from the PGF.
Moment Generating Functions
Just as the PGF is the OGF of a probability mass function, the Moment Generating Function (MGF) is the EGF of something else.
The MGF of a random variable $$$X$$$ is
$$$M_X(x)=\mathbb{E}\left[e^{xX}\right].$$$
Here are some algebraic manipulations:
$$$M_X(x)=\mathbb{E}\left[\sum_{i=0}^{\infty}\frac{(xX)^i}{i!}\right]$$$
$$$=\sum_{i=0}^{\infty}\mathbb{E}\left[X^i\right]\frac{x^i}{i!},$$$
which is exactly the EGF of our answer!
(Note: the definition via the expected value is the general one, as it also applies to random variables that do not take only non-negative integer values.)
Lastly, for a random variable $$$X$$$ taking non-negative integer values, as in our problem, we have
$$$M_X(x)=\sum_{i=0}^{\infty}p_ie^{ix}=g(e^x)$$$
by definition.
Therefore, our final goal is to find the coefficients up to $$$x^M$$$ of the MGF of $$$X,$$$ which is $$$g(e^x).$$$
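Before implementing anything with polynomials, it helps to have reference values for the moments themselves. The sketch below is my own and independent of the generating-function derivation: it runs a small DP over the number of distinct sides seen so far, reads off the probability of finishing at exactly roll $$$i,$$$ and accumulates $$$\sum_i i^k p_i$$$ in floating point. The name `moments_by_dp` and the truncation parameter `rolls` are assumptions of this illustration; the actual problem asks for answers modulo $$$998244353.$$$

```cpp
#include <cassert>
#include <vector>

// Returns {E[X^0], E[X^1], ..., E[X^m]} approximately, truncating the sum after `rolls` rolls.
std::vector<double> moments_by_dp(int n, int m, int rolls) {
    assert(n >= 1 && m >= 0 && rolls >= 0);
    std::vector<double> moments(m + 1, 0.0);
    // dist[j] = probability that exactly j distinct sides have been seen
    // and the process has not finished yet.
    std::vector<double> dist(n + 1, 0.0);
    dist[0] = 1.0;
    for (int i = 1; i <= rolls; ++i) {
        std::vector<double> next(n + 1, 0.0);
        for (int j = 0; j < n; ++j) {
            next[j] += dist[j] * j / n;            // rolled an already seen side
            next[j + 1] += dist[j] * (n - j) / n;  // rolled a new side
        }
        double p_i = next[n];  // all sides seen for the first time at roll i
        next[n] = 0.0;         // finished runs leave the process
        double pw = 1.0;       // i^k
        for (int k = 0; k <= m; ++k) {
            moments[k] += pw * p_i;
            pw *= i;
        }
        dist = next;
    }
    return moments;
}
```

For $$$N=2$$$ this gives $$$\mathbb{E}[X]\approx3$$$ and $$$\mathbb{E}[X^2]\approx11,$$$ and for $$$N=3$$$ it gives $$$\mathbb{E}[X]\approx5.5,$$$ which agrees with the classical coupon-collector value $$$N\left(1+\frac12+\cdots+\frac1N\right).$$$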
Implementation: Convolutions
Prerequisites: Convolution and inverse series.
In the implementation, I used the class modint998244353 and convolution() from the AtCoder Library for calculations modulo $$$998244353$$$ and for FFT.
For how FFT works and more, see this blog.
We do this by explicitly calculating the PGF $$$g(x)$$$ and then the MGF $$$g(e^x).$$$
Calculating $$$g(x)$$$
We have the explicit formula
$$$g(x)=(1-x)\sum_{k=0}^{N}\binom{N}{k}(-1)^{N-k}\frac{1}{1-\frac{k}{N}x}.$$$
The summation
$$$\sum_{k=0}^{N}\binom{N}{k}(-1)^{N-k}\frac{1}{1-\frac{k}{N}x}$$$
can be written as a rational function $$$\frac{p(x)}{q(x)},$$$ with $$$p(x)$$$ and $$$q(x)$$$ each a polynomial of degree at most $$$N+1.$$$
As it is a sum of fractions of the form $$$\frac{c}{1-rx},$$$ we may combine them in some order using convolution().
By FFT, the time complexity of multiplying two polynomials is $$$O(n\log n),$$$ where $$$n$$$ is the larger degree of the two polynomials. So the best way to combine the fractions is by Divide-and-Conquer: split the summation in half, compute each half as a rational function, and then add the two rational functions.
Here is the class of rational functions and its addition method:
using mint=modint998244353; //calculation in mod 998244353
using ply=vector<mint>; //polynomials
//Frn0(i,a,b) is the loop macro from the full code: for(int i(a);i<(b);++i)
ply add(ply a,ply b){ //adding two polynomials coefficient-wise
    if(a.size()<b.size())swap(a,b); //make a the longer one
    Frn0(i,0,b.size())a[i]+=b[i];
    return a;
}
struct R{ply p,q; //numerator and denominator
    R operator+(R b){ //p/q + b.p/b.q = (q*b.p + p*b.q)/(q*b.q)
        ply rp(add(convolution(q,b.p),convolution(p,b.q))),
            rq(convolution(q,b.q));
        return{rp,rq};
    }
};
Here is the divide-and-conquer summation of the rational functions stored in vector<R>a.
R sum(vector<R>&a,int l,int r){ //summing from a[l] to a[r]
if(l==r)return a[l];
int md((l+r)/2);
return sum(a,l,md)+sum(a,md+1,r);
}
The summation is done. For the remaining factor $$$1-x,$$$ there are two ways:
- Multiply it into the numerator. This can be done by subtracting from each coefficient its previous coefficient. Note that the degree will increase by $$$1.$$$
- (used here) As the denominator already has a $$$1-x$$$ factor (check the summands: the last one has denominator $$$1-x$$$), we can remove this factor by taking prefix sums of the denominator's coefficients, which is the same as multiplying it by $$$\frac{1}{1-x}.$$$
And now, we obtain the PGF $$$g(x)$$$ as a rational function.
From $$$g(x)$$$ to $$$g(e^x)$$$
As $$$g(x)=\frac{p(x)}{q(x)}$$$ is a rational function, we calculate $$$p(e^x)$$$ and $$$q(e^x)$$$ separately and use inverse series to combine them. As we only need the coefficients from $$$x$$$ to $$$x^M,$$$ we may take the results $$$\bmod\ x^{M+1}.$$$
---
For a polynomial
$$$f(x)=\sum_{i=0}^{n}f_ix^i,$$$
we use our trick of conversion between EGFs and OGFs again:
$$$\sum_{i=0}^{n}\frac{f_i}{1-ix}\leftrightarrow\sum_{i=0}^{n}f_ie^{ix}=f(e^x),$$$
with the OGF on the left and the EGF of the same sequence on the right.
So we may calculate the summation on the OGF side by the same Divide-and-Conquer technique, use inverse series to get its coefficients as a power series, and then divide the $$$i$$$-th coefficient by $$$i!$$$ to obtain the coefficients of $$$f(e^x).$$$
The following is an implementation of inverse series $$$f(x)^{-1}\bmod x^m$$$:
ply pinv(ply f,int m){
ply g({f[0].inv()});
f.resize(m);
for(int s(2);s<2*m;s<<=1){
auto tmp(convolution(convolution(g,g),
ply(f.begin(),f.begin()+min(s,m))));
g.resize(min(s,m));
Frn0(i,0,min(s,m))g[i]=2*g[i]-tmp[i];
}
return g;
}
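A slow reference implementation is handy for checking a Newton-iteration inverse such as the one above on small inputs. The sketch below is my own and uses only the standard library: it works with plain 64-bit integers modulo $$$998244353$$$ instead of modint998244353, and solves $$$f\cdot g=1$$$ coefficient by coefficient in quadratic time. The names `mod_pow` and `naive_inverse` are made up for this illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr std::int64_t MOD = 998244353;

// b^e mod MOD by binary exponentiation (negative b is normalised first).
std::int64_t mod_pow(std::int64_t b, std::int64_t e) {
    std::int64_t r = 1;
    b %= MOD;
    if (b < 0) b += MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// g = 1/f mod x^m, from f*g = 1:
// g_0 = 1/f_0 and g_i = -(1/f_0) * sum_{j=1..i} f_j * g_{i-j}.
std::vector<std::int64_t> naive_inverse(const std::vector<std::int64_t>& f, int m) {
    assert(!f.empty() && f[0] % MOD != 0 && m >= 1);
    std::int64_t inv0 = mod_pow(f[0], MOD - 2);  // Fermat inverse of the constant term
    std::vector<std::int64_t> g(m, 0);
    g[0] = inv0;
    for (int i = 1; i < m; ++i) {
        std::int64_t s = 0;
        for (int j = 1; j <= i && j < (int)f.size(); ++j) {
            std::int64_t fj = (f[j] % MOD + MOD) % MOD;  // allow negative coefficients
            s = (s + fj * g[i - j]) % MOD;
        }
        g[i] = (MOD - s) % MOD * inv0 % MOD;
    }
    return g;
}
```

For example, `naive_inverse({1,-2},4)` returns `{1,2,4,8}`, the first coefficients of $$$\frac{1}{1-2x}.$$$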
The following calculates $$$f(e^x)\bmod x^m$$$:
ply fex(ply f,int m){
vector<R>a(f.size());
Frn0(i,0,f.size())a[i].p={f[i]},a[i].q={1,-i};
R s(sum(a,0,a.size()-1)); //DC summation
auto re(convolution(s.p,pinv(s.q,m)));
re.resize(m);
Frn0(i,0,m)re[i]/=fc[i]; //dividing by i!
return re;
}
Code
Time Complexity: $$$O(n\log^2n+m\log m)$$$ (DC summation and inverse series)
Memory Complexity: $$$O(n+m)$$$
Further details are annotated.
//This program is written by Brian Peng.
#include<bits/stdc++.h>
#include<atcoder/convolution>
using namespace std;
using namespace atcoder;
#define Rd(a) (a=rd())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
int rd(){
int x;char c(getchar());bool k;
while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
c^'-'?(k=1,x=c&15):k=x=0;
while(isdigit(Gc(c)))x=x*10+(c&15);
return k?x:-x;
}
void wr(int a){
if(a<0)Pc('-'),a=-a;
if(a<=9)Pc(a|'0');
else wr(a/10),Pc((a%10)|'0');
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define All(a) (a).begin(),(a).end()
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
using mint=modint998244353;
using ply=vector<mint>;
#define N (200010)
int n,m;
mint fc[N]{1};
ply ans;
ply pinv(ply f,int m);
ply add(ply a,ply b);
struct R{ply p,q;
R operator+(R b){
ply rp(add(convolution(q,b.p),convolution(p,b.q))),
rq(convolution(q,b.q));
return{rp,rq};
}
}g;
vector<R>a;
R sum(vector<R>&a,int l,int r);
mint cmb(int n,int r){return fc[n]/(fc[r]*fc[n-r]);} //binomial numbers
ply fex(ply f,int m);
signed main(){
Rd(n),Rd(m);
Frn1(i,1,max(n,m))fc[i]=fc[i-1]*i; //factorials
a.resize(n+1);
mint niv(mint(n).inv());
Frn1(i,0,n){
a[i].p={(((n-i)&1)?-1:1)*cmb(n,i)};
a[i].q={1,-niv*i};
} //the terms of the summation in g(x)
g=sum(a,0,n);
Frn0(i,1,g.q.size())g.q[i]+=g.q[i-1]; //denominator dividing 1-x
//by taking prefix sums, obtaining PGF
ans=convolution(fex(g.p,m+1),pinv(fex(g.q,m+1),m+1));
//obtaining MGF
Frn1(i,1,m)wr((ans[i]*fc[i]).val()),Pe;
//remember to multiply by i! as it is an EGF
exit(0);
}
ply pinv(ply f,int m){
ply g({f[0].inv()});
f.resize(m);
for(int s(2);s<2*m;s<<=1){
auto tmp(convolution(convolution(g,g),
ply(f.begin(),f.begin()+min(s,m))));
g.resize(min(s,m));
Frn0(i,0,min(s,m))g[i]=2*g[i]-tmp[i];
}
return g;
}
ply add(ply a,ply b){
if(a.size()<b.size())swap(a,b);
Frn0(i,0,b.size())a[i]+=b[i];
return a;
}
R sum(vector<R>&a,int l,int r){
if(l==r)return a[l];
int md((l+r)/2);
return sum(a,l,md)+sum(a,md+1,r);
}
ply fex(ply f,int m){
vector<R>a(f.size());
Frn0(i,0,f.size())a[i].p={f[i]},a[i].q={1,-i};
R s(sum(a,0,a.size()-1));
auto re(convolution(s.p,pinv(s.q,m)));
re.resize(m);
Frn0(i,0,m)re[i]/=fc[i];
return re;
}
Extensions
An alternative way to find the PGF of $$$X$$$
We may track the number of rolls needed to get a new side showing up when $$$i$$$ sides have already shown up.
When $$$i$$$ sides have already shown up, the probability of getting a new side in a roll is $$$\frac{N-i}{N}.$$$ Let $$$X_i$$$ be the random variable of the number of rolls; then $$$X_i\sim\operatorname{Geo}\left(\frac{N-i}{N}\right).$$$
As the PGF of $$$\operatorname{Geo}(p)$$$ is $$$\frac{px}{1-(1-p)x},$$$ the PGF of $$$X_i$$$ is
$$$\frac{\frac{N-i}{N}x}{1-\frac{i}{N}x}=\frac{(N-i)x}{N-ix}.$$$
By the Convolution Theorem of PGFs (the $$$X_i$$$ are independent), the PGF of the total number of rolls $$$X=\sum_{i=0}^{N-1}X_i$$$ is
$$$g(x)=\prod_{i=0}^{N-1}\frac{(N-i)x}{N-ix}=\frac{N!\,x^N}{\prod_{i=0}^{N-1}(N-ix)}.$$$
It seems to be a lot easier to do... The product of these small polynomials can again be computed by a similar Divide-and-Conquer method.
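The product form is easy to expand numerically, which gives a quick way to look at the first few $$$p_i$$$ for small $$$N.$$$ The following sketch is my own illustration in floating point (the name `pgf_by_product` is made up): it multiplies in one geometric factor at a time, using prefix-style updates for $$$\frac{1}{1-rx}$$$ and a shift for the factor $$$x.$$$

```cpp
#include <cassert>
#include <vector>

// First `terms` coefficients p_0, p_1, ... of the PGF of X, obtained by
// multiplying the PGFs (1 - r) x / (1 - r x) of the geometric stages, r = i/n.
std::vector<double> pgf_by_product(int n, int terms) {
    assert(n >= 1 && terms >= 1);
    std::vector<double> g(terms, 0.0);
    g[0] = 1.0;  // start from the constant series 1
    for (int i = 0; i < n; ++i) {
        double r = double(i) / n;  // probability of rolling an already seen side
        // Multiply by 1 / (1 - r x): t_j = g_j + r * t_{j-1}, done in place.
        for (int j = 1; j < terms; ++j) g[j] += r * g[j - 1];
        // Multiply by (1 - r) x: shift by one position and scale.
        for (int j = terms - 1; j >= 1; --j) g[j] = (1.0 - r) * g[j - 1];
        g[0] = 0.0;
    }
    return g;
}
```

For $$$N=2$$$ this yields $$$p_2=\frac12,p_3=\frac14,p_4=\frac18,$$$ and for $$$N=3$$$ it yields $$$p_3=p_4=\frac29,$$$ matching the differences $$$a_i-a_{i-1}$$$ computed by hand.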
//This program is written by Brian Peng.
#include<bits/stdc++.h>
#include<atcoder/convolution>
using namespace std;
using namespace atcoder;
#define Rd(a) (a=rd())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
int rd(){
int x;char c(getchar());bool k;
while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
c^'-'?(k=1,x=c&15):k=x=0;
while(isdigit(Gc(c)))x=x*10+(c&15);
return k?x:-x;
}
void wr(int a){
if(a<0)Pc('-'),a=-a;
if(a<=9)Pc(a|'0');
else wr(a/10),Pc((a%10)|'0');
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define All(a) (a).begin(),(a).end()
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
using mint=modint998244353;
using ply=vector<mint>;
#define N (200010)
int n,m;
mint fc[N]{1};
ply ans;
ply pinv(ply f,int m);
ply add(ply a,ply b);
struct R{ply p,q;
R operator+(R b){
ply rp(add(convolution(q,b.p),convolution(p,b.q))),
rq(convolution(q,b.q));
return{rp,rq};
}
}g;
vector<ply>a;
R sum(vector<R>&a,int l,int r);
ply prod(vector<ply>&a,int l,int r); //DC Multiplication
ply fex(ply f,int m);
signed main(){
Rd(n),Rd(m);
Frn1(i,1,max(n,m))fc[i]=fc[i-1]*i;
g.p.resize(n+1),g.p[n]=fc[n],a.resize(n);
Frn0(i,0,n)a[i]={n,-i};
g.q=prod(a,0,n-1);
ans=convolution(fex(g.p,m+1),pinv(fex(g.q,m+1),m+1));
Frn1(i,1,m)wr((ans[i]*fc[i]).val()),Pe;
exit(0);
}
ply pinv(ply f,int m){
ply g({f[0].inv()});
f.resize(m);
for(int s(2);s<2*m;s<<=1){
auto tmp(convolution(convolution(g,g),
ply(f.begin(),f.begin()+min(s,m))));
g.resize(min(s,m));
Frn0(i,0,min(s,m))g[i]=2*g[i]-tmp[i];
}
return g;
}
ply add(ply a,ply b){
if(a.size()<b.size())swap(a,b);
Frn0(i,0,b.size())a[i]+=b[i];
return a;
}
R sum(vector<R>&a,int l,int r){
if(l==r)return a[l];
int md((l+r)/2);
return sum(a,l,md)+sum(a,md+1,r);
}
ply prod(vector<ply>&a,int l,int r){
if(l==r)return a[l];
int md((l+r)/2);
return convolution(prod(a,l,md),prod(a,md+1,r));
}
ply fex(ply f,int m){
vector<R>a(f.size());
Frn0(i,0,f.size())a[i].p={f[i]},a[i].q={1,-i};
R s(sum(a,0,a.size()-1));
auto re(convolution(s.p,pinv(s.q,m)));
re.resize(m);
Frn0(i,0,m)re[i]/=fc[i];
return re;
}
It is really easier to implement, and took 300ms less time than the previous one...
THANKS FOR READING!