洛谷 P3352 [ZJOI2016] 线段树 题解

对于一个数 ii,考虑它左边第一个比 aia_i 大的数的位置 LL 和右边第一个比 aia_i 大的数的位置 RR,显然在操作过程中,这两个位置都会逐渐向 ii 紧缩。但是我们无法得知紧缩的过程中这两个位置的值是多少,因此并不能计算出答案。

由于这个做法失败的原因是值的存在,不妨尝试通过设定一个值域线 VV 来消去值。具体的,将 V\leq V 的值设为 00>V>V 的值设为 11,则随着操作的进行,每一个 00 连续段两侧的 11 都会逐渐向内收缩。这样的好处是左右端点的值一定为 11,然后就可以对每一段分别设 fq,l,rf_{q,l,r} 表示 qq 次操作后,这一段缩到 [l,r][l,r] 的方案数。这个状态的转移非常明显,直接分类讨论下一次操作的位置,写出转移式后用前缀和优化一发就可以 O(qn2)O(qn^2) 了。

最后合并答案时,也可以按照期望的拆分,对每一个值拆一拆贡献。这部分相较 DP 来说非常容易,所以不再赘述。

可惜再结合枚举值是 O(qn3)O(qn^3) 的,虽然在随机数据下经过一定程度的剪枝可以通过,但是我们毕竟希望存在一个更加靠谱的做法。事实上写完代码之后可以发现每一个位置的转移是一些位置的 DP 值的常线性函数,所以可以“合并起来”转移。具体来说,在 O(qn3)O(qn^3) 的做法里面做多次 DP,每次有一些状态有初值;现在把所有的初值都加到同一个 DP 数组里面,仿照刚才的方程做一次 DP,最后对答案的贡献在之前是每一次 DP 之后分别加,由于这里每一次 DP 之后加的时候系数都是一样的,所以在合并所有 DP 数组之后还是可以直接取。

这样就从 nn 次 DP 变成了一次 DP,复杂度下降为 O(qn2)O(qn^2),对于任意的 n,q400n,q\leq 400 的数据就都能通过了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
const long long mod = 1000000007;
int n, q, a[405], b[405], vtop;
long long dp[2][405][405], pre[405][405], suf[405][405], ans[405];

inline long long Power(long long x, long long y) {
long long ans = 1;
while (y) {
if (y & 1) ans = ans * x % mod;
x = x * x % mod;
y >>= 1;
}
return ans;
}

inline void Read() {
n = qread(); q = qread();
for (int i = 1;i <= n;i++) b[i] = a[i] = qread();
}

inline void Prefix() {
sort(b + 1, b + n + 1);
vtop = unique(b + 1, b + n + 1) - b - 1;
for (int i = 1;i <= n;i++) a[i] = lower_bound(b + 1, b + vtop + 1, a[i]) - b;
}

inline void Solve() {
for (int i = 1;i < vtop;i++) {
for (int j = 1;j <= n;) {
if (a[j] > i) {
j++;
continue;
}
int l = j;
while (j <= n && a[j] <= i) j++;
dp[1][l][j - 1] = (dp[1][l][j - 1] + (b[i + 1] - b[i]) % mod) % mod;
for (int k = l;k < j;k++) ans[k] = (ans[k] + (b[i + 1] - b[i]) % mod * Power(n * (n + 1) / 2, q) % mod) % mod;
}
}
for (int j = 1;j <= n;j++) {
for (int k = j;k <= n;k++) {
pre[j][k] = (pre[j - 1][k] + dp[1][j][k] * (j - 1) % mod) % mod;
}
}
for (int j = 1;j <= n;j++) {
for (int k = n;k >= j;k--) {
suf[j][k] = (suf[j][k + 1] + dp[1][j][k] * (n - k) % mod) % mod;
}
}
long long inv = Power(n * (n + 1) / 2, mod - 2);
for (int i = 2;i <= q + 1;i++) {
for (int j = 1;j <= n;j++) {
for (int k = j;k <= n;k++) {
dp[i&1][j][k] = (pre[j - 1][k] + suf[j][k + 1] + dp[i - 1&1][j][k] * ((j - 1) * j / 2 + (n - k) * (n - k + 1) / 2 + (k - j + 1) * (k - j + 2) / 2) % mod) % mod;
}
}
for (int j = 1;j <= n;j++) {
for (int k = j;k <= n;k++) {
pre[j][k] = (pre[j - 1][k] + dp[i&1][j][k] * (j - 1) % mod) % mod;
}
}
for (int j = 1;j <= n;j++) {
for (int k = n;k >= j;k--) {
suf[j][k] = (suf[j][k + 1] + dp[i&1][j][k] * (n - k) % mod) % mod;
}
}
}
/*
for (int l = 1;l <= n;l++) {
for (int r = l;r <= n;r++) cout << dp[2][l][r] << " ";
cout << endl;
}
*/
for (int i = 1;i <= n;i++) {
for (int l = 1;l <= i;l++) {
for (int r = i;r <= n;r++) ans[i] = (ans[i] - dp[q + 1&1][l][r] + mod) % mod;
}
}
for (int i = 1;i <= n;i++) cout << (ans[i] + Power(n * (n + 1) / 2, q) * b[a[i]]) % mod << ' '; cout << endl;
}