64 lines
2.1 KiB
Python
64 lines
2.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Generate new skeletons using the trained VAE model.
|
|
"""
|
|
import argparse
|
|
import sys
|
|
from pathlib import Path
|
|
import torch
|
|
from torchvision.utils import save_image
|
|
import numpy as np
|
|
|
|
# Add root to path to allow importing models
|
|
sys.path.append(str(Path(__file__).parents[1]))
|
|
|
|
from models.skeleton_vae import SkeletonVAE
|
|
|
|
def main(argv=None):
    """Sample latent vectors, decode them with a trained SkeletonVAE, and save images.

    For each sample two PNGs are written to ``--out_dir``:
    ``gen_{i}_raw.png`` (the decoder output as-is) and ``gen_{i}_bin.png``
    (the decoder output thresholded at ``--threshold``).

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, so normal script invocation
            is unchanged; passing a list allows programmatic use and testing.

    Returns:
        ``1`` if the model file does not exist, otherwise ``None``.

    Raises:
        SystemExit: raised by argparse for invalid arguments, including a
            ``--num_samples`` value below 1.
    """
    p = argparse.ArgumentParser(description='Generate new skeletons using the trained VAE model.')
    p.add_argument('--model_path', type=str, required=True, help='Path to trained .pth model')
    p.add_argument('--out_dir', type=str, default='out/generated', help='Output directory for generated images')
    p.add_argument('--num_samples', type=int, default=10, help='Number of samples to generate')
    p.add_argument('--z_dim', type=int, default=64, help='Latent dimension size')
    p.add_argument('--threshold', type=float, default=0.5, help='Binarization threshold')
    p.add_argument('--seed', type=int, default=None,
                   help='Random seed for reproducible sampling (default: unseeded)')
    args = p.parse_args(argv)

    # Reject nonsensical sample counts up front instead of failing inside torch.randn.
    if args.num_samples < 1:
        p.error('--num_samples must be a positive integer')

    # Validate the checkpoint path before building the model or touching the GPU.
    # Report the failure on stderr with a non-zero return so callers can detect it.
    model_path = Path(args.model_path)
    if not model_path.is_file():
        print(f"Error: Model not found at {args.model_path}", file=sys.stderr)
        return 1

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if args.seed is not None:
        torch.manual_seed(args.seed)

    # Load model.
    model = SkeletonVAE(z_dim=args.z_dim).to(device)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # trusted sources. Newer PyTorch versions accept weights_only=True to restrict this.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Generating {args.num_samples} samples...")

    with torch.no_grad():
        # Draw latents from a standard normal on the CPU, then move them to the
        # device, so a given --seed yields the same latents on CPU and GPU.
        z = torch.randn(args.num_samples, args.z_dim).to(device)

        # assumes the decoder returns one image per latent, with values in
        # [0, 1] (e.g. a sigmoid output) — TODO confirm against SkeletonVAE.
        generated = model.decoder(z)

        for i, img in enumerate(generated):
            # Raw decoder output (treated as a probability map).
            save_image(img, out_dir / f'gen_{i}_raw.png')

            # Hard 0/1 mask via the user-supplied threshold.
            bin_img = (img > args.threshold).float()
            save_image(bin_img, out_dir / f'gen_{i}_bin.png')

    print(f"Saved generated images to {out_dir}")
|
|
|
|
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status.
    # None means exit 0; a non-zero int signals failure to the shell.
    sys.exit(main())
|