洛谷 P10167 [DTCPC 2024] 小方学乘法 题解

想法比较直接的优化 dp + 插值。

首先由于 xx 的插入会导致原来的数的错位,因此一个显然的想法是对于位数不同的 xx 分别处理。

然后注意到当 xx 的位数固定时,对于任意一种 xx 和乘号的填法,得到的乘法算式中的每一个数的值都是一个常数或者关于 xx 的一次函数。于是将它们乘起来就会得到一个 xxO(n)O(n) 次多项式。对所有填入乘号和 xx 的方法求和并不改变多项式的次数。至此我们可以考虑对若干个 xx 单独求值并通过多项式插值快速求和。

我们需要对一个固定的 xx(记这个 xx 的位数为 ll)在 O(n)O(n) 时间内计算答案。考虑一个 dp,设 fif_i 为考虑到第 ii 个问号之前的部分,所有填法的计算结果的和。暴力的转移是枚举最后一个乘号的位置(设为 jj),并将 fjf_j 乘上后面所有位置填 xx 构成的数转移到 fif_i。这样是 O(n2)O(n^2) 的,考虑优化。

注意到 jj 转移到的位置从 ii 变为 i+1i+1 时,fjf_j 乘上的值的变化一定是先乘上「xx 的位数」个 1010,加上 xx,再乘上「第 ii 个问号和第 i+1i+1 个问号之间的数字串长度」个 1010,再加上这串数字构成的值。所以我们在转移过程中维护 fjf_j 的和以及 fjf_j 乘上后面的值的和两个值,就可以 O(n)O(n) 完成上述 dp 了。具体可以看代码。

对每一个位数取出 O(n)O(n) 个点值,求前缀和之后插值即可。暴力插值或线性插值均可。复杂度 O(n2logR)O(n^2\log R)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#include <bits/stdc++.h>
using namespace std;

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;

inline int qread() {
char c = getchar();
int x = 0, f = 1;
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = (x << 3) + (x << 1) + c - 48;
c = getchar();
}
return x * f;
}

const long long mod = 1000000007;
const int N = 2005, L = 18;

int n;
char str[N];
vector <pair <long long, int> > val;
long long f[N], pw10[N], p10r[N], sl, sr, x[N], y[N], inv[N];

inline void Prefix() {
long long cur = 0, len = 0;
for (int i = 1;i <= n;i++) {
if (str[i] == '?') {
val.push_back(make_pair(cur, len));
cur = 0; len = 0;
} else {
cur = (cur * 10 + str[i] - '0') % mod;
len++;
}
}
val.push_back(make_pair(cur, len));
pw10[0] = 1;
for (int i = 1;i <= n;i++) pw10[i] = pw10[i - 1] * 10 % mod;
p10r[0] = 1;
for (int i = 1;i <= L;i++) p10r[i] = p10r[i - 1] * 10;
inv[1] = 1;
for (int i = 2;i <= n;i++) inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}

inline long long getAns(long long x, long long xln) {
x %= mod;
long long fsm = 0, fpsm = 0;
f[0] = val[0].first;
fsm = 1;
fpsm = val[0].first;
for (int i = 1;i < val.size();i++) {
fpsm = fpsm * pw10[xln] % mod;
fpsm = (fpsm + x * fsm) % mod;
fpsm = fpsm * pw10[val[i].second] % mod;
fpsm = (fpsm + val[i].first * fsm) % mod;
f[i] = (fpsm + f[i - 1] * val[i].first) % mod;
fsm = (fsm + f[i - 1]) % mod;
fpsm = (fpsm + f[i - 1] * val[i].first) % mod;
}
return f[val.size() - 1];
}

inline long long Inv(long long x) {
if (x < 0) return mod - inv[-x];
else return inv[x];
}

inline void Solve() {
long long ans = 0;
for (int i = 1;i <= L;i++) {
long long vl = p10r[i - 1], vr = p10r[i] - 1;
vl = max(vl, sl); vr = min(vr, sr);
if (vl > vr) continue;
long long cnt = val.size() + 5;
if (vr - vl + 1 <= cnt) {
for (long long j = vl;j <= vr;j++) ans = (ans + getAns(j % mod, i)) % mod;
} else {
y[0] = 0;
for (long long j = vl;j <= vl + cnt - 1;j++) {
x[j - vl + 1] = j;
y[j - vl + 1] = getAns(j % mod, i);
}
for (int j = 1;j <= cnt;j++) y[j] = (y[j] + y[j - 1]) % mod;
for (int j = 1;j <= cnt;j++) {
long long cur = y[j];
for (int k = 1;k <= cnt;k++) {
if (k == j) continue;
cur = cur * (vr % mod - x[k] % mod) % mod * Inv(x[j] - x[k]) % mod;
}
ans = (ans + cur) % mod;
}
}
}
cout << (ans % mod + mod) % mod << endl;
}

int main() {
cin >> str + 1; n = strlen(str + 1);
cin >> sl >> sr;
Prefix();
Solve();
return 0;
}