알고리즘 공부

[삼성SW역량테스트] 상어중학교 (백준 21609) | Python3

유나쒸 2021. 10. 18. 14:56


 

import sys
from collections import deque
input=sys.stdin.readline
 
def get_biggest_group(grid=None):
    """Return the cells of the block group that should be removed next.

    A group is a 4-connected set of cells that all share one normal colour
    (value > 0) or are rainbow blocks (value 0). Negative cells never join
    a group. A search is only started from a normal block, so every group
    contains at least one normal block. Only groups of two or more cells
    are eligible.

    Among eligible groups the winner is chosen by, in order:
      1. the largest number of cells,
      2. the largest number of rainbow blocks,
      3. the largest row of the group's standard block,
      4. the largest column of the group's standard block,
    where the standard block is the normal (non-rainbow) block with the
    smallest row, ties broken by the smallest column.

    Args:
        grid: square board to inspect. Defaults to the module-level ``maps``,
            looked up at call time because ``rotate_map`` rebinds that name.

    Returns:
        List of (row, col) cells of the chosen group, or [] when no eligible
        group exists.
    """
    board = maps if grid is None else grid
    size = len(board)
    visited = [[False] * size for _ in range(size)]

    def bfs(start_r, start_c, color):
        """Flood-fill from a normal block of ``color``.

        Returns (cells, rainbow_count, standard_block) for the group.
        """
        cells = []
        rainbow = 0
        # The start cell is a normal block, so it is a valid initial candidate.
        standard = (start_r, start_c)
        queue = deque([(start_r, start_c)])
        visited[start_r][start_c] = True
        while queue:
            r, c = queue.popleft()
            cells.append((r, c))
            if board[r][c] == 0:
                rainbow += 1
            else:
                # Smallest row first, then smallest column: plain tuple order.
                standard = min(standard, (r, c))
            for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                nr, nc = r + dr, c + dc
                if (0 <= nr < size and 0 <= nc < size and not visited[nr][nc]
                        and board[nr][nc] in (color, 0)):
                    visited[nr][nc] = True
                    queue.append((nr, nc))
        return cells, rainbow, standard

    groups = []
    for i in range(size):
        for j in range(size):
            if visited[i][j] or board[i][j] <= 0:
                continue
            cells, rainbow, (std_r, std_c) = bfs(i, j, board[i][j])
            # Rainbow blocks can belong to several groups, so release them
            # for later searches. Normal blocks stay visited: each belongs to
            # exactly one group.
            for r, c in cells:
                if board[r][c] == 0:
                    visited[r][c] = False
            if len(cells) >= 2:
                groups.append((len(cells), rainbow, std_r, std_c, cells))

    if not groups:
        return []
    # Distinct groups have distinct standard blocks, so the key never ties.
    return max(groups, key=lambda g: g[:4])[4]
 
def remove_and_get_score(group):
    """Empty every cell of *group* and credit its score.

    Each (row, col) in ``group`` is overwritten with -2 in the global
    ``maps``, and the square of the group's size is added to the global
    ``total``. Returns None.
    """
    global total
    total += len(group) ** 2
    for row, col in group:
        maps[row][col] = -2
 
def rotate_map():
    """Rotate the global ``maps`` board 90 degrees counter-clockwise.

    The transpose of the board, read with its rows in reverse order, is the
    counter-clockwise rotation. ``maps`` is rebound to a new list of lists.
    """
    global maps
    columns = list(zip(*maps))
    maps = [list(columns[idx]) for idx in reversed(range(n))]
 
def get_gravity():
    """Let every movable block in the global ``maps`` fall straight down.

    Cells with a value greater than -1 (rainbow or coloured blocks) slide
    down through empty cells (-2). A cell stops when it reaches the bottom
    edge or any non-empty cell. Cells holding -1 never move and act as
    floors for the cells above them.
    """
    for col in range(n):
        # The bottom row cannot fall. Walk upward so the lower blocks settle
        # before the ones above them are moved.
        for row in reversed(range(n - 1)):
            if maps[row][col] <= -1:
                continue
            cur = row
            while cur + 1 < n and maps[cur + 1][col] == -2:
                maps[cur + 1][col] = maps[cur][col]
                maps[cur][col] = -2
                cur += 1
 
if __name__ == "__main__":
    # Input: board size n and colour count m, then n rows of the board.
    # m is not needed by the simulation but must be consumed from input.
    n, m = map(int, input().split())
    maps = [list(map(int, input().split())) for _ in range(n)]
    total = 0

    # Auto-play: remove the best group, then gravity -> rotate -> gravity,
    # until no eligible group remains.
    while True:
        group = get_biggest_group()
        if not group:
            break
        remove_and_get_score(group)
        get_gravity()
        rotate_map()
        get_gravity()

    print(total)