Strassen 행렬 곱셈

목표

행렬 곱셈이란 무엇인가?

참고: 선형 대수/행렬 곱셈에 이미 익숙하다면 이 섹션을 건너뛰어도 된다.

시작하기 전에, 행렬 곱셈이 무엇인지 궁금할 수 있다. 좋은 질문이다! 행렬 곱셈은 두 행렬을 각 요소별로 곱하는 것이 아니다. 행렬 곱셈은 두 행렬을 하나로 결합하는 수학적 연산이다. 각 요소별로 곱하는 것처럼 들리겠지만, 그렇지 않다. 만약 그랬다면 우리의 삶이 훨씬 쉬웠을 것이다. 행렬 곱셈이 어떻게 작동하는지 알아보기 위해 예제를 살펴보자.

예제: 행렬 곱셈

행렬 A = |1 2|, 행렬 B = |5 6|
         |3 4|           |7 8|

A * B = C

|1 2| * |5 6| = |1*5+2*7 1*6+2*8| = |19 22|
|3 4|   |7 8|   |3*5+4*7 3*6+4*8|   |43 50|

여기서 무슨 일이 일어나고 있을까? 먼저, 행렬 A와 B를 곱한다. 새로운 행렬 C의 [i, j] 요소는 첫 번째 행렬의 i번째 행과 두 번째 행렬의 j번째 열의 내적으로 결정된다. 내적에 대해 복습하려면 여기를 참고한다.

따라서 새로운 행렬의 왼쪽 위 요소 [i=1, j=1]는 A의 첫 번째 행과 B의 첫 번째 열의 조합이다.

A의 첫 번째 행 = [1, 2]
B의 첫 번째 열 = [5, 7]

[1, 2] 내적 [5, 7] = [1*5 + 2*7] = [19] = C[1, 1]

이제 [i=1, j=2]에 대해 시도해 보자. i=1이고 j=2이므로, 이는 새로운 행렬 C의 오른쪽 위 요소를 나타낸다.

A의 첫 번째 행 = [1, 2]
B의 두 번째 열 = [6, 8]

[1, 2] 내적 [6, 8] = [1*6 + 2*8] = [22] = C[1, 2]

A와 B의 각 행과 열에 대해 이 과정을 반복하면 결과 행렬 C를 얻을 수 있다!

다음은 이 과정을 시각적으로 보여주는 훌륭한 그래픽이다.

출처

행렬 곱셈 알고리즘

행렬 곱셈을 알고리즘으로 어떻게 구현할까? 기본 버전부터 시작해 스트라센 알고리즘까지 차근차근 알아본다.

기본 버전

이전에 에서 설명한 행렬 곱셈 방법을 기억하는가? 먼저 그 방법을 구현해 보자! 두 행렬의 크기가 적절한지 확인한다.

assert(A.columns == B.rows, "두 행렬은 하나가 mxn 차원이고 다른 하나가 nxp 차원일 때만 곱할 수 있다. 여기서 m, n, p는 실수이다.")

참고: 행렬 곱셈이 가능하려면 A의 컬럼 수와 B의 행 수가 같아야 한다.

다음으로, A의 컬럼과 B의 행을 반복한다. A의 컬럼과 B의 행의 길이가 같다는 것을 알고 있으므로, 그 길이를 n으로 설정한다.

for i in 0..<n {
  for j in 0..<n {

그런 다음, A의 각 행과 B의 각 컬럼에 대해, A의 i번째 행과 B의 j번째 컬럼의 내적을 계산하고 그 결과를 C의 [i, j] 요소에 저장한다. 즉, C[i, j]에 저장한다.

for k in 0..<n {
  C[i, j] += A[i, k] * B[k, j]
}

마지막으로, 새로운 행렬 C를 반환한다!

전체 구현은 다음과 같다:

public func matrixMultiply(by B: Matrix<T>) -> Matrix<T> {
  let A = self
  assert(A.columns == B.rows, "두 행렬은 하나가 mxn 차원이고 다른 하나가 nxp 차원일 때만 곱할 수 있다. 여기서 m, n, p는 실수이다.")
  let n = A.columns
  var C = Matrix<T>(rows: A.rows, columns: B.columns)
    
  for i in 0..<n {
    for j in 0..<n {
      for k in 0..<n {
        C[i, j] += A[i, k] * B[k, j]
      }
    }
  }
    
  return C
}

이 알고리즘의 시간 복잡도는 **O(n^3)**이다. **O(n^3)**은 세 개의 for 루프에서 비롯된다. 행과 컬럼을 반복하는 두 루프와 내적을 계산하는 한 루프 때문이다!

그런데 **O(n^3)**은 그다지 빠르지 않다. 더 나은 방법이 있을까? 물론 있다!

스트라센 알고리즘

볼커 스트라센은 1969년에 이 알고리즘을 처음 발표했다. 이 알고리즘은 기본적인 O(n^3) 시간 복잡도가 최적이 아니라는 사실을 처음으로 증명했다.

스트라센 알고리즘의 기본 아이디어는 행렬 A와 B를 8개의 부분 행렬로 나눈 후, C의 부분 행렬을 재귀적으로 계산하는 것이다. 이 전략을 *분할 정복(Divide and Conquer)*이라고 한다.

matrix A = |a b|, matrix B = |e f|
           |c d|             |g h|

재귀 호출은 총 8번 발생한다:

  1. a * e
  2. b * g
  3. a * f
  4. b * h
  5. c * e
  6. d * g
  7. c * f
  8. d * h

그런 다음 이 결과를 사용해 C의 부분 행렬을 계산한다.

matrix C = |ae+bg af+bh|
           |ce+dg cf+dh| 

http://d1hyf4ir1gqw6c.cloudfront.net//wp-content/uploads/strassen_new.png

하지만 이 단계만으로는 시간 복잡도를 개선할 수 없다. 마스터 정리(Master Theorem)를 적용해 T(n) = 8T(n/2) + O(n^2)를 계산해보면 여전히 O(n^3)의 시간 복잡도를 얻는다.

스트라센의 통찰은 이 과정을 완료하기 위해 8번의 재귀 호출이 필요하지 않다는 것이다. 7번의 재귀 호출과 약간의 덧셈과 뺄셈만으로도 이 과정을 마칠 수 있다.

스트라센의 7번의 호출은 다음과 같다:

  1. a * (f - h)
  2. (a + b) * h
  3. (c + d) * e
  4. d * (g - e)
  5. (a + d) * (e + h)
  6. (b - d) * (g + h)
  7. (a - c) * (e + f)

이제 새로운 행렬 C의 사분면을 계산할 수 있다!

matrix C = |p5+p4-p2+p6    p1+p2   |
           |   p3+p4    p1+p5-p3-p7|    

이제 이게 어떻게 작동하는지 궁금할 것이다. 증명해보자!

첫 번째 부분 행렬:

p5+p4-p2+p6 = (a+d)*(e+h) + d*(g-e) - (a+b)*h + (b-d)*(g+h)
            = (ae+de+ah+dh) + (dg-de) - (ah+bh) + (bg-dg+bh-dh)
            = ae+bg

처음에 얻은 결과와 정확히 일치한다!

이제 나머지도 증명해보자.

두 번째 부분 행렬:

p1+p2 = a*(f-h) + (a+b)*h
      = (af-ah) + (ah+bh)
      = af+bh

세 번째 부분 행렬:

p3+p4 = (c+d)*e + d*(g-e)
      = (ce+de) + (dg-de)
      = ce+dg

네 번째 부분 행렬:

p1+p5-p3-p7 = a*(f-h) + (a+d)*(e+h) - (c+d)*e - (a-c)*(e+f)
            = (af-ah) + (ae+de+ah+dh) -(ce+de) - (ae-ce+af-cf)
            = cf+dh

훌륭하다! 수학적으로 정확하다!

다음은 이 과정을 보여주는 이미지다.

출처

구현

이제 실제 구현을 시작한다. 기본 구현과 동일한 첫 단계부터 시작한다. 행렬 A의 컬럼 수와 행렬 B의 행 수가 같은지 먼저 확인해야 한다.

assert(A.columns == B.rows, "두 행렬은 mxn과 nxp 차원을 가질 때만 행렬 곱셈이 가능하다. 여기서 m, n, p는 실수이다.")

이제 준비 작업을 진행한다. 각 행렬을 정사각형으로 만들고 크기를 다음 2의 거듭제곱으로 늘린다. 이렇게 하면 슈트라센 알고리즘을 훨씬 쉽게 관리할 수 있다. 이제 짝수 번 분할할 수 있는 정사각형 행렬만 다루면 된다!

let n = max(A.rows, A.columns, B.rows, B.columns)
let m = nextPowerOfTwo(after: n)
    
var APrep = Matrix(size: m)
var BPrep = Matrix(size: m)
   
for i in A.rows {
  for j in A.columns {
    APrep[i, j] = A[i,j]
  }
}

for i in B.rows {
  for j in B.columns {
    BPrep[i, j] = B[i, j]
  }
}

마지막으로, 슈트라센 알고리즘을 사용해 행렬을 재귀적으로 계산하고, 새로운 행렬 C를 올바른 차원으로 다시 변환한다!

let CPrep = APrep.strassenR(by: BPrep)
var C = Matrix(rows: A.rows, columns: B.columns)
    
for i in 0..<A.rows {
  for j in 0..<B.columns {
    C[i,j] = CPrep[i,j]
  }
}

재귀적으로 행렬 곱셈 계산하기

이제 strassenR이라는 재귀 함수를 살펴보자.

먼저 8개의 부분 행렬을 초기화한다.

var a = Matrix(size: nBy2)
var b = Matrix(size: nBy2)
var c = Matrix(size: nBy2)
var d = Matrix(size: nBy2)
var e = Matrix(size: nBy2)
var f = Matrix(size: nBy2)
var g = Matrix(size: nBy2)
var h = Matrix(size: nBy2)
    
for i in 0..<nBy2 {
  for j in 0..<nBy2 {
    a[i,j] = A[i,j]
    b[i,j] = A[i, j+nBy2]
    c[i,j] = A[i+nBy2, j]
    d[i,j] = A[i+nBy2, j+nBy2]
    e[i,j] = B[i,j]
    f[i,j] = B[i, j+nBy2]
    g[i,j] = B[i+nBy2, j]
    h[i,j] = B[i+nBy2, j+nBy2]
  }
}

그 다음, 7개의 행렬 곱셈을 재귀적으로 계산한다.

let p1 = a.strassenR(by: f-h)       // a * (f - h)
let p2 = (a+b).strassenR(by: h)     // (a + b) * h
let p3 = (c+d).strassenR(by: e)     // (c + d) * e
let p4 = d.strassenR(by: g-e)       // d * (g - e)
let p5 = (a+d).strassenR(by: e+h)   // (a + d) * (e + h)
let p6 = (b-d).strassenR(by: g+h)   // (b - d) * (g + h)
let p7 = (a-c).strassenR(by: e+f)   // (a - c) * (e + f)

그리고 C의 부분 행렬을 계산한다.

let c11 = p5 + p4 - p2 + p6         // p5 + p4 - p2 + p6
let c12 = p1 + p2                   // p1 + p2
let c21 = p3 + p4                   // p3 + p4
let c22 = p1 + p5 - p3 - p7         // p1 + p5 - p3 - p7

마지막으로, 이 부분 행렬들을 새로운 행렬 C로 합친다!

var C = Matrix(size: n)    
for i in 0..<nBy2 {
  for j in 0..<nBy2 {
    C[i, j]           = c11[i,j]
    C[i, j+nBy2]      = c12[i,j]
    C[i+nBy2, j]      = c21[i,j]
    C[i+nBy2, j+nBy2] = c22[i,j]
  }
}

이전과 마찬가지로 마스터 정리를 사용해 시간 복잡도를 분석할 수 있다. T(n) = 7T(n/2) + O(n^2)이므로 O(n^log(7))의 실행 시간이 나온다. 이는 약 O(n^2.8074)로, O(n^3)보다 효율적이다!

이것이 바로 슈트라센 알고리즘이다. 전체 구현은 다음과 같다:

// MARK: - 슈트라센 곱셈

extension Matrix {
  public func strassenMatrixMultiply(by B: Matrix<T>) -> Matrix<T> {
    let A = self
    assert(A.columns == B.rows, "두 행렬은 하나가 mxn 차원이고 다른 하나가 nxp 차원일 때만 행렬 곱셈이 가능하다. 여기서 m, n, p는 실수이다.")
    
    let n = max(A.rows, A.columns, B.rows, B.columns)
    let m = nextPowerOfTwo(after: n)
    
    var APrep = Matrix(size: m)
    var BPrep = Matrix(size: m)
    
    A.forEach { (i, j) in
      APrep[i,j] = A[i,j]
    }
    
    B.forEach { (i, j) in
      BPrep[i,j] = B[i,j]
    }
    
    let CPrep = APrep.strassenR(by: BPrep)
    var C = Matrix(rows: A.rows, columns: B.columns)
    for i in 0..<A.rows {
      for j in 0..<B.columns {
        C[i,j] = CPrep[i,j]
      }
    }
    
    return C
  }
  
  private func strassenR(by B: Matrix<T>) -> Matrix<T> {
    let A = self
    assert(A.isSquare && B.isSquare, "이 함수는 정사각 행렬을 필요로 한다!")
    guard A.rows > 1 && B.rows > 1 else { return A * B }
    
    let n    = A.rows
    let nBy2 = n / 2
    
    /*
    부분 행렬은 다음과 같이 할당된다고 가정한다
    
     행렬 A = |a b|,    행렬 B = |e f|
              |c d|              |g h|
    */
    
    var a = Matrix(size: nBy2)
    var b = Matrix(size: nBy2)
    var c = Matrix(size: nBy2)
    var d = Matrix(size: nBy2)
    var e = Matrix(size: nBy2)
    var f = Matrix(size: nBy2)
    var g = Matrix(size: nBy2)
    var h = Matrix(size: nBy2)
    
    for i in 0..<nBy2 {
      for j in 0..<nBy2 {
        a[i,j] = A[i,j]
        b[i,j] = A[i, j+nBy2]
        c[i,j] = A[i+nBy2, j]
        d[i,j] = A[i+nBy2, j+nBy2]
        e[i,j] = B[i,j]
        f[i,j] = B[i, j+nBy2]
        g[i,j] = B[i+nBy2, j]
        h[i,j] = B[i+nBy2, j+nBy2]
      }
    }
    
    let p1 = a.strassenR(by: f-h)       // a * (f - h)
    let p2 = (a+b).strassenR(by: h)     // (a + b) * h
    let p3 = (c+d).strassenR(by: e)     // (c + d) * e
    let p4 = d.strassenR(by: g-e)       // d * (g - e)
    let p5 = (a+d).strassenR(by: e+h)   // (a + d) * (e + h)
    let p6 = (b-d).strassenR(by: g+h)   // (b - d) * (g + h)
    let p7 = (a-c).strassenR(by: e+f)   // (a - c) * (e + f)
    
    let c11 = p5 + p4 - p2 + p6         // p5 + p4 - p2 + p6
    let c12 = p1 + p2                   // p1 + p2
    let c21 = p3 + p4                   // p3 + p4
    let c22 = p1 + p5 - p3 - p7         // p1 + p5 - p3 - p7
    
    var C = Matrix(size: n)
    for i in 0..<nBy2 {
      for j in 0..<nBy2 {
        C[i, j]           = c11[i,j]
        C[i, j+nBy2]      = c12[i,j]
        C[i+nBy2, j]      = c21[i,j]
        C[i+nBy2, j+nBy2] = c22[i,j]
      }
    }
    
    return C
  }
  
  private func nextPowerOfTwo(after n: Int) -> Int {
    return Int(pow(2, ceil(log2(Double(n)))))
  }
}

부록

숫자 프로토콜

매트릭스를 일반화하기 위해 숫자 프로토콜을 사용한다.

숫자 프로토콜은 세 가지를 보장한다:

  1. 숫자인 모든 것은 곱할 수 있다.
  2. 숫자인 모든 것은 더하거나 뺄 수 있다.
  3. 숫자인 모든 것은 제로 값을 가진다.

Int, Float, Double을 이 프로토콜에 맞게 확장하는 것은 매우 간단하다. static var zero를 구현하기만 하면 된다!

public protocol Number: Multipliable, Addable {
  static var zero: Self { get }
}

public protocol Addable {
  static func +(lhs: Self, rhs: Self) -> Self
  static func -(lhs: Self, rhs: Self) -> Self
}

public protocol Multipliable {
  static func *(lhs: Self, rhs: Self) -> Self
}

내적(Dot Product)

Array의 요소가 Number 프로토콜을 준수할 때, 내적(dot product) 함수를 추가하기 위해 Array를 확장한다.

extension Array where Element: Number {
  public func dot(_ b: Array<Element>) -> Element {
    let a = self
    assert(a.count == b.count, "같은 길이의 배열끼리만 내적을 계산할 수 있습니다!")
    let c = a.indices.map{ a[$0] * b[$0] }
    return c.reduce(Element.zero, { $0 + $1 })
  }
}

리소스

Swift Algorithm Club을 위해 Richard Ash가 작성