for textmining

Advanced VAEs

|

이번 글에서는 Variational AutoEncoder(VAE)의 발전된 모델들에 대해 살펴보도록 하겠습니다. 이 글은 전인수 서울대 박사과정이 2017년 12월에 진행한 패스트캠퍼스 강의와 위키피디아 등을 정리했음을 먼저 밝힙니다. PyTorch 코드는 이곳을 참고하였습니다. VAE의 기본적 내용에 대해서는 이곳을 참고하시면 좋을 것 같습니다. 그럼 시작하겠습니다.

Conditional VAE

Conditional VAE(CVAE)란 다음 그림과 같이 기존 VAE 구조를 지도학습(supervised learning)이 가능하도록 바꾼 것입니다. encoderdecoder에 정답 레이블 $y$가 추가된 형태입니다.

encoder에서 $z$를 만들 때 $y$ 정보가 추가됩니다. PyTorch 코드는 다음과 같습니다.

def Q(X, c):
    inputs = torch.cat([X, c], 1)# (X,y)
    z = encoder(inputs)
    z_mu = z[:,:Z_dim]
    z_var = z[:,Z_dim:]
    return z_mu, z_var

decoder에서 $x$를 복원할 때 $y$ 정보가 필요합니다. PyTorch 코드는 다음과 같습니다.

def P(z, c):
    inputs = torch.cat([z, c], 1) # (Z,y)
    x = decoder(inputs)
    return x

CVAE의 손실함수는 $y$의 추가로 수식은 달라지지만, 코드상으로는 기존 VAE와 동일합니다. 마지막 아웃풋에 VAE처럼 $x$만 있기 때문입니다. 다음과 같습니다.

def sample_z(mu, log_var):
    eps = Variable(torch.randn(mb_size, Z_dim))
    return mu + torch.exp(log_var / 2) * eps.cuda()

# Forward
z_mu, z_var = Q(X, c)
z = sample_z(z_mu, z_var)
X_sample = P(z, c)

# Loss
recon_loss = nn.functional.binary_cross_entropy(X_sample, X, size_average=False) / mb_size
kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var, 1))
loss = recon_loss + kl_loss

학습된 CVAE에 아래 그림과 같이 실제 손으로 3이라고 쓴 그림과 함께 label 정보를 바꿔가며 입력하게 되면 다음과 같이 출력된다고 합니다. 다시 말해 CVAE 모델이 데이터 분포를 학습할 때 범주 정보까지 함께 고려하게 된다는 의미입니다.

Adversarial Autoencoder

Adversarial Autoencoder(AAE)란 VAE에 GAN를 덧입힌 구조입니다. 다음 그림과 같습니다. GAN과 관련 자세한 내용은 이곳을 참고하시면 좋을 것 같습니다.

AAE에서는 기존 VAE 구조가 기존 GAN에서의 생성자(generator) 역할을 합니다. 생성자의 encoder는 데이터 $x$를 받아서 잠재변수 $z$를 샘플링하고, 생성자의 decoder는 이로부터 다시 $x$를 복원합니다. AAE가 기존 VAE와 다른 점은 기존 GAN의 구분자(discriminator) 역할을 하는 네트워크가 추가되었다는 점입니다. 이 구분자는 생성자의 encoder가 샘플링한 가짜 $z$와 $p(z)$로부터 직접 샘플링한 진짜 $z$를 구분하는 역할을 합니다.

이렇게 복잡한 네트워크를 만든 이유는 VAE 특유의 단점 때문입니다. VAE는 사전확률 분포 $p(z)$를 표준정규분포로 가정하고, $q(z$|$x)$를 이와 비슷하게 맞추는 과정에서 학습이 이루어집니다. VAE 아키텍처가 이처럼 구성되어 있는 이유는 표준정규분포와 같이 간단한(?) 확률함수여야 샘플링에 용이하고, KLD 계산을 쉽게 할 수 있기 때문입니다. (자세한 내용은 이곳 참고) 그런데 실제 데이터 분포가 정규분포를 따르지 않거나 이보다 복잡할 경우 VAE 성능에 문제가 발생할 수 있습니다.

그런데 GAN의 경우 모델에 특정 확률분포를 전제할 필요가 없습니다. GAN은 데이터가 어떤 분포를 따르든, 데이터의 실제 분포와 생성자(모델)가 만들어내는 분포 사이의 차이를 줄이도록 학습되기 때문입니다. (자세한 내용은 이곳 참고) VAE의 regularization term을 GAN Loss로 대체할 경우 사전확률과 사후확률 분포를 정규분포 이외에 다른 분포를 쓸 수 있게 돼 모델 선택의 폭이 넓어지는 효과를 누릴 수 있습니다. 어쨌든 개별 데이터 샘플 $x_i$와 사전확률분포 $p(z)$에서 뽑은 $z_i$에 대해 AAE의 학습과정은 다음과 같습니다.

AAE의 PyTorch 코드는 다음과 같습니다. 우선 생성자(encoder, decoder)와 구분자를 정의합니다.

# Encoder
Q = torch.nn.Sequential(
    torch.nn.Linear(X_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, z_dim))

# Decoder
P = torch.nn.Sequential(
    torch.nn.Linear(z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, X_dim),
    torch.nn.Sigmoid())

# Discriminator
D = torch.nn.Sequential(
    torch.nn.Linear(z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, 1),
    torch.nn.Sigmoid())

Step1의 reconstruction error를 계산하는 과정은 다음과 같습니다.

""" Reconstruction phase """
z_sample = Q(X)
X_sample = P(z_sample)
recon_loss = nn.binary_cross_entropy(X_sample, X)

Step2의 구분자 학습을 위한 손실을 구하는 과정은 다음과 같습니다. 생성자의 encoder가 샘플링하는 $z$는 가짜, $p(z)$로부터 직접 뽑는 $z$는 진짜라고 레이블을 부여합니다.

# Discriminator
z_real = Variable(torch.randn(mb_size, z_dim))
z_fake = Q(X)
D_real = D(z_real)
D_fake = D(z_fake)
D_loss = -torch.mean(torch.log(D_real) + torch.log(1 - D_fake))

Step3의 생성자 학습을 위한 손실을 구하는 과정은 다음과 같습니다.

# Generator
z_fake = Q(X)
D_fake = D(z_fake)
G_loss = -torch.mean(torch.log(D_fake))

GAN vs VAE

GAN과 VAE의 차이점을 도식적으로 나타낸 표는 다음과 같습니다

Model Optimization Converge Image Quality Generalization 비고
VAE Stochastic Gradient Descent Local Minimum 부드럽고 흐리다 오버피팅 경향이 상대적으로 큼 -
GAN Alternating Stochastic Gradient Descent Saddel points 선명하나 아티팩트가 많다 새로운 영상을 잘 생성해냄 Mode collapsing 문제 발생, 수렴이 어렵다

GAN과 VAE의 장점을 모두 취해 만든 연구로는 Energy-based GAN(EBGAN), Stack GAN 등이 있습니다.

Sketch RNN

Sketch RNN은 VAE에 RNN 구조를 덧입힌 아키텍처입니다. encoderdecoder에 RNN를 썼습니다. 다음 그림과 같습니다.



Comments