[2022.05.18.] n^m의 약수의 합
문제 정보
링크: https://www.acmicpc.net/problem/11693
난이도: Gold II
분류: [수학 → 정수론] [분할 정복 → 분할 정복을 이용한 거듭제곱]
문제 요약
\( N^M \)의 약수의 합을 구하면 됩니다. 정말 정직하죠?
풀이
\( N \)을 소인수분해한 결과가 아래와 같다고 합시다.
\( N = p_{1}^{e_{1}} \times p_{2}^{e_{2}} \times \ldots p_{k}^{e_{k}} \)
그럼 \( N^M \)을 소인수분해한 결과는 자명히 아래와 같게 됩니다.
\( N^M = p_{1}^{Me_{1}} \times p_{2}^{Me_{2}} \times \ldots p_{k}^{Me_{k}} \)
지금부터는 이 \( Me_{i} \)를 편의상 \( e_{i} \)라고 재정의하겠습니다.
그럼, 약수의 합은 어떻게 구할까요?
만약 위 소인수분해 결과에서 \( k = 1 \)이라면, 약수의 합은 \( 1 + p_{1} + p_{1}^2 + \cdots + p_{1}^{e_{1}} \)이 됩니다.
\( k = 2 \)라면, \( (1 + p_{1} + p_{1}^2 + \cdots + p_{1}^{e_{1}})(1 + p_{2} + p_{2}^2 + \cdots + p_{2}^{e_{2}}) \)가 됩니다.
이를 일반화할 수 있으며, \( N^M \)의 약수의 합은 \( (1 + p_{1} + p_{1}^2 + \cdots + p_{1}^{e_{1}})(1 + p_{2} + p_{2}^2 + \cdots + p_{2}^{e_{2}}) \cdots (1 + p_{k} + p_{k}^2 + \cdots + p_{k}^{e_{k}}) \)가 됩니다.
그럼 이 문제의 풀이는 \( N^M \)을 소인수분해해서 \( (p_{i}, e_{i}) \) 쌍을 얻는 거랑
각 \( (p, e) \) 쌍에 대해 \( 1 + p + p^2 + \cdots + p^e \)를 구하는 게 됩니다.
앞부분은 \( N \)을 소인수분해한 뒤 각 지수에 \( M \)을 곱하는 방식으로 해결할 수 있으며,
뒷부분은 분할 정복을 이용한 거듭제곱과 이를 응용한 방식으로 해결할 수 있습니다.
뒷부분의 분할 정복을 요약하자면, \( f(p, e) = 1 + p + p^2 + \cdots + p^{e-1} \)이라고 할 때
\( f(p, e) = \begin{cases} 1 & \text{if } e = 1 \\ f(p, e/2) + p^{e/2} + p^{e/2+1}f(p, e/2) & \text{if } e \text{ is odd} \\ f(p, e/2) + p^{e/2}f(p, e/2) & \text{if } e \text{ is even} \end{cases} \)가 됩니다.
const ll mod = 1e9 + 7;
ll fpw(ll a, ll b){
ll res = 1, mul = a, bit = b;
while (bit){
if (bit & 1){ res = res*mul % mod; }
mul = mul*mul % mod; bit >>= 1;
}
return res;
}
ll dnc(ll a, ll b){
if (b == 1){ return 1; }
ll bb = b >> 1;
ll res = dnc(a, bb);
if (b & 1){
ll r1 = res, r2 = res * fpw(a, bb+1) % mod, r3 = fpw(a, bb);
return (r1+r2+r3) % mod;
}
else{
ll r1 = res, r2 = res * fpw(a, bb);
return (r1+r2) % mod;
}
}
vector<pl2> v;
void Main(){
ll n, m; cin >> n >> m;
if (m == 0){ cout << 1; return; }
for (ll i = 2; i*i <= n; i++){
if (n%i != 0){ continue; }
ll cnt = 0;
while (n%i == 0){ n /= i; cnt += m; }
v.push_back({i, cnt});
}
if (n != 1){ v.push_back({n, m}); }
ll ans = 1;
for (pl2 pr : v){
ll p = pr.fr, e = pr.sc;
ll res = dnc(p, e+1);
//cout << p << ' ' << e << " -> " << res << endl;
ans *= res; ans %= mod;
}
cout << ans;
}