본문 바로가기

분할 정복을 이용한 거듭제곱/정수 거듭제곱

백준 17315번: Matrix Game (integer interpretation)

https://www.acmicpc.net/problem/17315

 

17315번: Matrix Game

The input will contain the six integers n, m, a, b, c, and d.

www.acmicpc.net


National Olympiad in Informatics, China, 2013, Day2에 Problem1로 출제된 문제라고 합니다.

 

행렬 거듭제곱 풀이에 대해 궁금하신 분은 이쪽 링크를 이용하시면 됩니다:Matrix Game (matrix interpretation)


정수 거듭제곱 풀이는 주어진 규칙대로 끝까지 전개하여 다음 식에서 $f$와 $g$를 구하는 것이 목표입니다.

$$ F[n][m] = f(a,b,c,d) \times F[1][1] + g(a,b,c,d) $$

 

일반화를 위해 한 단계씩 식을 전개하면 식이 꽤 길어지기 때문에, 요약하여 보여드리겠습니다. $n,m > 1$에 대하여

$$ F[n][m] = a^{m-1} F[n][1] + b \sum_{i=0}^{m-2} a^{i} $$

$$ \rightarrow F[n][m] = \left ( a^{m-1}c \right ) \times F[n-1][m] + da^{m-1} + b \sum_{i=0}^{m-2} a^{i} $$

$$ \rightarrow F[n][m] = \left ( a^{m-1}c \right )^{2} \times F[n-2][m] + da^{m-1} \left ( 1 + a^{m-1}c \right ) + b \left ( 1 + a^{m-1}c \right ) \sum_{i=0}^{m-2} a^{i} $$

$$ \vdots $$

$$ \rightarrow F[n][m] = \left ( a^{m-1}c \right )^{n-1} \times F[1][m] + da^{m-1} \sum_{i=0}^{n-2} \left ( a^{m-1}c \right )^{i} + b \sum_{i=0}^{n-2} \left ( a^{m-1}c \right )^{i} \sum_{i=0}^{m-2} a^{i} $$

$$ \rightarrow F[n][m] = \left ( a^{m-1}c \right )^{n-1} a^{m-1} \times F[1][1] + \left ( a^{m-1}d \right ) \sum_{i=0}^{n-2} \left ( a^{m-1}c \right )^{i} + b \sum_{i=0}^{n-1} \left ( a^{m-1}c \right )^{i} \sum_{i=1}^{m-2} a^{i} $$

$$ = \left ( a^{m-1}c \right )^{n-1} a^{m-1} + \left ( a^{m-1}d \right ) \sum_{i=0}^{n-2} \left ( a^{m-1}c \right )^{i} + b \sum_{i=0}^{n-1} \left ( a^{m-1}c \right )^{i} \sum_{i=1}^{m-2} a^{i} $$


위 식이 행렬 때와 마찬가지로 $n,m \geq 1$에 대하여 모두 만족하는 것을 쉽게 보일 수 있습니다. (참고) 그리고 $n,m$을 문자열로 저장하여 분할정복으로 거듭제곱하는 방법은 행렬 글에서 설명드렸습니다.


그러면 이제 남은 부분은 summation 함수입니다. 함수 $S(r,k)$을 다음과 같이 정의합니다.

$$ S(r,k) = \sum_{i=0}^{k} r^{i} $$

이렇게 하면 $S(1,k)=k+1$입니다. 나머지 $r,k$에 대하여 $S(r,k)=\frac{r^{k+1}-1}{r-1}$ 인 것은 고등학교 수학에서 나옵니다.

 

Summation 함수를 구현할 때 주의할 점은 $r=1$일 때인데, $n,m$을 $10^{9}+7$로 나눈 나머지를 구해야하기 때문입니다.

$$ n = d_{l}d_{l-1} \cdots d_{2}d_{1} = d_{1} + 10^{1}d_{2} + \cdots + 10^{l-1}d_{l} = \sum_{i=1}^{l} 10^{i-1}d_{i} $$

라는 것을 알면 쉽게 구현할 수 있을 것입니다. 제 코드는 문자열 $S$에 대해 $O( |S| )$인데 더 빠른 방법이 있을 지는 잘 모르겠습니다.

 

다음은 제 코드입니다.

#include <stdio.h>
#include <string.h>
#define mod 1000000007

long long mod_of_string (char* P, int L);
long long power_modular (long long X, long long Y);
long long power_modular_log10 (long long X, char* P, int L);
long long sum_geometric_modular_log10 (long long R, char* P, int L, int flag);

int main(void) {
	
	char N[1000002], M[1000002];
	long long a, b, c, d;
	scanf("%s %s %lld %lld %lld %lld", N, M, &a, &b, &c, &d);
	
	int lengthN = strlen(N);
	int lengthM = strlen(M);
	
	long long inv_a = power_modular(a, mod - 2);
	long long inv_c = power_modular(c, mod - 2);
	
	long long base = power_modular_log10(a, M, lengthM) * inv_a % mod;
	
	long long temp1 = base * d % mod;
	
	base = c * base % mod;
	long long temp2 = power_modular_log10(base, N, lengthN) * inv_c % mod;
	
	temp1 = temp1 * sum_geometric_modular_log10(base, N, lengthN, 2) % mod;
	
	long long temp3 = sum_geometric_modular_log10(a, M, lengthM, 2)
                    * sum_geometric_modular_log10(base, N, lengthN, 1) % mod;
	temp3 = b * temp3 % mod;
	
	long long answer = temp1 + temp2 + temp3;
	printf("%lld", answer % mod);
	return 0;
}

long long mod_of_string (char* P, int L) {
	long long result = 0, X = 1;
	for (int i = L-1; i >= 0; i--) {
		result = (result + X * (P[i] - '1' + 1)) % mod;
		X = X * 10 % mod;
	}
	return result;
}

long long power_modular (long long X, long long Y) {
	long long result = 1;
	while(Y) {
		if (Y % 2) {
			result = result * X % mod;
		}
		X = X * X % mod;
		Y /= 2;
	}
	return result;
}

long long power_modular_log10 (long long X, char* P, int L) {
	long long result = 1;
	for (int i = L-1; i >= 0; i--) {
		result = result * power_modular(X, P[i] - '1' + 1) % mod;
		X = power_modular(X, 10);
	}
	return result;
}

long long sum_geometric_modular_log10 (long long R, char* P, int L, int flag) {
    if (L == 1 && P[0] - '1' + 1 < flag) return 0;
	if (R == 1) {
		return (mod_of_string(P,L) - flag + 1 + mod) % mod;
	}
	else {
		long long inv = power_modular(power_modular(R, mod - 2), flag - 1);
		long long numerator = ((power_modular_log10(R, P, L) * inv % mod) - 1 + mod) % mod;
		long long denominator = power_modular(R - 1, mod - 2);
		return numerator * denominator % mod;
	}
}

(C11, 144ms, 2948KB, 제출번호 54457005)


Euler's Theorem을 생각하면 어차피 $a,b,c,d$가 $p$와 서로소이므로 지수의 범위를 크게 낮출 수 있음을 알 수 있습니다. (C11, 48ms, 2952KB, 제출번호 54495403) 솔직히 기대 안 했는데, C++ 고수 분들의 코드보다 약간 더 빨라서 정말 만족스럽습니다.