求长度小于等于n
,由前k
个小写字母组成的字符串中,满足双回文串定义的有多少个。
双回文串:本身为回文串或者由两个回文串拼接成的字符串。
先计算双回文串可能的数量,令
\displaystyle R(n) = \sum_{i=0}^{n-1}k^{\lceil \frac{i}{2} \rceil}k^{\lceil \frac{n-i}{2} \rceil}
考虑减去那些重复计算的字符串。可以发现如果一个字符串p
的划分方式不唯一,那么这个字符串必然可以写成s \times m
的形式,其中s
是一个划分唯一的串。然后通过R(n)
的表达式以及s \times m
的性质,这个字符串被统计了m
次。
则唯一划分字符串个数的表达式为\displaystyle D(n) = R(n)-\sum_{l|n, l < n} \frac{n}{l} D(l)
。最后答案为\displaystyle \sum_{i=1}^n \lfloor \frac{n}{i} \rfloor D(i)
R(n)
可以通过简单的数学推导在线性时间复杂度内求出,其余计算的时间复杂度O(n \log n)
。
#include<bits/stdc++.h>
using namespace std;
const int mod = 998244353;
const int maxn = 100010;
int f[maxn];
int pow(int x, int u){
int y = 1;
for (; u; u >>= 1, x = 1LL * x * x % mod) if (u & 1) y = 1LL * x * y % mod;
return y;
}
int main(){
int n, k;
cin >> n >> k;
for (int i = 1; i <= n; i++) {
if (i & 1) f[i] = 1LL * i * pow(k, (i + 1) / 2) % mod;
else f[i] = (1LL* i / 2 * pow(k, i / 2) + 1LL * i / 2 * pow(k, i / 2 + 1)) % mod;
}
for (int i = 1; i <= n; i++)
for (int j = i + i; j <= n; j += i)
f[j] = (f[j] - 1LL * j / i * f[i]) % mod;
int ans = 0;
for (int i = 1; i <= n; i++) ans = (ans + 1LL * f[i] * (n / i)) % mod;
if (ans < 0) ans += mod;
cout << ans << endl;
}