Files
IC-Layout-Generate/scripts/generate_skeleton.py
2025-11-24 20:34:50 +08:00

64 lines
2.1 KiB
Python

#!/usr/bin/env python3
"""
Generate new skeletons using the trained VAE model.
"""
import argparse
import sys
from pathlib import Path
import torch
from torchvision.utils import save_image
import numpy as np
# Add root to path to allow importing models
sys.path.append(str(Path(__file__).parents[1]))
from models.skeleton_vae import SkeletonVAE
def main():
    """Sample skeleton images from a trained SkeletonVAE and write them as PNGs.

    Command-line arguments:
        --model_path   Path to a ``.pth`` file whose contents are passed to
                       ``model.load_state_dict`` (required; must be a file).
        --out_dir      Directory for the generated images (created if missing).
        --num_samples  Number of latent vectors to draw and decode (must be >= 1).
        --z_dim        Latent size passed to ``SkeletonVAE``; presumably must
                       match the value the checkpoint was trained with, or
                       ``load_state_dict`` will fail — TODO confirm.
        --threshold    Cut-off used to produce the binarized copy of each image.
        --seed         Optional RNG seed for reproducible samples.

    For every sample ``i`` two files are written to ``out_dir``:
    ``gen_{i}_raw.png`` (the decoder output as-is) and ``gen_{i}_bin.png``
    (pixels strictly greater than ``threshold`` set to 1.0, the rest to 0.0).

    Exits with a non-zero status (message on stderr) when the model file is
    missing or ``--num_samples`` is not positive.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--model_path', type=str, required=True, help='Path to trained .pth model')
    p.add_argument('--out_dir', type=str, default='out/generated', help='Output directory for generated images')
    p.add_argument('--num_samples', type=int, default=10, help='Number of samples to generate')
    p.add_argument('--z_dim', type=int, default=64, help='Latent dimension size')
    p.add_argument('--threshold', type=float, default=0.5, help='Binarization threshold')
    p.add_argument('--seed', type=int, default=None,
                   help='Random seed for reproducible sampling (default: unseeded)')
    args = p.parse_args()

    if args.num_samples < 1:
        # An empty latent batch would otherwise reach the decoder.
        p.error('--num_samples must be a positive integer')

    # Validate the checkpoint path before doing any (possibly GPU) work.
    if not Path(args.model_path).is_file():
        # sys.exit(str) prints to stderr and exits with status 1, so callers
        # (shell scripts, CI) can detect the failure; a bare `return` exited 0.
        sys.exit(f"Error: Model not found at {args.model_path}")

    if args.seed is not None:
        # torch.manual_seed seeds the RNG on all devices (CPU and CUDA).
        torch.manual_seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load model
    model = SkeletonVAE(z_dim=args.z_dim).to(device)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from trusted sources. Since a plain state dict
    # is expected here, weights_only=True (torch >= 1.13) would be safer.
    state_dict = torch.load(args.model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Generating {args.num_samples} samples...")
    with torch.no_grad():
        # Sample from a standard normal distribution directly on the target
        # device (avoids a CPU allocation followed by a host-to-device copy).
        z = torch.randn(args.num_samples, args.z_dim, device=device)
        # Decode
        generated = model.decoder(z)

    # assumes decoder output is (num_samples, C, H, W) with values in [0, 1]
    # (e.g. a sigmoid output) — TODO confirm against SkeletonVAE.decoder
    for i, img in enumerate(generated):
        # Save raw probability map
        save_image(img, out_dir / f'gen_{i}_raw.png')
        # Binarize: strictly-above-threshold pixels become 1.0, others 0.0.
        bin_img = (img > args.threshold).float()
        save_image(bin_img, out_dir / f'gen_{i}_bin.png')

    print(f"Saved generated images to {out_dir}")
if __name__ == "__main__":
    # Entry point: run the CLI only when executed as a script, not on import.
    main()