Files
IC-Layout-Generate/run_experiment.py
2025-11-24 20:34:50 +08:00

68 lines
2.3 KiB
Python

#!/usr/bin/env python3
"""
Run the full experiment:
1. Build dataset from raw images
2. Train VAE model
3. Generate new skeletons
"""
import argparse
import shlex
import subprocess
import sys
from pathlib import Path
def run_command(cmd):
    """Echo *cmd* and run it, raising if it exits with a non-zero status.

    Args:
        cmd: Iterable of program arguments (``str`` or path-like). It is run
            directly, without a shell.

    Returns:
        The ``subprocess.CompletedProcess`` of the finished command.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero.
    """
    # Materialize once so a generator is not consumed by the echo below.
    cmd = list(cmd)
    # Shell-quote each argument so the echoed line can be copy-pasted verbatim;
    # str() lets path-like arguments be echoed too.
    printable = ' '.join(shlex.quote(str(part)) for part in cmd)
    # Flush so this line precedes the child's output even when stdout is piped.
    print(f"Running: {printable}", flush=True)
    return subprocess.run(cmd, check=True)
def main(argv=None):
    """Run the pipeline: build the dataset, train the VAE, generate skeletons.

    Each step runs its helper script in a child process using the same
    interpreter that runs this driver, so every step sees the same
    environment (e.g. the active virtualenv).

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``.

    Raises:
        SystemExit: On invalid arguments, or when ``--skip_train`` is given
            but ``<model_dir>/vae_best.pth`` does not exist.
        subprocess.CalledProcessError: If any pipeline step fails.
    """
    # Resolve helper scripts relative to this file instead of the CWD so the
    # driver works from any directory. Assumes datasets/, train/ and scripts/
    # sit next to run_experiment.py (the repo root) -- confirm if moved.
    root = Path(__file__).resolve().parent
    # sys.executable can be empty in embedded interpreters; fall back to the
    # previous behaviour in that case.
    python = sys.executable or 'python3'

    p = argparse.ArgumentParser()
    p.add_argument('--img_dir', type=str, default='ICCAD2019/img', help='Input raw images directory')
    p.add_argument('--data_dir', type=str, default='out/dataset', help='Output dataset directory')
    p.add_argument('--model_dir', type=str, default='out/models', help='Directory to save models')
    p.add_argument('--gen_dir', type=str, default='out/generated', help='Directory for generated samples')
    p.add_argument('--epochs', type=int, default=20, help='Training epochs')
    p.add_argument('--batch_size', type=int, default=16, help='Batch size')
    p.add_argument('--num_samples', type=int, default=20, help='Number of skeletons to generate')
    p.add_argument('--skip_data', action='store_true', help='Skip dataset building')
    p.add_argument('--skip_train', action='store_true', help='Skip training')
    args = p.parse_args(argv)

    model_dir = Path(args.model_dir)
    model_path = model_dir / 'vae_best.pth'
    # Fail fast, before any long-running step, if generation would have no model.
    if args.skip_train and not model_path.is_file():
        p.error(f"--skip_train was given but no trained model exists at {model_path}")

    # 1. Build dataset
    if not args.skip_data:
        print("=== Step 1: Building Dataset ===", flush=True)
        # build_skeleton_dataset.py takes positional img_folder and out_folder.
        run_command([
            python, str(root / 'datasets' / 'build_skeleton_dataset.py'),
            args.img_dir,
            args.data_dir,
        ])

    # 2. Train model (writes the checkpoint that step 3 reads)
    if not args.skip_train:
        print("=== Step 2: Training VAE ===", flush=True)
        run_command([
            python, str(root / 'train' / 'train_skeleton_vae.py'),
            args.data_dir,
            '--epochs', str(args.epochs),
            '--batch', str(args.batch_size),
            '--save_path', str(model_path),
            '--sample_dir', str(model_dir / 'train_samples'),
        ])

    # 3. Generate
    print("=== Step 3: Generating Skeletons ===", flush=True)
    run_command([
        python, str(root / 'scripts' / 'generate_skeleton.py'),
        '--model_path', str(model_path),
        '--out_dir', args.gen_dir,
        '--num_samples', str(args.num_samples),
    ])

    print("=== Experiment Complete ===")
    print(f"Generated images are in {args.gen_dir}")
if __name__ == '__main__':
    try:
        main()
    except subprocess.CalledProcessError as exc:
        # The failing step has already printed its own diagnostics; report
        # which command failed and propagate its exit status instead of
        # adding a Python traceback (which would always exit with 1).
        print(f"Pipeline step failed with exit code {exc.returncode}: {exc.cmd}", file=sys.stderr)
        sys.exit(exc.returncode)