initial commit

This commit is contained in:
Jiao77
2025-11-24 20:34:50 +08:00
commit 633749886e
15 changed files with 2665 additions and 0 deletions

136
train/train_skeleton_vae.py Normal file
View File

@@ -0,0 +1,136 @@
#!/usr/bin/env python3
"""
Minimal training loop for SkeletonVAE. Expects a directory that contains
(at any depth) skeleton PNGs whose file names end in `_sk.png`.
"""
import argparse
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
import numpy as np
import logging
import time
from models.skeleton_vae import SkeletonVAE
class SkeletonDataset(Dataset):
    """Dataset of grayscale skeleton images found recursively under a folder.

    Every file matching ``*_sk.png`` anywhere below ``folder`` is one sample.
    A sample is opened, converted to single-channel ('L') mode, resized to
    ``size`` x ``size`` and passed through torchvision's ``ToTensor``.

    Args:
        folder: Root directory searched recursively for ``*_sk.png`` files.
        size: Side length (pixels) every image is resized to.
    """

    def __init__(self, folder, size=64):
        # Sort the matches so that index -> file is reproducible across runs;
        # raw glob order depends on the underlying filesystem.
        self.files = sorted(Path(folder).glob('**/*_sk.png'))
        self.tr = T.Compose([T.Resize((size, size)), T.ToTensor()])

    def __len__(self):
        """Return the number of skeleton images discovered."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return it transformed."""
        # The context manager closes the file handle as soon as the pixel
        # data has been read by convert(), instead of leaving it to the
        # garbage collector.
        with Image.open(self.files[idx]) as im:
            gray = im.convert('L')
        return self.tr(gray)
def loss_fn(recon_x, x, mu, logvar):
    """Compute the VAE training objective and its two components.

    Args:
        recon_x: Reconstruction produced by the model.
        x: Target the reconstruction is compared against.
        mu: Latent mean returned by the model.
        logvar: Latent log-variance returned by the model.

    Returns:
        A 3-tuple ``(total, bce, kld)`` where ``bce`` is the summed binary
        cross-entropy between ``recon_x`` and ``x``, ``kld`` is
        ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))`` and ``total`` is
        their sum. All three are summed over the batch, not averaged.
    """
    functional = torch.nn.functional
    reconstruction_term = functional.binary_cross_entropy(
        recon_x, x, reduction='sum'
    )
    # Closed-form KL term; evaluated in the same order as written above.
    kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_inner)
    total = reconstruction_term + kl_term
    return total, reconstruction_term, kl_term
def setup_logger(log_dir):
    """Configure root logging to a file plus the console and return a logger.

    Args:
        log_dir: Directory that will hold ``train.log``; it is created
            (including missing parents) when absent.

    Returns:
        The logger named after this module.
    """
    target_dir = Path(log_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    file_handler = logging.FileHandler(target_dir / 'train.log')
    console_handler = logging.StreamHandler()

    # force=True makes basicConfig drop handlers already attached to the
    # root logger, so calling this again reconfigures instead of no-op'ing.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[file_handler, console_handler],
        force=True,
    )
    return logging.getLogger(__name__)
def _train_one_epoch(model, dl, opt, device, logger, epoch, num_epochs):
    """Run one optimisation pass over ``dl``; return summed (loss, BCE, KLD)."""
    model.train()
    total_loss = 0.0
    total_bce = 0.0
    total_kld = 0.0
    num_batches = len(dl)
    for i, xb in enumerate(dl):
        xb = xb.to(device)
        recon, mu, logvar = model(xb)
        loss, bce, kld = loss_fn(recon, xb, mu, logvar)
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Read each scalar once and reuse it for both the running totals
        # and the progress line below.
        loss_val = loss.item()
        bce_val = bce.item()
        kld_val = kld.item()
        total_loss += loss_val
        total_bce += bce_val
        total_kld += kld_val

        if (i + 1) % 10 == 0:
            # loss_fn sums over the batch, so divide by the batch size to
            # report per-sample values.
            n = len(xb)
            logger.info(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i+1}/{num_batches}] "
                        f"Loss: {loss_val/n:.4f} (BCE: {bce_val/n:.4f}, KLD: {kld_val/n:.4f})")
    return total_loss, total_bce, total_kld


def _save_reconstructions(model, dl, device, out_path):
    """Save a grid of up to 8 inputs followed by their reconstructions."""
    from torchvision.utils import save_image

    model.eval()
    with torch.no_grad():
        # The loader shuffles, so this is an arbitrary batch rather than a
        # fixed "first" one.
        xb = next(iter(dl)).to(device)
        recon, _, _ = model(xb)
        vis = torch.cat([xb[:8], recon[:8]], dim=0)
    save_image(vis, out_path, nrow=8)


def main():
    """Train a SkeletonVAE on a folder of skeleton images.

    Command-line arguments:
        skel_folder: Directory handed to ``SkeletonDataset``.
        --epochs: Number of passes over the dataset (default 10).
        --batch: Batch size (default 16).
        --lr: Adam learning rate (default 1e-3).
        --save_path: Where the model ``state_dict`` is written after every
            epoch (default ``out/vae_model.pth``); ``train.log`` is written
            next to it.
        --sample_dir: Directory receiving a reconstruction grid every 5th
            epoch (default ``out/samples``).
        --workers: Number of DataLoader worker processes (default 4).

    Returns early, after logging an error, when the dataset is empty.
    """
    p = argparse.ArgumentParser()
    p.add_argument('skel_folder')
    p.add_argument('--epochs', type=int, default=10)
    p.add_argument('--batch', type=int, default=16)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--save_path', type=str, default='out/vae_model.pth')
    p.add_argument('--sample_dir', type=str, default='out/samples')
    # Previously hard-coded to 4; exposed so it can be lowered (e.g. to 0)
    # on machines where worker processes are undesirable.
    p.add_argument('--workers', type=int, default=4,
                   help='number of DataLoader worker processes')
    args = p.parse_args()

    sample_dir = Path(args.sample_dir)
    save_dir = Path(args.save_path).parent
    sample_dir.mkdir(parents=True, exist_ok=True)
    save_dir.mkdir(parents=True, exist_ok=True)
    logger = setup_logger(save_dir)

    ds = SkeletonDataset(args.skel_folder, size=64)
    if len(ds) == 0:
        logger.error(f"No skeleton images found in {args.skel_folder}")
        return
    dl = DataLoader(ds, batch_size=args.batch, shuffle=True,
                    num_workers=args.workers)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")
    logger.info(f"Dataset size: {len(ds)} images")

    model = SkeletonVAE(z_dim=64).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    logger.info("Starting training...")
    num_samples = len(ds)
    for epoch in range(args.epochs):
        start_time = time.time()
        total_loss, total_bce, total_kld = _train_one_epoch(
            model, dl, opt, device, logger, epoch, args.epochs
        )
        # Totals are sums over every sample, so normalise by dataset size.
        avg_loss = total_loss / num_samples
        avg_bce = total_bce / num_samples
        avg_kld = total_kld / num_samples
        epoch_time = time.time() - start_time
        logger.info(f'Epoch {epoch+1}/{args.epochs} completed in {epoch_time:.2f}s')
        logger.info(f' Average Loss: {avg_loss:.4f}')
        logger.info(f' BCE: {avg_bce:.4f} | KLD: {avg_kld:.4f}')

        # Overwrite the checkpoint after every epoch.
        torch.save(model.state_dict(), args.save_path)

        # Dump a reconstruction grid every 5th epoch.
        if (epoch + 1) % 5 == 0:
            logger.info(f"Saving reconstruction samples to {args.sample_dir}")
            _save_reconstructions(
                model, dl, device, sample_dir / f'epoch_{epoch+1}.png'
            )
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()