ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [2022.05.18.] n^m의 약수의 합
    ALOHA - 2022/ALOHA - 오늘의 문제 (2022 1학기) 2022. 5. 22. 00:00

    문제 정보

    링크: https://www.acmicpc.net/problem/11693

    난이도: Gold II

    분류: [수학 → 정수론] [분할 정복 → 분할 정복을 이용한 거듭제곱]

     

    문제 요약

    \( N^M \)의 약수의 합을 구하면 됩니다. 정말 정직하죠?

     

    풀이

    \( N \)을 소인수분해한 결과가 아래와 같다고 합시다.

    \( N = p_{1}^{e_{1}} \times p_{2}^{e_{2}} \times \ldots p_{k}^{e_{k}} \)

    그럼 \( N^M \)을 소인수분해한 결과는 자명히 아래와 같게 됩니다.

    \( N^M = p_{1}^{Me_{1}} \times p_{2}^{Me_{2}} \times \ldots p_{k}^{Me_{k}} \)

    지금부터는 이 \( Me_{i} \)를 편의상 \( e_{i} \)라고 재정의하겠습니다.

     

    그럼, 약수의 합은 어떻게 구할까요?

    만약 위 소인수분해 결과에서 \( k = 1 \)이라면, 약수의 합은 \( 1 + p_{1} + p_{1}^2 + \cdots + p_{1}^{e_{1}} \)이 됩니다.

    \( k = 2 \)라면, \( (1 + p_{1} + p_{1}^2 + \cdots + p_{1}^{e_{1}})(1 + p_{2} + p_{2}^2 + \cdots + p_{2}^{e_{2}}) \)가 됩니다.

    이를 일반화할 수 있으며, \( N^M \)의 약수의 합은 \( (1 + p_{1} + p_{1}^2 + \cdots + p_{1}^{e_{1}})(1 + p_{2} + p_{2}^2 + \cdots + p_{2}^{e_{2}}) \cdots (1 + p_{k} + p_{k}^2 + \cdots + p_{k}^{e_{k}}) \)가 됩니다.

     

    그럼 이 문제의 풀이는 \( N^M \)을 소인수분해해서 \( (p_{i}, e_{i}) \) 쌍을 얻는 거랑

    각 \( (p, e) \) 쌍에 대해 \( 1 + p + p^2 + \cdots + p^e \)를 구하는 게 됩니다.

     

    앞부분은 \( N \)을 소인수분해한 뒤 각 지수에 \( M \)을 곱하는 방식으로 해결할 수 있으며,

    뒷부분은 분할 정복을 이용한 거듭제곱과 이를 응용한 방식으로 해결할 수 있습니다.

    뒷부분의 분할 정복을 요약하자면, \( f(p, e) = 1 + p + p^2 + \cdots + p^{e-1} \)이라고 할 때

    \( f(p, e) = \begin{cases} 1 & \text{if } e = 1 \\ f(p, e/2) + p^{e/2} + p^{e/2+1}f(p, e/2) & \text{if } e \text{ is odd} \\ f(p, e/2) + p^{e/2}f(p, e/2) & \text{if } e \text{ is even} \end{cases} \)가 됩니다.

    const ll mod = 1e9 + 7;
    
    ll fpw(ll a, ll b){
    	ll res = 1, mul = a, bit = b;
    	while (bit){
    		if (bit & 1){ res = res*mul % mod; }
    		mul = mul*mul % mod; bit >>= 1;
    	}
    	return res;
    }
    ll dnc(ll a, ll b){
    	if (b == 1){ return 1; }
    	ll bb = b >> 1;
    	ll res = dnc(a, bb);
    	if (b & 1){
    		ll r1 = res, r2 = res * fpw(a, bb+1) % mod, r3 = fpw(a, bb);
    		return (r1+r2+r3) % mod;
    	}
    	else{
    		ll r1 = res, r2 = res * fpw(a, bb);
    		return (r1+r2) % mod;
    	}
    }
    
    vector<pl2> v;
    
    void Main(){
    	ll n, m; cin >> n >> m;
    	if (m == 0){ cout << 1; return; }
    	for (ll i = 2; i*i <= n; i++){
    		if (n%i != 0){ continue; }
    		ll cnt = 0;
    		while (n%i == 0){ n /= i; cnt += m; }
    		v.push_back({i, cnt});
    	}
    	if (n != 1){ v.push_back({n, m}); }
    	ll ans = 1;
    	for (pl2 pr : v){
    		ll p = pr.fr, e = pr.sc;
    		ll res = dnc(p, e+1);
    		//cout << p << ' ' << e << " -> " << res << endl;
    		ans *= res; ans %= mod;
    	}
    	cout << ans;
    }
Designed by Tistory.