for textmining

Generative Adversarial Network

|

이번 글에서는 Generative Adversarial Network(이하 GAN)에 대해 살펴보도록 하겠습니다. 이 글은 전인수 서울대 박사과정이 2017년 12월에 진행한 패스트캠퍼스 강의와 위키피디아 등을 정리했음을 먼저 밝힙니다. 그럼 시작하겠습니다.

concept

Ian Goodfellow가 2014년 제안한 GAN은 생성자(generator, $G$)와 구분자(discirimiator, $D$), 두 네트워크를 적대적(adversarial)으로 학습시키는 비지도 학습 기반 생성모델(unsupervised generative model)입니다. $G$는 Zero-Mean Gaussian으로 생성된 $z$를 받아서 실제 데이터와 비슷한 데이터를 만들어내도록 학습됩니다. $D$는 실제 데이터와 $G$가 생성한 가짜 데이터를 구별하도록 학습됩니다. GAN의 궁극적인 목적은 실제 데이터의 분포에 가까운 데이터를 생성하는 것입니다. GAN을 도식화한 그림은 다음과 같습니다.

목적함수

GAN의 목적함수는 다음과 같습니다. 게임이론 타입의 목적함수로 두 명의 플레이어($G$와 $D$)가 싸우면서 서로 균형점(nash equilibrium)을 찾아가도록 하는 방식입니다.

우선 $D$ 입장에서 살펴보겠습니다. 실제 데이터($x$)를 입력하면 높은 확률이 나오도록 하고($D(x)$를 높임), 가짜 데이터($G(z)$)를 입력하면 확률이 낮아지도록($1-D(G(z))$를 낮춤=$D(G(z))$를 높임) 학습됩니다. 다시 말해 $D$는 실제 데이터와 $G$가 만든 가상데이터를 잘 구분하도록 조금씩 업데이트됩니다.

이번엔 $G$를 살펴보겠습니다. Zero-Mean Gaussian으로 뽑은 노이즈 $z$를 받아 생성된 가짜 데이터($G(z)$)를 $D$에 넣었을 때, 실제 데이터처럼 확률이 높게 나오도록($1-D(G(z))$를 높임=$D(G(z))$를 낮춤) 학습됩니다. 다시 말해 $G$는 $D$가 잘 구분하지 못하는 데이터를 생성하도록 조금씩 업데이트됩니다.

실제 학습을 진행할 때는 두 네트워크를 동시에 학습시키지 않고 따로따로 업데이트를 합니다. $D$를 학습시킬 때는 $G$를 고정한 상태에서, $G$를 학습시킬 때는 $D$를 고정한 상태에서 진행한다는 이야기이죠.

우선 $D$를 학습하기 위한 목적함수는 다음과 같이 다시 쓸 수 있습니다. 즉 $G$의 파라메터를 고정한 상태에서 실제 데이터 $m$개, $G$가 생성한 가짜 데이터 $m$개를 $D$에 넣고 $V$를 계산한 뒤, $D$에 대한 $V$의 그래디언트를 구하고 $V$를 높이는 방향으로 $D$의 파라메터를 업데이트합니다.

$G$의 목적함수는 다음과 같습니다. $D$의 파라메터를 고정한 상태에서 가짜 데이터 $m$개를 생성해 $V$을 계산한 뒤, $G$에 대한 $V$의 그래디언트를 구하고 $V$를 낮추는 방향으로 $G$의 파라메터를 업데이트합니다.

학습 초기에 $G$가 생성하는 데이터의 품질은 조악할 것입니다. 이 경우 $D$는 확실하게 가짜 데이터라고 잘 구분하게 되겠죠. $D(G(z))$가 0에 가까울 것이라는 이야기입니다.

그런데 목적함수를 잘 보면 expectation 안쪽이 $\log{(1-x)}$ 꼴임을 알 수 있습니다. 이 경우 $x$가 0일 때 기울기가 작습니다. 학습 초기 $G$의 파라메터를 팍팍 업데이트해줘야 하는데 그러지 못할 가능성이 크다는 말입니다.

이에 $G$의 목적함수를 아래처럼 살짝 바꿔서 초기 $G$ 학습을 가속화합니다. $\log(x)$의 경우 $x$가 0일 때 기울기가 매우 큽니다.

학습

몇 가지 수식을 통해 GAN 학습 과정을 좀 더 살펴보도록 하겠습니다. 이상적인 경우, 즉 $G$가 매우 학습이 잘 되었다면 $G$가 Zero-Mean Gaussian으로 뽑은 노이즈 $z$를 받아 생성한 데이터와 실제 데이터가 일치할 것입니다($G(z)=x$, $p_g(x)=p_{data}(x)$). 이 경우 최적의 구분자 $D$는 다음과 같이 식을 쓸 수 있습니다.

위 식이 최대화되는 지점은 위 식을 우리가 알고자 하는 $D(x)$로 미분한 값이 0이 되는 지점입니다. 식을 $D(x)$로 미분한 결과를 0으로 만든 식을 $D(x)$로 정리하면 다음과 같습니다. 아래 식에 원래 가정($p_g(x)=p_{data}(x)$)을 대입해 풀면 최적의 구분자 $D$는 1/2로 수렴합니다.

$D$는 데이터 $x$가 주어졌을 때 실제 데이터($y=1$)일 확률을 의미합니다. 이를 베이즈룰을 이용해 정리하면 최적의 $D$는 사전확률 $p(y=1)$와 $p(y=0)$이 1/2로 서로 같을 때 도출됩니다.

GAN의 목적함수를 최적의 구분자 $D$를 전제하고 식을 정리하면 다음과 같습니다. 다시 말해 최적의 $D$가 전제된 상황이라면 GAN의 목적함수를 최적화하는 과정은 $p_{data}$와 $p_g$ 사이의 젠슨-섀넌 다이버전스(Jensen-Shannon divergence)를 최소화하는 것과 같습니다. 젠슨-섀넌 다이버전스는 두 확률 분포 간 차이를 재는 함수의 일종인데요. 데이터의 분포와 $G$가 생성하는 분포 사이의 차이를 줄인다고 해석할 수 있습니다.

GAN의 학습과정을 도식화한 그림은 다음과 같습니다. 학습이 진행될 수록 $p_g$(녹색 실선)는 $p_{data}$(검정 점선), $D$(파란 점선)는 1/2로 수렴하는 것을 확인할 수 있습니다.

단점과 극복방안

GAN의 이론적 배경은 탄탄하지만, 실제 적용에는 많은 문제가 있습니다. 대표적으로 학습이 어렵다는 점을 꼽을 수 있겠습니다. 차례대로 살펴보겠습니다.

mode collapsing

mode collapsing이란 우리가 학습시키려는 모형이 실제 데이터의 분포를 모두 커버하지 못하고 다양성을 잃어버리는 현상을 가리킵니다. 그저 손실(loss)만을 줄이려고 학습을 하기 때문에 $G$가 전체 데이터 분포를 찾지 못하고, 아래 그림처럼 한번에 하나의 mode에만 강하게 몰리게 되는 경우입니다. 예컨대 MNIST를 학습한 $G$가 특정 숫자만 생성한다든지 하는 사례가 바로 여기에 속합니다.

진동

$G$와 $D$가 서로 진동하듯 수렴하지 않는 문제 역시 mode collapsing 문제와 관련이 있습니다.

mode collapsing 해결방안

mode collapsing 해결방안 가운데 간단하면서 효과적인 것으로 알려진 세 가지는 다음과 같습니다. 모델이 전체 데이터 분포의 경계를 골고루 학습하게 하고, 그것을 계속 기억할 수 있도록 하는 것이 핵심입니다.

  • feature matching : 가짜 데이터와 실제 데이터 사이의 least square error를 목적함수에 추가
  • mini-batch discrimination : 미니배치별로 가짜 데이터와 실제 데이터 사이의 거리 합의 차이를 목적함수에 추가
  • historical averaging : 배치 단위로 파라메터를 업데이트하면 이전 학습은 잘 잊히게 되므로, 이전 학습 내용을 기억하는 방식으로 학습

힘의 균형

$D$보다 $G$를 학습시키는 것이 일반적으로 어렵습니다. $G$가 학습이 잘 안되어서 둘 사이의 힘의 균형이 깨지는 경우 GAN 학습이 더 이상 진전될 수 없습니다. GAN 연구 초기에는 $G$를 $k$번 업데이트시키고, $D$를 한번 업데이트시키거나, 목적함수 비율을 조절하는 등 밸런스를 맞추기 위해 다양한 방식을 시도하였으나 뾰족한 대안이 되진 못했습니다. 그러나 최근엔 LSGAN, WGAN, F-GAN, EBGAN 등 손실함수를 바꿔서 이 문제를 해결한 연구가 여럿 제안되었습니다.

평가

GAN은 데이터 생성이 목적이기 때문에 정량적인 평가가 어렵습니다. 그럼에도 다음과 같은 세 가지 방식이 널리 쓰입니다.

  • 정성평가 : 사람이 직접 평가
  • 학습된 분류기를 이용 : 기존 뉴럴네트워크를 활용해 label이 있는 데이터 셋을 학습시킨다. 동일한 데이터로 GAN을 학습한 후 $G$를 이용해서 새로운 데이터를 생성하고 미리 학습시켜둔 분류기 모델에 넣어 분류를 시행한다. 이 때 (1)생성된 새로운 데이터가 한 범주에 높은 확률로 분류되거나 (2)전체적으로 다양한 범주의 데이터가 생성됐다면 GAN의 성능을 높다고 평가할 수 있다.
  • inception score : $G$가 생성한 데이터의 다양성(개성)을 측정하는 지표로 클 수록 좋다.

Comments