Luogu P10037 赛时爆标做法

复杂度 O(n+mt2+mtk)O(n+mt^2+mt\sqrt k),如有更优做法欢迎讨论。

分 2 部分 dp,第一部分限制 vv 前缀和不低于 00,记 fi,jf_{i,j} 表示前 iivv 前缀和为 jj 的方案数。

低于 00 时转到第二部分,设状态 gi,jg_{i,j} 表示前 ii 天,vv 前缀和比允许的 vv 前缀和最小值高出 jj 的方案数,考虑状态 fi,jf_{i,j}vi+1<jv_{i+1}<-j。则所有深度 >i>i 的点没法要,设剩下的点有 cc 个,则可以选择保留 1c1\sim c 个。保留 xx 个点需要 vv 前缀和最小值 k/x\geq -\lfloor k/x\rfloor,于是对所有 xx,将 fi,jf_{i,j} 转移到 gi+1,j+vi+1+k/xg_{i+1,j+v_{i+1}+\lfloor k/x\rfloor},使用整除分块在 O(mtk)O(mt\sqrt k) 时间内完成。

ff 之间的转移和 gg 之间的转移都是直接枚举 vi+1v_{i+1} 的值进行转移,可以使用前缀和优化。

第二部分直接做是 O(tk)O(tk) 的,但是注意到如果比允许的最小值高出太多则无论如何都合法,于是可以直接乘一个 mm 的幂加到答案里,这样剪一下枝就变成 O(mt2)O(mt^2)

赛时图方便写的带 O(tk)O(tk),跑的和 std 几乎一样快。代码里面 ggff 跟上面是反过来的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <bits/stdc++.h>
using namespace std;

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;

inline int qread() {
char c = getchar();
int x = 0, f = 1;
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = (x << 3) + (x << 1) + c - 48;
c = getchar();
}
return x * f;
}

const int N = 200005, K = 300005, M = 55, T = 505;
const long long mod = 998244353;

int n, m, t, k, dct[N];
vector <int> gr[N];
long long f[2][K * 2], g[T][M * T], s[K * 2];

inline void Read() {
n = qread(); m = qread(); t = qread(); k = qread();
for (int i = 1;i < n;i++) {
int u = qread(), v = qread();
gr[u].push_back(v); gr[v].push_back(u);
}
}

inline void Dfs(int u, int fa, int dep) {
dct[dep]++;
for (int v : gr[u]) {
if (v != fa) Dfs(v, u, dep + 1);
}
}

inline void Prefix() {
Dfs(1, -1, 0);
for (int i = 1;i <= t;i++) dct[i] += dct[i - 1];
}

inline void Solve() {
g[0][0] = 1;
for (int i = 1;i <= t;i++) {
s[0] = 0;
for (int j = 1;j <= m * (i + 1) + 1;j++) s[j] = (s[j - 1] + g[i - 1][j - 1]) % mod;
for (int j = 0;j <= m * i;j++) g[i][j] = (s[j + m + 1] - s[max(0, j - m)]) % mod;
memset(f[i & 1], 0, sizeof(f[i & 1]));
for (int j = 1;j <= m;j++) {
// sum v = -j
long long w = s[m - j + 1];
for (int l = 1, r;l <= min(k, dct[i]);l = r + 1) {
r = min(dct[i], k / (k / l));
if (k / l - j >= 0) f[i & 1][k / l - j] = (f[i & 1][k / l - j] + (r - l + 1) * w) % mod;
}
}
s[0] = 0;
for (int j = 1;j <= m * (i + 1) + 1 + k;j++) s[j] = (s[j - 1] + f[i - 1 & 1][j - 1]) % mod;
for (int j = 0;j <= m * i + k;j++) f[i & 1][j] = (f[i & 1][j] + s[j + m + 1] - s[max(0, j - m)]) % mod;
}
long long ans = 0;
for (int j = 0;j <= m * t;j++) ans = (ans + g[t][j] * n) % mod;
for (int j = 0;j <= k + m * t;j++) ans = (ans + f[t & 1][j]) % mod;
cout << (ans % mod + mod) % mod << endl;
}

int main() {
Read();
Prefix();
Solve();
return 0;
}