알고리즘

Dynamic Programming 4 - Matrix Multiplication Problem

402번째 거북이 2021. 10. 20. 21:53

Matrix chain multiplication: Optimal Substructure

연산(총 multiplication)의 양이 최소가 되도록 하는 행렬의 곱 순서를 푸는 알고리즘으로 구현해보자.

 

*행렬곱에 관한 원칙

  • 행렬 A(p×q)와 B(r×s)를 곱하기 위해서는 q = r이어야 하며, 곱의 결괏값은 (p×s)인 행렬이다.
  • 행렬곱(p×q인 행렬과 q×r인 행렬의 곱)에서 시행되는 곱의 횟수는 총 pqr 회이다.
  • 행렬곱은 결합법칙이 성립한다. 하지만, 계산순서가 바뀌었을 때, 연산의 '양' pqr은 달라질 수 있다.

 

 

## Matrix Chain Multiplication Problem

 

다음의 행렬곱이 있다고 하자.

$$A_{1}·A_{2}·A_{3}·A_{4}····A_{n}$$

곱해지는 각 행렬의 차원은 다음과 같다.

$$A_{1} :\; p_{0}×p_{1}$$

$$A_{2} :\; p_{1}×p_{2}$$

$$A_{3} :\; p_{2}×p_{3}$$

$$....$$

$$A_{i} :\; p_{i-1}×p_{i}$$

$$....$$

$$A_{n} :\; p_{n-1}×p_{n}$$

 

이 행렬 chain에서 시행되는 곱의 총 횟수가 가장 작아지도록 연산순서를 짜는 것이 problem 이다.

이를 앞으로 아래처럼 이야기할 것이다.

parenthesizing(괄호치기) the product of matrices minimizing scalar

 

## Non-Dynamic Programming: Brute Force Approach

 

다음의 instance가 problem이라고 하자.

$$A_{1}·A_{2}·A_{3}·A_{4}$$

총 3개의 연산기호가 있음을 확인하자.

parenthesizing의 모든 경우의 수는 다음과 같이 나열된다.

총 3개의 연산기호가 있기 때문에 최대 3! = 6 개의 경우의 수에서 겹치는 것이 빠진 수만큼의 경우의 수가 존재할 것이다.

$$A_{1}·(A_{2}·(A_{3}·A_{4})),\;\;(A_{1}·A_{2})·(A_{3}·A_{4})...$$

3! - 1 = 5개의 경우의 수가 있음을 확인하자.

 

Brute Force Approach는 재귀적으로 곱의 연산순서를 정한다.

아래에서 BFA가 어떤 원리로 작동되는지 이해만 해보자.

 

n 개의 행렬을 곱하는 문제에서, parenthesization의 개수를 P(n)이라고 하면, P(n)은 다음과 같은 규칙을 따른다.

$$\begin{cases} 1 & \mbox{if }n\mbox{ =1} \\ \sum_{k = 1}^{n-1}P(k)P(n-k) & \mbox{if }n\mbox{ ≥2} \end{cases}$$

n이 1인 경우 행렬이 하나뿐이므로 parenthesization의 경우의 수가 하나다.

n이 2 이상인 경우, (last multiplication을 정하는 경우의 수) × (나머지를 arrange하는 경우의 수)로 해석하자.

 

이 Brute Force Approach의 경우 시간복잡도의 lower bound가 아래와 같다.

$$\Omega (4^{n}/n^{3/2})$$

시간이 너무 오래걸린다.

 


 

Dynamic: Optimal Substructure

Dynamic Programming으로 BFA를 좀 더 효율적으로 만들어보자.

 

각 행렬의 차원이 주어졌을 때, (p)

다음을 의미하는 m[ i ][ j ]를 구하는 것이 목적이다.

행렬 A_{i}에서부터 A_{j}까지의 곱에 있어서 scalar multiplication의 최소횟수

이를 경우에 따라 작성하면 아래와 같다.

$$\begin{cases} 0 & if\;\;i = j \\ \min_{i≤k≤j}(m[i][k]+m[k+1][j]+p_{i-1}p_{k}p_{j}) &if\;\;i<j \end{cases}$$

Ai에서 Aj까지의 chain을 A_(i)~A_(k), A_(k+1)~A_(j)의 큰 두 덩어리로 나누어 곱하는 과정을 재귀적으로 나타낸 것이라고 이해하자.

이 연산을 통해 얻은 scalar multiplication의 횟수들을 새로운 table m에 저장할 것이다.

여기서, 해당 재귀의 값이 최소가 되는 k를, 새로운 table s에 저장할 것이다.

 

그림으로 전체적인 과정을 이해해보자.

Table p

input이다.

행렬 chain들의 차원을 적어놓는 공간이다.

 

Table m

scalar multiplication의 총 개수를 저장하는 공간이다.

대각을 기준으로 오른쪽 위, 즉 i가 j보다 작거나 같은 경우에만 정의되어있다. (Ai~Aj까지의 chain을 생각한다면 당연한 이야기.)

대각방향으로 차례대로 올라가며 채워진다.

가장 먼저 i = j일 경우, 행이 하나뿐인 경우를 의미하므로 일괄적으로 0을 채워준다.

다음으로 j = i+1인 경우, j = i+2인 경우,,,, 순으로 채워준다.

각각의 칸을 채울 때는 아래의 경우의 수를 따져, 최솟값을 채워주면 된다.

$$\begin{cases} 0 & if\;\;i = j \\ \min_{i≤k≤j}(m[i][k]+m[k+1][j]+p_{i-1}p_{k}p_{j}) &if\;\;i<j \end{cases}$$

 

Table s

backtracking을 위해 만들어놓은 공간이다.

optimal k를 저장해놓는 공간이다. (parenthesizing의 기준이 되는 행의 위치)

m[i][j]를 구한 후, optimal한 k를 해당 인덱스에 채워준다.

 

m을 다 채운 후에 최종적인 답은, 행렬곱의 처음(1) 부터 끝(6)까지의 곱 연산 수이므로 m table에서 가장 오른쪽 위에 들어있는 값이다.

 

 

## Pseudo Code of MATRIX-CHAIN-ORDER

table m과 s를 반환해주는 코드를 이해해보자.

MATRIX-CHAIN-ORDER(p) #p는 matrix들의 차원을 담은 공간이다.(input)
	n = p.length - 1
    table m[1...n][1...n]과 s[1...n-1][2...n]을 새로운 table로 선언해준다
    
    for i = 1 to n:
    	m[i][i] = 0     # 대각을 0으로 넣어준다.
   	for l = 2 to n:     # m[1][2]부터 시작하므로 시작이 2이다. 
    	for i = 1 to n-l+1:
        	j = i + l -1
            m[i][j] = 1000000000
            for k = i to j-1:
            	q = m[i][k] + m[k+1][j] + p[i-1]p[k]p[j]
                if q < m[i][j]:
                	m[i][j] = q # 최솟값으로 계속 갱신
                    	s[i][j] = k
   return m and s

 

##PRINT-OPTIMAL-PARENS

s를 통한 backtracking으로 optimal한 parenthesizing 방법을 출력하게 하는 코드다.

PRINT-OPTIMAL-PARENS(s, i, j):
	if i == j:
    	print("A_i")
    else:
    	print("(")
        PRINT-OPTIMAL-PARENS(s, i, s[i][j]) # k를 기준으로 왼쪽에서 다시 parens를 시행
        PRINT-OPTIMAL-PARENS(s, s[i][j]+1, j) # k를 기준으로 오른쪽에서 다시 parens를 시행
        print(")")

 

 

## Space Consumption과 Time Consumption

 

(1) 공간복잡도

n*n 차원의 table m과 s를 새로 생성해야 하므로 공간복잡도는 다음과 같다.

$$\theta(2n^{2}) = \theta{n^{2}}$$

 

(2)시간복잡도

m 테이블의 대각을 올라가며 채워넣는 과정을 생각하면서 다음의 과정을 따라가보자.

$$\begin{align}1(n-0)+1(n-1)+2(n-2)+3(n-3)+....+(n-1)1 \\ = 1(n-0)+\sum_{k = 1}^{n-1}k(n-k)\\=...\\ =(n^{3}+5n)/6\\=\theta(n^{3}) \end{align}$$

$$\theta(n^{3})$$