""" Simple convolutional VAE for skeleton images (single-channel) """ import torch import torch.nn as nn class ConvEncoder(nn.Module): def __init__(self, z_dim=64): super().__init__() self.enc = nn.Sequential( nn.Conv2d(1,32,4,2,1), nn.ReLU(), nn.Conv2d(32,64,4,2,1), nn.ReLU(), nn.Conv2d(64,128,4,2,1), nn.ReLU(), nn.Flatten() ) self.fc_mu = nn.Linear(128*8*8, z_dim) self.fc_logvar = nn.Linear(128*8*8, z_dim) def forward(self, x): h = self.enc(x) return self.fc_mu(h), self.fc_logvar(h) class ConvDecoder(nn.Module): def __init__(self, z_dim=64): super().__init__() self.fc = nn.Linear(z_dim, 128*8*8) self.dec = nn.Sequential( nn.Unflatten(1, (128,8,8)), nn.ConvTranspose2d(128,64,4,2,1), nn.ReLU(), nn.ConvTranspose2d(64,32,4,2,1), nn.ReLU(), nn.ConvTranspose2d(32,1,4,2,1), nn.Sigmoid() ) def forward(self, z): h = self.fc(z) return self.dec(h) class SkeletonVAE(nn.Module): def __init__(self, z_dim=64): super().__init__() self.encoder = ConvEncoder(z_dim) self.decoder = ConvDecoder(z_dim) def reparameterize(self, mu, logvar): std = (0.5 * logvar).exp() eps = torch.randn_like(std) return mu + eps * std def forward(self, x): mu, logvar = self.encoder(x) z = self.reparameterize(mu, logvar) out = self.decoder(z) return out, mu, logvar