initial commit
This commit is contained in:
136
train/train_skeleton_vae.py
Normal file
136
train/train_skeleton_vae.py
Normal file
@@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Minimal training loop for SkeletonVAE. Expects a directory of skeleton PNGs.
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from PIL import Image
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
import logging
|
||||
import time
|
||||
from models.skeleton_vae import SkeletonVAE
|
||||
|
||||
|
||||
class SkeletonDataset(Dataset):
|
||||
def __init__(self, folder, size=64):
|
||||
self.files = list(Path(folder).glob('**/*_sk.png'))
|
||||
self.tr = T.Compose([T.Resize((size,size)), T.ToTensor()])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.files)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
p = self.files[idx]
|
||||
im = Image.open(p).convert('L')
|
||||
t = self.tr(im)
|
||||
return t
|
||||
|
||||
|
||||
def loss_fn(recon_x, x, mu, logvar):
|
||||
BCE = torch.nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
|
||||
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
return BCE + KLD, BCE, KLD
|
||||
|
||||
|
||||
def setup_logger(log_dir):
|
||||
log_dir = Path(log_dir)
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(log_dir / 'train.log'),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
force=True
|
||||
)
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument('skel_folder')
|
||||
p.add_argument('--epochs', type=int, default=10)
|
||||
p.add_argument('--batch', type=int, default=16)
|
||||
p.add_argument('--lr', type=float, default=1e-3)
|
||||
p.add_argument('--save_path', type=str, default='out/vae_model.pth')
|
||||
p.add_argument('--sample_dir', type=str, default='out/samples')
|
||||
args = p.parse_args()
|
||||
|
||||
Path(args.sample_dir).mkdir(parents=True, exist_ok=True)
|
||||
Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger = setup_logger(Path(args.save_path).parent)
|
||||
|
||||
ds = SkeletonDataset(args.skel_folder, size=64)
|
||||
if len(ds) == 0:
|
||||
logger.error(f"No skeleton images found in {args.skel_folder}")
|
||||
return
|
||||
|
||||
dl = DataLoader(ds, batch_size=args.batch, shuffle=True, num_workers=4)
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
logger.info(f"Using device: {device}")
|
||||
logger.info(f"Dataset size: {len(ds)} images")
|
||||
|
||||
model = SkeletonVAE(z_dim=64).to(device)
|
||||
opt = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||
|
||||
logger.info("Starting training...")
|
||||
for epoch in range(args.epochs):
|
||||
start_time = time.time()
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
total_bce = 0.0
|
||||
total_kld = 0.0
|
||||
|
||||
for i, xb in enumerate(dl):
|
||||
xb = xb.to(device)
|
||||
recon, mu, logvar = model(xb)
|
||||
loss, bce, kld = loss_fn(recon, xb, mu, logvar)
|
||||
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
opt.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_bce += bce.item()
|
||||
total_kld += kld.item()
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
logger.info(f"Epoch [{epoch+1}/{args.epochs}] Batch [{i+1}/{len(dl)}] "
|
||||
f"Loss: {loss.item()/len(xb):.4f} (BCE: {bce.item()/len(xb):.4f}, KLD: {kld.item()/len(xb):.4f})")
|
||||
|
||||
avg_loss = total_loss / len(ds)
|
||||
avg_bce = total_bce / len(ds)
|
||||
avg_kld = total_kld / len(ds)
|
||||
epoch_time = time.time() - start_time
|
||||
|
||||
logger.info(f'Epoch {epoch+1}/{args.epochs} completed in {epoch_time:.2f}s')
|
||||
logger.info(f' Average Loss: {avg_loss:.4f}')
|
||||
logger.info(f' BCE: {avg_bce:.4f} | KLD: {avg_kld:.4f}')
|
||||
|
||||
# Save model
|
||||
torch.save(model.state_dict(), args.save_path)
|
||||
|
||||
# Save sample reconstructions
|
||||
if (epoch + 1) % 5 == 0:
|
||||
logger.info(f"Saving reconstruction samples to {args.sample_dir}")
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
# Reconstruct first batch
|
||||
xb = next(iter(dl)).to(device)
|
||||
recon, _, _ = model(xb)
|
||||
# Concat input and recon
|
||||
vis = torch.cat([xb[:8], recon[:8]], dim=0)
|
||||
# Save grid
|
||||
from torchvision.utils import save_image
|
||||
save_image(vis, Path(args.sample_dir) / f'epoch_{epoch+1}.png', nrow=8)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user