initial commit
This commit is contained in:
63
scripts/generate_skeleton.py
Normal file
63
scripts/generate_skeleton.py
Normal file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate new skeletons using the trained VAE model.
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
import numpy as np
|
||||
|
||||
# Add root to path to allow importing models
|
||||
sys.path.append(str(Path(__file__).parents[1]))
|
||||
|
||||
from models.skeleton_vae import SkeletonVAE
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument('--model_path', type=str, required=True, help='Path to trained .pth model')
|
||||
p.add_argument('--out_dir', type=str, default='out/generated', help='Output directory for generated images')
|
||||
p.add_argument('--num_samples', type=int, default=10, help='Number of samples to generate')
|
||||
p.add_argument('--z_dim', type=int, default=64, help='Latent dimension size')
|
||||
p.add_argument('--threshold', type=float, default=0.5, help='Binarization threshold')
|
||||
args = p.parse_args()
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Load model
|
||||
model = SkeletonVAE(z_dim=args.z_dim).to(device)
|
||||
if not Path(args.model_path).exists():
|
||||
print(f"Error: Model not found at {args.model_path}")
|
||||
return
|
||||
|
||||
model.load_state_dict(torch.load(args.model_path, map_location=device))
|
||||
model.eval()
|
||||
|
||||
out_dir = Path(args.out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Generating {args.num_samples} samples...")
|
||||
|
||||
with torch.no_grad():
|
||||
# Sample from standard normal distribution
|
||||
z = torch.randn(args.num_samples, args.z_dim).to(device)
|
||||
|
||||
# Decode
|
||||
generated = model.decoder(z)
|
||||
|
||||
# Save raw and binarized
|
||||
for i in range(args.num_samples):
|
||||
img = generated[i]
|
||||
|
||||
# Save raw probability map
|
||||
save_image(img, out_dir / f'gen_{i}_raw.png')
|
||||
|
||||
# Binarize
|
||||
bin_img = (img > args.threshold).float()
|
||||
save_image(bin_img, out_dir / f'gen_{i}_bin.png')
|
||||
|
||||
print(f"Saved generated images to {out_dir}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user