ICPC2021台北 Gym103443K Insertion Array 后缀数组

Problem – K – Codeforces

题意:有两个字符串 $a$ 和 $b$,现在要尝试把 $a$ 插入到 $b$ 中间。记 $s_i$($0 \le i \le |b|$)表示把 $a$ 插入到 $b$ 的第 $i$ 个字符($1$-based)后面得到的字符串($i = 0$ 即插在 $b$ 的最前面),现在要把所有 $s_i$ 按字典序排序,输出排序后的结果。

多组数据,$a$ 和 $b$ 的长度总和分别都不超过 $2 \times 10^6$。


不难想到直接 sort,然后写一个比较函数就好了。

考虑到每个 $s_i$ 都是 $b$ 的前缀接上 $a$ 再接上剩下的部分,比较的时候就可以按照这些分界点把串划成几段,每段用后缀数组求 LCP 就好了。后缀数组顺便还能比较大小,就不用找第一个不同的位置了。

注意划分然后讨论的细节,像代码里这样用 4 个分界点($i$、$j$、$i + |a|$、$j + |a|$ 排序后)把串切成 5 段基本上是最简洁的(第一段和最后一段肯定一样,不用比,只需要比较中间 3 段)。

另外本题对后缀数组来说非常卡常,瓶颈显然在于 RMQ 的预处理和比较,RMQ 是很难省掉的(除非写四毛子或者换用 SAM),而减少比较次数可以通过使用 stable_sort 实现。因为快排虽然快,但比较次数是比归并排序要多的。

也有扩展 KMP 做法,不过比起后缀数组来说需要多动一些脑子。填出来之后会单独另发一篇博客。

教训:抄非常长的板子的时候最好有个人在旁边检查是不是抄错了,训练的时候抄错了 SA-IS 导致大解体。

#include <bits/stdc++.h>

using namespace std;

// Capacity used for the file's global buffers.
constexpr int maxn = 4000005;

// Suffix type tags: position i is tagged s_type when its suffix is smaller
// than the suffix at i + 1, and l_type when it is larger.
constexpr int l_type = 0;
constexpr int s_type = 1;

// Returns true when x is a left-most S-type (LMS) position, i.e. x is tagged
// s_type and the position right before it is tagged l_type.
// Position 0 has no predecessor and is therefore never LMS.
bool is_lms(int *tp, int x) {
	if (x <= 0)
		return false;

	return tp[x - 1] == l_type && tp[x] == s_type;
}

// Compares the two LMS substrings of s that start at x and y.
// Both cursors advance in lockstep; the walk stops as soon as either cursor
// lands on an LMS position, and the verdict is whether the characters at
// that stopping point agree. Any earlier mismatch returns false immediately.
bool equal_substr(int *s, int x, int y, int *tp) {
	for (;;) {
		if (s[x] != s[y])
			return false;

		++x;
		++y;

		if (is_lms(tp, x) || is_lms(tp, y))
			break;
	}

	return s[x] == s[y];
}

// Induces the positions of all L-type and S-type suffixes from the entries
// already present in sa (entries equal to -1 are treated as empty).
//   s    : the integer string, indices 0..n
//   sa   : suffix array under construction, indices 0..n
//   tp   : per-position type tags (l_type / s_type)
//   buc  : buc[c] = number of characters <= c
//   lbuc : next free slot at the head of each bucket (consumed here)
//   sbuc : next free slot at the tail of each bucket (reset here for 1..m)
void induced_sort(int *s, int *sa, int *tp, int *buc, int *lbuc, int *sbuc, int n, int m) {
	// Left-to-right scan: every L-type predecessor goes to the head of its bucket.
	for (int idx = 0; idx <= n; ++idx) {
		const int prev = sa[idx] - 1;

		if (prev >= 0 && tp[prev] == l_type) {
			sa[lbuc[s[prev]]] = prev;
			++lbuc[s[prev]];
		}
	}

	// Rewind the tail pointers before the S-type pass.
	for (int c = 1; c <= m; ++c)
		sbuc[c] = buc[c] - 1;

	// Right-to-left scan: every S-type predecessor goes to the tail of its bucket.
	for (int idx = n; idx >= 0; --idx) {
		const int prev = sa[idx] - 1;

		if (prev >= 0 && tp[prev] == s_type) {
			sa[sbuc[s[prev]]] = prev;
			--sbuc[s[prev]];
		}
	}
}

// Builds the suffix array of the integer string s[0..len-1] with the SA-IS
// scheme (type classification, LMS-substring naming, recursion on the reduced
// string, induced sorting).
//   s   : input string; every value must lie in [0, m]
//   len : number of elements, including the final one
//   m   : largest character value (bucket arrays have m + 1 entries)
// Returns a new[]-allocated array of len ints; the caller owns it and must
// release it with delete[].
// NOTE(review): the code appears to assume s[len - 1] is a unique smallest
// sentinel (get_sa appends '$', the recursion appends name 0) — confirm before
// calling it with other inputs.
int *sais(int *s, int len, int m) {
	int n = len - 1; // index of the last element

	int *tp = new int[n + 1];   // type tag of each position
	int *pos = new int[n + 1];  // LMS positions in text order
	int *name = new int[n + 1]; // name of the LMS substring starting at each position, -1 elsewhere
	int *sa = new int[n + 1];   // result
	int *buc = new int[m + 1];  // buc[c] = number of characters <= c (after the prefix sum)
	int *lbuc = new int[m + 1]; // head cursor of each bucket
	int *sbuc = new int[m + 1]; // tail cursor of each bucket

	memset(buc, 0, sizeof(int) * (m + 1));
	memset(lbuc, 0, sizeof(int) * (m + 1));
	memset(sbuc, 0, sizeof(int) * (m + 1));

	// Count characters, then turn the counts into bucket boundaries.
	for (int i = 0; i <= n; i++)
		buc[s[i]]++;
	
	for (int i = 1; i <= m; i++) {
		buc[i] += buc[i - 1];

		lbuc[i] = buc[i - 1];
		sbuc[i] = buc[i] - 1;
	}

	// Classify positions from right to left; equal neighbours inherit the type on their right.
	tp[n] = s_type;
	for (int i = n - 1; ~i; i--) {
		if (s[i] < s[i + 1])
			tp[i] = s_type;
		else if (s[i] > s[i + 1])
			tp[i] = l_type;
		else
			tp[i] = tp[i + 1];
	}

	// Collect every LMS position (S-type preceded by L-type).
	int cnt = 0;
	for (int i = 1; i <= n; i++)
		if (tp[i] == s_type && tp[i - 1] == l_type)
			pos[cnt++] = i;
	
	// First pass: drop the LMS positions at their bucket tails in arbitrary
	// order and induce, which sorts the LMS substrings.
	memset(sa, -1, sizeof(int) * (n + 1));
	for (int i = 0; i < cnt; i++)
		sa[sbuc[s[pos[i]]]--] = pos[i];
	induced_sort(s, sa, tp, buc, lbuc, sbuc, n, m);

	// Name the LMS substrings in sorted order; equal neighbours share a name.
	// The scan starts at index 1, so sa[0] is skipped here and position n
	// receives name 0 explicitly after the loop.
	memset(name, -1, sizeof(int) * (n + 1));
	int lastx = -1, namecnt = 1;
	bool flag = false; // becomes true once two LMS substrings share a name

	for (int i = 1; i <= n; i++) {
		int x = sa[i];

		if (is_lms(tp, x)) {
			if (lastx >= 0 && !equal_substr(s, x, lastx, tp))
				namecnt++;
			
			if (lastx >= 0 && namecnt == name[lastx])
				flag = true;
			
			name[x] = namecnt;
			lastx = x;
		}
	}
	name[n] = 0;

	// Reduced string: the names read off in text order.
	int *t = new int[cnt];
	int p = 0;
	for (int i = 0; i <= n; i++)
		if (name[i] >= 0)
			t[p++] = name[i];

	// Suffix array of the reduced string: a direct inverse when all names are
	// distinct, otherwise a recursive call.
	int *tsa;
	if (!flag) {
		tsa = new int[cnt];
		
		for (int i = 0; i < cnt; i++)
			tsa[t[i]] = i;
	}
	else
		tsa = sais(t, cnt, namecnt);
	
	// Reset the bucket cursors that the first pass consumed.
	lbuc[0] = sbuc[0] = 0;
	for (int i = 1; i <= m; i++) {
		lbuc[i] = buc[i - 1];
		sbuc[i] = buc[i] - 1;
	}

	// Second pass: place the LMS suffixes in their sorted order (largest first,
	// since the tail cursors move downwards) and induce the full suffix array.
	memset(sa, -1, sizeof(int) * (n + 1));
	for (int i = cnt - 1; ~i; i--)
		sa[sbuc[s[pos[tsa[i]]]]--] = pos[tsa[i]];
	induced_sort(s, sa, tp, buc, lbuc, sbuc, n, m);

	// Release every scratch buffer; only sa is handed back to the caller.
	delete[] tp;
	delete[] pos;
	delete[] name;
	delete[] buc;
	delete[] lbuc;
	delete[] sbuc;
	delete[] t;
	delete[] tsa;

	return sa;
}

// Computes the 1-indexed suffix array, rank array and height array of the
// 1-indexed string s[1..n].
//   sa[r]     : start (1-based) of the suffix with rank r, r in 1..n; sa[0] is set to 0
//   rnk[i]    : rank of the suffix starting at i
//   height[r] : number of leading characters shared by the suffixes ranked r - 1 and r
// The height scan reads s[0 + k] (via sa[0] == 0) and past s[n], so it relies
// on s[0] and s[n + 1] holding values that never occur in s[1..n]; main
// provides zero bytes at both ends.
void get_sa(char *s, int n, int *sa, int *rnk, int *height) {
	static int a[maxn];

	// Shift to 0-based integers and append '$' as the final element.
	for (int i = 0; i < n; ++i)
		a[i] = s[i + 1];
	a[n] = '$';

	int *raw = sais(a, n + 1, 256);
	memcpy(sa, raw, sizeof(int) * (n + 1));
	delete[] raw;

	// Slot 0 belongs to the appended '$'; convert the rest to 1-based and
	// invert into rnk.
	sa[0] = 0;
	for (int r = 1; r <= n; ++r) {
		sa[r] += 1;
		rnk[sa[r]] = r;
	}

	// Height scan in text order, carrying over the previous match length minus one.
	int k = 0;
	for (int i = 1; i <= n; ++i) {
		if (k > 0)
			--k;

		const int prev = sa[rnk[i] - 1];
		while (s[i + k] == s[prev + k])
			++k;

		height[rnk[i]] = k;
	}
}

// Suffix array, rank array and height array of str, filled by get_sa (1-indexed).
int sa[maxn], rnk[maxn], height[maxn];

// Sparse table over height: f[j][i] is the minimum of height[i .. i + 2^j - 1].
// Built in main and queried by get_lcp.
int f[25][maxn]; // a log_tbl[maxn] lookup table was dropped in favour of __builtin_clz

// Length of the longest common prefix of the suffixes of str that start at
// positions l and r (1-based). Identical positions yield 1e9 as "unbounded".
// Otherwise the answer is the range minimum of height over the ranks strictly
// after the smaller one up to the larger one, read from the sparse table f.
int get_lcp(int l, int r) {
	int lo = rnk[l];
	int hi = rnk[r];

	if (lo == hi)
		return 1e9;

	if (hi < lo)
		std::swap(lo, hi);

	++lo;

	// Floor of log2 of the range length (the range is never empty here).
	const int k = 31 - __builtin_clz(hi - lo + 1); // log_tbl[hi - lo + 1];
	const int left_block = f[k][lo];
	const int right_block = f[k][hi - (1 << k) + 1];

	return std::min(left_block, right_block);
}

// a and b hold the two input strings (1-indexed); str holds a, then the
// separator 'z' + 1, then b, and is the text the suffix array is built on.
char a[maxn / 2], b[maxn / 2], str[maxn];

// Insertion positions 0..m, sorted by main.
int arr[maxn / 2];

// Reads T test cases. Each case consists of two lowercase strings a and b.
// For every insertion position i in 0..|b| the string s_i = b[1..i] + a + b[i+1..]
// is considered; the positions are sorted by s_i using suffix-array LCP
// queries on str = a + '{' + b, and a polynomial hash of the resulting order
// is printed.
//
// Input is read character by character. The character is kept as an int so
// that EOF can be told apart from real characters; truncated input ends the
// program instead of looping forever or running past the buffers.
int main() {
	static constexpr int base = 1234567, mod = 1000000007;

	int T;
	if (scanf("%d", &T) != 1)
		return 0;

	while (T--) {
		int n = 0, m = 0, N = 0;

		// Skip everything before the first string.
		int c = getchar();
		while (c != EOF && c < 'a')
			c = getchar();

		if (c == EOF)
			break;

		// a goes both into its own buffer and into the front of str.
		while (c >= 'a') {
			str[++N] = a[++n] = c;
			c = getchar();
		}

		// Skip the whitespace between the two strings.
		while (c != EOF && c < 'a')
			c = getchar();

		if (c == EOF)
			break;

		// Separator larger than any lowercase letter.
		str[++N] = 'z' + 1;

		while (c >= 'a') {
			str[++N] = b[++m] = c;
			c = getchar();
		}

		assert(N == n + m + 1);

		// Terminators: get_sa's height scan stops on them.
		str[N + 1] = a[n + 1] = b[m + 1] = 0;

		get_sa(str, N, sa, rnk, height);

		// Sparse table over height for O(1) LCP queries.
		for (int i = 1; i <= N; i++)
			f[0][i] = height[i];

		for (int j = 1; (1 << j) <= N; j++)
			for (int i = 1; i + (1 << j) - 1 <= N; i++)
				f[j][i] = min(f[j - 1][i], f[j - 1][i + (1 << (j - 1))]);

		for (int i = 0; i <= m; i++)
			arr[i] = i;

		// stable_sort (merge sort) is used on purpose: it performs fewer
		// comparisons than sort, and each comparison costs several LCP queries.
		stable_sort(arr, arr + m + 1, [&] (int i, int j) -> bool {
			// End positions of the pieces of s_i and s_j. After sorting the
			// four inner cut points, both strings are made of the same kind
			// of material inside each piece.
			int cut[] = {0, i, j, i + n, j + n, n + m};

			sort(cut + 1, cut + 5);

			// The first and last pieces are identical in both strings, so
			// only the three middle pieces need to be compared.
			for (int k = 1; k < 4; k++) {
				int ti, tj, pk = cut[k] + 1;

				// Map position pk of s_i to its index in str.
				if (pk <= i)
					ti = pk + n + 1; // inside b's prefix
				else if (pk <= i + n)
					ti = pk - i;     // inside a
				else
					ti = pk + 1;     // inside b's remainder

				// Same mapping for s_j.
				if (pk <= j)
					tj = pk + n + 1;
				else if (pk <= j + n)
					tj = pk - j;
				else
					tj = pk + 1;

				int lcp = get_lcp(ti, tj);

				// Whole piece matches: move on to the next one.
				if (lcp >= cut[k + 1] - cut[k])
					continue;

				// First mismatch lies inside this piece, so suffix ranks
				// decide the order.
				return rnk[ti] < rnk[tj];
			}

			// Equal strings: the larger insertion position comes first.
			return i > j;
		});

		int ans = 0;
		for (int i = m; ~i; i--)
			ans = ((long long)ans * base + arr[i]) % mod;

		printf("%d\n", ans);

		// No per-test cleanup: every array slot that is read is rewritten
		// for the next test case.
	}

	return 0;
}
暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇