Loading [MathJax]/jax/output/CommonHTML/jax.js
본문 바로가기

분할 정복을 이용한 거듭제곱/행렬 거듭제곱

백준 24660번: High Powers

https://www.acmicpc.net/problem/24660

 

24660번: High Powers

It can be shown that the answer can be represented as a rational number p/q where p and q are integers, (p,q)=1, q>0 and q is not divisible by 998244353.

www.acmicpc.net


문제의 생김새와는 다르게 풀이는 매우 깔끔했습니다.

 

문제가 요구하는 것은 결국

an(bmcm)+bn(cmam)+cn(ambm(ab)(bc)(ca)

의 값을 modulo 998244353으로 표현하는 것입니다.

 

제가 여기서 했던 실수는, 분모와 분자를 각각 s,t,u에 관한 식으로 환원하여 값을 구하려고 했던 것입니다.

 

물론 분모, 분자를 각각 s,t,u로 잘 표현해낼 수만 있다면, 이 풀이는 매우 타당합니다.

 

998244353은 소수이며 s,t,u의 범위에 의해 분모는 반드시 역원이 존재하는 값이 될 것이기 때문입니다.

 

또한, 분모는 다음과 같이 정리되므로

(ab)(bc)(ca)=(1)×{a2(bc)+b2(ca)+c2(ab)}

저는 한동안 분모로부터 분자로 이어지는 점화식을 찾기 위해 노력했었습니다.

 

그 결과, 괴상한 점화식들만 잔뜩 생기고 점화식을 알아도 답을 구할 수 없는 상황이 펼쳐졌습니다. 이때 조언을 들어 3차방정식을 생각한 게 신의 한 수였습니다.


문제의 조건에 따라 a,b,c는 다음 3차방정식의 해에 대응합니다.

x3sx2+txu=0

이때 다음을 알 수 있습니다.

an+3san+2+tan+1uan=0

이를 다시 발전시키면 다음의 세 식을 얻을 수 있습니다.

an+3(bmcm)san+2(bmcm)+tan+1(bmcm)uan(bmcm)=0

bn+3(cmam)sbn+2(cmam)+tbn+1(cmam)ubn(cmam)=0

cn+3(ambm)scn+2(ambm)+tcn+1(ambm)ucn(ambm)=0

이제 이 세 식을 더해서, 문제에서 주어진 분모로 나누면, 완벽한 점화식을 얻게 됩니다.

poly(x,y)=ax(bycy)+bx(cyay)+cx(ayby)(ab)(bc)(ca)

일 때,

poly(n+3,m)s×poly(n+2,m)+t×poly(n+1,m)u×poly(n,m)=0

입니다.

 

다만, n,m의 범위에 의해 시간 안에 통과하려면 행렬 형태로 바꿔야겠습니다.

[poly(n+3,m)poly(n+2,m)poly(n+1,m)]=[stu100010][poly(n+2,m)poly(n+1,m)poly(n,m)]

[poly(n+3,m)poly(n+2,m)poly(n+1,m)]=[stu100010]n[poly(3,m)poly(2,m)poly(1,m)]

여기서 완전한 답을 얻어내려면 다음의 식을 캐치해야 합니다.

poly(n,m)=poly(m,n)

자명하게 성립하는 이 식을 사용하면 답을 구할 수 있게 됩니다. 시간 복잡도는 O(k3logn)이고 k=3입니다.


이제 남은 건 n,m이 충분히 작을 때의 답을 구하는 것입니다.

 

위에서 사용했던 대칭성에 관한 식을 가져오면, 다음을 얻을 수 있습니다.

poly(n,n)=0

이로부터 우리가 구해야 하는 범위는 n,m3이었는데, 실제로 구해야 하는 경우의 수는 3가지 뿐임을 알게 되었습니다.

 

그런데 poly(n,m)은 결국 a,b,c에 관한 n+m3차 다항식임을 생각하면, 직관적으로 답을 얻어낼 수 있게 됩니다.

(또한, 이런 성질들을 사용하는 게 좋습니다.)

 

그렇게 얻어낸 답을 검증하는 것은 아쉽게도 다항식의 곱셈을 직접 전개하는 게 제일 빠를 것 같습니다.

 

저는 wolframalpha를 이용해서 검증했었습니다.

 

이제 이를 구현하면 끝입니다. 음수 모듈러에 주의하여 구현합니다.

#include <stdio.h>
#define mod 998244353

typedef struct {
    long long array[3][3];
} Matrix;

Matrix matrix_multiply_modular (Matrix A, Matrix B);
Matrix matrix_power_modular (Matrix A, long long P);

int main() {
    
    long long N, M, s, t, u;
    scanf("%lld %lld %lld %lld %lld", &N, &M, &s, &t, &u);
    
    Matrix Base =
    {{
        {s, -t, u},
        {1,  0, 0},
        {0,  1, 0}
    }};
    Matrix tempN = matrix_power_modular(Base, N-3);
    Matrix tempM = matrix_power_modular(Base, M-3);
    
    long long initialCases[4][4] =
    {
        {0,  0,  0, 0},
        {0,  0,  1, s},
        {0, -1,  0, t},
        {0, -s, -t, 0}
    };
    
    
    long long answer = 0;
    if (N <= 3) {
        if (M <= 3) {
            answer = initialCases[N][M];
        }
        else {
            for (int i = 0; i < 3; i++) {
                answer = (answer + tempM.array[0][i] * initialCases[3-i][N]) % mod;
            }
            answer *= -1;
        }
    }
    else if (M <= 3) {
        for (int i = 0; i < 3; i++) {
            answer = (answer + tempN.array[0][i] * initialCases[3-i][M]) % mod;
        }
    }
    else {
        long long temp[4] = {};
        for (int i = 1; i <= 3; i++) {
            for (int j = 0; j < 3; j++) {
                temp[i] = (temp[i] + tempM.array[0][j] * initialCases[3-j][i]) % mod;
            }
            temp[i] *= -1; // (M, i) *= -1
        }
        
        for (int i = 0; i < 3; i++) {
            answer = (answer + tempN.array[0][i] * temp[3-i]) % mod;
        }
    }
    answer = (answer + mod) % mod;
    
    printf("%lld", answer);
    return 0;
}

Matrix matrix_multiply_modular (Matrix A, Matrix B) {
    Matrix result = {{{}}};
    for (int i = 0; i < 3; i++) {
        for (int k = 0; k < 3; k++) {
            for (int j = 0; j < 3; j++) {
                result.array[i][j] = (result.array[i][j] + A.array[i][k] * B.array[k][j]) % mod;
            }
        }
    }
    return result;
}

Matrix matrix_power_modular (Matrix A, long long P) {
    Matrix result = {{{}}};
    for (int i = 0; i < 3; i++) {
        result.array[i][i] = 1;
    }
    while (P) {
        if (P % 2) {
            result = matrix_multiply_modular(result, A);
        }
        A = matrix_multiply_modular(A, A);
        P /= 2;
    }
    return result;
}

(C11, 0ms, 1116KB, 제출번호 66242188)


행렬로 식을 바꾸고 나서 9×9 행렬로 바꾸면 좀 더 식이 예뻐지지 않을까 했는데 그만두었습니다...