Variational Autoencoder(VAE)는 생성 모델 중 하나로, 데이터의 잠재 표현(latent representation)을 학습하면서 데이터를 재구성하는 모델입니다. VAE를 학습하기 위해 중요한 개념인 Evidence Lower Bound (ELBO)와 KL Divergence에 대해 자세히 알아보겠습니다.
목록
1. VAE의 기본 개념 |
2. Evidence Lower Bound (ELBO) |
3. KL Divergence의 형태 유도 |
4. MNIST 데이터를 활용한 VAE 구현 예제 (PyTorch) |
5. 결론 |
1. VAE의 기본 개념
VAE는 두 가지 네트워크로 구성됩니다.
- 인코더(Encoder): 입력 데이터 $x$로부터 잠재 변수 $z$의 근사 분포 $q(z|x)$ (보통 정규 분포의 평균 $\mu$와 분산 $\sigma^2$로 표현)를 추정합니다.
- 디코더(Decoder): 인코더에서 샘플링한 $z$를 바탕으로 입력 데이터 $x$를 재구성합니다.
VAE의 데이터 생성 과정 다이어그램
1. 잠재 변수 생성:
$z \sim p(z)$ (예: $p(z)=\cal{N}(0, I)$)
2. 데이터 생성:
$x \sim p(x|z)$
결합 확률 분포는 다음과 같이 정의됩니다.
$$ p(x,z) = p(z)p(x|z) $$
2. Evidence Lower Bound (ELBO)
목표로 하는 것은, 입력 이미지 $x$를 어떤 의미 있는 잠재 공간 $z$로 투영하는 것입니다. 그러나 실제 후방 분포 $p(z|x)$는 보통 계산하기 어렵거나 복잡하기 때문에, 대신에 $q(z|x)$라는 더 단순하고 다루기 쉬운 근사 분포를 사용하여 그 역할을 대신합니다.
$$
\begin{align*}
D_\text{KL}(q(z|x)||p(z|x)) &= \mathbb{E}_{q(z|x)} \big[ \log \frac{q(z|x)}{p(z|x)} \big] \\
&= \mathbb{E}_{q(z|x)} \big[ \log \frac{q(z|x)p(x)}{p(x,z)} \big] \\
&= \mathbb{E}_{q(z|x)} \big[ \log \frac{q(z|x)}{p(x,z)} \big]+\log(p(x)), \\
&= \mathbb{E}_{q(z|x)} \big[ \log \frac{q(z|x)}{p(x,z)} \big]+\log(p(x)) \\
\log(p(x)) &= - \mathbb{E}_{q(z|x)} \big[ \log \frac{q(z|x)}{p(x,z)} \big] + D_\text{KL}(q(z|x)||p(z|x)), \\
\log(p(x)) &= \mathbb{E}_{q(z|x)} \big[ \log \frac{p(x,z)}{q(z|x)} \big] + D_\text{KL}(q(z|x)||p(z|x)).
\end{align*}
$$
KL divergence는 0보다 크거나 같으므로
$$ \log(p(x)) \geq \mathbb{E}_{q(z|x)} \big[ \log \frac{p(x,z)}{q(z|x)} \big] $$
따라서, Evidence Lower Bound (ELBO)는 다음과 같이 정의됩니다.
$$ \text{ELBO} = \mathbb{E}_{q(z|x)} \big[ \text{log}\frac{p(x,z)}{q(z|x)} \big] $$.
식을 변형하면
$$
\begin{align*}
\text{ELBO} &= \mathbb{E}_{q(z|x)} \big[ \text{log}\frac{p(x,z)}{q(z|x)} \big] \\
&= \mathbb{E}_{q(z|x)} \big[ \text{log}\frac{p(x|z)p(z)}{q(z|x)} \big] \\
&= \mathbb{E}_{q(z|x)} [\log p(x|z)] + \mathbb{E}_{q(z|x)} \big[ \text{log}\frac{p(z)}{q(z|x)} \big] \\
&= \mathbb{E}_{q(z|x)} [\log p(x|z)] - \mathbb{E}_{q(z|x)} \big[ \text{log}\frac{q(z|x)}{p(z)} \big] \\
&= \mathbb{E}_{q(z|x)} [\log p(x|z)] - D_\text{KL} ( q(z|x) || p(z) )
\end{align*}
$$
$$ \text{ELBO} = \mathbb{E}_{q(z|x)} [\log p(x|z)]-D_\text{KL} ( q(z|x) || p(z) ) $$.
이를 두 항으로 나누면,
- 재구성 항: 디코더가 $z$로부터 $x$를 얼마나 잘 재구성하는지 평가합니다.
- KL 정규화 항: 인코더가 예측한 분포 $q(z|x)$가 사전 분포 $p(z)$ (보통 $\cal{N} (0,I)$)와 크게 벗어나지 않도록 합니다.
학습 시에는 ELBO를 최대화하는 대신, -ELBO를 손실 함수로 사용하여 최소화합니다.
3. KL Divergence의 형태 유도
두 정규 분포 $q(z|x)=\cal{N}(z;\mu, \sigma^2)$와 $p(z)=\cal{N}(z;0, I)$ 사이의 KL Divergence는 아래와 같이 계산할 수 있습니다.
KL Divergence 정의
$$D_\text{KL}(q(z|x) || p(z)) = \int q(z|x) \text{log} \frac{q(z|x)}{p(z)} dz $$.
1) 각 분포의 로그 확률 밀도
- $q(z|x)$의 로그 밀도:
$$ \text{log} q(z|x) = -\frac{1}{2} \text{log} (2\pi\sigma^2) - \frac{(z-\mu)^2}{2\sigma^2} $$.
- $p(z)$의 로그 밀도:
$$ \text{log} p(z) = -\frac{1}{2} \text{log} (2\pi) - \frac{z^2}{2} $$.
2) 로그 비율 계산
$$
\begin{align*}
\log \frac{q(z|x)}{p(z)} &= \log q(z|x) - \log p(z) \\
&= -\frac{1}{2}\log (2\pi\sigma^2) - \frac{(z-\mu)^2}{2\sigma^2} + \frac{1}{2}\log (2\pi) + \frac{z^2}{2} \\
&= -\frac{1}{2}\log (\sigma^2) - \frac{(z-\mu)^2}{2\sigma^2} + \frac{z^2}{2}
\end{align*}
$$
에 대한 기댓값 취하기
$$
\begin{align*}
D_\text{KL}(q(z|x) \,\|\, p(z))
&= \mathbb{E}_{q(z|x)} \left[ -\frac{1}{2}\log (\sigma^2) - \frac{(z-\mu)^2}{2\sigma^2} + \frac{z^2}{2} \right] \\
&= -\frac{1}{2}\log (\sigma^2) - \frac{1}{2\sigma^2} \mathbb{E}_{q(z|x)} \left[ (z-\mu)^2 \right] + \frac{1}{2} \mathbb{E}_{q(z|x)} \left[ z^2 \right]
\end{align*}
$$
정규 분포 $\cal{N}(\mu, \sigma^2)$의 성질에 의해,
$$ \mathbb{E}_{q(z|x)} \left[ (z-\mu)^2 \right] = \sigma^2, \mathbb{E}_{q(z|x)} \left[ z^2 \right] = \sigma^2+\mu^2 $$
따라서,
$$
\begin{align*}
D_\text{KL}(q(z|x) \,\|\, p(z))
&= -\frac{1}{2}\log (\sigma^2) - \frac{1}{2} + \frac{1}{2} \left[ \sigma^2+\mu^2 \right] \\
&= \frac{1}{2}\Bigl(\sigma^2+\mu^2-1-\log(\sigma^2)\Bigr)
\end{align*}
$$
만약 $z$가 다차원이라면 각 차원에 대해 합산합니다.
4. MNIST 데이터를 활용한 VAE 구현 예제 (PyTorch)
아래 코드는 MNIST 데이터셋(0부터 9까지의 숫자 이미지)을 사용하여 VAE를 학습하는 예제입니다.
Colab 바로가기: https://colab.research.google.com/drive/1Qpe8ARnbBcrfORItTP94ueUqBmLgoN_j?usp=sharing
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# VAE 모델 정의
class VAE(nn.Module):
def __init__(self, latent_dim=20):
super(VAE, self).__init__()
self.latent_dim = latent_dim
# Encoder: MNIST 이미지 (28x28) -> 잠재 변수의 평균과 로그 분산
self.fc1 = nn.Linear(28 * 28, 400)
self.fc_mu = nn.Linear(400, latent_dim)
self.fc_logvar = nn.Linear(400, latent_dim)
# Decoder: 잠재 변수 -> MNIST 이미지 복원
self.fc3 = nn.Linear(latent_dim, 400)
self.fc4 = nn.Linear(400, 28 * 28)
def encode(self, x):
h1 = F.relu(self.fc1(x))
mu = self.fc_mu(h1)
logvar = self.fc_logvar(h1)
return mu, logvar
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
h3 = F.relu(self.fc3(z))
# Sigmoid 활성화를 통해 0과 1 사이의 값 출력 (이미지 재구성)
return torch.sigmoid(self.fc4(h3))
def forward(self, x):
mu, logvar = self.encode(x.view(-1, 28 * 28))
z = self.reparameterize(mu, logvar)
recon_x = self.decode(z)
return recon_x, mu, logvar
# Loss 함수: -ELBO
def loss_function(recon_x, x, mu, logvar):
# 재구성 손실 (Binary Cross-Entropy)
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 28 * 28), reduction='sum')
# KL Divergence: closed-form expression
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
# MNIST 데이터셋 로드 (0부터 9까지의 숫자)
batch_size = 128
transform = transforms.ToTensor()
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 모델, 옵티마이저 정의
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim=20).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# 학습 루프
def train(epoch):
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = loss_function(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.item()
optimizer.step()
if batch_idx % 100 == 0:
print(f"Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item() / len(data):.4f}")
print(f"====> Epoch {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}")
def test(epoch):
model.eval()
test_loss = 0
with torch.no_grad():
# 첫 번째 배치만 가져와서 시각화를 위해 저장합니다.
for data, _ in test_loader:
data = data.to(device)
recon, mu, logvar = model(data)
test_loss += loss_function(recon, data, mu, logvar).item()
orig = data # 원본 이미지 저장
recon_images = recon # 재구성 이미지 저장
break # 첫 번째 배치만 사용
test_loss /= len(test_loader.dataset)
print(f"====> Test set loss: {test_loss:.4f}")
# 시각화: 원본 이미지(Top)와 재구성 이미지(Bottom)를 출력
import matplotlib.pyplot as plt
n = 8 # 출력할 이미지 수
plt.figure(figsize=(16, 4))
# 첫 번째 행: 원본 이미지
for i in range(n):
ax = plt.subplot(2, n, i + 1)
plt.imshow(orig[i].cpu().view(28, 28), cmap='gray')
ax.axis('off')
# 두 번째 행: 재구성 이미지
for i in range(n):
ax = plt.subplot(2, n, i + 1 + n)
plt.imshow(recon_images[i].cpu().view(28, 28), cmap='gray')
ax.axis('off')
plt.suptitle(f'Epoch {epoch}: Original (Top) vs Reconstructed (Bottom)')
plt.show()
# 실제 학습 실행
num_epochs = 10
for epoch in range(1, num_epochs + 1):
train(epoch)
test(epoch)
코드 설명
$$ z = \mu + \sigma \cdot \epsilon, \epsilon \sim \mathcal{N}(0,I) $$
- Reparameterize trick: 미분 가능한 재파라미터화 기법(reparameterization trick)을 사용하면, 직접 $z \sim \mathcal{N}(\mu, \sigma^2)$에서 샘플링하는 대신, 표준 정규분포 $\epsilon \sim \mathcal{N}(0, I)$로부터 샘플링한 후 위와 같이 변환하여 미분 가능하게 만듭니다.
5. 결론
이 글에서는 VAE의 핵심 개념인 ELBO와 KL Divergence에 대해 정리하고, MNIST 데이터셋을 활용한 PyTorch 기반의 VAE 구현 예제를 소개했습니다.
'딥러닝 기초' 카테고리의 다른 글
CLIP score, CLIP aesthetics score 란? (0) | 2024.09.04 |
---|---|
5분 안에 이해하는 ControlNet 간단 정리 자료 (0) | 2024.08.30 |
인공신경망을 이용한 MNIST 손글씨 분류하기 (1) | 2020.07.31 |
Softmax 회귀로 MNIST 손글씨 분류하기 (0) | 2020.07.29 |
머신러닝(기계학습)과 딥러닝의 차이점 (0) | 2020.07.29 |