题意:给定一个字符串,每次询问一个它的一个子串中有多少个本质不同的子串。注意询问的子串长度是固定的。
鉴于本题的特殊性可以用后缀数组+set维护。但是对于一般情况(询问区间有包含),用set维护前驱后继的做法就不太好搞了。
区间本质不同子串统计有一个经典做法是后缀自动机+LCT+线段树维护。
考虑建出后缀自动机,然后枚举右端点,用线段树维护左端点的答案。显然只有right集合在$[l, r]$中的串才有可能有贡献,那么我们可以只考虑每个串最大的right。
每次右端点+1时其实就是把它对应的结点到SAM的根节点全部更新一遍,因为我们要保证这个本质不同的子串左端点不能越过$l$,所以对于一个结点$p$,我们知道它对应的子串长度$(val[par[p]], val[p]]$,那么在对应的区间上全部+1,询问的时候就是询问$[l, r]$区间和。
更新的操作实际上可以看成是整段整段的,和LCT的access操作相同,所以可以写一个LCT维护每次更新,再用一个线段树维护区间+1、区间求和即可。
这个做法是$O(n\log^2 n)$的,如果要支持任意区间且在线询问可以用一个主席树替代线段树维护,复杂度$O(n\log^2 n + q\log n)$。
#include <bits/stdc++.h>
using namespace std;
constexpr int maxn = 200005;
int val[maxn], par[maxn], go[maxn][26], last, sam_cnt;
void extend(int c) {
int p = last, np = ++sam_cnt;
val[np] = val[p] + 1;
while (p && !go[p][c]) {
go[p][c] = np;
p = par[p];
}
if (!p)
par[np] = 1;
else {
int q = go[p][c];
if (val[q] == val[p] + 1)
par[np] = q;
else {
int nq = ++sam_cnt;
val[nq] = val[p] + 1;
memcpy(go[nq], go[q], sizeof(go[q]));
par[nq] = par[q];
par[np] = par[q] = nq;
while (p && go[p][c] == q) {
go[p][c] = nq;
p = par[p];
}
}
}
last = np;
}
int N;
long long tree[524305], mark[524305];
void update(int l, int r, int d) {
if (l > r)
return;
int len = 1, cntl = 0, cntr = 0;
for (l += N - 1, r += N + 1; l ^ r ^ 1; l >>= 1, r >>= 1, len <<= 1) {
tree[l] += (long long)cntl * d;
tree[r] += (long long)cntr * d;
if (~l & 1) {
tree[l ^ 1] += (long long)d * len;
mark[l ^ 1] += d;
cntl += len;
}
if (r & 1) {
tree[r ^ 1] += (long long)d * len;
mark[r ^ 1] += d;
cntr += len;
}
}
while (l) {
tree[l] += (long long)cntl * d;
tree[r] += (long long)cntr * d;
l >>= 1;
r >>= 1;
}
}
long long query(int l, int r) {
long long ans = 0;
int len = 1, cntl = 0, cntr = 0;
for (l += N - 1, r += N + 1; l ^ r ^ 1; l >>= 1, r >>= 1, len <<= 1) {
ans += (long long)cntl * mark[l] + (long long)cntr * mark[r];
if (~l & 1) {
ans += tree[l ^ 1];
cntl += len;
}
if (r & 1) {
ans += tree[r ^ 1];
cntr += len;
}
}
while (l) {
ans += (long long)cntl * mark[l] + (long long)cntr * mark[r];
l >>= 1;
r >>= 1;
}
return ans;
}
struct node {
int id, l, r;
int val;
bool lazy;
int size;
node *ch[2], *p;
node() : size(1) {}
node(int id) : id(id), l(id), r(id) {}
void pushdown() {
if (lazy) {
ch[0] -> val = ch[1] -> val = val;
ch[0] -> lazy = ch[1] -> lazy = true;
lazy = false;
}
}
void refresh() {
l = r = id;
if (ch[0] -> id)
l = ch[0] -> l;
if (ch[1] -> id)
r = ch[1] -> r;
}
} null[maxn];
void init(node *x) {
*x = node(x - null);
x -> ch[0] = x -> ch[1] = x -> p = null;
x -> val = 0;
}
inline bool isroot(node *x) {
return x != x -> p -> ch[0] && x != x -> p -> ch[1];
}
inline bool dir(node *x) {
return x == x -> p -> ch[1];
}
void rot(node *x, int d) {
node *y = x -> ch[d ^ 1];
if ((x -> ch[d ^ 1] = y -> ch[d]) != null)
y -> ch[d] -> p = x;
y -> p = x -> p;
if (!isroot(x))
x -> p -> ch[dir(x)] = y;
(y -> ch[d] = x) -> p = y;
x -> refresh();
y -> refresh();
}
void splay(node *x) {
x -> pushdown();
while (!isroot(x)) {
if (!isroot(x -> p))
x -> p -> p -> pushdown();
x -> p -> pushdown();
x -> pushdown();
if (isroot(x -> p)) {
rot(x -> p, dir(x) ^ 1);
break;
}
if (dir(x) == dir(x -> p))
rot(x -> p -> p, dir(x -> p) ^ 1);
else
rot(x -> p, dir(x) ^ 1);
rot(x -> p, dir(x) ^ 1);
}
}
int tim;
node *access(node *x) {
node *y = null;
while (x != null) {
splay(x);
x -> ch[1] = null;
x -> refresh();
if (x -> val)
update(x -> val - val[x -> r] + 1, x -> val - val[par[x -> l]], -1);
x -> val = tim;
x -> lazy = true;
update(x -> val - val[x -> r] + 1, x -> val - val[par[x -> l]], 1);
x -> ch[1] = y;
(y = x) -> refresh();
x = x -> p;
}
return y;
}
char s[maxn];
long long ans[maxn];
int n, id[maxn];
int main() {
last = sam_cnt = 1;
scanf("%s", s + 1);
n = strlen(s + 1);
for (int i = 1; i <= n; i++) {
extend(s[i] - 'a');
id[i] = last;
}
N = 1;
while (N <= sam_cnt + 1)
N <<= 1;
for (int i = 0; i <= sam_cnt; i++) {
init(null + i);
null[i].id = null[i].l = null[i].r = i;
}
null -> size = 0;
for (int i = 2; i <= sam_cnt; i++)
null[i].p = null + par[i];
int m, q;
scanf("%d%d", &q, &m);
for (int i = 1; i <= n; i++) {
tim++;
access(null + id[i]);
if (i >= m)
ans[i - m + 1] = query(i - m + 1, i);
}
while (q--) {
int x;
scanf("%d", &x);
printf("%lld\n", ans[x]);
}
return 0;
}
代码没有高亮吗?
草,插件可能挂了,我搞一搞