initial commit

This commit is contained in:
Jiao77
2025-11-24 20:34:50 +08:00
commit 633749886e
15 changed files with 2665 additions and 0 deletions

0
models/__init__.py Normal file
View File

62
models/skeleton_vae.py Normal file
View File

@@ -0,0 +1,62 @@
"""
Simple convolutional VAE for skeleton images (single-channel)
"""
import torch
import torch.nn as nn
class ConvEncoder(nn.Module):
def __init__(self, z_dim=64):
super().__init__()
self.enc = nn.Sequential(
nn.Conv2d(1,32,4,2,1),
nn.ReLU(),
nn.Conv2d(32,64,4,2,1),
nn.ReLU(),
nn.Conv2d(64,128,4,2,1),
nn.ReLU(),
nn.Flatten()
)
self.fc_mu = nn.Linear(128*8*8, z_dim)
self.fc_logvar = nn.Linear(128*8*8, z_dim)
def forward(self, x):
h = self.enc(x)
return self.fc_mu(h), self.fc_logvar(h)
class ConvDecoder(nn.Module):
def __init__(self, z_dim=64):
super().__init__()
self.fc = nn.Linear(z_dim, 128*8*8)
self.dec = nn.Sequential(
nn.Unflatten(1, (128,8,8)),
nn.ConvTranspose2d(128,64,4,2,1),
nn.ReLU(),
nn.ConvTranspose2d(64,32,4,2,1),
nn.ReLU(),
nn.ConvTranspose2d(32,1,4,2,1),
nn.Sigmoid()
)
def forward(self, z):
h = self.fc(z)
return self.dec(h)
class SkeletonVAE(nn.Module):
def __init__(self, z_dim=64):
super().__init__()
self.encoder = ConvEncoder(z_dim)
self.decoder = ConvDecoder(z_dim)
def reparameterize(self, mu, logvar):
std = (0.5 * logvar).exp()
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
mu, logvar = self.encoder(x)
z = self.reparameterize(mu, logvar)
out = self.decoder(z)
return out, mu, logvar