본문 바로가기

분할 정복을 이용한 거듭제곱/행렬 거듭제곱

백준 13430번: 합 구하기

 

www.acmicpc.net/problem/13430

 

13430번: 합 구하기

첫째 줄에 k와 n이 주어진다. (1 ≤ k ≤ 50, 1 ≤ n ≤ 1,000,000,000)

www.acmicpc.net


기본적인 아이디어를 얻기 위해 작은 경우부터 나열해보는 게 좋은 문제입니다.

 

$S\left ( i,j \right)$ 의 값은

 

$i=0$ 일 때 : $1$, $2$,  $3$,   $4$,    $5$, $\cdots$

 

$i=1$ 일 때 : $1$, $3$,  $6$,  $10$,  $15$, $\cdots$

 

$i=2$ 일 때 : $1$, $4$, $10$, $20$,  $35$, $\cdots$

 

$i=3$ 일 때 : $1$, $5$, $15$, $35$,  $70$, $\cdots$

 

$i=4$ 일 때 : $1$, $6$, $21$, $56$, $126$, $\cdots$

 

 

이렇게 쭉 나열했을 때 (표를 그려보면 보기 더 쉬워집니다) 규칙이 보일까요?

 

우리가 찾아야하는 규칙은 $n$의 값이 아주 크기 때문에 $S\left(k,n \right)$ 과 $S\left(k,n-1 \right)$ 사이의 규칙이 되면 좋을 것입니다.

 

첫 번째로, $S\left (i,1 \right )$ 의 값은 항상 $1$ 이 됩니다.

 

증명은 간단하게 할 수 있습니다.

 

주어진 점화식에서 $S\left (k,1 \right)=S\left(k-1,1 \right)=\cdots=S\left (1,1 \right)=S\left (0,1 \right)=1$ 이기 때문입니다.

 

두 번째로, 조금 더 발견하기 어렵지만 $S\left ( i,j \right )=S\left ( i-1,j \right )+S\left ( i,j-1 \right )$ 을 추측할 수 있습니다. $\left( i,j\geq 1\right)$

 

이것도 증명할 수 있습니다.

 

주어진 점화식에서 $S\left(k-1,1 \right)+S\left(k-1,2 \right)+\cdots+S\left(k-1,n-1 \right)=S\left(k,n-1 \right)$입니다.

 

따라서 $S\left(k,n \right)=S\left(k,n-1 \right)+S\left(k-1,n \right)$ 입니다.

 

마지막으로, 두 번째 성질을 확장해서 $S\left(i,j \right)=1+S\left(0,j-1 \right)+S\left(1,j-1 \right)+\cdots+S\left(i,j-1 \right)$이 성립합니다.

 

이는 다음의 과정을 거쳐 보일 수 있습니다.

 

$S\left(k,n \right)=S\left(k,n-1 \right)+S\left(k-1,n \right)$

$S\left(k,n \right)=S\left(k,n-1 \right)+S\left(k-1,n-1 \right)+S\left(k-2,n \right)$

$S\left(k,n \right)=S\left(k,n-1 \right)+S\left(k-1,n-1 \right)+S\left(k-2,n-1 \right)+S\left(k-3,n \right)$

$\vdots$

$S\left(k,n \right)=S\left(k,n-1 \right)+S\left(k-1,n-1 \right)+\cdots+S\left(2,n-1 \right)+S\left(1,n \right)$

$S\left(k,n \right)=S\left(k,n-1 \right)+S\left(k-1,n-1 \right)+\cdots+S\left(2,n-1 \right)+S\left(1,n-1 \right)+S\left(0,n \right)$

$S\left(k,n \right)=S\left(k,n-1 \right)+S\left(k-1,n-1 \right)+\cdots+S\left(2,n-1 \right)+S\left(1,n-1 \right)+S\left(0,n-1 \right)+1$

 

이를 행렬 점화식으로 나타내려면 어떻게 해야 할까요?

 

우선 $S\left(1,n \right)$ 부터 생각해봅시다.

 

$S\left(1,n \right)=S\left(1,n-1 \right)+S\left(0,n-1 \right)+1$ 인데,

 

이를 만족하는 점화식이 많이 있겠지만 저는 다음 식을 사용하겠습니다.

 

$\begin{bmatrix}1\\ S\left(0,n \right)\\ S\left(1,n \right)\end{bmatrix}=\begin{bmatrix}1 & 0 & 0\\ 1 & 1 & 0\\ 1 & 1 & 1\end{bmatrix}\begin{bmatrix}1\\ S\left(0,n-1 \right)\\ S\left(1,n-1 \right)\end{bmatrix}$

 

이제 이 식을 일반화하면 다음을 얻습니다.

 

$\begin{bmatrix}1\\ S\left(0,n \right)\\ S\left(1,n \right)\\ \vdots \\ S\left(k-1,n \right)\\ S\left(k,n \right)\end{bmatrix}=\begin{bmatrix}1 & 0 & 0 & \cdots & 0 & 0\\ 1 & 1 & 0 & \cdots & 0 & 0\\ 1 & 1 & 1 & \cdots & 0 & 0\\ \vdots & \vdots & \vdots & \ddots & \vdots & \vdots\\ 1 & 1 & 1 & \cdots & 1 & 0\\ 1 & 1 & 1 & \cdots & 1 & 1\end{bmatrix}\begin{bmatrix}1\\ S\left(0,n-1 \right)\\ S\left(1,n-1 \right)\\ \vdots\\ S\left(k-1,n-1 \right)\\ S\left(k,n-1 \right)\end{bmatrix}$

 

$S\left(i,1 \right)=1$임을 이용하여 구현하면 문제가 해결됩니다.

 

코드는 다음과 같습니다.

 

#include <stdio.h>

typedef struct {
	long long array[52][52];
} Matrix;

Matrix matrix_multiply_modular (Matrix A, Matrix B, int n, long long M); // n x n square matrix AB
Matrix matrix_power_modular (Matrix A, int n, long long K, long long M); // n x n square matrix power k

long long Get_S (int k, long long n);

int main() {
	int K;
	long long N;
	scanf("%d %lld", &K, &N);
	printf("%lld", Get_S(K, N));
	return 0;
}

Matrix matrix_multiply_modular (Matrix A, Matrix B, int n, long long M) {
	Matrix result;
	int i, j, k;
	for (i = 0; i < n; i++) {
		for (j = 0; j < n; j++) {
			result.array[i][j] = 0;
			for (k = 0; k < n; k++) {
				long long temp = (A.array[i][k] * B.array[k][j]) % M;
				result.array[i][j] = (result.array[i][j] + temp) % M;
			}
		}
	}
	return result;
}

Matrix matrix_power_modular (Matrix A, int n, long long K, long long M) {
	Matrix result;
	int i, j;
	for (i = 0; i < n; i++) {
		for (j = 0; j < n; j++) {
			if (i == j) {
				result.array[i][j] = 1;
			}
			else {
				result.array[i][j] = 0;
			}
		}
	}
	while (K) {
		if (K & 1) {
			result = matrix_multiply_modular (result, A, n, M);
		}
		A = matrix_multiply_modular (A, A, n, M);
		K >>= 1;
	}
	return result;
}

long long Get_S (int k, long long n) {
	Matrix Base;
	int i, j;
	for (i = 0; i < k + 2; i++) {
		for (j = 0; j <= i; j++) {
			Base.array[i][j] = 1;
		}
		for (j = i + 1; j < k + 2; j++) {
			Base.array[i][j] = 0;
		}
	}
	Base = matrix_power_modular (Base, k+2, n-1, 1000000007);
	Matrix result;
	for (i = 0; i < k + 2; i++) {
		for (j = 0; j < k + 2; j++) {
			if (j == 0) {
				result.array[i][j] = 1;
			}
			else {
				result.array[i][j] = 0;
			}
		}
	}
	result = matrix_multiply_modular (Base, result, k+2, 1000000007);
	return result.array[k+1][0];
}

 

(C11, 1184KB, 148ms, 제출번호 25818591)


사실 20일 전에 제출한 코드인데, 이제 보니 미묘하게 고치고 싶고 그렇습니다.

 

위 코드는 다음과 같이 수정됐습니다. (2021/4/4)

 

#include <stdio.h>
#define mod 1000000007

typedef struct {
	unsigned long long array[52][52];
} Matrix;

Matrix matrix_multiply_modular (Matrix A, Matrix B, int n); // n x n square matrix AB
Matrix matrix_power_modular (Matrix A, int n, long long K); // n x n square matrix power k

void Get_S (int k, long long n);

int main() {
	int K;
	long long N;
	scanf("%d %lld", &K, &N);
	Get_S(K, N);
	return 0;
}

Matrix matrix_multiply_modular (Matrix A, Matrix B, int n) {
	Matrix result = { {{0}} };
	for (int i = 0; i < n; i++) {
		for (int k = 0; k < n; k++) {
			for (int j = 0; j < n; j++) {
				unsigned long long temp = A.array[i][k] * B.array[k][j] % mod;
				result.array[i][j] = (result.array[i][j] + temp) % mod;
			}
		}
	}
	return result;
}

Matrix matrix_power_modular (Matrix A, int n, long long K) {
	Matrix result = { {{0}} };
	for (int i = 0; i < n; i++) {
		result.array[i][i] = 1ULL;
	}
	while (K) {
		if (K & 1) {
			result = matrix_multiply_modular (result, A, n);
		}
		A = matrix_multiply_modular (A, A, n);
		K >>= 1;
	}
	return result;
}

void Get_S (int k, long long n) {
	Matrix Base = { {{0}} };
	for (int i = 0; i < k + 2; i++) {
		for (int j = 0; j <= i; j++) {
			Base.array[i][j] = 1;
		}
	}
	Base = matrix_power_modular (Base, k+2, n-1);
    long long result = 0;
    for (int i = 0; i < k + 2; i++) {
        result = (result + Base.array[k+1][i]) % mod;
    }
    printf("%lld", result);
}

(C11, 1164KB, 20ms, 제출번호 28021893)