본문 바로가기
딥러닝 기초

Variational Autoencoder (VAE)와 ELBO, KL Divergence 이해하기

by 루루트 2025. 2. 13.
반응형

Variational Autoencoder(VAE)는 생성 모델 중 하나로, 데이터의 잠재 표현(latent representation)을 학습하면서 데이터를 재구성하는 모델입니다. VAE를 학습하기 위해 중요한 개념인 Evidence Lower Bound (ELBO)와 KL Divergence에 대해 자세히 알아보겠습니다.

 

목록

1. VAE의 기본 개념
2. Evidence Lower Bound (ELBO)
3. KL Divergence의 형태 유도
4. MNIST 데이터를 활용한 VAE 구현 예제 (PyTorch)
5. 결론

 

 

1. VAE의 기본 개념

VAE는 두 가지 네트워크로 구성됩니다.

  • 인코더(Encoder): 입력 데이터 $x$로부터 잠재 변수 $z$의 근사 분포 $q(z|x)$ (보통 정규 분포의 평균 $\mu$와 분산 $\sigma^2$로 표현)를 추정합니다.
  • 디코더(Decoder): 인코더에서 샘플링한 $z$를 바탕으로 입력 데이터 $x$를 재구성합니다.

VAE의 데이터 생성 과정 다이어그램

1. 잠재 변수 생성:

$z \sim p(z)$ (예: $p(z)=\cal{N}(0, I)$)

2. 데이터 생성:

$x \sim p(x|z)$

 

결합 확률 분포는 다음과 같이 정의됩니다.

$$ p(x,z) = p(z)p(x|z) $$

 

2. Evidence Lower Bound (ELBO)

목표로 하는 것은, 입력 이미지 $x$를 어떤 의미 있는 잠재 공간 $z$로 투영하는 것입니다. 그러나 실제 후방 분포 $p(z|x)$는 보통 계산하기 어렵거나 복잡하기 때문에, 대신에 $q(z|x)$라는 더 단순하고 다루기 쉬운 근사 분포를 사용하여 그 역할을 대신합니다.

 

$$
\begin{align*}
D_\text{KL}(q(z|x)||p(z|x)) &= \mathbb{E}_{q(z|x)} \big[ \log \frac{q(z|x)}{p(z|x)} \big] \\
                                         &= \mathbb{E}_{q(z|x)} \big[ \log \frac{q(z|x)p(x)}{p(x,z)} \big] \\
                                         &= \mathbb{E}_{q(z|x)} \big[ \log \frac{q(z|x)}{p(x,z)} \big]+\log(p(x)), \\
                                         &= \mathbb{E}_{q(z|x)} \big[ \log \frac{q(z|x)}{p(x,z)} \big]+\log(p(x)) \\
\log(p(x))  &= - \mathbb{E}_{q(z|x)} \big[ \log \frac{q(z|x)}{p(x,z)} \big] + D_\text{KL}(q(z|x)||p(z|x)), \\
\log(p(x))  &= \mathbb{E}_{q(z|x)} \big[ \log \frac{p(x,z)}{q(z|x)} \big] + D_\text{KL}(q(z|x)||p(z|x)).
\end{align*}
$$

KL divergence는 0보다 크거나 같으므로

$$ \log(p(x)) \geq \mathbb{E}_{q(z|x)} \big[ \log \frac{p(x,z)}{q(z|x)} \big] $$

 

따라서,  Evidence Lower Bound (ELBO)는 다음과 같이 정의됩니다.

$$ \text{ELBO} = \mathbb{E}_{q(z|x)} \big[ \text{log}\frac{p(x,z)}{q(z|x)} \big] $$.

 

식을 변형하면

$$
\begin{align*}
\text{ELBO} &= \mathbb{E}_{q(z|x)} \big[ \text{log}\frac{p(x,z)}{q(z|x)} \big] \\
&= \mathbb{E}_{q(z|x)} \big[ \text{log}\frac{p(x|z)p(z)}{q(z|x)} \big] \\
&= \mathbb{E}_{q(z|x)} [\log p(x|z)] + \mathbb{E}_{q(z|x)} \big[ \text{log}\frac{p(z)}{q(z|x)} \big] \\
&= \mathbb{E}_{q(z|x)} [\log p(x|z)] - \mathbb{E}_{q(z|x)} \big[ \text{log}\frac{q(z|x)}{p(z)} \big] \\
&= \mathbb{E}_{q(z|x)} [\log p(x|z)] - D_\text{KL} ( q(z|x) || p(z) ) 
\end{align*}
$$

 

$$ \text{ELBO} = \mathbb{E}_{q(z|x)} [\log p(x|z)]-D_\text{KL} ( q(z|x) || p(z) ) $$.

이를 두 항으로 나누면,

  • 재구성 항: 디코더가 $z$로부터 $x$를 얼마나 잘 재구성하는지 평가합니다.
  • KL 정규화 항: 인코더가 예측한 분포 $q(z|x)$가 사전 분포 $p(z)$ (보통 $\cal{N} (0,I)$)와 크게 벗어나지 않도록 합니다.

학습 시에는 ELBO를 최대화하는 대신, -ELBO를 손실 함수로 사용하여 최소화합니다.

 

3. KL Divergence의 형태 유도

두 정규 분포 $q(z|x)=\cal{N}(z;\mu, \sigma^2)$와 $p(z)=\cal{N}(z;0, I)$ 사이의 KL Divergence는 아래와 같이 계산할 수 있습니다.

KL Divergence 정의

$$D_\text{KL}(q(z|x) || p(z)) = \int q(z|x) \text{log} \frac{q(z|x)}{p(z)} dz $$.

 

1) 각 분포의 로그 확률 밀도

  • $q(z|x)$의 로그 밀도:

$$ \text{log} q(z|x) = -\frac{1}{2} \text{log} (2\pi\sigma^2) - \frac{(z-\mu)^2}{2\sigma^2} $$.

  • $p(z)$의 로그 밀도:

$$ \text{log} p(z) = -\frac{1}{2} \text{log} (2\pi) - \frac{z^2}{2} $$.

 

2) 로그 비율 계산

$$
\begin{align*}
\log \frac{q(z|x)}{p(z)} &= \log q(z|x) - \log p(z) \\
                         &= -\frac{1}{2}\log (2\pi\sigma^2) - \frac{(z-\mu)^2}{2\sigma^2} + \frac{1}{2}\log (2\pi) + \frac{z^2}{2} \\
                         &= -\frac{1}{2}\log (\sigma^2) - \frac{(z-\mu)^2}{2\sigma^2} + \frac{z^2}{2}
\end{align*}
$$
 

에 대한 기댓값 취하기

$$
\begin{align*}
D_\text{KL}(q(z|x) \,\|\, p(z)) 
  &= \mathbb{E}_{q(z|x)} \left[ -\frac{1}{2}\log (\sigma^2) - \frac{(z-\mu)^2}{2\sigma^2} + \frac{z^2}{2} \right] \\
  &= -\frac{1}{2}\log (\sigma^2) - \frac{1}{2\sigma^2} \mathbb{E}_{q(z|x)} \left[ (z-\mu)^2 \right] + \frac{1}{2} \mathbb{E}_{q(z|x)} \left[ z^2 \right]
\end{align*}
$$

 

정규 분포 $\cal{N}(\mu, \sigma^2)$의 성질에 의해,

$$ \mathbb{E}_{q(z|x)} \left[ (z-\mu)^2 \right] = \sigma^2, \mathbb{E}_{q(z|x)} \left[ z^2 \right] = \sigma^2+\mu^2  $$

 

따라서,

$$
\begin{align*}
D_\text{KL}(q(z|x) \,\|\, p(z)) 
  &= -\frac{1}{2}\log (\sigma^2) - \frac{1}{2} + \frac{1}{2} \left[ \sigma^2+\mu^2 \right] \\
  &= \frac{1}{2}\Bigl(\sigma^2+\mu^2-1-\log(\sigma^2)\Bigr)
\end{align*}
$$

만약 $z$가 다차원이라면 각 차원에 대해 합산합니다.

4. MNIST 데이터를 활용한 VAE 구현 예제 (PyTorch)

아래 코드는 MNIST 데이터셋(0부터 9까지의 숫자 이미지)을 사용하여 VAE를 학습하는 예제입니다.

Colab 바로가기: https://colab.research.google.com/drive/1Qpe8ARnbBcrfORItTP94ueUqBmLgoN_j?usp=sharing

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# VAE 모델 정의
class VAE(nn.Module):
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        
        # Encoder: MNIST 이미지 (28x28) -> 잠재 변수의 평균과 로그 분산
        self.fc1 = nn.Linear(28 * 28, 400)
        self.fc_mu = nn.Linear(400, latent_dim)
        self.fc_logvar = nn.Linear(400, latent_dim)
        
        # Decoder: 잠재 변수 -> MNIST 이미지 복원
        self.fc3 = nn.Linear(latent_dim, 400)
        self.fc4 = nn.Linear(400, 28 * 28)
    
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        mu = self.fc_mu(h1)
        logvar = self.fc_logvar(h1)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        # Sigmoid 활성화를 통해 0과 1 사이의 값 출력 (이미지 재구성)
        return torch.sigmoid(self.fc4(h3))
    
    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 28 * 28))
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

# Loss 함수: -ELBO
def loss_function(recon_x, x, mu, logvar):
    # 재구성 손실 (Binary Cross-Entropy)
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 28 * 28), reduction='sum')
    # KL Divergence: closed-form expression
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# MNIST 데이터셋 로드 (0부터 9까지의 숫자)
batch_size = 128
transform = transforms.ToTensor()

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 모델, 옵티마이저 정의
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim=20).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 학습 루프
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f"Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item() / len(data):.4f}")
    print(f"====> Epoch {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}")

def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        # 첫 번째 배치만 가져와서 시각화를 위해 저장합니다.
        for data, _ in test_loader:
            data = data.to(device)
            recon, mu, logvar = model(data)
            test_loss += loss_function(recon, data, mu, logvar).item()
            orig = data  # 원본 이미지 저장
            recon_images = recon  # 재구성 이미지 저장
            break  # 첫 번째 배치만 사용

    test_loss /= len(test_loader.dataset)
    print(f"====> Test set loss: {test_loss:.4f}")

    # 시각화: 원본 이미지(Top)와 재구성 이미지(Bottom)를 출력
    import matplotlib.pyplot as plt
    n = 8  # 출력할 이미지 수
    plt.figure(figsize=(16, 4))
    
    # 첫 번째 행: 원본 이미지
    for i in range(n):
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(orig[i].cpu().view(28, 28), cmap='gray')
        ax.axis('off')
    
    # 두 번째 행: 재구성 이미지
    for i in range(n):
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(recon_images[i].cpu().view(28, 28), cmap='gray')
        ax.axis('off')
    
    plt.suptitle(f'Epoch {epoch}: Original (Top) vs Reconstructed (Bottom)')
    plt.show()

# 실제 학습 실행
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train(epoch)
    test(epoch)

코드 설명

$$ z = \mu + \sigma \cdot \epsilon, \epsilon \sim \mathcal{N}(0,I) $$

  • Reparameterize trick: 미분 가능한 재파라미터화 기법(reparameterization trick)을 사용하면, 직접 $z \sim \mathcal{N}(\mu, \sigma^2)$에서 샘플링하는 대신, 표준 정규분포 $\epsilon \sim \mathcal{N}(0, I)$로부터 샘플링한 후 위와 같이 변환하여 미분 가능하게 만듭니다.

5. 결론

이 글에서는 VAE의 핵심 개념인 ELBO와 KL Divergence에 대해 정리하고, MNIST 데이터셋을 활용한 PyTorch 기반의 VAE 구현 예제를 소개했습니다.

반응형