문제 설명

- 첫 번째 줄에서는 주어질 데이터의 개수(N)가 주어진다.

- 다음 줄에는 N개의 데이터가 주어진다. (1 <= N <= 500,000)

- 시간 제한  1초

출력 : 숫자 간 교환이 일어난 횟수를 구해야 한다.

 

 

버블 정렬 예시

(1) 321 : 3 2 1 => 2 3 1  => 2 1 3 => 1 2 3

(2) 4253 : 4 2 5 3 => 2 4 5 3=> 2 4 3 5 => 2 3 4 5

 

문제 풀이

(1) 어떤 정렬 알고리즘을 사용할지?

버블 소트는 N 제곱의 시간 복잡도를 가진다.

문제 조건에 따르면 최대 25 * (10^10)의 시간 복잡도를 가진다.

 

파이썬의 1초당 계산 가능 횟수는 2 * (10^7)이므로, 버블 소팅으로는 문제를 풀 수 없다.

 

따라서 다른 방식의 정렬을 이용해야 한다.

병합정렬을 이용하면 N * log(N)의 시간복잡도가 소요된다. 

5*(10^5) * log(5*(10^5))이므로, 시간 복잡도는 해결이 된다.

 

(2) 숫자 간 교환 횟수를 어떻게 구할지?

버블 정렬에서는, 

1. 두 요소 간 순서가 맞으면 가만히 두고

2. 순서가 틀렸다면 숫자를 교환한다.

 

2번에 착안해서 버블 정렬에서 정렬하는 과정을 살펴보자.

1. 위 그림의 맨 밑에 줄부터 A와 B를 비교해서 B가 작으면 교환이 일어난다.

이 때, 교환될 숫자의 개수는 CNT에 더해주자. (남은 A의 개수만큼 더해주면 된다.)

[3] [1] => [1, 3]           [2] [4] => [2, 4]

CNT += len(A)

 

2. [1, 3] [2, 4]를 비교한다. 

A의 맨 앞과 B의 맨 앞을 비교했을 때, A가 작으므로, A의 1이 빠진다.

[1,3] [2, 4] => [3] [2, 4],  [1]

 

3. A의 맨 앞과 B의 맨 앞을 비교했을 때, B가 작으므로, B의 2가 빠진다.

B가 작으면, 스와핑이 일어나야 하므로 남은 A의 총 길이를 CNT에 더해준다. 

[3] [2, 4] => [3] [4],  [1, 2]

CNT += len(A)

 

4. A의 맨 앞과 B의 맨 앞을 비교했을 때, A가 작으므로, A의 3이 빠진다.

[3] [4], => [] [4], [1, 2, 3]

 

5. A에 들어있던 모든 원소가 빠졌으므로, B에 남은 원소를 추가한다

[] [4] => [] [], [1, 2, 3, 4]

 

6. 스와핑이 일어난 횟수(CNT):  2

 

반례

문제에는 동일한 원소가 나오지 않는다는 조건이 없다.

즉, 3 2 5 3 4와 같은 경우가 있을 수 있다.

이를 고려하여 코드를 작성해야 한다.

 

코드

import sys
# sys.stdin = open("C:/Users/JIn/PycharmProjects/coding_Test/input.txt", "rt")


def mergeSort(start, end):
    global cnt
    if start < end:
        mid = (start + end) // 2
        mergeSort(start, mid)
        mergeSort(mid + 1, end)

        a = start
        b = mid + 1
        tmp = []
        while a <= mid and b <= end:
            if arr[a] <= arr[b]:
                tmp.append(arr[a])
                a += 1

            else:
                tmp.append(arr[b])
                b += 1
                cnt += (mid - a + 1)  # 스와핑 했을 때 개수추가
        if a <= mid:
            tmp = tmp + arr[a:mid+1]
        if b <= end:
            tmp = tmp + arr[b:end+1]

        for i in range(len(tmp)):
            arr[start + i] = tmp[i]

if __name__ == '__main__':
    cnt = 0
    n = int(input())
    arr = list(map(int, input().split()))
    mergeSort(0, n-1)
    print(cnt)

+ Recent posts