题意:你有一个字符串集合 $D$,还有一个串 $s$。现在有 $q$ 次修改操作,每次会把 $s$ 的某个后缀全部改成同一个字符,要求在最开始和每次修改后输出有多少个 $D$ 中的字符串的字典序严格小于 $s$。
$q \le 10^6$,$s$ 的长度和 $D$ 中字符串的串长总和不超过 $10^6$。
对于这类每次修改后缀的问题,可以考虑用一个栈维护连续的相同段,修改时一定是弹出几个尾部元素,然后加入一个新的段。
因为 $s$ 的长度是始终保持不变的,所以可以对 $D$ 建出字典树之后预处理出从每个结点开始一直填某个字符直到长度和 $s$ 相同时会走到哪里,和字典序比这个串更小的串的个数。另外为了方便往回跳,还需要维护一个向父亲方向走的倍增。
注意到 $s$ 对应的结点在字典树上很可能是不存在的,因此如果按 $s$ 走到某一步之后无路可走了,就要在这个点停下来(而不是继续走到 $0$,那样就没法回溯了)。关于这个部分在遍历字典树时还需要加一些判断,因为如果之前已经在某个结点被卡住了,换一个新的字符之后即使存在对应的儿子也不能再走了(因为实际上之前就卡住了,例如字典树只有 ba,$s$ 是 aa…ab 的情况)。
细节还是比较多的,没有用到什么板子,关于细节的处理全靠手写。以后应当多练类似的题目,抄板子对码力的提升显然不如手写来得明显。
#include <bits/stdc++.h>
using namespace std;
constexpr int maxn = 1000005;
int ch[maxn][26], val[maxn], sum[maxn], trie_cnt = 0;
int go[maxn][26], dp[maxn][26];
int f[25][maxn];
int d[maxn];
void insert(const char *c) {
int x = 1;
while (*c) {
if (!ch[x][*c - 'a']) {
int y = ch[x][*c - 'a'] = ++trie_cnt;
d[y] = d[x] + 1;
f[0][y] = x;
}
x = ch[x][*c++ - 'a'];
}
val[x]++;
}
char s[maxn];
char t[maxn];
int main() {
int n, q;
scanf("%d%d", &n, &q);
scanf("%s", s + 1);
trie_cnt = 1;
while (n--) {
scanf("%s", t);
insert(t);
}
n = strlen(s + 1);
for (int i = 1; i <= trie_cnt; i++)
for (int c = 0; c < 26; c++)
if (ch[i][c])
sum[ch[i][c]] = sum[i] + val[ch[i][c]];
for (int i = trie_cnt; i; i--)
for (int c = 0; c < 26; c++)
if (ch[i][c])
val[i] += val[ch[i][c]];
for (int i = trie_cnt; i; i--) {
if (d[i] < n) {
int tmp = 0;
for (int c = 0; c < 26; c++) {
dp[i][c] = dp[ch[i][c]][c] + tmp;
tmp += val[ch[i][c]];
}
}
if (d[i] <= n) {
for (int c = 0; c < 26; c++)
go[i][c] = go[ch[i][c]][c] ? go[ch[i][c]][c] : i;
}
}
for (int j = 1; (1 << j) <= trie_cnt; j++)
for (int i = 1; i <= trie_cnt; i++)
f[j][i] = f[j - 1][f[j - 1][i]];
vector<tuple<int, int, int> > v;
int ans = 0;
v.emplace_back(0, 0, -1);
int x = 1;
for (int i = 1; i <= n; i++) {
if (d[x] == i - 1) {
if (i > 1)
ans -= dp[x][s[i - 1] - 'a'];
ans += dp[x][s[i] - 'a'];
}
v.emplace_back(i, x, s[i] - 'a');
if (d[x] == i - 1 && ch[x][s[i] - 'a'])
x = ch[x][s[i] - 'a'];
}
printf("%d\n", ans + sum[d[x] == n ? f[0][x] : x]);
while (q--) {
int k;
char chr;
scanf("%d %c", &k, &chr);
int c = chr - 'a';
while (get<0>(v.back()) >= k) {
auto &o = v.back(), &u = v[(int)v.size() - 2];
if (d[get<1>(o)] == get<0>(o) - 1) {
if (get<0>(o) > 1)
ans += dp[get<1>(o)][get<2>(u)];
ans -= dp[get<1>(o)][get<2>(o)];
}
v.pop_back();
}
int x;
if (k > 1) {
if (d[get<1>(v.back())] == get<0>(v.back()) - 1)
x = go[get<1>(v.back())][get<2>(v.back())];
else
x = get<1>(v.back());
}
else
x = go[1][c];
if (d[x] >= k) {
for (int j = 0; j < 21; j++)
if ((d[x] - k + 1) >> j & 1)
x = f[j][x];
}
auto &u = v.back();
if (d[x] == k - 1) {
if (k > 1)
ans -= dp[x][get<2>(u)];
ans += dp[x][c];
}
v.emplace_back(k, x, c);
int y = (d[x] == k - 1 ? go[x][c] : x);
if (d[y] == n)
y = f[0][y];
printf("%d\n", ans + sum[y]);
}
return 0;
}