#!/usr/bin/env python3
"""Minimal training loop for SkeletonVAE.

Expects a directory of skeleton PNGs (files matching ``*_sk.png``,
searched recursively). Trains a small VAE, logs progress to both a file
and stderr, checkpoints every epoch, and dumps reconstruction grids
every 5 epochs.
"""

import argparse
import logging
import time
from pathlib import Path

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image

from models.skeleton_vae import SkeletonVAE


class SkeletonDataset(Dataset):
    """Grayscale skeleton images loaded recursively from *folder*.

    Each item is a single-channel float tensor of shape (1, size, size)
    with values in [0, 1] (via ``ToTensor``), which matches the BCE
    reconstruction loss used in training.
    """

    def __init__(self, folder, size=64):
        # Recursive glob so skeletons in nested subdirectories are found too.
        self.files = list(Path(folder).glob('**/*_sk.png'))
        self.tr = T.Compose([T.Resize((size, size)), T.ToTensor()])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        p = self.files[idx]
        # 'L' forces single-channel grayscale regardless of the PNG's mode.
        im = Image.open(p).convert('L')
        t = self.tr(im)
        return t


def loss_fn(recon_x, x, mu, logvar):
    """Standard VAE objective: summed BCE reconstruction + KL divergence.

    Returns ``(total, bce, kld)``. All three are summed over the batch
    (``reduction='sum'``), so callers divide by batch/dataset size to
    report per-sample values.
    """
    BCE = torch.nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior
    # (Kingma & Welling, "Auto-Encoding Variational Bayes", Eq. 10).
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD, BCE, KLD


def setup_logger(log_dir):
    """Configure root logging to ``log_dir/train.log`` and to the console.

    ``force=True`` replaces any handlers a library may have installed at
    import time, so our format/handlers always take effect.
    """
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_dir / 'train.log'),
            logging.StreamHandler(),
        ],
        force=True,
    )
    return logging.getLogger(__name__)


def main():
    """Parse CLI args, build the dataset/model, and run the training loop."""
    p = argparse.ArgumentParser()
    p.add_argument('skel_folder')
    p.add_argument('--epochs', type=int, default=10)
    p.add_argument('--batch', type=int, default=16)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--save_path', type=str, default='out/vae_model.pth')
    p.add_argument('--sample_dir', type=str, default='out/samples')
    args = p.parse_args()

    # Create output locations up front so saves never fail mid-training.
    Path(args.sample_dir).mkdir(parents=True, exist_ok=True)
    Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)
    logger = setup_logger(Path(args.save_path).parent)

    ds = SkeletonDataset(args.skel_folder, size=64)
    if len(ds) == 0:
        logger.error(f"No skeleton images found in {args.skel_folder}")
        return
    dl = DataLoader(ds, batch_size=args.batch, shuffle=True, num_workers=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")
    logger.info(f"Dataset size: {len(ds)} images")

    model = SkeletonVAE(z_dim=64).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    logger.info("Starting training...")
    for epoch in range(args.epochs):
        start_time = time.time()
        model.train()
        total_loss = 0.0
        total_bce = 0.0
        total_kld = 0.0

        for i, xb in enumerate(dl):
            xb = xb.to(device)
            recon, mu, logvar = model(xb)
            loss, bce, kld = loss_fn(recon, xb, mu, logvar)
            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss += loss.item()
            total_bce += bce.item()
            total_kld += kld.item()

            if (i + 1) % 10 == 0:
                # Per-sample values: loss_fn sums over the batch.
                logger.info(f"Epoch [{epoch+1}/{args.epochs}] Batch [{i+1}/{len(dl)}] "
                            f"Loss: {loss.item()/len(xb):.4f} (BCE: {bce.item()/len(xb):.4f}, KLD: {kld.item()/len(xb):.4f})")

        # Epoch averages are per-sample (totals are summed over all items).
        avg_loss = total_loss / len(ds)
        avg_bce = total_bce / len(ds)
        avg_kld = total_kld / len(ds)
        epoch_time = time.time() - start_time
        logger.info(f'Epoch {epoch+1}/{args.epochs} completed in {epoch_time:.2f}s')
        logger.info(f'  Average Loss: {avg_loss:.4f}')
        logger.info(f'  BCE: {avg_bce:.4f} | KLD: {avg_kld:.4f}')

        # Save model (checkpoint every epoch; overwritten each time).
        torch.save(model.state_dict(), args.save_path)

        # Save sample reconstructions
        if (epoch + 1) % 5 == 0:
            logger.info(f"Saving reconstruction samples to {args.sample_dir}")
            with torch.no_grad():
                model.eval()
                # Reconstruct first batch (model.train() at the top of the
                # next epoch restores training mode).
                xb = next(iter(dl)).to(device)
                recon, _, _ = model(xb)
                # Concat input and recon: top row inputs, bottom row outputs.
                vis = torch.cat([xb[:8], recon[:8]], dim=0)
                save_image(vis, Path(args.sample_dir) / f'epoch_{epoch+1}.png', nrow=8)


if __name__ == '__main__':
    main()