본문 바로가기
자유로운 이야기

[백준] 14500번 - 테트로미노

by Kaya_Alpha 2023. 6. 21.

문제 출처 : https://www.acmicpc.net/problem/14500

 

14500번: 테트로미노

폴리오미노란 크기가 1×1인 정사각형을 여러 개 이어서 붙인 도형이며, 다음과 같은 조건을 만족해야 한다. 정사각형은 서로 겹치면 안 된다. 도형은 모두 연결되어 있어야 한다. 정사각형의 변

www.acmicpc.net

백준 14500번 - 테트로미노의 풀이와 개인적인 저의 생각을 정리하였습니다.

 

[문제]

문제 설명

[입력과 출력]

입력과 출력

우선 문제를 읽으면 BFS 혹은 DFS로 접근을 해야겠다는 생각이 들게 됩니다. 여기서 저는 BFS는 주변을 순차적으로 탐색하기 때문에 적절하지 않으므로 DFS를 이용하여 탐색을 해야겠구나! 라고 접근하였습니다.

 

그렇다면 DFS를 이용하여 접근을 하게 된다면 다음과 같이 코드를 작성할 수 있습니다.

import sys

N,M = map(int,sys.stdin.readline().strip().split())
board = []
for _ in range(N):
    line = list(map(int,sys.stdin.readline().strip().split()))
    board.append(line)
    
def find(total_board,check_board,start_point,sums,count):
    global MAXIMUM
    x,y = start_point
    s = total_board[x][y]
    
    if count == 4:
        if sums > MAXIMUM:
            MAXIMUM = sums
        return
    
    directions = [(0,1),(0,-1),(1,0),(-1,0)]
    for d in directions:
        next_x = x + d[0]
        next_y = y + d[1]
        if 0 <= next_x < N and 0 <= next_y < M and check_board[next_x][next_y] == False:
            next_point = [next_x,next_y]
            
            check_board[next_x][next_y] = True
            find(total_board,check_board,next_point,sums + s,count + 1)
            check_board[next_x][next_y] = False

MAXIMUM = -1
check = [[False]*M for _ in range(N)]

for i in range(N):
    for j in range(M):
        check[i][j] = True
        find(board,check,(i,j),0,0)
        check[i][j] = False

print(MAXIMUM)

 

구현을 하고 테스트를 하게 되면 다음과 같은 예제에서 틀리게 된다.

 

테스트케이스 3

이 테스트케이스를 실행시켜보면 7이 아닌 6이 나오게 된다.

그 이유는 DFS로 접근하기 때문에 ㅗ/ ㅓ / ㅏ / ㅜ 같은 경우에 대해서는 탐색을 할 수 없다.

 

이 상황을 그림으로 표현하면 다음과 같다.

만약 위 그림에서 가장 아랫 부분부터 DFS탐색을 시작했으며, 재귀를 돌면서 현재 지금 파란색 부분까지 Visited를 찍었다고 가정하자.

이렇게 되면 파란색 부분 다음, 아직 방문하지 않은 초록색 부분을 탐색하기 위해서는 다시 이전 노드로 돌아가야 하기 때문에 상당히 까다롭다. 초록색 부분이 오른쪽이 아닌 왼쪽도 있는 경우가 있기 때문에 두 경우를 모두 고려해야한다.

이러한 경우에 가장 간단한 해결 방법은 이러한 유형의 블록은 예외로 두고 따로 처리를 하는 것이다.

 

따로 처리를 하는 함수를 find_plus 라는 이름으로 구현하였다.

import sys

N,M = map(int,sys.stdin.readline().strip().split())
board = []
for _ in range(N):
    line = list(map(int,sys.stdin.readline().strip().split()))
    board.append(line)
    
def find(total_board,check_board,start_point,sums,count):
    global MAXIMUM
    x,y = start_point
    s = total_board[x][y]
    
    if count == 4:
        if sums > MAXIMUM:
            MAXIMUM = sums
        return
    
    directions = [(0,1),(0,-1),(1,0),(-1,0)]
    for d in directions:
        next_x = x + d[0]
        next_y = y + d[1]
        if 0 <= next_x < N and 0 <= next_y < M and check_board[next_x][next_y] == False:
            next_point = [next_x,next_y]
            
            check_board[next_x][next_y] = True
            find(total_board,check_board,next_point,sums + s,count + 1)
            check_board[next_x][next_y] = False


def find_plus(total_board,start_point):
    global MAXIMUM
    x,y = start_point
    #ㅗ/ㅜ/ㅏ/ㅓ 케이스
    cases = [[(0,1),(0,-1),(1,0)],
             [(0,1),(0,-1),(-1,0)],
             [(0,1),(1,0),(-1,0)],
             [(0,-1),(1,0),(-1,0)]]
    for c in cases:
        total = total_board[x][y]
        for d in c:
            nx = x + d[0]
            ny = y + d[1]
            if 0 <= nx < N and 0 <= ny < M:
                total += total_board[nx][ny]
            else:
                break
        else:
            if total > MAXIMUM:
                MAXIMUM = total

MAXIMUM = -1
check = [[False]*M for _ in range(N)]

for i in range(N):
    for j in range(M):
        check[i][j] = True
        find(board,check,(i,j),0,0)
        find_plus(board,(i,j))
        check[i][j] = False

print(MAXIMUM)

하지만 이 코드는 아쉽게도 Python3에서는 시간초과가 뜨지만 PyPy3에서는 통과가 된다.

느리다!

즉, 느리다는 뜻이다.

그렇다면, 좀 더 빠르게 하기 위해서는 함수 호출 횟수를 좀 줄여볼 수 있을 것이다.

즉, find_plus의 로직을 find함수에 녹여내는 것이다. 즉, 하나의 함수로 ㅗ/ㅜ/ㅓ/ㅏ 의 경우도 처리하는 것이다.

 

다시 ㅗ/ㅜ/ㅓ/ㅏ 의 경우를 살펴보자.

두 번째 블럭까지 방문하게 되면, 이 특수 케이스를 처리하기 위해서는 인접한 두 블록을 방문처리 해야하는 것이다.

여기서 반짝이는 아이디어는 다음 방문해야할 블록을 방문처리 하고, 다음 블록이 아닌 현재 블록을 한번 더 DFS를 돌리는 것이다.

즉 그림으로 표현하면 다음과 같다.

코드로 보면 다음과 같다.

import sys

N,M = map(int,sys.stdin.readline().strip().split())
board = []
for _ in range(N):
    line = list(map(int,sys.stdin.readline().strip().split()))
    board.append(line)
    
def find(total_board,check_board,start_point,sums,count):
    global MAXIMUM
    x,y = start_point
    s = total_board[x][y]
    
    if count == 3:
        sums += total_board[x][y]
        if sums > MAXIMUM:
            MAXIMUM = sums
        return
    
    directions = [(0,1),(0,-1),(1,0),(-1,0)]
    for d in directions:
        next_x = x + d[0]
        next_y = y + d[1]
        if 0 <= next_x < N and 0 <= next_y < M and check_board[next_x][next_y] == False:
            next_point = [next_x,next_y]
            if count == 1:
                # 2칸을 전진했으면 ㅗ/ㅓ/ㅏ/ㅜ 같은 경우는 한칸 전진 후, 그 값을 더해주고
                #현재 start_point지점에서 DFS를 한번 더 돌림 -> ㅜ 와 같은 모양이 나옴
                check_board[next_x][next_y] = True
                find(total_board,check_board,start_point,sums + total_board[next_x][next_y],count + 1)
                check_board[next_x][next_y] = False

            check_board[next_x][next_y] = True
            find(total_board,check_board,next_point,sums + s,count + 1)
            check_board[next_x][next_y] = False


MAXIMUM = -1
check = [[False]*M for _ in range(N)]

for i in range(N):
    for j in range(M):
        check[i][j] = True
        find(board,check,(i,j),0,0)
        check[i][j] = False

print(MAXIMUM)

이렇게 되면 드디어 Python3에서 통과가 된다!

통과!

하지만 시간을 보면 7940ms나 걸린다...

즉, 아직도 느리다는 소리다.

 

어떤 문제가 있는지 생각해보았는데, 로직은 나쁘지 않은것 같았다.

그래서 다음은 불필요한 계산을 최대한 줄여보고자 했다. 여기서 불필요한 계산은 굳이 계산을 해야하나? 라는 느낌이다.

이 문제에서는 정수가 쓰인 N X M 크기의 종이에서 값이 가장 큰 값이 무엇인지 구한 다음에,

매 재귀마다 나머지 칸을 최대값으로 채웠을 때의 값이 현재 MAXSIZE 값보다 작다면? 굳이 더 재귀를 돌릴 필요가 없다는 의미라는 뜻이다.(우리는 MAXSIZE를 찾는 것이기 때문에 남은 빈 부분을 최대값으로 채운 테트로미노의 값이 MAXSIZE보다 작으면 우리가 찾는 테트로미노는 아니다.)

따라서 간단히 다음 코드를 추가해준다.

if sums + total_board[x][y] + board_max*(3-count) < MAXIMUM:
        return

위 코드를 적용한 전체 코드는 다음과 같다. (보드의 최대값을 구하는 부분을 제외한 나머지 부분은 전부 똑같다)

import sys

N,M = map(int,sys.stdin.readline().strip().split())
board = []
board_max = -1
for _ in range(N):
    line = list(map(int,sys.stdin.readline().strip().split()))
    line_max = max(line)
    board_max = max(board_max,line_max)#입력한 판때기의 최대값 구하기
    board.append(line)
    
def find(total_board,check_board,start_point,sums,count):
    global MAXIMUM
    x,y = start_point
    s = total_board[x][y]
    #불필요한 계산은 끝내버리기
    #나머지 부분을 최대값으로 채웠을때보다 현재 최대값이 더 큰 경우 -> 더 해볼 필요가 없음
    if sums + total_board[x][y] + board_max*(3-count) < MAXIMUM:
        return
    #finish!
    if count == 3:
        sums += total_board[x][y]
        if sums > MAXIMUM:
            MAXIMUM = sums
        return
    
    directions = [(0,1),(0,-1),(1,0),(-1,0)]
    for d in directions:
        next_x = x + d[0]
        next_y = y + d[1]
        if 0 <= next_x < N and 0 <= next_y < M and check_board[next_x][next_y] == False:
            next_point = [next_x,next_y]
            if count == 1:
                # 2칸을 전진했으면 ㅗ/ㅓ/ㅏ/ㅜ 같은 경우는 한칸 전진 후, 그 값을 더해주고
                #현재 start_point지점에서 DFS를 한번 더 돌림 -> ㅜ 와 같은 모양이 나옴
                check_board[next_x][next_y] = True
                find(total_board,check_board,start_point,sums + total_board[next_x][next_y],count + 1)
                check_board[next_x][next_y] = False
            
            check_board[next_x][next_y] = True
            find(total_board,check_board,next_point,sums + s,count + 1)
            check_board[next_x][next_y] = False

MAXIMUM = -1 #정답값
check = [[False]*M for _ in range(N)] #visited

for i in range(N):
    for j in range(M):
        check[i][j] = True
        find(board,check,(i,j),0,0)
        check[i][j] = False

print(MAXIMUM)

위 코드를 제출하게 되면 

코드 한줄만 추가 했을 뿐인데 처음과 비교할 수 없을 정도로 빠르다!

즉, 불필요한 계산이 기존 코드는 너무 많았다는 뜻으로 보인다.