137 lines
4.4 KiB
Python
137 lines
4.4 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Minimal training loop for SkeletonVAE. Expects a directory of skeleton PNGs.
|
|
"""
|
|
import argparse
|
|
from pathlib import Path
|
|
import torch
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from PIL import Image
|
|
import torchvision.transforms as T
|
|
import numpy as np
|
|
import logging
|
|
import time
|
|
from models.skeleton_vae import SkeletonVAE
|
|
|
|
|
|
class SkeletonDataset(Dataset):
    """Dataset of skeleton images found below a folder.

    Every file matching ``*_sk.png`` anywhere under ``folder`` is one item.
    An item is the image converted to mode ``'L'``, resized to
    ``size`` x ``size`` and passed through ``T.ToTensor()``.
    """

    def __init__(self, folder, size=64):
        """Collect the image paths and build the preprocessing pipeline.

        Args:
            folder: Root directory searched recursively for ``*_sk.png``.
            size: Side length (pixels) every image is resized to.
        """
        self.files = self._discover(folder)
        self.tr = T.Compose([T.Resize((size, size)), T.ToTensor()])

    @staticmethod
    def _discover(folder):
        """Return the ``*_sk.png`` paths under ``folder`` in sorted order.

        Glob order depends on the filesystem; sorting keeps the
        index -> file mapping stable across runs and machines.
        """
        return sorted(Path(folder).glob('**/*_sk.png'))

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load, convert and transform the image at position ``idx``."""
        # ``Image.open`` keeps the file open until the image is closed;
        # ``convert`` returns an independent in-memory copy, so the handle
        # can be released as soon as the conversion is done.
        with Image.open(self.files[idx]) as im:
            gray = im.convert('L')
        return self.tr(gray)
|
|
|
|
|
|
def loss_fn(recon_x, x, mu, logvar):
    """Compute the summed VAE loss and its two components.

    Args:
        recon_x: Reconstruction; passed as the input of
            ``binary_cross_entropy``, so values must lie in [0, 1].
        x: Target tensor with the same shape as ``recon_x``.
        mu: Latent mean tensor.
        logvar: Latent tensor that is exponentiated in the KL term
            (treated as a log-variance).

    Returns:
        Tuple ``(total, bce, kld)`` of scalar tensors, where ``total`` is
        ``bce + kld``. Both terms are sums over all elements, not means.
    """
    reconstruction = torch.nn.functional.binary_cross_entropy(
        recon_x, x, reduction='sum'
    )
    # -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * torch.sum(kl_terms)
    total = reconstruction + divergence
    return total, reconstruction, divergence
|
|
|
|
|
|
def setup_logger(log_dir):
    """Configure root logging to a file and the console; return a logger.

    Creates ``log_dir`` (including parents) if needed, then reconfigures the
    root logger at INFO level with two handlers: ``log_dir/train.log`` and a
    stream handler. ``force=True`` replaces any handlers already installed on
    the root logger.

    Args:
        log_dir: Directory in which ``train.log`` is written.

    Returns:
        The logger named after this module (``__name__``).
    """
    target_dir = Path(log_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    file_sink = logging.FileHandler(target_dir / 'train.log')
    console_sink = logging.StreamHandler()

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[file_sink, console_sink],
        force=True,
    )
    return logging.getLogger(__name__)
|
|
|
|
|
|
def _parse_args(argv=None):
    """Parse command-line options; ``argv=None`` reads ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('skel_folder')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch', type=int, default=16)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--save_path', type=str, default='out/vae_model.pth')
    parser.add_argument('--sample_dir', type=str, default='out/samples')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='DataLoader worker processes (0 loads in the main process)')
    return parser.parse_args(argv)


def _train_one_epoch(model, loader, optimizer, device, logger, epoch, num_epochs):
    """Run one optimisation pass over ``loader``.

    Returns:
        Tuple ``(total_loss, total_bce, total_kld)`` of Python floats summed
        over every batch of the epoch.
    """
    model.train()
    total_loss = total_bce = total_kld = 0.0
    num_batches = len(loader)

    for batch_no, xb in enumerate(loader, start=1):
        xb = xb.to(device)
        recon, mu, logvar = model(xb)
        loss, bce, kld = loss_fn(recon, xb, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read each scalar once and reuse it for both the running totals
        # and the progress line.
        loss_val, bce_val, kld_val = loss.item(), bce.item(), kld.item()
        total_loss += loss_val
        total_bce += bce_val
        total_kld += kld_val

        if batch_no % 10 == 0:
            n = len(xb)
            logger.info(f"Epoch [{epoch+1}/{num_epochs}] Batch [{batch_no}/{num_batches}] "
                        f"Loss: {loss_val/n:.4f} (BCE: {bce_val/n:.4f}, KLD: {kld_val/n:.4f})")

    return total_loss, total_bce, total_kld


def _save_reconstructions(model, loader, device, out_path):
    """Write a grid of up to 8 inputs followed by their reconstructions.

    The batch is the first one yielded by a fresh iterator over ``loader``;
    if the loader shuffles, that is a random batch. The model's previous
    train/eval mode is restored afterwards.
    """
    from torchvision.utils import save_image

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            xb = next(iter(loader)).to(device)
            recon, _, _ = model(xb)
            vis = torch.cat([xb[:8], recon[:8]], dim=0)
            save_image(vis, out_path, nrow=8)
    finally:
        model.train(was_training)


def main(argv=None):
    """Train a SkeletonVAE on a folder of skeleton images.

    Saves the model state dict to ``--save_path`` after every epoch and a
    reconstruction grid to ``--sample_dir`` every fifth epoch. Logs go to the
    console and to ``train.log`` next to the saved model.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as the script did before.
    """
    args = _parse_args(argv)

    sample_dir = Path(args.sample_dir)
    model_dir = Path(args.save_path).parent
    sample_dir.mkdir(parents=True, exist_ok=True)
    model_dir.mkdir(parents=True, exist_ok=True)

    logger = setup_logger(model_dir)

    ds = SkeletonDataset(args.skel_folder, size=64)
    if len(ds) == 0:
        logger.error(f"No skeleton images found in {args.skel_folder}")
        return

    dl = DataLoader(ds, batch_size=args.batch, shuffle=True,
                    num_workers=args.num_workers)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")
    logger.info(f"Dataset size: {len(ds)} images")

    model = SkeletonVAE(z_dim=64).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    # loss_fn sums over every element, so dividing the epoch totals by the
    # number of images gives per-image averages.
    num_images = len(ds)

    logger.info("Starting training...")
    for epoch in range(args.epochs):
        start_time = time.time()
        total_loss, total_bce, total_kld = _train_one_epoch(
            model, dl, opt, device, logger, epoch, args.epochs
        )
        avg_loss = total_loss / num_images
        avg_bce = total_bce / num_images
        avg_kld = total_kld / num_images
        epoch_time = time.time() - start_time

        logger.info(f'Epoch {epoch+1}/{args.epochs} completed in {epoch_time:.2f}s')
        logger.info(f'  Average Loss: {avg_loss:.4f}')
        logger.info(f'  BCE: {avg_bce:.4f} | KLD: {avg_kld:.4f}')

        # Overwrite the checkpoint after every epoch.
        torch.save(model.state_dict(), args.save_path)

        # Dump sample reconstructions every fifth epoch.
        if (epoch + 1) % 5 == 0:
            logger.info(f"Saving reconstruction samples to {args.sample_dir}")
            _save_reconstructions(
                model, dl, device, sample_dir / f'epoch_{epoch+1}.png'
            )
|
|
|
|
|
|
|
|
# Run the training loop only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|