Files
RoRD-Layout-Recognation/tools/diffusion/example_usage.py
2025-11-20 01:47:09 +08:00

198 lines
6.2 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
优化扩散模型使用示例
演示如何使用优化后的IC版图扩散模型
"""
import os
import sys
import torch
import subprocess
from pathlib import Path
def example_basic_training():
"""基本训练示例"""
print("=== 基本训练示例 ===")
# 创建示例数据目录
data_dir = "example_data/ic_layouts"
output_dir = "example_outputs/basic_training"
# 训练命令
cmd = [
sys.executable, "run_optimized_pipeline.py",
"--data_dir", data_dir,
"--output_dir", output_dir,
"--epochs", 20, # 示例用较少轮数
"--batch_size", 2,
"--image_size", 256,
"--num_samples", 10,
"--create_sample_data", # 创建示例数据
"--edge_condition",
"--augment"
]
print(f"运行命令: {' '.join(map(str, cmd))}")
print("注意:这将创建示例数据并开始训练")
# 实际运行时取消注释
# subprocess.run(cmd)
def example_advanced_training():
"""高级训练示例"""
print("\n=== 高级训练示例 ===")
cmd = [
sys.executable, "train_optimized.py",
"--data_dir", "data/high_quality_ic_layouts",
"--output_dir", "models/advanced_diffusion",
"--image_size", 512, # 更高分辨率
"--batch_size", 8,
"--epochs", 200,
"--lr", 5e-5, # 更低学习率
"--manhattan_weight", 0.15, # 更强的几何约束
"--edge_condition",
"--augment",
"--schedule_type", "cosine",
"--save_interval", 5,
"--sample_interval", 10
]
print(f"高级训练命令: {' '.join(map(str, cmd))}")
def example_generation_only():
"""仅生成示例"""
print("\n=== 仅生成示例 ===")
cmd = [
sys.executable, "generate_optimized.py",
"--checkpoint", "models/diffusion_optimized/best_model.pth",
"--output_dir", "generated_samples/high_quality",
"--num_samples", 100,
"--num_steps", 30, # 更快采样
"--use_ddim",
"--batch_size", 16,
"--use_post_process",
"--post_process_threshold", 0.45
]
print(f"生成命令: {' '.join(map(str, cmd))}")
def example_custom_parameters():
"""自定义参数示例"""
print("\n=== 自定义参数示例 ===")
# 针对特定需求的参数调整
scenarios = {
"高质量生成": {
"description": "追求最高质量的生成结果",
"params": {
"--num_steps": 100,
"--guidance_scale": 2.0,
"--eta": 0.0, # 完全确定性
"--use_post_process": True,
"--post_process_threshold": 0.5
}
},
"快速生成": {
"description": "快速生成大量样本",
"params": {
"--num_steps": 20,
"--batch_size": 32,
"--eta": 0.3, # 增加随机性
"--use_ddim": True
}
},
"几何约束严格": {
"description": "严格要求曼哈顿几何",
"params": {
"--manhattan_weight": 0.3, # 更强约束
"--use_post_process": True,
"--post_process_threshold": 0.4
}
}
}
for scenario_name, config in scenarios.items():
print(f"\n{scenario_name}: {config['description']}")
for param, value in config['params'].items():
print(f" {param}: {value}")
def example_integration():
"""集成到现有管线示例"""
print("\n=== 集成示例 ===")
# 更新配置文件
config_update = {
"config_file": "configs/train_config.yaml",
"updates": {
"synthetic.enabled": True,
"synthetic.ratio": 0.0,
"synthetic.diffusion.enabled": True,
"synthetic.diffusion.png_dir": "outputs/diffusion_optimized/generated",
"synthetic.diffusion.ratio": 0.4,
"synthetic.diffusion.model_checkpoint": "outputs/diffusion_optimized/model/best_model.pth"
}
}
print("配置文件更新示例:")
print(f"配置文件: {config_update['config_file']}")
for key, value in config_update['updates'].items():
print(f" {key}: {value}")
integration_cmd = [
sys.executable, "run_optimized_pipeline.py",
"--data_dir", "data/training_layouts",
"--output_dir", "outputs/integration",
"--update_config", config_update["config_file"],
"--diffusion_ratio", 0.4,
"--epochs", 100,
"--num_samples", 500
]
print(f"\n集成命令: {' '.join(map(str, integration_cmd))}")
def show_tips():
"""显示使用建议"""
print("\n=== 使用建议 ===")
tips = [
"🎯 数据质量是关键使用高质量、多样化的IC版图数据进行训练",
"⚖️ 平衡约束曼哈顿权重不宜过高0.05-0.2),避免过度约束影响生成多样性",
"🔄 迭代优化:根据生成结果调整损失函数权重和后处理参数",
"📊 质量监控:定期检查生成样本的质量指标",
"💾 定期保存:设置合理的保存间隔,避免训练中断导致损失",
"🚀 性能优化使用DDIM采样可以显著提高生成速度",
"🔧 参数调优:根据具体任务需求调整模型参数"
]
for tip in tips:
print(tip)
def main():
"""主函数"""
print("IC版图扩散模型优化版本 - 使用示例")
print("=" * 50)
# 检查是否在正确的目录
if not Path("ic_layout_diffusion_optimized.py").exists():
print("错误:请在 tools/diffusion/ 目录下运行此脚本")
sys.exit(1)
# 显示示例
example_basic_training()
example_advanced_training()
example_generation_only()
example_custom_parameters()
example_integration()
show_tips()
print("\n" + "=" * 50)
print("运行示例:")
print("1. 基本使用python run_optimized_pipeline.py --data_dir data/ic_layouts --output_dir outputs")
print("2. 查看完整参数python train_optimized.py --help")
print("3. 查看生成参数python generate_optimized.py --help")
print("4. 阅读详细文档README_OPTIMIZED.md")
if __name__ == "__main__":
main()