题意:你有 $n$ 种球,同种球没有区别,并且每种球都有无限个。
你可以拿一些球,如果第 $i$ 种球拿了 $k_i$ 个,那么你就会得到 $\prod_i k_i^{c_i}$ 的价值。
求总共拿不超过 $w$ 个球的所有方案的价值的总和。
$n \le 10^5,\;\sum_{i} c_i \le 10^5,\; w \le 10 ^ {18}$。
如果我们能得到关于选球个数的生成函数 $F(x)$ ,那么 $\frac {F(x)}{1 – x}$ 的 $x^w$ 项的系数就是答案。
设第 $i$ 种球的生成函数为 $f_{c_i}(x)$,显然 $F(x)=\prod f_{c_i}(x)$。
设 $c_i=n$,我们可以用第二类斯特林数展开 $n$ 次幂:
$$
\begin{aligned} f_n(x) =& \sum_{i = 0} ^ \infty i^n x^i \\ =& \sum_{i = 0} ^ \infty x^i \sum_{k = 0} ^ n {n \brace k} i^{\underline k} \\ =& \sum_{i = 0} ^ \infty x^i \sum_{k = 0} ^ n {n \brace k} {i \choose k} k! \\ =& \sum_{k = 0} ^ n {n \brace k} k! \sum_{i = 0} ^ \infty {i \choose k} x^i \\ =& \sum_{k = 0} ^ n {n \brace k} k! \frac {x^k} {(1-x) ^ {k + 1}} \\ =& \sum_{k = 0} ^ n {n \brace k} k! \frac {x^k (1-x) ^ {n – k}} {(1-x) ^ {n + 1}} \end{aligned}
$$
$\frac 1 {(1-x) ^ {n+1}}$ 是一个与 $k$ 无关的值,我们可以放在最后再一起乘起来:
$$
\begin{aligned} g_n(x) =& (1-x) ^ {n+1} f_n(x) \\ =& \sum_{k = 0} ^ n {n \brace k} k! x^k (1-x) ^ {n-k} \\ =& \sum_{i = 0} ^ n x^i \sum_{k = 0} ^ i {n \brace k}k! \left[ x^{i-k} \right] (1-x) ^ {n – k} \\ =& \sum_{i = 0} ^ n x^i \sum_{k = 0} ^ i {n \brace k}k! {n-k \choose i-k}(-1)^{i-k} \\ =& \sum_{i = 0} ^ n \frac {x^i} {(n-i)!} \sum_{k = 0} ^ i {n \brace k}k!(n-k)! \frac {(-1)^{i-k}} {(i-k)!} \end{aligned}
$$
化成了一个卷积,第二类斯特林数可以 $O(n\log n)$ 预处理,然后代入再跑一次卷积就可以求出 $g_n(x)$ 了。
去掉 $\frac 1 {(1-x) ^ {n + 1}}$ 之后剩下的 $g_n(x)$ 显然是一个 $n$ 次多项式,因为 $\sum c_i \le 10^5$,所以直接用分治 FFT 合并所有的 $g_{c_i}(x)$ 就行了,假设最后得到的结果是 $G(x)$,那么就有
$$ F(x) = \frac {G(x)} {(1 – x) ^ {n + \sum c_i + 1}} $$
(多除一次是因为题目要求对 $\le w$ 的所有情况都求和)
$\frac 1 {(1-x)^k}$ 可以用组合数表示,所以要求第 $w$ 项只需要递推组合数加起来就行了。
总的来说比较值得注意的几个点:
- 整体的思路,用斯特林数展开很好做的原因主要在于,斯特林数展开之后会得到组合数,而组合数又可以表示成 $\frac 1 {(1-x)^k}$ 的形式,去掉这部分之后化成 $c_i$ 次多项式就可以直接分治 FFT 了。
- 推式子的细节,补题的时候推式子还是花了比较久,并且出了几次问题。
- 板子里没有的细节,比如补题的时候预处理 $g_n(x)$ 忘了清空,还查了一会。
#include <bits/stdc++.h>
using namespace std;
constexpr int maxn = 262155, p = 998244353;
int qpow(int a, int b) {
int ans = 1;
while (b) {
if (b & 1)
ans = (long long)ans * a % p;
b >>= 1;
a = (long long)a * a % p;
}
return ans;
}
int ntt_n, omega[maxn], omega_inv[maxn];
void NTT_init(int n) {
ntt_n = n;
int wn = qpow(3, (p - 1) / n);
omega[0] = omega_inv[0] = 1;
for (int i = 1; i < n; i++)
omega_inv[n - i] = omega[i] = (long long)omega[i - 1] * wn % p;
}
void NTT(int *a, int n, int tp) {
for (int i = 1, j = 0, k; i < n - 1; i++) {
k = n;
do
j ^= (k >>= 1);
while (j < k);
if (i < j)
swap(a[i], a[j]);
}
for (int k = 2, m = ntt_n / 2; k <= n; k *= 2, m /= 2)
for (int i = 0; i < n; i += k)
for (int j = 0; j < k / 2; j++) {
int w = (tp > 0 ? omega : omega_inv)[m * j];
int u = a[i + j], v = (long long)w * a[i + j + k / 2] % p;
a[i + j] = u + v;
if (a[i + j] >= p)
a[i + j] -= p;
a[i + j + k / 2] = u - v;
if (a[i + j + k / 2] < 0)
a[i + j + k / 2] += p;
}
if (tp < 0) {
int inv = qpow(n, p - 2);
for (int i = 0; i < n; i++)
a[i] = (long long)a[i] * inv % p;
}
}
int fac[maxn], fac_inv[maxn];
void work(int n, vector<int> &v) {
static int a[maxn], b[maxn], s[maxn];
int N = 1;
while (N < (n + 1) * 2)
N *= 2;
for (int i = 0; i <= n; i++) {
a[i] = (i % 2 ? p - fac_inv[i] : fac_inv[i]);
b[i] = (long long)qpow(i, n) * fac_inv[i] % p;
}
NTT(a, N, 1);
NTT(b, N, 1);
for (int i = 0; i < N; i++)
a[i] = (long long)a[i] * b[i] % p;
NTT(a, N, -1);
for (int i = 0; i <= n; i++)
s[i] = (long long)a[i] * fac[i] % p; // 斯特林数乘阶乘
memset(a, 0, sizeof(int) * N);
memset(b, 0, sizeof(int) * N);
for (int i = 0; i <= n; i++) {
a[i] = (long long)s[i] * fac[n - i] % p;
b[i] = (i % 2 ? p - fac_inv[i] : fac_inv[i]);
}
NTT(a, N, 1);
NTT(b, N, 1);
for (int i = 0; i < N; i++)
a[i] = (long long)a[i] * b[i] % p;
NTT(a, N, -1);
v.resize(n + 1);
for (int i = 0; i <= n; i++)
v[i] = (long long)fac_inv[n - i] * a[i] % p;
memset(s, 0, sizeof(int) * (n + 1));
memset(a, 0, sizeof(int) * N);
memset(b, 0, sizeof(int) * N);
}
void multi(vector<int> &u, vector<int> &v) {
static int a[maxn], b[maxn];
int n = (int)u.size() - 1, m = (int)v.size() - 1;
memcpy(a, &u[0], sizeof(int) * (n + 1));
memcpy(b, &v[0], sizeof(int) * (m + 1));
int N = 1;
while (N <= n + m)
N *= 2;
NTT(a, N, 1);
NTT(b, N, 1);
for (int i = 0; i < N; i++)
a[i] = (long long)a[i] * b[i] % p;
NTT(a, N, -1);
u.resize(n + m + 1);
memcpy(&u[0], a, sizeof(int) * (n + m + 1));
v.clear();
memset(a, 0, sizeof(int) * N);
memset(b, 0, sizeof(int) * N);
}
int c[maxn];
vector<int> vec[maxn];
void solve(int l, int r) {
if (l == r) {
work(c[l], vec[l]);
return;
}
int mid = (l + r) / 2;
solve(l, mid);
solve(mid + 1, r);
multi(vec[l], vec[mid + 1]);
}
int main() {
int n;
long long w;
scanf("%d%lld", &n, &w);
int m = 0;
for (int i = 1; i <= n; i++) {
scanf("%d", &c[i]);
m += c[i];
}
int N = 1;
while (N <= m * 2)
N *= 2;
NTT_init(N);
fac[0] = fac_inv[0] = 1;
for (int i = 1; i <= m; i++) {
fac[i] = (long long)fac[i - 1] * i % p;
fac_inv[i] = (long long)fac_inv[i - 1] * qpow(i, p - 2) % p;
}
solve(1, n);
vector<int> &v = vec[1];
w += n + m;
int choose = 1;
for (int i = 1; i <= m + n; i++)
choose = (long long)choose * i % p;
choose = qpow(choose, p - 2);
for (int i = 0; i < n + m; i++)
choose = (w - i) % p * choose % p;
int ans = 0;
for (int i = 0; i <= m && choose; i++) {
ans = (ans + (long long)choose * v[i]) % p;
choose = (long long)choose * qpow((w - i) % p, p - 2) % p;
choose = (w - i - (n + m)) % p * choose % p;
}
printf("%d\n", ans);
return 0;
}
绿某人的博客,我这种小朋友看了是根本顶不住
Great content! Keep up the good work!