63 lines
1.6 KiB
Python
63 lines
1.6 KiB
Python
"""
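Simple convolutional VAE for single-channel skeleton images.

With the default settings the model expects 1x64x64 inputs: the encoder's
three stride-2 convolutions reduce them to a 128x8x8 feature map, and the
decoder mirrors those steps back up to a 1x64x64 output with values in [0, 1].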
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class ConvEncoder(nn.Module):
    """Convolutional encoder mapping an image batch to Gaussian posterior parameters.

    Three ``Conv2d(kernel=4, stride=2, padding=1)`` layers each halve the
    spatial size, so an ``img_size x img_size`` input becomes a
    ``128 x (img_size // 8) x (img_size // 8)`` feature map. That map is
    flattened and linearly projected to ``mu`` and ``logvar``.

    Args:
        z_dim: Size of the latent vector.
        img_size: Height/width of the square input images; must be a positive
            multiple of 8. The default of 64 reproduces the original
            ``128*8*8`` flattened feature size.
        in_channels: Number of input image channels (1 for skeleton images).

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 8.
    """

    def __init__(self, z_dim=64, *, img_size=64, in_channels=1):
        super().__init__()
        if img_size <= 0 or img_size % 8:
            raise ValueError(
                f"img_size must be a positive multiple of 8, got {img_size}"
            )
        # Spatial side length after three stride-2 halvings (8 for 64x64 input).
        feat_side = img_size // 8
        feat_dim = 128 * feat_side * feat_side
        # Layer order and indices are unchanged so existing state_dicts still load.
        self.enc = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(feat_dim, z_dim)
        self.fc_logvar = nn.Linear(feat_dim, z_dim)

    def forward(self, x):
        """Encode a batch ``x`` of shape (N, in_channels, img_size, img_size).

        Returns:
            Tuple ``(mu, logvar)``, each of shape (N, z_dim).
        """
        h = self.enc(x)
        return self.fc_mu(h), self.fc_logvar(h)
|
|
|
|
|
|
class ConvDecoder(nn.Module):
    """Convolutional decoder mapping latent vectors back to images.

    A linear layer expands ``z`` to a ``128 x (img_size // 8) x (img_size // 8)``
    feature map. Three ``ConvTranspose2d(kernel=4, stride=2, padding=1)``
    layers then each double the spatial size. A final sigmoid keeps the
    output values in [0, 1].

    Args:
        z_dim: Size of the latent vector.
        img_size: Height/width of the square output images; must be a positive
            multiple of 8. The default of 64 reproduces the original
            ``(128, 8, 8)`` bottleneck.
        out_channels: Number of output image channels (1 for skeleton images).

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 8.
    """

    def __init__(self, z_dim=64, *, img_size=64, out_channels=1):
        super().__init__()
        if img_size <= 0 or img_size % 8:
            raise ValueError(
                f"img_size must be a positive multiple of 8, got {img_size}"
            )
        # Bottleneck side length; three stride-2 upsamplings restore img_size.
        feat_side = img_size // 8
        self.fc = nn.Linear(z_dim, 128 * feat_side * feat_side)
        # Layer order and indices are unchanged so existing state_dicts still load.
        self.dec = nn.Sequential(
            nn.Unflatten(1, (128, feat_side, feat_side)),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 4, 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode ``z`` of shape (N, z_dim) into (N, out_channels, img_size, img_size)."""
        h = self.fc(z)
        return self.dec(h)
|
|
|
|
|
|
class SkeletonVAE(nn.Module):
    """Variational autoencoder built from a ConvEncoder and a ConvDecoder.

    Args:
        z_dim: Latent dimensionality, shared by the encoder and the decoder.
    """

    def __init__(self, z_dim=64):
        super().__init__()
        # Encoder is registered before decoder; this fixes parameter/state_dict order.
        self.encoder = ConvEncoder(z_dim)
        self.decoder = ConvDecoder(z_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + sigma * eps`` with ``sigma = exp(logvar / 2)``, ``eps ~ N(0, I)``.

        The noise is drawn from torch's global RNG on every call, including
        when the module is in eval mode.
        """
        sigma = torch.exp(logvar * 0.5)
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        reconstruction = self.decoder(latent)
        return reconstruction, mu, logvar
|