Loading [MathJax]/jax/output/CommonHTML/jax.js
본문 바로가기

분할 정복을 이용한 거듭제곱/행렬 거듭제곱

백준 29705번: 문자열 만들기

https://www.acmicpc.net/problem/29705

 

29705번: 문자열 만들기

알파벳 대문자 ‘A-Z’, 알파벳 소문자 ‘a-z’, 숫자 ‘0-9’를 원소로 가지는 집합 C의 부분 집합 P를 정의역으로 하는 함수 f(c)가 있다. f(c)의 값은 1 이상 9 이하의 자연수이다. 1 이상 $

www.acmicpc.net


Notation이 좀 난해한 문제였습니다.

 

그러나 25962번과 아이디어가 비슷하다는 걸 캐치하면, 풀이의 앞 부분은 쉽게 해결됩니다.


우선 자연수 N에 대해 g(x)=Nx의 개수를 AN으로 놓으면

AN+9=|S1|×AN+8+|S2|×AN+7++|S8|×AN+1+|S9|×AN

이고, N이 작은 경우

AN=Ni=1(|Si|×ANi)(N9,A0=1)

입니다.

 

이런 점화식에서 AN 하나의 값을 구하는 시간 복잡도는 O(dN)부터 O(d3logN), O(d2logN), O(dlogdlogN) 까지 다양합니다. (d=9)

 

그러나, 이 문제는 구간 합을 구하길 원하므로 Aa부터 Ab까지의 값을 모두 구하느냐, 혹은 더 세련된 방법을 쓰느냐로 접근 방식을 나누게 됩니다.


전자의 경우, FFT + Kitamasa를 쓰더라도 ba의 상한이 5×106이므로, 구간에 속한 모든 값을 하나하나 구할 수는 없습니다. (다시 말해, O(dlogd×(ba))는 시간 초과입니다.)

 

하지만, Aa부터 Aa+8까지의 값을 구하면, 이후 Aa+9부터 Ab까지의 값은 O(9(ba))로 구할 수 있게 됩니다.

 

한편 이 9개의 값을 구하는 것은 어떻게 구하더라도 큰 상관은 없겠지만, 저는 아래 식을 썼습니다.

[Aa+8Aa+7Aa+6Aa+1Aa]=[|S1||S2||S3||S8||S9|10000010000000000010]a[A8A7A6A1A0]

 

이 방법의 시간복잡도는 결국 O(d3loga+d(ba)) 인데, 캐시 덕분인지 생각보다는 빠른 실행시간이 나오는 걸로 보입니다.


후자는 약간의 식 변형을 필요로 하겠습니다.

행렬 Mij열을 Mij로 표기한다면, (저는 0-index를 쓰겠습니다)

AN=([|S1||S2||S3||S8||S9|10000010000000000010]N[A8A7A6A1A0])80

이므로, BN=Ni=0AN 이면,

BN=Ni=0([|S1||S2||S3||S8||S9|10000010000000000010]i[A8A7A6A1A0])80

=((Ni=0[|S1||S2||S3||S8||S9|10000010000000000010]i)[A8A7A6A1A0])80

으로 식이 정리가 되며, d×d 행렬의 거듭제곱의 합 또한 O(d3logN)에 계산됩니다.

 

즉, 이 방법의 시간 복잡도는 O(d3(loga+logb))가 됩니다.

 

참고로, 어떤 좋은 행렬에 대해서는 geometric progression처럼 취급하여 거듭제곱의 합을 구할 수 있습니다만,

 

역행렬이 보장되는지도 문제이고, 구현량도 많아지는 것 같습니다. 이 글과 이 글을 참고해볼 수는 있을 것 같습니다.

 

일단 아래는 제 코드입니다.

#include <stdio.h>
#include <string.h>
#define mod 1000000007

typedef struct {
    unsigned long long array[9][9];
} Matrix;

Matrix I;

Matrix matrix_multiply_modular (Matrix A, Matrix B);
Matrix matrix_power_modular (Matrix A, long long P);
Matrix matrix_power_sum_modular (Matrix A, long long P);

int main() {
    
    for (int i = 0; i < 9; i++) I.array[i][i] = 1;
    
    long long a, b;
    scanf("%lld %lld", &a, &b);

    char string[10][51] = {{}};
    long long len[10] = {};
    for (int i = 1; i <= 9; i++) {
        scanf("%s", string[i]);
        len[i] = strlen(string[i]);
    }
    
    long long s[10] = {1};
    for (int i = 1; i <= 9; i++) {
        for (int j = 1; j <= i; j++) {
            s[i] += s[i - j] * len[j];
        }
        s[i] %= mod;
    }
    
    Matrix simple = {{{}}};
    for (int i = 0; i < 9; i++) {
        simple.array[i][0] = s[i];
    }
    
    Matrix base = {{{}}};
    for (int i = 0; i < 9; i++) {
        base.array[8][i] = len[9 - i];
    }
    for (int i = 0; i + 1 < 9; i++) {
        base.array[i][i + 1] = 1;
    }
    
    Matrix temp;
    
    temp = matrix_multiply_modular(matrix_power_sum_modular(base, b), simple);
    long long ans_b = temp.array[0][0];
    
    temp = matrix_multiply_modular(matrix_power_sum_modular(base, a - 1), simple);
    long long ans_am1 = temp.array[0][0];
    
    long long answer = (ans_b + mod - ans_am1) % mod;
    printf("%lld", answer);
    
    return 0;
}

Matrix matrix_addition_modular (Matrix A, Matrix B) {
    for (int i = 0; i < 9; i++) {
        for (int j = 0; j < 9; j++) {
            A.array[i][j] = (A.array[i][j] + B.array[i][j]) % mod;
        }
    }
    return A;
}

Matrix matrix_multiply_modular (Matrix A, Matrix B) {
    Matrix result = {{{}}};
    for (int i = 0; i < 9; i++) {
        for (int j = 0; j < 9; j++) {
            for (int k = 0; k < 9; k++) {
                result.array[i][j] += A.array[i][k] * B.array[k][j];
            }
            result.array[i][j] %= mod;
        }
    }
    return result;
}

Matrix matrix_power_modular (Matrix A, long long P) {
    Matrix result = I;
    while(P) {
        if (P % 2) {
            result = matrix_multiply_modular(result, A);
        }
        A = matrix_multiply_modular(A, A);
        P /= 2;
    }
    return result;
}

Matrix matrix_power_sum_modular (Matrix A, long long P) {
    // A^0 + A^1 + A^2 + ... + A^P
	if (P <= 0) {
		return I;
	}
	if (P % 2) {
		return matrix_multiply_modular(matrix_power_sum_modular(A, P / 2),
                                       matrix_addition_modular(matrix_power_modular(A, (P + 1) / 2), I));
	}
	else {
		return matrix_addition_modular(matrix_power_modular(A, P / 2),
		                               matrix_multiply_modular(matrix_power_sum_modular(A, P / 2 - 1),
		                                                       matrix_addition_modular(I,
		                                                                               matrix_power_modular(A, P / 2 + 1))));
	}
}

(C11, 8ms, 1384KB, 제출번호 67070258)


생각 이상으로 제 matrix_power_sum이 비효율적이었는지, 28ms부터 8ms까지 널뛰는 것 같습니다.

 

12987번을 다시 확인해봐야할 것 같습니다...