#!/usr/bin/env python3
"""
Generate new skeletons using the trained VAE model.

Samples latent vectors from a standard normal prior, decodes them with the
decoder of a trained ``SkeletonVAE`` and writes two PNG files per sample:
the raw decoder output (``gen_<i>_raw.png``) and a thresholded binary
version (``gen_<i>_bin.png``).
"""
import argparse
import sys
from pathlib import Path

import torch
from torchvision.utils import save_image
import numpy as np

# Add root to path to allow importing models
sys.path.append(str(Path(__file__).parents[1]))
from models.skeleton_vae import SkeletonVAE


def _parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. Invalid values terminate the
        process through ``ArgumentParser.error`` (exit status 2).
    """
    p = argparse.ArgumentParser(
        description='Generate new skeletons using the trained VAE model.')
    p.add_argument('--model_path', type=str, required=True, help='Path to trained .pth model')
    p.add_argument('--out_dir', type=str, default='out/generated', help='Output directory for generated images')
    p.add_argument('--num_samples', type=int, default=10, help='Number of samples to generate')
    p.add_argument('--z_dim', type=int, default=64, help='Latent dimension size')
    p.add_argument('--threshold', type=float, default=0.5, help='Binarization threshold')
    p.add_argument('--seed', type=int, default=None,
                   help='Random seed for reproducible sampling (default: unseeded)')
    args = p.parse_args(argv)

    # Reject sizes that would make torch.randn fail or produce an empty batch.
    if args.num_samples < 1:
        p.error('--num_samples must be a positive integer')
    if args.z_dim < 1:
        p.error('--z_dim must be a positive integer')
    return args


def _load_model(model_path, z_dim, device):
    """Build a ``SkeletonVAE``, load the weights at *model_path*, and switch to eval mode.

    Args:
        model_path: Path to a saved ``state_dict`` checkpoint.
        z_dim: Latent size. It must match the size the checkpoint was trained
            with, otherwise ``load_state_dict`` raises.
        device: Device the model and its weights are placed on.

    Returns:
        The model in evaluation mode.
    """
    model = SkeletonVAE(z_dim=z_dim).to(device)
    # NOTE(review): torch.load unpickles the file. Only load checkpoints you
    # trust, or pass weights_only=True on torch versions that support it.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def main(argv=None):
    """Sample latent codes, decode them, and save raw and binarized images.

    Args:
        argv: Optional argument list, for programmatic use; ``None`` reads
            the real command line.
    """
    args = _parse_args(argv)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Check the checkpoint before building the model. is_file() also rejects
    # a directory passed by mistake, which exists() would accept.
    model_path = Path(args.model_path)
    if not model_path.is_file():
        # sys.exit(str) prints to stderr and exits with status 1, so callers
        # and shell pipelines see the failure.
        sys.exit(f"Error: Model not found at {args.model_path}")

    if args.seed is not None:
        # torch.manual_seed seeds the CPU and all CUDA generators.
        torch.manual_seed(args.seed)

    model = _load_model(model_path, args.z_dim, device)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Generating {args.num_samples} samples...")
    with torch.no_grad():
        # Sample from standard normal distribution directly on the target device
        z = torch.randn(args.num_samples, args.z_dim, device=device)
        # Decode. Assumes the decoder returns a (N, C, H, W) batch with values
        # in [0, 1] (e.g. sigmoid output) -- TODO confirm against SkeletonVAE.
        generated = model.decoder(z)

        # Save raw and binarized versions of every sample
        for i, img in enumerate(generated):
            save_image(img, out_dir / f'gen_{i}_raw.png')
            bin_img = (img > args.threshold).float()
            save_image(bin_img, out_dir / f'gen_{i}_bin.png')

    print(f"Saved generated images to {out_dir}")


if __name__ == '__main__':
    main()