ALOHA - 2022/ALOHA - 오늘의 문제 (2022 1학기)

[2022.05.18.] n^m의 약수의 합

hibye1217-aloha 2022. 5. 22. 00:00

문제 정보

링크: https://www.acmicpc.net/problem/11693

난이도: Gold II

분류: [수학 → 정수론] [분할 정복 → 분할 정복을 이용한 거듭제곱]

 

문제 요약

\( N^M \)의 약수의 합을 구하면 됩니다. 정말 정직하죠?

 

풀이

\( N \)을 소인수분해한 결과가 아래와 같다고 합시다.

\( N = p_{1}^{e_{1}} \times p_{2}^{e_{2}} \times \ldots p_{k}^{e_{k}} \)

그럼 \( N^M \)을 소인수분해한 결과는 자명히 아래와 같게 됩니다.

\( N^M = p_{1}^{Me_{1}} \times p_{2}^{Me_{2}} \times \ldots p_{k}^{Me_{k}} \)

지금부터는 이 \( Me_{i} \)를 편의상 \( e_{i} \)라고 재정의하겠습니다.

 

그럼, 약수의 합은 어떻게 구할까요?

만약 위 소인수분해 결과에서 \( k = 1 \)이라면, 약수의 합은 \( 1 + p_{1} + p_{1}^2 + \cdots + p_{1}^{e_{1}} \)이 됩니다.

\( k = 2 \)라면, \( (1 + p_{1} + p_{1}^2 + \cdots + p_{1}^{e_{1}})(1 + p_{2} + p_{2}^2 + \cdots + p_{2}^{e_{2}}) \)가 됩니다.

이를 일반화할 수 있으며, \( N^M \)의 약수의 합은 \( (1 + p_{1} + p_{1}^2 + \cdots + p_{1}^{e_{1}})(1 + p_{2} + p_{2}^2 + \cdots + p_{2}^{e_{2}}) \cdots (1 + p_{k} + p_{k}^2 + \cdots + p_{k}^{e_{k}}) \)가 됩니다.

 

그럼 이 문제의 풀이는 \( N^M \)을 소인수분해해서 \( (p_{i}, e_{i}) \) 쌍을 얻는 거랑

각 \( (p, e) \) 쌍에 대해 \( 1 + p + p^2 + \cdots + p^e \)를 구하는 게 됩니다.

 

앞부분은 \( N \)을 소인수분해한 뒤 각 지수에 \( M \)을 곱하는 방식으로 해결할 수 있으며,

뒷부분은 분할 정복을 이용한 거듭제곱과 이를 응용한 방식으로 해결할 수 있습니다.

뒷부분의 분할 정복을 요약하자면, \( f(p, e) = 1 + p + p^2 + \cdots + p^{e-1} \)이라고 할 때

\( f(p, e) = \begin{cases} 1 & \text{if } e = 1 \\ f(p, e/2) + p^{e/2} + p^{e/2+1}f(p, e/2) & \text{if } e \text{ is odd} \\ f(p, e/2) + p^{e/2}f(p, e/2) & \text{if } e \text{ is even} \end{cases} \)가 됩니다.

const ll mod = 1e9 + 7;

ll fpw(ll a, ll b){
	ll res = 1, mul = a, bit = b;
	while (bit){
		if (bit & 1){ res = res*mul % mod; }
		mul = mul*mul % mod; bit >>= 1;
	}
	return res;
}
ll dnc(ll a, ll b){
	if (b == 1){ return 1; }
	ll bb = b >> 1;
	ll res = dnc(a, bb);
	if (b & 1){
		ll r1 = res, r2 = res * fpw(a, bb+1) % mod, r3 = fpw(a, bb);
		return (r1+r2+r3) % mod;
	}
	else{
		ll r1 = res, r2 = res * fpw(a, bb);
		return (r1+r2) % mod;
	}
}

vector<pl2> v;

void Main(){
	ll n, m; cin >> n >> m;
	if (m == 0){ cout << 1; return; }
	for (ll i = 2; i*i <= n; i++){
		if (n%i != 0){ continue; }
		ll cnt = 0;
		while (n%i == 0){ n /= i; cnt += m; }
		v.push_back({i, cnt});
	}
	if (n != 1){ v.push_back({n, m}); }
	ll ans = 1;
	for (pl2 pr : v){
		ll p = pr.fr, e = pr.sc;
		ll res = dnc(p, e+1);
		//cout << p << ' ' << e << " -> " << res << endl;
		ans *= res; ans %= mod;
	}
	cout << ans;
}