#!/usr/bin/env python3
"""
|
|
Run the full experiment:
|
|
1. Build dataset from raw images
|
|
2. Train VAE model
|
|
3. Generate new skeletons
|
|
"""
|
|
import argparse
import subprocess
import sys
from pathlib import Path
|
|
|
|
def run_command(cmd):
    """Echo a command line, then execute it and fail loudly on error.

    Args:
        cmd: Sequence of program arguments (strings or path-like objects),
            passed to ``subprocess.run`` without a shell.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status (``check=True``).
    """
    # str() each argument so path-like entries, which subprocess.run accepts,
    # do not break the log line. Flush so the message is emitted before the
    # child's own output when stdout is buffered (piped or redirected).
    print(f"Running: {' '.join(map(str, cmd))}", flush=True)
    subprocess.run(cmd, check=True)
|
|
|
|
def main(argv=None):
    """Run the experiment pipeline: build dataset, train the VAE, generate.

    Each step is launched as a separate Python process through
    ``run_command``. The dataset and training steps can be skipped with
    ``--skip_data`` / ``--skip_train``; the generation step always runs.

    Args:
        argv: Optional list of command-line arguments to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--img_dir', type=str, default='ICCAD2019/img', help='Input raw images directory')
    p.add_argument('--data_dir', type=str, default='out/dataset', help='Output dataset directory')
    p.add_argument('--model_dir', type=str, default='out/models', help='Directory to save models')
    p.add_argument('--gen_dir', type=str, default='out/generated', help='Directory for generated samples')
    p.add_argument('--epochs', type=int, default=20, help='Training epochs')
    p.add_argument('--batch_size', type=int, default=16, help='Batch size')
    p.add_argument('--num_samples', type=int, default=20, help='Number of samples to generate')
    p.add_argument('--skip_data', action='store_true', help='Skip dataset building')
    p.add_argument('--skip_train', action='store_true', help='Skip training')
    args = p.parse_args(argv)

    # Launch the child scripts with the interpreter that is running this
    # script so they share its environment; fall back to the previous
    # hard-coded 'python3' if the interpreter path is unavailable.
    python = sys.executable or 'python3'

    # 1. Build Dataset
    if not args.skip_data:
        print("=== Step 1: Building Dataset ===")
        # Note: build_skeleton_dataset.py expects img_folder and out_folder
        # as positional arguments.
        cmd = [
            python, 'datasets/build_skeleton_dataset.py',
            args.img_dir,
            args.data_dir,
        ]
        run_command(cmd)

    # 2. Train Model
    # The checkpoint path is computed even when training is skipped because
    # the generation step below reads it.
    model_path = Path(args.model_dir) / 'vae_best.pth'
    if not args.skip_train:
        print("=== Step 2: Training VAE ===")
        cmd = [
            python, 'train/train_skeleton_vae.py',
            args.data_dir,
            '--epochs', str(args.epochs),
            '--batch', str(args.batch_size),
            '--save_path', str(model_path),
            '--sample_dir', str(Path(args.model_dir) / 'train_samples'),
        ]
        run_command(cmd)

    # 3. Generate
    print("=== Step 3: Generating Skeletons ===")
    cmd = [
        python, 'scripts/generate_skeleton.py',
        '--model_path', str(model_path),
        '--out_dir', args.gen_dir,
        '--num_samples', str(args.num_samples),
    ]
    run_command(cmd)

    print("=== Experiment Complete ===")
    print(f"Generated images are in {args.gen_dir}")
|
|
|
|
if __name__ == '__main__':
    # Run the pipeline only when executed as a script, not when imported.
    main()
|