improve IC Layout Diffussion model 20251120

This commit is contained in:
Jiao77
2025-11-20 01:47:09 +08:00
parent 930f1952d5
commit 49fe21fb2f
8 changed files with 2254 additions and 0 deletions

198
tools/diffusion/example_usage.py Executable file
View File

@@ -0,0 +1,198 @@
#!/usr/bin/env python3
"""
优化扩散模型使用示例
演示如何使用优化后的IC版图扩散模型
"""
import os
import sys
import torch
import subprocess
from pathlib import Path
def example_basic_training():
"""基本训练示例"""
print("=== 基本训练示例 ===")
# 创建示例数据目录
data_dir = "example_data/ic_layouts"
output_dir = "example_outputs/basic_training"
# 训练命令
cmd = [
sys.executable, "run_optimized_pipeline.py",
"--data_dir", data_dir,
"--output_dir", output_dir,
"--epochs", 20, # 示例用较少轮数
"--batch_size", 2,
"--image_size", 256,
"--num_samples", 10,
"--create_sample_data", # 创建示例数据
"--edge_condition",
"--augment"
]
print(f"运行命令: {' '.join(map(str, cmd))}")
print("注意:这将创建示例数据并开始训练")
# 实际运行时取消注释
# subprocess.run(cmd)
def example_advanced_training():
"""高级训练示例"""
print("\n=== 高级训练示例 ===")
cmd = [
sys.executable, "train_optimized.py",
"--data_dir", "data/high_quality_ic_layouts",
"--output_dir", "models/advanced_diffusion",
"--image_size", 512, # 更高分辨率
"--batch_size", 8,
"--epochs", 200,
"--lr", 5e-5, # 更低学习率
"--manhattan_weight", 0.15, # 更强的几何约束
"--edge_condition",
"--augment",
"--schedule_type", "cosine",
"--save_interval", 5,
"--sample_interval", 10
]
print(f"高级训练命令: {' '.join(map(str, cmd))}")
def example_generation_only():
"""仅生成示例"""
print("\n=== 仅生成示例 ===")
cmd = [
sys.executable, "generate_optimized.py",
"--checkpoint", "models/diffusion_optimized/best_model.pth",
"--output_dir", "generated_samples/high_quality",
"--num_samples", 100,
"--num_steps", 30, # 更快采样
"--use_ddim",
"--batch_size", 16,
"--use_post_process",
"--post_process_threshold", 0.45
]
print(f"生成命令: {' '.join(map(str, cmd))}")
def example_custom_parameters():
"""自定义参数示例"""
print("\n=== 自定义参数示例 ===")
# 针对特定需求的参数调整
scenarios = {
"高质量生成": {
"description": "追求最高质量的生成结果",
"params": {
"--num_steps": 100,
"--guidance_scale": 2.0,
"--eta": 0.0, # 完全确定性
"--use_post_process": True,
"--post_process_threshold": 0.5
}
},
"快速生成": {
"description": "快速生成大量样本",
"params": {
"--num_steps": 20,
"--batch_size": 32,
"--eta": 0.3, # 增加随机性
"--use_ddim": True
}
},
"几何约束严格": {
"description": "严格要求曼哈顿几何",
"params": {
"--manhattan_weight": 0.3, # 更强约束
"--use_post_process": True,
"--post_process_threshold": 0.4
}
}
}
for scenario_name, config in scenarios.items():
print(f"\n{scenario_name}: {config['description']}")
for param, value in config['params'].items():
print(f" {param}: {value}")
def example_integration():
"""集成到现有管线示例"""
print("\n=== 集成示例 ===")
# 更新配置文件
config_update = {
"config_file": "configs/train_config.yaml",
"updates": {
"synthetic.enabled": True,
"synthetic.ratio": 0.0,
"synthetic.diffusion.enabled": True,
"synthetic.diffusion.png_dir": "outputs/diffusion_optimized/generated",
"synthetic.diffusion.ratio": 0.4,
"synthetic.diffusion.model_checkpoint": "outputs/diffusion_optimized/model/best_model.pth"
}
}
print("配置文件更新示例:")
print(f"配置文件: {config_update['config_file']}")
for key, value in config_update['updates'].items():
print(f" {key}: {value}")
integration_cmd = [
sys.executable, "run_optimized_pipeline.py",
"--data_dir", "data/training_layouts",
"--output_dir", "outputs/integration",
"--update_config", config_update["config_file"],
"--diffusion_ratio", 0.4,
"--epochs", 100,
"--num_samples", 500
]
print(f"\n集成命令: {' '.join(map(str, integration_cmd))}")
def show_tips():
"""显示使用建议"""
print("\n=== 使用建议 ===")
tips = [
"🎯 数据质量是关键使用高质量、多样化的IC版图数据进行训练",
"⚖️ 平衡约束曼哈顿权重不宜过高0.05-0.2),避免过度约束影响生成多样性",
"🔄 迭代优化:根据生成结果调整损失函数权重和后处理参数",
"📊 质量监控:定期检查生成样本的质量指标",
"💾 定期保存:设置合理的保存间隔,避免训练中断导致损失",
"🚀 性能优化使用DDIM采样可以显著提高生成速度",
"🔧 参数调优:根据具体任务需求调整模型参数"
]
for tip in tips:
print(tip)
def main():
"""主函数"""
print("IC版图扩散模型优化版本 - 使用示例")
print("=" * 50)
# 检查是否在正确的目录
if not Path("ic_layout_diffusion_optimized.py").exists():
print("错误:请在 tools/diffusion/ 目录下运行此脚本")
sys.exit(1)
# 显示示例
example_basic_training()
example_advanced_training()
example_generation_only()
example_custom_parameters()
example_integration()
show_tips()
print("\n" + "=" * 50)
print("运行示例:")
print("1. 基本使用python run_optimized_pipeline.py --data_dir data/ic_layouts --output_dir outputs")
print("2. 查看完整参数python train_optimized.py --help")
print("3. 查看生成参数python generate_optimized.py --help")
print("4. 阅读详细文档README_OPTIMIZED.md")
if __name__ == "__main__":
main()