본문 바로가기

분할 정복을 이용한 거듭제곱/행렬 거듭제곱

백준 31987번: ESC와 쿼리

https://www.acmicpc.net/problem/31987


문제에서 요구하는 것은 다음과 같습니다.

jx=iakx+bkx+ckxdndxn(exsin(x)cos(x))=anexsin2(x)+bnexcos2(x)+cnexsin(x)cos(x)jx=iakx+bkx+ckxdndxn(exsin(x)cos(x))=anexsin2(x)+bnexcos2(x)+cnexsin(x)cos(x)

 

그런데 exex 는 미분해도 자기 자신이며, sin(x)sin(x) 를 미분하면 cos(x)cos(x) 가 되고, cos(x)cos(x) 를 미분하면 sin(x)sin(x) 가 된다는 것이 알려져 있습니다. 따라서, 이 값을 구하기 위해 초항부터 구해나가며 관찰을 해도 좋지만, nn 번 미분된 식을 한 번 더 미분해도 되겠습니다.

 

dn+1dxn+1(anexsin2(x)+bnexcos2(x)+cnexsin(x)cos(x))dn+1dxn+1(anexsin2(x)+bnexcos2(x)+cnexsin(x)cos(x))

=(ancn)exsin2(x)+(bn+cn)excos2(x)+(2an2bn+cn)exsin(x)cos(x)=(ancn)exsin2(x)+(bn+cn)excos2(x)+(2an2bn+cn)exsin(x)cos(x)

an+1exsin2(x)+bn+1excos2(x)+cn+1exsin(x)cos(x)an+1exsin2(x)+bn+1excos2(x)+cn+1exsin(x)cos(x)

 

마지막 줄로부터, 점화식을 행렬로 쉽게 표현할 수 있습니다.

[anbncn]=[101011221]n1[a1b1c1]

 

그런데, 구해야 하는 식은 jx=i(akx+bkx+ckx) 입니다. 나이브하게 구하게 되면 쿼리 당 O(33×logik+33×(ji)) 라는, 너무 큰 복잡도가 나오기 때문에, 식을 간소화하고 복잡도를 낮출 방법을 찾아보겠습니다.


제일 먼저, 위 점화식으로부터

an+1+bn+1=an+bn

임을 알 수 있으며, 초항을 구해보면 이는 0입니다! 그러므로 구해야 하는 식은

jx=ickx

로 간소화됩니다. 추가적으로, an+bn=0 으로부터 점화식도 간소화됩니다.

[ancn]=[1141]n1[11]

하지만, 여전히 복잡도의 차수는 동일하기 때문에, 거듭제곱의 합을 빠르게 구할 방법을 찾아야만 합니다. 행렬의 등비수열로 접근한다면 예제까지는 충분히 답을 얻을 수 있으나, 역행렬이 보장되게 만드는 게 쉽진 않을 것 같습니다. 그렇다면, 일반항 밖에 선택의 여지가 없습니다.


Wolframalpha, 또는 직접 풀어 점화식의 일반항을 얻으면 아마 다음과 같은 형태로 나타날 것입니다. 직접 푸는 과정은 나중에 첨부하도록 하겠습니다.

cn=(1+2i)n+(12i)n2

이 식을 다양하게 변형할 수 있으나, 문제의 답을 구할 때는 등비수열의 합으로 구하는 게 가장 자연스러울 것 같습니다.

jx=ickx=12×jx=i((1+2i)kx+(12i)kx)

=12×[{jx=0(1+2i)kxi1x=0(1+2i)kx}+{jx=0(12i)kxi1x=0(12i)kx}]

=12×{(1+2i)k(j+1)(1+2i)ki(1+2i)k1+(12i)k(j+1)(12i)ki(12i)k1}

놀랍게도, 이 복잡한 식에서 통분을 하게 되면 상당히 예쁜 식이 나오게 됩니다.

jx=ickx=5k(ckjck(i1))ck(j+1)+cki5k+12×ck

이 식을 바탕으로 문제를 풀게 되면 시간 복잡도는 O(Q×23×logjk) 입니다.

 

C/C++은 시간제한이 0.2s이기 때문에, 이 정도의 복잡도를 가지고도 구현에서 최적화를 하지 않으면 시간 초과될 여지가 있습니다. 제 코드는 다음과 같습니다.

#include <iostream>
#define mod 1000000007

typedef unsigned long long ull;

ull power_modular (ull X, ull Y) {
	ull result = 1;
	while(Y) {
		if (Y % 2) {
			result = result * X % mod;
		}
		X = X * X % mod;
		Y /= 2;
	}
	return result;
}

typedef struct {
	ull array[2][2];
} Matrix;

Matrix matrix_multiply_modular (Matrix A, Matrix B) {
	Matrix result = {{{}}};
	result.array[0][0] = (A.array[0][0] * B.array[0][0] + A.array[0][1] * B.array[1][0]) % mod;
	result.array[0][1] = (A.array[0][0] * B.array[0][1] + A.array[0][1] * B.array[1][1]) % mod;
	result.array[1][0] = (A.array[1][0] * B.array[0][0] + A.array[1][1] * B.array[1][0]) % mod;
	result.array[1][1] = (A.array[1][0] * B.array[0][1] + A.array[1][1] * B.array[1][1]) % mod;
	return result;
}

Matrix matrix_power_modular (Matrix A, long long P) {
	Matrix result = {{{1,0},{0,1}}};
	while (P) {
		if (P % 2) {
			result = matrix_multiply_modular(result, A);
		}
		A = matrix_multiply_modular(A, A);
		P /= 2;
	}
	return result;
}

using namespace std;
int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);
	
	int Q;
	cin >> Q;
	while(Q--)
	{
		long long i, j, k;
		cin >> i >> j >> k;

		Matrix base = {{{1, mod - 1},{4, 1}}};
		
		Matrix temp1 = matrix_power_modular(base, k);
		Matrix temp2 = matrix_power_modular(base, k * (i - 1));
		Matrix temp3 = matrix_multiply_modular(temp1, temp2);
		Matrix temp4 = matrix_power_modular(base, k * j);
		Matrix temp5 = matrix_multiply_modular(temp1, temp4);
		
		ull temp = power_modular(5, k);
		
		ull divisor	= (temp + 1 + (mod - 2) * temp1.array[1][1]) % mod;
		ull dividend = (temp * ((temp4.array[1][1] + (mod - 1) * temp2.array[1][1]) % mod) + (mod - 1) * temp5.array[1][1] + temp3.array[1][1]) % mod;
		
		ull answer = dividend * power_modular(divisor, mod - 2) % mod;
		cout << answer << "\n";
	}

	return 0;
}

(C++20, 112ms, 2020KB, 제출번호 86595282)