13246번: 행렬 제곱의 합
첫째 줄에 행렬의 크기 N과 B가 주어진다. (2 ≤ N ≤ 5, 1 ≤ B ≤ 100,000,000,000) 둘째 줄부터 N개의 줄에 행렬의 각 원소가 주어진다. 행렬의 각 원소는 1,000보다 작거나 같은 자연수 또는 0이다.
www.acmicpc.net
저번 글(백준 1492번: 합)과 비슷하게 $B$ 개의 항을 전부 구하려 하면 반드시 시간초과가 뜨게 되어있습니다.
즉 이번에도 적절한 방법을 찾아야 하는데, 다행히 1492번 보다는 쉽게 발견할 수 있는 편입니다.
본문의 기호를 좀 차용해서 다음과 같이 써보겠습니다.
$S\left ( B \right )=A^{1}+A^{2}+\cdots+A^{B}$
$B$ 가 짝수인 경우, 다음과 같이 식을 변형할 수 있습니다.
$S\left ( B \right )=\left ( A^{1}+A^{2}+\cdots+A^{\frac{B}{2}} \right )A^{\frac{B}{2}}+\left ( A^{1}+A^{2}+\cdots+A^{\frac{B}{2}} \right )$
$=S\left ( \frac{B}{2} \right )\times \left ( I+A^{\frac{B}{2}} \right )$ (단, $I$ 는 단위행렬)
이를 바탕으로 재귀적으로 풀면 문제를 해결할 수 있습니다.
행렬끼리의 덧셈은 $O\left ( N^{2} \right )$으로 구현하면 충분합니다. (저의 경우 단위행렬을 더하는 건 $O\left ( N \right )$으로 구현하긴 했습니다.)
그런데 $B$ 가 홀수인 경우는 위 식을 맹목적으로 적용하면 오류를 일으킬 수도 있습니다.
$S\left ( B \right )\neq S\left ( \left \lfloor \frac{B}{2} \right \rfloor \right )\times \left ( I+A^{\left \lfloor \frac{B}{2} \right \rfloor} \right )$
이렇게 쓴 식은 좌변과 우변의 차가 $A^{B}$ 만큼 나기 때문입니다.
$S\left ( B \right )\neq S\left ( \left \lfloor \frac{B}{2} \right \rfloor \right )\times \left ( I+A^{\left \lceil \frac{B}{2} \right \rceil} \right )$
이렇게 쓴 식은 좌변과 우변의 차가 $A^{\left \lceil \frac{B}{2} \right \rceil}$ 만큼 납니다.
그 밖에도 몇 가지 더 시도해볼 수 있겠지만, 제가 추천하는 방법은 아래의 두 가지입니다.
$S\left ( B \right )=S\left ( B-1 \right )+A^{B}$
$S\left ( B \right )=S\left ( \left \lfloor \frac{B}{2} \right \rfloor \right )\times \left ( I+A^{\left \lceil \frac{B}{2} \right \rceil} \right )+A^{\left \lceil \frac{B}{2} \right \rceil}$
어느쪽으로 구현하든 사실 시간초과에는 걸리지 않습니다.
개인적으로는 소스코드의 길이가 더 짧은 전자를 선호합니다.
코드는 다음과 같습니다.
#include <stdio.h>
#define mod 1000
typedef struct {
int array[5][5];
} Matrix;
Matrix matrix_multiply_modular (Matrix A, Matrix B, int size);
Matrix matrix_power_modular (Matrix Base, int size, long long power);
Matrix matrix_power_sum_modular (Matrix Base, int size, long long power);
int main() {
int size;
long long power;
scanf("%d %lld", &size, &power);
Matrix Base;
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
scanf("%d", &Base.array[i][j]);
Base.array[i][j] %= mod;
}
}
Matrix Answer = matrix_power_sum_modular (Base, size, power);
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
printf("%d ", Answer.array[i][j]);
}
printf("\n");
}
return 0;
}
Matrix matrix_multiply_modular (Matrix A, Matrix B, int size) {
Matrix result;
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
result.array[i][j] = 0;
for (int k = 0; k < size; k++) {
int temp = (A.array[i][k] * B.array[k][j]) % mod;
result.array[i][j] = (result.array[i][j] + temp) % mod;
}
}
}
return result;
}
Matrix matrix_power_modular (Matrix Base, int size, long long power) {
Matrix result;
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
if (j == i) {
result.array[i][j] = 1;
}
else {
result.array[i][j] = 0;
}
}
}
while(power) {
if (power % 2) {
result = matrix_multiply_modular (result, Base, size);
}
Base = matrix_multiply_modular (Base, Base, size);
power /= 2;
}
return result;
}
Matrix matrix_power_sum_modular (Matrix Base, int size, long long power) {
if (power == 1) {
return Base;
}
else if (power % 2) {
Matrix result = matrix_power_modular (Base, size, power);
Matrix temp = matrix_power_sum_modular (Base, size, power - 1);
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
result.array[i][j] = (result.array[i][j] + temp.array[i][j]) % mod;
}
}
return result;
}
else {
Matrix result = matrix_power_modular (Base, size, power/2);
Matrix temp = matrix_power_sum_modular (Base, size, power/2);
for (int i = 0; i < size; i++) {
result.array[i][i] = (result.array[i][i] + 1) % mod;
}
result = matrix_multiply_modular (temp, result, size);
return result;
}
}
(전자의 방법을 구현, C11, 1116KB, 0ms, 제출번호 26469324)
#include <stdio.h>
#define mod 1000
typedef struct {
int array[5][5];
} Matrix;
void print_matrix (Matrix A, int size);
Matrix matrix_multiply_modular (Matrix A, Matrix B, int size);
Matrix matrix_power_modular (Matrix Base, int size, long long power);
Matrix matrix_power_sum_modular (Matrix A, int size, long long power);
int main() {
int size;
long long power;
scanf("%d %lld", &size, &power);
Matrix Base;
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
scanf("%d", &Base.array[i][j]);
Base.array[i][j] %= mod;
}
}
Matrix Answer = matrix_power_sum_modular (Base, size, power);
print_matrix(Answer, size);
return 0;
}
void print_matrix (Matrix A, int size) {
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
printf("%d ", A.array[i][j]);
}
printf("\n");
}
}
Matrix matrix_multiply_modular (Matrix A, Matrix B, int size) {
Matrix result;
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
result.array[i][j] = 0;
for (int k = 0; k < size; k++) {
int temp = (A.array[i][k] * B.array[k][j]) % mod;
result.array[i][j] = (result.array[i][j] + temp) % mod;
}
}
}
return result;
}
Matrix matrix_power_modular (Matrix Base, int size, long long power) {
Matrix result;
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
if (j == i) {
result.array[i][j] = 1;
}
else {
result.array[i][j] = 0;
}
}
}
while(power) {
if (power % 2) {
result = matrix_multiply_modular (result, Base, size);
}
Base = matrix_multiply_modular (Base, Base, size);
power /= 2;
}
return result;
}
Matrix matrix_power_sum_modular (Matrix A, int size, long long power) {
if (power == 1) {
return A;
}
else if (power % 2) {
long long temp_power = (power + 1) / 2;
Matrix temp1 = matrix_power_modular (A, size, temp_power);
Matrix temp2 = matrix_power_sum_modular (A, size, power / 2);
Matrix result = temp1;
for (int i = 0; i < size; i++) {
temp1.array[i][i] = (temp1.array[i][i] + 1) % mod;
}
Matrix temp = matrix_multiply_modular (temp2, temp1, size);
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
result.array[i][j] = (result.array[i][j] + temp.array[i][j]) % mod;
}
}
return result;
}
else {
Matrix result = matrix_power_modular (A, size, power/2);
for (int i = 0; i < size; i++) {
result.array[i][i] = (result.array[i][i] + 1) % mod;
}
Matrix temp = matrix_power_sum_modular (A, size, power/2);
result = matrix_multiply_modular (temp, result, size);
return result;
}
}
(후자의 방법을 구현, C11, 1120KB, 0ms, 제출번호 26468982)
전자든 후자든, 결과적으로 시간복잡도는 $O\left ( N^{3}logB \right )$ 정도가 될 것입니다.
사실 메모리 사용량?도 전자든 후자든 같을 것입니다.
아마 후자에서 4KB가 증가한건 print_matrix 함수의 영향이 다소 있는 듯합니다.
'분할 정복을 이용한 거듭제곱 > 행렬 거듭제곱' 카테고리의 다른 글
백준 16467번: 병아리의 변신은 무죄 (0) | 2021.02.19 |
---|---|
백준 17272번: 리그 오브 레전설 (Large) (0) | 2021.02.19 |
백준 1160번: Random Number Generator (0) | 2021.01.29 |
백준 15712번: 등비수열 (0) | 2021.01.29 |
백준 14440번: 정수 수열 (0) | 2021.01.29 |