CCPC2021广州 Gym103415A Math Ball

Problem – A – Codeforces

题意:你有$n$种球,同种球没有区别,并且每种球都有无限个。

你可以拿一些球,如果第$i$种球拿了$k_i$个,那么你就会得到$\prod_i k_i^{c_i}$的价值。

求总共拿不超过$w$个球的所有方案的价值的总和。

$n \le 10^5,\;\sum_{i} c_i \le 10^5,\; w \le 10 ^ {18}$。


如果我们能得到关于选球个数的生成函数$F(x)$,那么$\frac {F(x)}{1 – x}$的$x^w$项的系数就是答案。

设第$i$种球的生成函数为$f_{c_i}(x)$,显然$F(x)=\prod f_{c_i}(x)$。

设$c_i=n$,我们可以用第二类斯特林数展开n次幂:

$$
\begin{aligned} f_n(x) =& \sum_{i = 0} ^ \infty i^n x^i \\ =& \sum_{i = 0} ^ \infty x^i \sum_{k = 0} ^ n {n \brace k} i^{\underline k} \\ =& \sum_{i = 0} ^ \infty x^i \sum_{k = 0} ^ n {n \brace k} {i \choose k} k! \\ =& \sum_{k = 0} ^ n {n \brace k} k! \sum_{i = 0} ^ \infty {i \choose k} x^i \\ =& \sum_{k = 0} ^ n {n \brace k} k! \frac {x^k} {(1-x) ^ {k + 1}} \\ =& \sum_{k = 0} ^ n {n \brace k} k! \frac {x^k (1-x) ^ {n – k}} {(1-x) ^ {n + 1}} \end{aligned}
$$

$\frac 1 {(1-x) ^ {n+1}}$是一个与$k$无关的值,我们可以放在最后再一起乘起来:

$$
\begin{aligned} g_n(x) =& (1-x) ^ {n+1} f_n(x) \\ =& \sum_{k = 0} ^ n {n \brace k} k! x^k (1-x) ^ {n-k} \\ =& \sum_{i = 0} ^ n x^i \sum_{k = 0} ^ i {n \brace k}k! \left[ x^{i-k} \right] (1-x) ^ {n – k} \\ =& \sum_{i = 0} ^ n x^i \sum_{k = 0} ^ i {n \brace k}k! {n-k \choose i-k}(-1)^{i-k} \\ =& \sum_{i = 0} ^ n \frac {x^i} {(n-i)!} \sum_{k = 0} ^ i {n \brace k}k!(n-k)! \frac {(-1)^{i-k}} {(i-k)!} \end{aligned}
$$

化成了一个卷积,第二类斯特林数可以$O(n\log n)$预处理,然后代入再跑一次卷积就可以求出$g_n(x)$了。

去掉$\frac 1 {(1-x) ^ {n + 1}}$之后剩下的$g_n(x)$显然是一个$n$次多项式,因为$\sum c_i \le 10^5$,所以直接用分治FFT合并所有的$g_{c_i}(x)$就行了,假设最后得到的结果是$G(x)$,那么就有

$$ F(x) = \frac {G(x)} {(1 – x) ^ {n + \sum c_i + 1}} $$

(多除一次是因为题目要求对$\le w$的所有情况都求和)

$\frac 1 {(1-x)^k}$可以用组合数表示,所以要求第$w$项只需要递推组合数加起来就行了。


总的来说比较值得注意的几个点:

  1. 整体的思路,用斯特林数展开很好做的原因主要在于,斯特林数展开之后会得到组合数,而组合数又可以表示成$\frac 1 {(1-x)^k}$的形式,去掉这部分之后化成$c_i$次多项式就可以直接分治FFT了。
  2. 推式子的细节,补题的时候推式子还是花了比较久,并且出了几次问题。
  3. 板子里没有的细节,比如补题的时候预处理$g_n(x)$忘了清空,还查了一会。
#include <bits/stdc++.h>

using namespace std;

constexpr int maxn = 262155, p = 998244353;

int qpow(int a, int b) {
	int ans = 1;

	while (b) {
		if (b & 1)
			ans = (long long)ans * a % p;
		
		b >>= 1;
		a = (long long)a * a % p;
	}

	return ans;
}

int ntt_n, omega[maxn], omega_inv[maxn];

void NTT_init(int n) {
	ntt_n = n;

	int wn = qpow(3, (p - 1) / n);

	omega[0] = omega_inv[0] = 1;

	for (int i = 1; i < n; i++)
		omega_inv[n - i] = omega[i] = (long long)omega[i - 1] * wn % p;
}

void NTT(int *a, int n, int tp) {
	for (int i = 1, j = 0, k; i < n - 1; i++) {
		k = n;
		do
			j ^= (k >>= 1);
		while (j < k);

		if (i < j)
			swap(a[i], a[j]);
	}

	for (int k = 2, m = ntt_n / 2; k <= n; k *= 2, m /= 2)
		for (int i = 0; i < n; i += k)
			for (int j = 0; j < k / 2; j++) {
				int w = (tp > 0 ? omega : omega_inv)[m * j];

				int u = a[i + j], v = (long long)w * a[i + j + k / 2] % p;
				
				a[i + j] = u + v;
				if (a[i + j] >= p)
					a[i + j] -= p;
				
				a[i + j + k / 2] = u - v;
				if (a[i + j + k / 2] < 0)
					a[i + j + k / 2] += p;
			}
	
	if (tp < 0) {
		int inv = qpow(n, p - 2);
		for (int i = 0; i < n; i++)
			a[i] = (long long)a[i] * inv % p;
	}
}

int fac[maxn], fac_inv[maxn];

void work(int n, vector<int> &v) {
	static int a[maxn], b[maxn], s[maxn];

	int N = 1;
	while (N < (n + 1) * 2)
		N *= 2;

	for (int i = 0; i <= n; i++) {
		a[i] = (i % 2 ? p - fac_inv[i] : fac_inv[i]);
		b[i] = (long long)qpow(i, n) * fac_inv[i] % p;
	}

	NTT(a, N, 1);
	NTT(b, N, 1);

	for (int i = 0; i < N; i++)
		a[i] = (long long)a[i] * b[i] % p;
	
	NTT(a, N, -1);
	
	for (int i = 0; i <= n; i++)
		s[i] = (long long)a[i] * fac[i] % p; // 斯特林数乘阶乘
	
	memset(a, 0, sizeof(int) * N);
	memset(b, 0, sizeof(int) * N);

	for (int i = 0; i <= n; i++) {
		a[i] = (long long)s[i] * fac[n - i] % p;
		b[i] = (i % 2 ? p - fac_inv[i] : fac_inv[i]);
	}

	NTT(a, N, 1);
	NTT(b, N, 1);

	for (int i = 0; i < N; i++)
		a[i] = (long long)a[i] * b[i] % p;
	
	NTT(a, N, -1);

	v.resize(n + 1);

	for (int i = 0; i <= n; i++)
		v[i] = (long long)fac_inv[n - i] * a[i] % p;

	memset(s, 0, sizeof(int) * (n + 1));
	memset(a, 0, sizeof(int) * N);
	memset(b, 0, sizeof(int) * N);
}

void multi(vector<int> &u, vector<int> &v) {
	static int a[maxn], b[maxn];

	int n = (int)u.size() - 1, m = (int)v.size() - 1;

	memcpy(a, &u[0], sizeof(int) * (n + 1));
	memcpy(b, &v[0], sizeof(int) * (m + 1));

	int N = 1;
	while (N <= n + m)
		N *= 2;

	NTT(a, N, 1);
	NTT(b, N, 1);

	for (int i = 0; i < N; i++)
		a[i] = (long long)a[i] * b[i] % p;
	
	NTT(a, N, -1);

	u.resize(n + m + 1);
	memcpy(&u[0], a, sizeof(int) * (n + m + 1));

	v.clear();

	memset(a, 0, sizeof(int) * N);
	memset(b, 0, sizeof(int) * N);
}

int c[maxn];

vector<int> vec[maxn];

void solve(int l, int r) {
	if (l == r) {
		work(c[l], vec[l]);
		return;
	}

	int mid = (l + r) / 2;

	solve(l, mid);
	solve(mid + 1, r);

	multi(vec[l], vec[mid + 1]);
}

int main() {

	int n;
	long long w;

	scanf("%d%lld", &n, &w);

	int m = 0;

	for (int i = 1; i <= n; i++) {
		scanf("%d", &c[i]);
		m += c[i];
	}

	int N = 1;
	while (N <= m * 2)
		N *= 2;
	
	NTT_init(N);

	fac[0] = fac_inv[0] = 1;

	for (int i = 1; i <= m; i++) {
		fac[i] = (long long)fac[i - 1] * i % p;
		fac_inv[i] = (long long)fac_inv[i - 1] * qpow(i, p - 2) % p;
	}

	solve(1, n);

	vector<int> &v = vec[1];

	w += n + m;

	int choose = 1;
	for (int i = 1; i <= m + n; i++)
		choose = (long long)choose * i % p;
	choose = qpow(choose, p - 2);

	for (int i = 0; i < n + m; i++)
		choose = (w - i) % p * choose % p;
	
	int ans = 0;
	for (int i = 0; i <= m && choose; i++) {
		ans = (ans + (long long)choose * v[i]) % p;

		choose = (long long)choose * qpow((w - i) % p, p - 2) % p;
		choose = (w - i - (n + m)) % p * choose % p;
	}

	printf("%d\n", ans);

	return 0;
}

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注