본문 바로가기

분할 정복을 이용한 거듭제곱/행렬 거듭제곱

백준 24660번: High Powers

https://www.acmicpc.net/problem/24660

 

24660번: High Powers

It can be shown that the answer can be represented as a rational number $p/q$ where $p$ and $q$ are integers, $(p,q)=1$, $q>0$ and $q$ is not divisible by $998\,244\,353$.

www.acmicpc.net


문제의 생김새와는 다르게 풀이는 매우 깔끔했습니다.

 

문제가 요구하는 것은 결국

$$ \frac{a^{n}(b^{m}-c^{m}) + b^{n}(c^{m}-a^{m}) + c^{n}(a^{m}-b^{m}}{(a-b)(b-c)(c-a)} $$

의 값을 modulo 998244353으로 표현하는 것입니다.

 

제가 여기서 했던 실수는, 분모와 분자를 각각 $s,t,u$에 관한 식으로 환원하여 값을 구하려고 했던 것입니다.

 

물론 분모, 분자를 각각 $s,t,u$로 잘 표현해낼 수만 있다면, 이 풀이는 매우 타당합니다.

 

998244353은 소수이며 $s,t,u$의 범위에 의해 분모는 반드시 역원이 존재하는 값이 될 것이기 때문입니다.

 

또한, 분모는 다음과 같이 정리되므로

$$ (a-b)(b-c)(c-a) = (-1) \times \left \{ a^{2}(b-c) + b^{2}(c-a) + c^{2}(a-b) \right \} $$

저는 한동안 분모로부터 분자로 이어지는 점화식을 찾기 위해 노력했었습니다.

 

그 결과, 괴상한 점화식들만 잔뜩 생기고 점화식을 알아도 답을 구할 수 없는 상황이 펼쳐졌습니다. 이때 조언을 들어 3차방정식을 생각한 게 신의 한 수였습니다.


문제의 조건에 따라 $a,b,c$는 다음 3차방정식의 해에 대응합니다.

$$ x^{3} - sx^{2} + tx - u = 0 $$

이때 다음을 알 수 있습니다.

$$ a^{n+3} - sa^{n+2} + ta^{n+1} - ua^{n} = 0 $$

이를 다시 발전시키면 다음의 세 식을 얻을 수 있습니다.

$$ a^{n+3}(b^{m}-c^{m}) - sa^{n+2}(b^{m}-c^{m}) + ta^{n+1}(b^{m}-c^{m}) - ua^{n}(b^{m}-c^{m}) = 0 $$

$$ b^{n+3}(c^{m}-a^{m}) - sb^{n+2}(c^{m}-a^{m}) + tb^{n+1}(c^{m}-a^{m}) - ub^{n}(c^{m}-a^{m}) = 0 $$

$$ c^{n+3}(a^{m}-b^{m}) - sc^{n+2}(a^{m}-b^{m}) + tc^{n+1}(a^{m}-b^{m}) - uc^{n}(a^{m}-b^{m}) = 0 $$

이제 이 세 식을 더해서, 문제에서 주어진 분모로 나누면, 완벽한 점화식을 얻게 됩니다.

$$ poly(x,y) = \frac{a^{x}(b^{y}-c^{y}) + b^{x}(c^{y}-a^{y}) + c^{x}(a^{y}-b^{y})}{(a-b)(b-c)(c-a)} $$

일 때,

$$ poly(n+3,m) - s \times poly(n+2,m) + t \times poly(n+1,m) - u \times poly(n,m) = 0 $$

입니다.

 

다만, $n,m$의 범위에 의해 시간 안에 통과하려면 행렬 형태로 바꿔야겠습니다.

$$ \begin{bmatrix} poly(n+3,m) \\ poly(n+2,m) \\ poly(n+1,m) \end{bmatrix} = \begin{bmatrix} s & -t & u \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix} \begin{bmatrix} poly(n+2,m) \\ poly(n+1,m) \\ poly(n,m) \end{bmatrix}$$

$$ \rightarrow \begin{bmatrix} poly(n+3,m) \\ poly(n+2,m) \\ poly(n+1,m) \end{bmatrix} = \begin{bmatrix} s & -t & u \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}^{n} \begin{bmatrix} poly(3,m) \\ poly(2,m) \\ poly(1,m) \end{bmatrix}$$

여기서 완전한 답을 얻어내려면 다음의 식을 캐치해야 합니다.

$$ poly(n,m) = -poly(m,n) $$

자명하게 성립하는 이 식을 사용하면 답을 구할 수 있게 됩니다. 시간 복잡도는 $O(k^{3}\log n)$이고 $k=3$입니다.


이제 남은 건 $n,m$이 충분히 작을 때의 답을 구하는 것입니다.

 

위에서 사용했던 대칭성에 관한 식을 가져오면, 다음을 얻을 수 있습니다.

$$ poly(n,n) = 0 $$

이로부터 우리가 구해야 하는 범위는 $n,m \leq 3$이었는데, 실제로 구해야 하는 경우의 수는 3가지 뿐임을 알게 되었습니다.

 

그런데 $poly(n,m)$은 결국 $a,b,c$에 관한 $n+m-3$차 다항식임을 생각하면, 직관적으로 답을 얻어낼 수 있게 됩니다.

(또한, 이런 성질들을 사용하는 게 좋습니다.)

 

그렇게 얻어낸 답을 검증하는 것은 아쉽게도 다항식의 곱셈을 직접 전개하는 게 제일 빠를 것 같습니다.

 

저는 wolframalpha를 이용해서 검증했었습니다.

 

이제 이를 구현하면 끝입니다. 음수 모듈러에 주의하여 구현합니다.

#include <stdio.h>
#define mod 998244353

typedef struct {
    long long array[3][3];
} Matrix;

Matrix matrix_multiply_modular (Matrix A, Matrix B);
Matrix matrix_power_modular (Matrix A, long long P);

int main() {
    
    long long N, M, s, t, u;
    scanf("%lld %lld %lld %lld %lld", &N, &M, &s, &t, &u);
    
    Matrix Base =
    {{
        {s, -t, u},
        {1,  0, 0},
        {0,  1, 0}
    }};
    Matrix tempN = matrix_power_modular(Base, N-3);
    Matrix tempM = matrix_power_modular(Base, M-3);
    
    long long initialCases[4][4] =
    {
        {0,  0,  0, 0},
        {0,  0,  1, s},
        {0, -1,  0, t},
        {0, -s, -t, 0}
    };
    
    
    long long answer = 0;
    if (N <= 3) {
        if (M <= 3) {
            answer = initialCases[N][M];
        }
        else {
            for (int i = 0; i < 3; i++) {
                answer = (answer + tempM.array[0][i] * initialCases[3-i][N]) % mod;
            }
            answer *= -1;
        }
    }
    else if (M <= 3) {
        for (int i = 0; i < 3; i++) {
            answer = (answer + tempN.array[0][i] * initialCases[3-i][M]) % mod;
        }
    }
    else {
        long long temp[4] = {};
        for (int i = 1; i <= 3; i++) {
            for (int j = 0; j < 3; j++) {
                temp[i] = (temp[i] + tempM.array[0][j] * initialCases[3-j][i]) % mod;
            }
            temp[i] *= -1; // (M, i) *= -1
        }
        
        for (int i = 0; i < 3; i++) {
            answer = (answer + tempN.array[0][i] * temp[3-i]) % mod;
        }
    }
    answer = (answer + mod) % mod;
    
    printf("%lld", answer);
    return 0;
}

Matrix matrix_multiply_modular (Matrix A, Matrix B) {
    Matrix result = {{{}}};
    for (int i = 0; i < 3; i++) {
        for (int k = 0; k < 3; k++) {
            for (int j = 0; j < 3; j++) {
                result.array[i][j] = (result.array[i][j] + A.array[i][k] * B.array[k][j]) % mod;
            }
        }
    }
    return result;
}

Matrix matrix_power_modular (Matrix A, long long P) {
    Matrix result = {{{}}};
    for (int i = 0; i < 3; i++) {
        result.array[i][i] = 1;
    }
    while (P) {
        if (P % 2) {
            result = matrix_multiply_modular(result, A);
        }
        A = matrix_multiply_modular(A, A);
        P /= 2;
    }
    return result;
}

(C11, 0ms, 1116KB, 제출번호 66242188)


행렬로 식을 바꾸고 나서 $9 \times 9$ 행렬로 바꾸면 좀 더 식이 예뻐지지 않을까 했는데 그만두었습니다...