题意:有两个字符串 $a$ 和 $b$,现在要尝试把 $a$ 插入到 $b$ 中间。记 $s_i$($0 \le i \le |b|$)表示把 $a$ 插入到 $b$ 的第 $i$ 个字符($1$-based)后面得到的字符串($i = 0$ 表示插在 $b$ 的最前面),现在要把所有 $s_i$ 排序,输出排序后的结果。
多组数据,$a$ 和 $b$ 的长度总和分别都不超过 $2 \times 10^6$。
不难想到直接 sort,然后写一个比较函数就好了。
考虑到每个 $s_i$ 都是 $b$ 的前缀接上 $a$ 再接上剩下的部分,比较的时候就可以按照这些分界点把串划成几段,每段用后缀数组求 LCP 就好了。后缀数组顺便还能比较大小,就不用找第一个不同的位置了。
注意划分后分情况讨论的细节。像代码里这样以 $i$、$j$、$i+|a|$、$j+|a|$ 为分界点直接分成 4 段基本上是最简洁的(第一段是 $b$ 的同一个前缀,肯定一样,不用比了;超过 $\max(i,j)+|a|$ 之后的部分是 $b$ 的同一个后缀,同样不用比)。
另外本题对后缀数组来说非常卡常,瓶颈显然在于 RMQ 的预处理和比较。RMQ 是很难省掉的(除非写四毛子或者换用 SAM),而比较次数可以通过改用 stable_sort 来减少:快排(std::sort)虽然快,但比较次数比归并排序(stable_sort)要多。
也有扩展 KMP 做法,不过比起后缀数组来说需要多动一些脑子。等这个坑填完之后会另发一篇博客。
教训:抄非常长的板子的时候最好有个人在旁边检查是不是抄错了,训练的时候抄错了 SA-IS 导致大解体。
#include <bits/stdc++.h>
using namespace std;
// Capacity of the global string / suffix-array buffers used below.
constexpr int maxn = 4000005;
// Suffix-type tags stored in the SA-IS `tp` arrays.
constexpr int l_type = 0; // L-type: suffix is larger than the one after it
constexpr int s_type = 1; // S-type: suffix is smaller than the one after it
// Returns true iff position x is an LMS (leftmost-S) position: x itself is
// S-type and the position right before it is L-type. Position 0 has no
// predecessor and therefore never qualifies.
bool is_lms(int *tp, int x) {
    if (x <= 0 || tp[x] != s_type)
        return false;
    return tp[x - 1] == l_type;
}
// Decides whether the LMS substrings starting at positions x and y of s are
// equal; used by sais to give equal LMS substrings the same name.
// Characters are compared pairwise while both positions advance together:
// the first mismatch returns false. The scan stops as soon as EITHER position
// reaches an LMS position, and the characters at that final position are
// compared as well.
// Only characters are compared; tp is consulted solely (via is_lms) to detect
// where an LMS substring ends.
// NOTE(review): relies on each scan reaching an LMS position before running off
// the end of s (sais marks the last index S-type, which is LMS when the last
// symbol is a unique smallest sentinel) — confirm for every caller.
bool equal_substr(int *s, int x, int y, int *tp) {
    do {
        if (s[x] != s[y])
            return false;
        x++;
        y++;
    } while (!is_lms(tp, x) && !is_lms(tp, y));
    return s[x] == s[y];
}
// Induced-sorting pass of SA-IS over s[0..n] with symbols in 0..m.
// Preconditions (as prepared by sais): sa[0..n] holds seed suffixes with -1 in
// empty slots; lbuc[c] is the next free slot at the HEAD of bucket c; buc[c] is
// the prefix count of symbols <= c.
//  1. Left-to-right scan: for every filled slot sa[i] whose predecessor
//     position sa[i] - 1 is L-type, place that predecessor at the head of its
//     bucket (lbuc advances).
//  2. sbuc[1..m] is reset to the bucket tails (buc[c] - 1). Then a right-to-left
//     scan places every S-type predecessor sa[i] - 1 at the tail of its bucket
//     (sbuc retreats).
// Side effects: lbuc/sbuc are left moved; sbuc[0] is NOT reset by this
// function, so callers must re-initialise the bucket pointers before reuse.
void induced_sort(int *s, int *sa, int *tp, int *buc, int *lbuc, int *sbuc, int n, int m) {
    // L-type pass (sa[i] > 0 also skips the -1 empty markers).
    for (int i = 0; i <= n; i++)
        if (sa[i] > 0 && tp[sa[i] - 1] == l_type)
            sa[lbuc[s[sa[i] - 1]]++] = sa[i] - 1;
    // Point every bucket (except 0) back at its last slot for the S-type pass.
    for (int i = 1; i <= m; i++)
        sbuc[i] = buc[i] - 1;
    // S-type pass, right to left.
    for (int i = n; ~i; i--)
        if (sa[i] > 0 && tp[sa[i] - 1] == s_type)
            sa[sbuc[s[sa[i] - 1]]--] = sa[i] - 1;
}
// SA-IS suffix array construction.
// Input: s[0..len-1], symbols in 0..m.
// Returns: a new[]-allocated array sa of len entries, where sa[r] is the start
// of the r-th smallest suffix (0-based). The caller owns it and must delete[] it.
// NOTE(review): the code assumes s[len-1] is a unique sentinel strictly smaller
// than every other symbol. It hard-codes tp[n] = S, skips sa[0] when naming,
// and gives position n the name 0. get_sa appends '$' and the recursion
// appends name 0 for this purpose — confirm for any new caller.
int *sais(int *s, int len, int m) {
    int n = len - 1; // index of the sentinel
    int *tp = new int[n + 1];   // L/S type of each position
    int *pos = new int[n + 1];  // LMS positions in increasing order (first cnt entries)
    int *name = new int[n + 1]; // name of the LMS substring starting at i, -1 if not LMS
    int *sa = new int[n + 1];
    int *buc = new int[m + 1];  // buc[c] = number of symbols <= c (after prefix sum)
    int *lbuc = new int[m + 1]; // head of bucket c
    int *sbuc = new int[m + 1]; // tail of bucket c
    memset(buc, 0, sizeof(int) * (m + 1));
    memset(lbuc, 0, sizeof(int) * (m + 1));
    memset(sbuc, 0, sizeof(int) * (m + 1));
    for (int i = 0; i <= n; i++)
        buc[s[i]]++;
    for (int i = 1; i <= m; i++) {
        buc[i] += buc[i - 1];
        lbuc[i] = buc[i - 1];
        sbuc[i] = buc[i] - 1;
    }
    // Classify suffixes right to left; equal neighbours inherit the type.
    tp[n] = s_type;
    for (int i = n - 1; ~i; i--) {
        if (s[i] < s[i + 1])
            tp[i] = s_type;
        else if (s[i] > s[i + 1])
            tp[i] = l_type;
        else
            tp[i] = tp[i + 1];
    }
    // Collect LMS positions (same test as is_lms).
    int cnt = 0;
    for (int i = 1; i <= n; i++)
        if (tp[i] == s_type && tp[i - 1] == l_type)
            pos[cnt++] = i;
    // Stage 1: seed LMS positions at bucket tails and induce, which sorts the
    // LMS substrings.
    memset(sa, -1, sizeof(int) * (n + 1));
    for (int i = 0; i < cnt; i++)
        sa[sbuc[s[pos[i]]]--] = pos[i];
    induced_sort(s, sa, tp, buc, lbuc, sbuc, n, m);
    // Stage 2: name LMS substrings in sorted order. Equal neighbours share a
    // name; flag records whether any name was reused (then names are not
    // unique and we must recurse).
    memset(name, -1, sizeof(int) * (n + 1));
    int lastx = -1, namecnt = 1;
    bool flag = false;
    for (int i = 1; i <= n; i++) {
        int x = sa[i];
        if (is_lms(tp, x)) {
            if (lastx >= 0 && !equal_substr(s, x, lastx, tp))
                namecnt++;
            // Same name as the previous LMS substring => duplicate name.
            if (lastx >= 0 && namecnt == name[lastx])
                flag = true;
            name[x] = namecnt;
            lastx = x;
        }
    }
    name[n] = 0; // sentinel gets the smallest name
    // Reduced string: names of the LMS positions in text order, so t[k]
    // corresponds to pos[k] (both enumerate LMS positions left to right).
    int *t = new int[cnt];
    int p = 0;
    for (int i = 0; i <= n; i++)
        if (name[i] >= 0)
            t[p++] = name[i];
    // Stage 3: tsa = suffix array of t (indices into pos).
    int *tsa;
    if (!flag) {
        // All names distinct: the suffix order is just the inverse of t.
        tsa = new int[cnt];
        for (int i = 0; i < cnt; i++)
            tsa[t[i]] = i;
    }
    else
        tsa = sais(t, cnt, namecnt);
    // Stage 4: reset bucket pointers and seed the LMS suffixes in their final
    // order. The walk goes from largest to smallest because slots are filled
    // from each bucket's tail. Then induce the full suffix array.
    lbuc[0] = sbuc[0] = 0;
    for (int i = 1; i <= m; i++) {
        lbuc[i] = buc[i - 1];
        sbuc[i] = buc[i] - 1;
    }
    memset(sa, -1, sizeof(int) * (n + 1));
    for (int i = cnt - 1; ~i; i--)
        sa[sbuc[s[pos[tsa[i]]]]--] = pos[tsa[i]];
    induced_sort(s, sa, tp, buc, lbuc, sbuc, n, m);
    delete[] tp;
    delete[] pos;
    delete[] name;
    delete[] buc;
    delete[] lbuc;
    delete[] sbuc;
    delete[] t;
    delete[] tsa;
    return sa;
}
// Builds the suffix array of the 1-based string s[1..n] using sais, and
// produces 1-based outputs:
//   sa[r]     - start position of the r-th smallest suffix, r = 1..n
//               (sa[0] is overwritten with 0 as a placeholder);
//   rnk[p]    - rank of the suffix starting at p (inverse of sa);
//   height[r] - LCP of the suffixes sa[r] and sa[r - 1] (Kasai's scan).
//               height[1] compares against position 0 (sa[0] = 0) and is
//               only a placeholder.
// Uses a static scratch buffer, so the function is not reentrant.
// NOTE(review): assumes every s[i] is a positive char value below 256 and
// greater than '$', which is appended as the unique smallest sentinel for sais.
// This holds for main's lowercase input plus its '{' separator.
// The LCP extension loop relies on s[n + 1] holding a terminator that cannot
// match (main writes 0 there) and reads s[0] for rank 1.
void get_sa(char *s, int n, int *sa, int *rnk, int *height) {
    static int a[maxn];
    // Shift to 0-based ints and append the sentinel.
    for (int i = 1; i <= n; i++)
        a[i - 1] = s[i];
    a[n] = '$';
    int *t = sais(a, n + 1, 256);
    memcpy(sa, t, sizeof(int) * (n + 1));
    delete[] t;
    // sa[0] was the sentinel suffix; ranks 1..n are converted to 1-based positions.
    sa[0] = 0;
    for (int i = 1; i <= n; i++)
        rnk[++sa[i]] = i;
    // Kasai: the LCP with the predecessor drops by at most 1 from i to i + 1.
    for (int i = 1, k = 0; i <= n; i++) {
        if (k)
            k--;
        while (s[i + k] == s[sa[rnk[i] - 1] + k])
            k++;
        height[rnk[i]] = k;
    }
}
// Suffix array of str (1-based positions in sorted order), inverse ranks, and
// LCP of adjacent sorted suffixes; all filled by get_sa and indexed from 1.
int sa[maxn], rnk[maxn], height[maxn];
// Sparse table over height[]: f[j][i] = min(height[i .. i + 2^j - 1]).
// Row j is only built while 2^j <= N, and N < maxn < 2^22, so rows 0..21 are
// the most that can ever be used. The previous 25 rows reserved roughly 48 MB
// of storage that was never touched.
constexpr int sparse_levels = 22;
static_assert((1 << sparse_levels) > maxn,
              "sparse table needs a row for every j with 2^j <= maxn");
int f[sparse_levels][maxn];
int get_lcp(int l, int r) {
l = rnk[l];
r = rnk[r];
if (l == r)
return 1e9;
if (l > r)
swap(l, r);
l++;
int k = 31 - __builtin_clz(r - l + 1); // log_tbl[r - l + 1];
return min(f[k][l], f[k][r - (1 << k) + 1]);
}
// Input strings, stored 1-based; main writes a terminating 0 at a[n + 1] / b[m + 1].
char a[maxn / 2];
char b[maxn / 2];
// Concatenation a + '{' + b, 1-based, with str[N + 1] = 0.
char str[maxn];
// Insertion positions 0..m, sorted by the resulting string.
int arr[maxn / 2];
// Reads T test cases. For each one it:
//  1. reads strings a and b;
//  2. builds the suffix array of str = a + '{' + b;
//  3. sorts the insertion points 0..m by the string s_i obtained by inserting a
//     after the i-th character of b;
//  4. prints a polynomial hash (base 1234567, mod 1e9+7) of the sorted order.
int main() {
    int T;
    if (scanf("%d", &T) != 1)
        return 0;
    while (T--) {
        int n = 0, m = 0, N = 0;
        // Read character by character. `c` must be an int so EOF stays
        // distinguishable. Stored in a plain char, EOF becomes 255 on platforms
        // where char is unsigned; that passes `c >= 'a'` and the reader would
        // run past the end of the buffers.
        int c = getchar();
        while (c != EOF && c < 'a')
            c = getchar();
        while (c >= 'a') {
            str[++N] = a[++n] = (char)c;
            c = getchar();
        }
        while (c != EOF && c < 'a')
            c = getchar();
        str[++N] = 'z' + 1; // unique separator between a and b
        while (c >= 'a') {
            str[++N] = b[++m] = (char)c;
            c = getchar();
        }
        assert(N == n + m + 1);
        // Terminators; get_sa's LCP scan relies on str[N + 1].
        str[N + 1] = a[n + 1] = b[m + 1] = 0;
        get_sa(str, N, sa, rnk, height);

        // Sparse table for range-minimum queries over height[1..N].
        for (int i = 1; i <= N; i++)
            f[0][i] = height[i];
        for (int j = 1; (1 << j) <= N; j++)
            for (int i = 1; i + (1 << j) - 1 <= N; i++)
                f[j][i] = min(f[j - 1][i], f[j - 1][i + (1 << (j - 1))]);

        for (int i = 0; i <= m; i++)
            arr[i] = i;
        // stable_sort (merge sort) performs fewer comparisons than std::sort,
        // and every comparison here costs several RMQ queries.
        stable_sort(arr, arr + m + 1, [&](int i, int j) -> bool {
            // Segment boundaries inside s_i / s_j (both of length n + m).
            // The part before min(i, j) is the same prefix of b in both strings.
            // The part after max(i, j) + n is the same suffix of b.
            // So only the three middle segments have to be compared.
            int p[] = {0, i, j, i + n, j + n, n + m};
            sort(p + 1, p + 5);
            for (int k = 1; k < 4; k++) {
                int ti, tj, pk = p[k] + 1; // first 1-based position of the segment
                // Map position pk of s_i to its index in str:
                //   inside the prefix of b -> str[pk + n + 1],
                //   inside a               -> str[pk - i],
                //   inside the rest of b   -> str[pk + 1].
                if (pk <= i)
                    ti = pk + n + 1;
                else if (pk <= i + n)
                    ti = pk - i;
                else
                    ti = pk + 1;
                // Same mapping for s_j.
                if (pk <= j)
                    tj = pk + n + 1;
                else if (pk <= j + n)
                    tj = pk - j;
                else
                    tj = pk + 1;
                int lcp = get_lcp(ti, tj);
                if (lcp >= p[k + 1] - p[k])
                    continue; // this segment is identical in s_i and s_j
                // The first difference lies inside this segment, so the suffix
                // ranks decide the order.
                return rnk[ti] < rnk[tj];
            }
            // s_i == s_j: the larger insertion index comes first.
            return i > j;
        });
        static constexpr int base = 1234567, p = (int)1e9 + 7;
        int ans = 0;
        for (int i = m; ~i; i--)
            ans = ((long long)ans * base + arr[i]) % p;
        printf("%d\n", ans);
        // No clearing of the global buffers is needed between test cases.
    }
    return 0;
}