zygygy@home:~$

최장 공통 부분 수열 - Longest Common Subsequence

문제

두 문자열의 최대 공통 부분 수열 (Longest Common Subsequence)의 길이를 구하라.

LCS(“ABCDGH”, “AEDFHR”) = “ADH”, length = 3
LCS(“AGGTAB”, “GXTXAYB”) = “GTAB”, length = 4

출처: https://www.geeksforgeeks.org/longest-common-subsequence-dp-4/

추가문제:

  • 공간 효율적인 방법으로 LCS를 구하라.
  • LCS 문자열을 구하라.
  • 모든 LCS 문자열을 구하라.

풀이

동적 계획법, DP (Dynamic programming)로 O(mn) 시간에 풀 수 있다.

Dynamic programming

LCS 문제의 optimal substructure

LCS 문제의 점화식은

lcs(0, _) = 0,
lcs(_, 0) = 0,
lcs(m, n) = 1 + lcs(m - 1, n - 1) if M[m] == N[n] else max(lcs(m - 1, n), lcs(m, n - 1))

로 표현할 수 있다.

테이블로 풀어보면,

M = “ABCDGH”, N = “AEDFHR”

|   | ''| A | E | D | F | H | R |
|---|---|---|---|---|---|---|---|
| ''| 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| A | 0 |"1"| 1 | 1 | 1 | 1 | 1 |
| B | 0 | 1 | 1 | 1 | 1 | 1 | 1 |
| C | 0 | 1 | 1 | 1 | 1 | 1 | 1 |
| D | 0 | 1 | 1 |"2"| 2 | 2 | 2 |
| G | 0 | 1 | 1 | 2 | 2 | 2 | 2 |
| H | 0 | 1 | 1 | 2 | 2 |"3"| 3 |

LCS는 “ADH”, 길이는 3이라는 것을 알 수 있다.

LCS 길이 구하기

def lcs_length(str_m, str_n):
    if not str_m or not str_n:
        return 0
    len_m = len(str_m)
    len_n = len(str_n)
    tbl = [[0] * (len_n + 1) for _ in range(len_m + 1)]
    for m in range(1, len_m + 1):
        for n in range(1, len_n + 1):
            if str_m[m - 1] == str_n[n - 1]:
                tbl[m][n] = 1 + tbl[m - 1][n - 1]
            else:
                tbl[m][n] = max(tbl[m - 1][n], tbl[m][n - 1])
    return tbl[m][n]
  • Python에서 배열(리스트)를 초기화 할 때 list comprehension을 사용하는 코드를 기억해두자.
  • 빈문자열의 경우를 추가했기 때문에 배열의 가로세로 사이즈는 각각 문자열 길이 + 1

공간 효율적인 방법

테이블의 스캔 형태를 보면 곱하기 배열 퍼즐 - Product array puzzle 에서 사용한 트릭과 비슷하게 1차원 배열 두개만으로 LCS를 구하는 것이 가능하다.

def lcs_length_two_array(str_m, str_n):
    if not str_m or not str_n:
        return 0
    len_m = len(str_m)
    len_n = len(str_n)
    tbl = [[0] * (len_n + 1) for _ in range(2)]
    for m in range(1, len_m + 1):
        i = m % 2    # i is 0 or 1
        j = 0 if i == 1 else 1
        for n in range(1, len_n + 1):
            if str_m[m - 1] == str_n[n - 1]:
                tbl[i][n] = 1 + tbl[j][n - 1]
            else:
                tbl[i][n] = max(tbl[j][n], tbl[i][n - 1])
    return tbl[i][n]

직접 실행해보기: https://tech.io/snippet/DcioG9b

모든 LCS 찾기

“ABCBDAB”, “BDCABA” 두 문자열은 길이 4짜리 LCS가 하나가 아니라 여러 개 있을 수 있다. “BCAB”, “BCBA”, “BDAB” 이렇게 3가지가 가능하다.

직관적으로 확인하기 위해서 DP 테이블의 모양을 살펴보자.

        B   D   C   A   B   A
 [[ 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
A [ 0 , 0 , 0 , 0 , 1 , 1 , 1 ],
B [ 0 , 1 , 1 , 1 , 1 , 2 , 2 ],
C [ 0 , 1 , 1 , 2 , 2 , 2 , 2 ],
B [ 0 , 1 , 1 , 2 , 2 , 3 , 3 ],
D [ 0 , 1 , 2 , 2 , 2 , 3 , 3 ],
A [ 0 , 1 , 2 , 2 , 3 , 3 , 4 ],
B [ 0 , 1 , 2 , 2 , 3 , 4 , 4 ]]

LCS 길이를 구할 때와는 반대 방향으로 맨 오른쪽 아래에서 부터 왼쪽 위로 거슬러 올라가면서 부분 최적 구조를 탐색해 나가야 한다.

def lcs_dp(str_m, str_n):
    len_m = len(str_m)
    len_n = len(str_n)
    tbl = [[0] * (len_n + 1) for _ in range(len_m + 1)]
    for m in range(1, len_m + 1):
        for n in range(1, len_n + 1):
            if str_m[m - 1] == str_n[n - 1]:
                tbl[m][n] = 1 + tbl[m - 1][n - 1]
            else:
                tbl[m][n] = max(tbl[m - 1][n], tbl[m][n - 1])
    return tbl

def find_all_lcs_rec(str_m, str_n, m, n, tbl):
    if m == 0 or n == 0:
        return set([""])

    if str_m[m - 1] == str_n[n - 1]:
        return set(s + str_m[m - 1] for s in find_all_lcs_rec(str_m, str_n, m - 1, n - 1, tbl))
    else:
        if tbl[m - 1][n] == tbl[m][n - 1]:
            return find_all_lcs_rec(str_m, str_n, m - 1, n, tbl) \
                 | find_all_lcs_rec(str_m, str_n, m, n - 1, tbl)
        elif tbl[m - 1][n] > tbl[m][n - 1]:
            return find_all_lcs_rec(str_m, str_n, m - 1, n, tbl)
        else:
            return find_all_lcs_rec(str_m, str_n, m, n - 1, tbl)

def find_all_lcs(str_m, str_n):
    tbl = lcs_dp(str_m, str_n)
    return find_all_lcs_rec(str_m, str_n, len(str_m), len(str_n), tbl)

직접 실행해보기: https://tech.io/snippet/BuFYoyO

Takeaway

  • DP 문제는 점화식을 구하기 위한 문제 패턴 파악이 핵심 열쇠
  • 대부분 문제 유형이 이미 연구되어 있기 때문에 대표적인 문제들을 다양하게 풀어보는 것이 최선일듯
    • 저걸 어떻게 바로 생각해내나? (나만 어렵나?)