for textmining

Variational AutoEncoder

|

이번 글에서는 Variational AutoEncoder(VAE)에 대해 살펴보도록 하겠습니다. 이 글은 전인수 서울대 박사과정이 2017년 12월에 진행한 패스트캠퍼스 강의와 위키피디아, 그리고 이곳 등을 정리했음을 먼저 밝힙니다. PyTorch 코드는 이곳을 참고하였습니다. 그럼 시작하겠습니다.

concept

VAE는 데이터가 생성되는 과정, 즉 데이터의 확률분포를 학습하기 위한 두 개의 뉴럴네트워크로 구성되어 있습니다. VAE는 잠재변수(latent variable) $z$를 가정하고 있는데요. 우선 encoder라 불리는 뉴럴네트워크는 관측된 데이터 $x$를 받아서 잠재변수 $z$를 만들어 냅니다. decoder라 불리는 뉴럴네트워크는 encoder가 만든 $z$를 활용해 $x$를 복원해내는 역할을 합니다. VAE 아키텍처는 다음 그림과 같습니다.

그렇다면 여기에서 잠재변수 $z$는 어떤 의미인 걸까요? 고양이 그림 예시를 들어 생각해보겠습니다. 수많은 고양이 사진이 있다고 칩시다. 사람은 고양이 사진들이 저마다 다르게 생겼다 하더라도 이들 사진이 고양이임을 단박에 알아낼 수 있습니다. 사람들은 고양이 사진을 픽셀 단위로 자세하게 보고 고양이라고 판단하는게 아니라, 털 색깔, 눈 모양, 이빨 개수 등 추상화된 특징을 보고 고양이라는 결론을 냅니다.

이를 잠재변수 $z$와 VAE 아키텍처 관점에서 이해해 보자면, encoder는 입력 데이터를 추상화하여 잠재적인 특징을 추출하는 역할, decoder는 이러한 잠재적인 특징을 바탕으로 원 데이터로 복원하는 역할을 한다고 해석해볼 수 있겠습니다. 실제로 잘 학습된 VAE는 임의의 $z$값을 decoder에 넣으면 다양한 데이터를 생성할 수 있다고 합니다.

latent vector 만들기

VAE의 decoder 파트는 다음과 같이 정규분포를 전제로 하고 있습니다. 다시 말해 encoder가 만들어낸 $z$의 평균과 분산을 모수로 하는 정규분포입니다.

최대우도추정(MLE) 방식으로 VAE 모델의 파라메터를 추정하려면 다음과 같이 정의된 marginal log-likelihood $\log{p(x)}$를 최대화하면 됩니다. 아래 식을 최대화하면 모델이 데이터를 그럴싸하게 설명할 수 있게 됩니다.

위 식은 최적화하기 어렵습니다. $z$는 무수히 많은 경우가 존재할 수 있는데 가능한 모든 $z$에 대해서 고려해야 하기 때문입니다. 이럴 때 써먹는 것이 변분추론(Variational Inference)입니다. 변분추론은 계산이 어려운 확률분포를, 다루기 쉬운 분포 $q(z)$로 근사하는 방법입니다. 한편 $p(x)$는 베이즈 정리에서 evidence라고 이름이 붙여진 항인데요. 몇 가지 수식 유도 과정을 거치면 evidence의 하한(ELBO)을 다음과 같이 구할 수 있습니다.

계산이 쉬운 위 부등식 우변, 즉 ELBO를 최대화하면 $\log{p(x)}$를 최대화할 수 있을 것입니다. 일반적인 변분추론에서는 $q(z)$를 정규분포로 정합니다. 예컨대 다음과 같습니다.

그런데 데이터 $x$가 고차원일 때는 $q$를 위와 같이 정하게 되면 학습이 대단히 어렵다고 합니다. 그도 그럴 것이 모든 데이터에 대해 동일한 평균과 분산, 즉 단 하나의 정규분포를 가정하게 되는 셈이니, 데이터가 복잡한 데 비해 모델이 너무 단순하기 때문인 것 같습니다. VAE에서는 이 문제를 해결하기 위해 $q$의 파라메터를 $x$에 대한 함수로 둡니다. 다음과 같습니다.

$q$를 위와 같이 설정하고 ELBO를 최대화하는 방향으로 $q$를 잘 학습하면, $x$가 달라질 때마다 $q$의 분포도 계속 달라지게 됩니다. $x$에 따라 $q$의 모수(평균, 분산)가 바뀌게 되니까요. VAE의 encoder에는, $x$를 받아서 $z$의 평균과 분산을 만들어내는 뉴럴네트워크 두 개($f_μ$, $f_σ$)가 포함되어 있습니다. 이 덕분에 복잡한 데이터에 대해서도 모델이 적절하게 대응할 수 있게 됩니다.

어쨌든 노이즈를 zero-mean Gaussian에서 하나 뽑아 $f_μ$, $f_σ$가 산출한 평균과 분산을 더하고 곱해줘서 sampled latent vector $z$를 만듭니다. 수식은 다음과 같으며, 이같은 과정을 reparameterization trick이라고 부릅니다. 다시 말해 $z$를 직접 샘플링하는게 아니고 노이즈를 샘플링하는 방식입니다. 이렇게 되면 역전파를 통해 encoder가 산출하면 평균과 분산을 업데이트할 수 있게 됩니다.

VAE는 코드로 보는 것이 훨씬 잘 이해가 되는데요. 지금까지 설명한 내용이 아래 코드에 함축돼 있습니다(pytorch).

def __init__(self):
    self.fc1_1 = nn.Linear(784, hidden_size)
    self.fc1_2 = nn.Linear(784, hidden_size)
    self.relu = nn.ReLU()
                        
def encode(self,x):
    x = x.view(batch_size,-1)
    mu = self.relu(self.fc1_1(x))
    log_var = self.relu(self.fc1_2(x))
    return mu,log_var
    
def reparametrize(self, mu, logvar):
    std = logvar.mul(0.5).exp_()
    eps = torch.FloatTensor(std.size()).normal_()
    return eps.mul(std).add_(mu)

VAE는 latent vector $z$를 위와 같이 만들기 때문에 데이터 $x$가 동일하다 하더라도 $z$는 얼마든지 달라질 수 있고, decoder의 최종 결과물 역시 변종이 발생할 가능성이 있습니다. $z$ 생성 과정에 zero-mean Gaussian으로 뽑은 노이즈가 개입되기 때문입니다. 데이터 $x$를 넣어 다시 $x$가 출력되는 구조의 autoencoder이지만, 맨 앞에 variational이 붙은 이유가 바로 여기에 있는 것 같습니다.

VAE의 목적함수

VAE의 decoder는 데이터의 사후확률 $p(z$|$x)$를 학습합니다. 하지만 사후확률은 계산이 어렵기 때문에 다루기 쉬운 분포 $q(z)$로 근사하는 변분추론 기법을 적용하게 됩니다. 변분추론은 $p(z$|$x)$와 $q(z)$ 사이의 KL Divergence를 계산하고, KLD가 줄어드는 쪽으로 $q(z)$를 조금씩 업데이트해서 $q(z)$를 얻어냅니다. KLD 식을 조금 변형하면 다음과 같은 식을 유도할 수 있습니다. (자세한 내용은 이곳을 참고하시면 좋을 것 같습니다)

그런데 우리는 전 챕터에서 $q$를 정규분포로 두고, $q$의 평균과 분산을 $x$에 대한 함수로 정의한 바 있습니다. 이를 반영하고, evidence인 $\log{p(x)}$ 중심으로 식을 다시 쓰면 아래와 같습니다. 이 식을 $A$라고 두겠습니다.

동일한 확률변수에 대한 KLD 값(위 식 우변 세번째 항)은 항상 양수이므로 아래와 같은 부등식이 항상 성립합니다. 아래 식을 $B$라고 두겠습니다.

따라서 위 부등식 우변, 즉 ELBO를 최대화하면 우리의 목적인 marginal log-likelihood $\log{p(x)}$를 최대화할 수 있게 됩니다. 아울러 $A$와 $B$를 비교하면서 보면 ELBO를 최대화한다는 것은 $q(z$|$x)$와 $p(z$|$x)$ 사이의 KLD를 최소화하는 의미가 됩니다.

딥러닝 모델은 보통 손실함수를 목적함수로 쓰는 경향이 있으므로 위 부등식의 우변에 음수를 곱한 식이 loss function이 되고, 이 함수를 최소화하는 게 학습 목표가 됩니다. 손실함수는 다음과 같습니다.

위 식 우변 첫번째 항은 reconstruction loss에 해당합니다. encoder가 데이터 $x$를 받아서 $q$로부터 $z$를 뽑습니다. decoderencoder가 만든 $z$를 받아서 원 데이터 $x$를 복원합니다. 위 식 우변 첫번째 항은 이 둘 사이의 크로스 엔트로피를 가리킵니다.

위 식 우변 두번째 항은 KL Divergence Regularizer에 해당합니다. VAE는 $z$가 zero-mean Gaussian이라고 가정합니다. 정규분포끼리의 KLD는 분석적인 방식으로 도출 가능합니다. 계산이 쉽다는 말이지요. $q(z$|$x)$를 각 변수가 독립인 다변량 정규분포로 볼 경우 위 식 우변 두번째 항을 다음과 같이 다시 쓸 수 있습니다.

위 식 우변 두번째 항을 최소화한다는 말은 $q$를 zero-mean Gaussian에 가깝게 만든다는 의미입니다. 지금까지 말씀드린 내용을 종합해 VAE의 손실함수를 도식적으로 나타내면 다음과 같습니다.

KLD를 아래처럼 분해해서 다음과 같이 해석하는 것도 가능합니다.

VAE의 목적함수를 PyTorch 코드로 구현한 결과는 다음과 같습니다.

# the Binary Cross Entropy between the target and the output
reconstruction_function = nn.BCELoss(size_average=False)

def loss_function(recon_x, x, mu, logvar):
    BCE = reconstruction_function(recon_x, x)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    return BCE + KLD

VAE의 장단점

VAE는 GAN에 비해 학습이 안정적인 편이라고 합니다. 손실함수에서 확인할 수 있듯 reconstruction error과 같이 평가 기준이 명확하기 때문입니다. 아울러 데이터뿐 아니라 데이터에 내재한 잠재변수 $z$도 함께 학습할 수 있다는 장점이 있습니다(feature learning). 하지만 출력이 선명하지 않고 평균값 형태로 표시되는 문제, reparameterization trick이 모든 경우에 적용되지 않는 문제 등이 단점으로 꼽힌다고 합니다.



Comments