본문 바로가기

BOJ

[백준 C++] 11050번, 11051번, 11401번 이항계수 1, 2, 3

이항 계수 1

이항 계수 2

이항 계수 3


비록 문과생이지만 나름 수학에 자신있던 편이었어서 쉽게 풀 줄 알았다가 이틀동안 개고생을 했다. 페르마의 소정리, 확장 유클리드 알고리즘, 나머지 연산 등등... 아는 걸 찾는게 빠를 정도로 다 몰랐다. 이번 기회에 이항 계수 문제를 푸는 몇가지 알고리즘을 정리해야겠다는 생각을 했다.


이항 계수는 고등학교 수학 과정에서, 그리고 대학교 통계 수업에서 배웠던 걸로 기억한다. 이항 계수란 n 개의 원소에서 r 개를 뽑아내는 방법의 수를 의미한다. 


보통 nCr 의 형태를 띄고 있으며, 계산을 할 때는 다음과 같이 했던 것으로 기억한다.


$$\binom{n}{r} = \left( \frac{n!}{r!(n - r)!} \right)$$


1. 이항 계수 1


이항계수 1은 다음을 활용해서 풀었다.

$$\binom{n}{r} = \binom{n - 1}{r} + \binom{n - 1}{r - 1}$$


n 이 10 이하의 값이어서 이 코드로도 충분하다. K 가 0이거나 K 와 N 의 값이 같다면 이항 계수는 1 이다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#include <iostream>
using namespace std;
 
int bi(int N, int K) {
  if (K == 0 || K == N) return 1;
  else return bi(N - 1, K) + bi(N - 1, K - 1);
}
 
int main() {
 
  int N, K;
  cin >> N >> K;
 
  cout << bi(N, K) << '\n';
}
cs


2. 이항 계수 2


이항계수 2는 n 이 1,000 까지 커진다. 그 정도로 수가 커지면 이항 계수의 값도 크기 때문에 10,007 로 나눈 나머지를 출력해야 한다.


두 가지를 신경써야 한다.


첫 번째로, 1,000 까지 계산을 해야 하므로 이항 계수 1 에서 했던 것보다 빠르게 계산을 해야 한다.

두 번째로, 결과값을 10,007로 나눈 나머지를 출력해야 한다. 


첫 번째 사항을 해결하기 위해 메모이제이션이라는 방법을 활용한다. 메모이제이션은 동일한 계산을 반복해야 할 경우에 한 번 계산한 결과를 메모리에 저장해 둿다가 꺼내 씀으로써 중복 계산을 방지할 수 있게 하는 기법이다. 5C2를 계산할 경우에 메모이제이션 기법을 활용하지 않고 이항 계수 1 에서 풀었던 것처럼 풀면 3C2와 2C1이 여러번 등장하는데 똑같은 계산을 여러번 할수록 시간이 그만큼 오래 걸린다. 특히 숫자가 클 수록 같은 계산을 반복하는데 오랜 시간을 잡아 먹는다.


2차원 배열을 만들어 결과를 넣어서 바로 arr[N][K] 를 출력했다.


두번째 사항을 해결하기 위해 분배 법칙을 공부했다(참고한 블로그). 기본적으로 나머지 연산은 덧셈에 대해 분배 법칙이 성립하기 때문에 아래와 같은 식이 성립한다.


$$(a + b) \% M = (( a \% M) + (b \% M)) \% M$$


그렇기 때문에  두 값을 더해서 새로운 값을 배열에 추가할 때 마다 MOD로 나눠주는 작업을 한다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#include <iostream>
using namespace std;
#define SIZE 1001
#define MOD 10007
int arr[SIZE][SIZE];
 
int main() {
 
  int N, K;
  cin >> N >> K;
 
  arr[0][0= arr[1][1= arr[1][0= 1;
 
  for (int i = 2; i <= N; i++) {
    for (int j = 0; j <= i; j++) {
      if ((i == j) || (j == 0)) arr[i][j] = 1;
      else if ((j == 1|| (j == i - 1)) arr[i][j] = i;
      else
        arr[i][j] = ((arr[i - 1][j]) + (arr[i - 1][j - 1])) % MOD;
    }
  }
  cout << arr[N][K] << '\n';
}
cs

3. 이항 계수 3

이항 계수 3 에서는 입력되는 수가 급격하게 커진다. N 의 경우에는 4,000,000 까지 커질 수 있다. 그렇기 때문에 mod 도 1,000,000,007 로 상당히 크다. 수가 커진데 비해서 시간은 1초 밖에 주어지지 않았으므로 이항 계수 2 번에서 작성했던 코드보다 더 효율적으로 계산해야 한다. 참고로 백준 블로그에 이 문제를 설명 해놓은 것이 있다.


맨 앞에서 언급했던 $\binom{n}{r} = \left( \frac{n!}{r!(n - r)!} \right)$ 공식을 사용한다.


이 때 페르마의 소정리(Fermat’s Little Theorem)가 쓰인다. 페르마의 소정리는 두 수 a 와 m 가 서로소라면 $a^{m-1}\%m = 1$ 이다. 증명하는 방법도 찾긴 찾았지만 수학 공부 하는 것은 아니기 때문에 일단 저 공식만 가져왔다. $a^{m-1} \%m= (a * a^{m-2})\%m = 1$이 되고, $a^{m-2}$가 $(a * x)\%m = 1$을 만족하는 $x$가 되기 때문에, 역원은 $a^{m-2}$가 된다.


정리를 해보자면 $n!$을 A라고 놓고, $r!(n - r)$을 B라고 한다면 A / B를 구해야 한다(m = 1,000,000,007). m은 소수이기 때문에 페르마의 소정리를 이용할 수 있다. 따라서 구해야 하는 값은 $A * B^{m - 2}$다.


m 의 값이 1,000,000,007 인 것을 감안하면 이 계산도 만만치 않다. 다른 분의 블로그에서 어떻게 할지 배웠다.


$2^{10}$ 을 계산한다고 하자. 편의를 위해 작은 숫자인 10을 썼다. 


$2^{10} = 4^5$

$4^5 = 4 * 16^2$

$4 * 16^2 = 4 * 256^1$

$4 * 256^1 = 4 * 256 * (256 * 256)^0$


그러므로 답은 4 * 256인 1,024다.


거듭제곱을 구하는 코드를 함수로 만들어 보면 다음과 같다.


1
2
3
4
5
6
7
8
long long power (long long x, long long n )
 
if( n == 0 ) return 1;
else if ( (n % 2== 0 )
    return power( x*x , n/2 );
else 
    return x*power(x*x, n-1/2);
}
cs

아래는 실제로 제출했던 코드다. 위에 나왔던 함수는 재귀를 사용했던 반면 이 함수는 반복문으로 문제를 해결한다. 충분히 큰 데이터 타입인 long long 으로 변수 선언을 해 줬지만 결과값이 그것보다 훨씬 크므로 나머지 연산을 계속 해줘야 한다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <iostream>
using namespace std;
#define ll long long
 
ll power (ll x, ll y, ll mod) {
  ll ans = 1;
  while (y > 0) {
    if (y % 2 == 1) {
      ans *= x;
      ans %= mod;
    }
    x *= x;
    x %= mod;
    y /= 2;
  }
  return ans;
}
 
int main() {
  int N, K;
  cin >> N >> K;
  ll mod = 1000000007;
 
  ll n1 = 1, n2 = 1, n3;
 
  for (int i = 2; i <= N; i++) {
    n1 *= i;
    n1 %= mod;
  }
 
  for (int i = 2; i <= K; i++) {
    n2 *= i;
    n2 %= mod;
  }
 
  for (int i = 2; i <= (N - K); i++) {
    n2 *= i;
    n2 %= mod;
  }
 
  n3 = power(n2, mod - 2, mod);
  n3 %= mod;
  n3 *= n1;
  n3 %= mod;
  cout << n3 << '\n';
 
  return 0;
}
cs



참고 1

참고 2

참고 3