initial commit
This commit is contained in:
62
models/skeleton_vae.py
Normal file
62
models/skeleton_vae.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Simple convolutional VAE for skeleton images (single-channel)
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ConvEncoder(nn.Module):
|
||||
def __init__(self, z_dim=64):
|
||||
super().__init__()
|
||||
self.enc = nn.Sequential(
|
||||
nn.Conv2d(1,32,4,2,1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(32,64,4,2,1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(64,128,4,2,1),
|
||||
nn.ReLU(),
|
||||
nn.Flatten()
|
||||
)
|
||||
self.fc_mu = nn.Linear(128*8*8, z_dim)
|
||||
self.fc_logvar = nn.Linear(128*8*8, z_dim)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.enc(x)
|
||||
return self.fc_mu(h), self.fc_logvar(h)
|
||||
|
||||
|
||||
class ConvDecoder(nn.Module):
|
||||
def __init__(self, z_dim=64):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(z_dim, 128*8*8)
|
||||
self.dec = nn.Sequential(
|
||||
nn.Unflatten(1, (128,8,8)),
|
||||
nn.ConvTranspose2d(128,64,4,2,1),
|
||||
nn.ReLU(),
|
||||
nn.ConvTranspose2d(64,32,4,2,1),
|
||||
nn.ReLU(),
|
||||
nn.ConvTranspose2d(32,1,4,2,1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, z):
|
||||
h = self.fc(z)
|
||||
return self.dec(h)
|
||||
|
||||
|
||||
class SkeletonVAE(nn.Module):
|
||||
def __init__(self, z_dim=64):
|
||||
super().__init__()
|
||||
self.encoder = ConvEncoder(z_dim)
|
||||
self.decoder = ConvDecoder(z_dim)
|
||||
|
||||
def reparameterize(self, mu, logvar):
|
||||
std = (0.5 * logvar).exp()
|
||||
eps = torch.randn_like(std)
|
||||
return mu + eps * std
|
||||
|
||||
def forward(self, x):
|
||||
mu, logvar = self.encoder(x)
|
||||
z = self.reparameterize(mu, logvar)
|
||||
out = self.decoder(z)
|
||||
return out, mu, logvar
|
||||
Reference in New Issue
Block a user