본문 바로가기

Coding Test

백준 7453 합이 0인 네 정수 (Java)

1. 문제

정수로 이루어진 크기가 같은 배열 A, B, C, D가 있다.

A[a], B[b], C[c], D[d]의 합이 0인 (a, b, c, d) 쌍의 개수를 구하는 프로그램을 작성하시오.

입력

첫째 줄에 배열의 크기 n (1 ≤ n ≤ 4000)이 주어진다. 다음 n개 줄에는 A, B, C, D에 포함되는 정수가 공백으로 구분되어져서 주어진다. 배열에 들어있는 정수의 절댓값은 최대 2^28이다.

출력

합이 0이 되는 쌍의 개수를 출력한다.

2. 풀이

네 개의 배열을 모두 탐색하면 최대 O(400^4)로 시간초과가 날 수 있으므로 두 개씩 묶어서 계산한다.

        int idx = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                arr2[0][idx] = arr[0][i] + arr[1][j];
                arr2[1][idx] = arr[2][i] + arr[3][j];
                idx++;
            }
        }

우선 입력받은 배열의 각 2개, 즉 A와 B, C와 D의 요소를 각각 더해 새로운 배열을 만든다.

        Arrays.sort(arr2[0]);
        Arrays.sort(arr2[1]);

        int s = 0;
        int e = idx - 1;

        while (s < n * n && e >= 0) {
            long sum = arr2[0][s] + arr2[1][e];

            if (sum == 0) {
                long cnt1 = 1;
                long cnt2 = 1;

                while (s < n * n - 1 && arr2[0][s] == arr2[0][s + 1]) {
                    cnt1++;
                    s++;
                }

                while (e > 0 && arr2[1][e] == arr2[1][e - 1]) {
                    cnt2++;
                    e--;
                }

                ans += (cnt1 * cnt2);
                s++;
                e--;
            } else if (sum < 0) {
                s++;
            } else {
                e--;
            }
        }

 이분 탐색을 위해 각 배열을 정렬한다. A와 B의 합에서는 제일 작은 부분, C와 D의 합에서는 제일 큰 부분부터 탐색한다. 두 인덱스가 반대쪽 끝으로 갈 때 까지 해당하는 값들을 더한다. 합이 0보다 작을 경우 값이 커져야 0으로 가까이가고, 0보다 클 경우 작아져야 0보다 가까이 가므로 이에 대한 처리를 해준다.

 0일 경우에는 각 배열에서 같은 값이 나올 수 있으므로 다를 때까지 더해준다. C와 D에서도 마찬가지로 더해준다. 만약 처음 배열에서 a개, 두 번째 배열에서 b개가 나왔을 경우 이 두 경우를 조합해주어야 하므로 a*b를 하여 정답에 추가해준다.

3. 코드 전문

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.StringTokenizer;

public class Main {

    static int n;
    static long[][] arr;

    static long[][] arr2;

    static long ans = 0;

    public static void main(String[] args) throws IOException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer tokenizer;

        n = Integer.parseInt(reader.readLine());

        arr = new long[4][n];
        arr2 = new long[2][n * n];

        for (int i = 0; i < n; i++) {
            tokenizer = new StringTokenizer(reader.readLine());

            for (int j = 0; j < 4; j++) {
                arr[j][i] = Integer.parseInt(tokenizer.nextToken());
            }
        }

        int idx = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                arr2[0][idx] = arr[0][i] + arr[1][j];
                arr2[1][idx] = arr[2][i] + arr[3][j];
                idx++;
            }
        }

        Arrays.sort(arr2[0]);
        Arrays.sort(arr2[1]);

        int s = 0;
        int e = idx - 1;

        while (s < n * n && e >= 0) {
            long sum = arr2[0][s] + arr2[1][e];

            if (sum == 0) {
                long cnt1 = 1;
                long cnt2 = 1;

                while (s < n * n - 1 && arr2[0][s] == arr2[0][s + 1]) {
                    cnt1++;
                    s++;
                }

                while (e > 0 && arr2[1][e] == arr2[1][e - 1]) {
                    cnt2++;
                    e--;
                }

                ans += (cnt1 * cnt2);
                s++;
                e--;
            } else if (sum < 0) {
                s++;
            } else {
                e--;
            }
        }

        System.out.println(ans);

    }

}