题意:你有一个长为 $n$ 的表达式,其中只有一种运算符 $\bigotimes_k$,效果为 $a \bigotimes_k b = k(a + b)$。并且 $k$ 越小则运算符优先级越低,如果有多个相同的 $k$ 则按从左到右的顺序依次计算。
现在有 $m$ 次询问,每次询问区间 $[l, r]$ 对应的表达式的结果,对 $10^9 + 7$ 取模。
$n, m \le 3 \times 10^5$
考虑把表达式树建出来,顺便求出每个数左边和右边第一个比它小的数。可以发现区间可以从最小值处拆成左右两半,左边可以从 $l$ 开始一直往右跳到最小值为止,右边则是一直往左跳。(当然要先特判掉 $l = r$ 的情况。)
两边是对称的,只考虑左半往右跳。如果往右跳了一步,假设右边第一个更小的结点是 $x$,那么当前答案就需要加上表达式树上 $x$ 的整个右子树的结果,然后乘上 $k_x$。
不难发现这个过程实际上是可以用倍增加速的,处理一下从某个点往左/右跳 $2^i$ 步时答案需要乘多少、再加多少即可。
然而因为要建表达式树,点数就变成了 $6\times 10^5$,开这么多个倍增数组空间不是很够用。一种可行的解决方法是特判表达式树的叶子,让树的结点数重新变成 $3\times 10^5$,但是有那么一点麻烦,不是很想写。
如果直接求出需要往左/右跳多少步,那么倍增时需要在哪几层跳过去就已经确定了。可以把所有询问离线下来,一层一层地做,用滚动数组维护当前层的倍增数组,然后每层都扫一遍所有询问,把这一层应该跳的询问处理一下。
建树因为懒就写了个 sparse table,除此之外的空间复杂度都是 $O(n)$ 的,并且常数并不大。
#include <bits/stdc++.h>
using namespace std;
constexpr int maxn = 600005, p = (int)1e9 + 7;
int a[maxn], op[maxn]; // op : 1 ~ n - 1
int st[25][maxn];
int query_min(int l, int r) {
r--;
int k = 31 - __builtin_clz(r - l + 1);
return min(st[k][l], st[k][r - (1 << k) + 1], [] (int x, int y) {
return op[x] != op[y] ? op[x] < op[y] : x > y;
});
}
int lc[maxn], rc[maxn];
int N;
int sum[maxn];
void build(int l, int r, int &o) {
if (l == r) {
o = l;
sum[o] = a[l];
return;
}
int mid = query_min(l, r);
o = mid + N;
build(l, mid, lc[o]);
build(mid + 1, r, rc[o]);
sum[o] = (long long)(sum[lc[o]] + sum[rc[o]]) * op[mid] % p;
}
int dl[maxn], dr[maxn];
int fl[2][maxn], fr[2][maxn]; // 向左/右跳
int gl[2][maxn], gr[2][maxn], vl[2][maxn], vr[2][maxn]; // plus vl/vr
int u[maxn], v[maxn], w[maxn], tu[maxn], tv[maxn], ansu[maxn], ansv[maxn];
int id[maxn];
int main() {
int n;
scanf("%d", &n);
N = n;
for (int i = 1; i <= n; i++)
scanf("%d", &a[i]);
for (int i = 1; i < n; i++) {
scanf("%d", &op[i]);
st[0][i] = i;
}
for (int j = 1; (1 << j) < n; j++)
for (int i = 1; i + (1 << j) - 1 < n; i++)
st[j][i] = min(st[j - 1][i], st[j - 1][i + (1 << (j - 1))], [] (int x, int y) {
return op[x] != op[y] ? op[x] < op[y] : x > y;
});
int rt;
build(1, n, rt);
for (int i = 1; i < n; i++)
id[i] = i + n;
sort(id + 1, id + n, [&n] (int x, int y) {
return op[x - n] != op[y - n] ? op[x - n] < op[y - n] : x > y;
});
for (int k = 1; k < n; k++) {
int x = id[k];
fl[0][rc[x]] = fr[0][lc[x]] = x;
fl[0][lc[x]] = fl[0][x];
fr[0][rc[x]] = fr[0][x];
}
for (int k = 1; k < n * 2; k++) {
int x = (k < n ? id[k] : k - n + 1);
dl[x] = dl[fl[0][x]] + 1;
dr[x] = dr[fr[0][x]] + 1;
if (fl[0][x]) {
gl[0][x] = op[fl[0][x] - n];
vl[0][x] = (long long)op[fl[0][x] - n] * sum[lc[fl[0][x]]] % p;
}
if (fr[0][x]) {
gr[0][x] = op[fr[0][x] - n];
vr[0][x] = (long long)op[fr[0][x] - n] * sum[rc[fr[0][x]]] % p;
}
}
int m;
scanf("%d", &m);
for (int i = 1; i <= m; i++) {
scanf("%d%d", &u[i], &v[i]);
if (u[i] != v[i]) {
w[i] = query_min(u[i], v[i]) + n;
tu[i] = dr[u[i]] - dr[w[i]] - 1;
tv[i] = dl[v[i]] - dl[w[i]] - 1;
ansu[i] = a[u[i]];
ansv[i] = a[v[i]];
}
}
int cur = 0;
for (int j = 0; (1 << j) <= n; j++) {
for (int i = 1; i <= m; i++) {
if (tu[i] >> j & 1) {
ansu[i] = ((long long)ansu[i] * gr[cur][u[i]] + vr[cur][u[i]]) % p;
u[i] = fr[cur][u[i]];
}
if (tv[i] >> j & 1) {
ansv[i] = ((long long)ansv[i] * gl[cur][v[i]] + vl[cur][v[i]]) % p;
v[i] = fl[cur][v[i]];
}
}
cur ^= 1;
for (int i = 1; i < n * 2; i++) {
fl[cur][i] = fl[cur ^ 1][fl[cur ^ 1][i]];
gl[cur][i] = (long long)gl[cur ^ 1][i] * gl[cur ^ 1][fl[cur ^ 1][i]] % p;
vl[cur][i] = ((long long)vl[cur ^ 1][i] * gl[cur ^ 1][fl[cur ^ 1][i]] + vl[cur ^ 1][fl[cur ^ 1][i]]) % p;
fr[cur][i] = fr[cur ^ 1][fr[cur ^ 1][i]];
gr[cur][i] = (long long)gr[cur ^ 1][i] * gr[cur ^ 1][fr[cur ^ 1][i]] % p;
vr[cur][i] = ((long long)vr[cur ^ 1][i] * gr[cur ^ 1][fr[cur ^ 1][i]] + vr[cur ^ 1][fr[cur ^ 1][i]]) % p;
}
}
for (int i = 1; i <= m; i++) {
int ans;
if (u[i] == v[i])
ans = a[u[i]];
else
ans = (long long)(ansu[i] + ansv[i]) * op[w[i] - n] % p;
printf("%d\n", ans);
}
return 0;
}