improve IC Layout Diffussion model 20251120
This commit is contained in:
355
tools/diffusion/run_optimized_pipeline.py
Executable file
355
tools/diffusion/run_optimized_pipeline.py
Executable file
@@ -0,0 +1,355 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
一键运行优化的IC版图扩散模型训练和生成管线
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import yaml
|
||||
import argparse
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import shutil
|
||||
|
||||
def setup_logging():
|
||||
"""设置日志"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout),
|
||||
logging.FileHandler('optimized_pipeline.log')
|
||||
]
|
||||
)
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
def run_command(cmd, description, logger):
|
||||
"""运行命令并处理错误"""
|
||||
logger.info(f"执行: {description}")
|
||||
logger.info(f"命令: {' '.join(cmd)}")
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
logger.info(f"{description} - 成功")
|
||||
if result.stdout:
|
||||
logger.debug(f"输出: {result.stdout}")
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"{description} - 失败")
|
||||
logger.error(f"错误码: {e.returncode}")
|
||||
logger.error(f"错误输出: {e.stderr}")
|
||||
return False
|
||||
|
||||
def validate_data_directory(data_dir, logger):
|
||||
"""验证数据目录"""
|
||||
data_path = Path(data_dir)
|
||||
if not data_path.exists():
|
||||
logger.error(f"数据目录不存在: {data_path}")
|
||||
return False
|
||||
|
||||
# 检查图像文件
|
||||
image_extensions = ['.png', '.jpg', '.jpeg']
|
||||
image_files = []
|
||||
for ext in image_extensions:
|
||||
image_files.extend(data_path.glob(f"*{ext}"))
|
||||
image_files.extend(data_path.glob(f"*{ext.upper()}"))
|
||||
|
||||
if len(image_files) == 0:
|
||||
logger.error(f"数据目录中没有找到图像文件: {data_path}")
|
||||
return False
|
||||
|
||||
logger.info(f"数据验证通过 - 找到 {len(image_files)} 个图像文件")
|
||||
return True
|
||||
|
||||
def create_sample_images(output_dir, logger, num_samples=5):
|
||||
"""创建示例图像"""
|
||||
logger.info("创建示例图像...")
|
||||
|
||||
# 创建简单的曼哈顿几何图案
|
||||
from PIL import Image, ImageDraw
|
||||
import numpy as np
|
||||
|
||||
sample_dir = Path(output_dir) / "reference_samples"
|
||||
sample_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i in range(num_samples):
|
||||
# 创建空白图像
|
||||
img = Image.new('L', (256, 256), 255) # 白色背景
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
# 绘制曼哈顿几何图案
|
||||
np.random.seed(i)
|
||||
|
||||
# 外框
|
||||
draw.rectangle([20, 20, 236, 236], outline=0, width=2)
|
||||
|
||||
# 随机内部矩形
|
||||
for _ in range(np.random.randint(3, 8)):
|
||||
x1 = np.random.randint(40, 180)
|
||||
y1 = np.random.randint(40, 180)
|
||||
x2 = x1 + np.random.randint(20, 60)
|
||||
y2 = y1 + np.random.randint(20, 60)
|
||||
if x2 < 220 and y2 < 220: # 确保不超出边界
|
||||
draw.rectangle([x1, y1, x2, y2], outline=0, width=1)
|
||||
|
||||
# 保存图像
|
||||
img.save(sample_dir / f"sample_{i:03d}.png")
|
||||
|
||||
logger.info(f"示例图像已保存到: {sample_dir}")
|
||||
|
||||
def run_optimized_pipeline(args):
|
||||
"""运行优化管线"""
|
||||
logger = setup_logging()
|
||||
|
||||
logger.info("=== 开始优化的IC版图扩散模型管线 ===")
|
||||
|
||||
# 验证输入
|
||||
if not validate_data_directory(args.data_dir, logger):
|
||||
return False
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 如果需要,创建示例数据
|
||||
if args.create_sample_data:
|
||||
create_sample_images(args.data_dir, logger)
|
||||
|
||||
# 训练阶段
|
||||
if not args.skip_training:
|
||||
logger.info("\n=== 第一阶段: 训练优化模型 ===")
|
||||
|
||||
train_cmd = [
|
||||
sys.executable, "train_optimized.py",
|
||||
"--data_dir", args.data_dir,
|
||||
"--output_dir", str(output_dir / "model"),
|
||||
"--image_size", str(args.image_size),
|
||||
"--batch_size", str(args.batch_size),
|
||||
"--epochs", str(args.epochs),
|
||||
"--lr", str(args.lr),
|
||||
"--timesteps", str(args.timesteps),
|
||||
"--schedule_type", args.schedule_type,
|
||||
"--manhattan_weight", str(args.manhattan_weight),
|
||||
"--seed", str(args.seed),
|
||||
"--save_interval", str(args.save_interval),
|
||||
"--sample_interval", str(args.sample_interval),
|
||||
"--num_samples", str(args.train_samples)
|
||||
]
|
||||
|
||||
if args.edge_condition:
|
||||
train_cmd.append("--edge_condition")
|
||||
|
||||
if args.augment:
|
||||
train_cmd.append("--augment")
|
||||
|
||||
if args.resume:
|
||||
train_cmd.extend(["--resume", args.resume])
|
||||
|
||||
success = run_command(train_cmd, "训练优化模型", logger)
|
||||
if not success:
|
||||
logger.error("训练阶段失败")
|
||||
return False
|
||||
|
||||
# 查找最佳模型
|
||||
model_checkpoint = output_dir / "model" / "best_model.pth"
|
||||
if not model_checkpoint.exists():
|
||||
# 如果没有最佳模型,使用最终模型
|
||||
model_checkpoint = output_dir / "model" / "final_model.pth"
|
||||
|
||||
if not model_checkpoint.exists():
|
||||
logger.error("找不到训练好的模型")
|
||||
return False
|
||||
|
||||
else:
|
||||
logger.info("\n=== 跳过训练阶段 ===")
|
||||
model_checkpoint = args.checkpoint
|
||||
if not model_checkpoint:
|
||||
logger.error("跳过训练时需要提供 --checkpoint 参数")
|
||||
return False
|
||||
|
||||
if not Path(model_checkpoint).exists():
|
||||
logger.error(f"指定的检查点不存在: {model_checkpoint}")
|
||||
return False
|
||||
|
||||
# 生成阶段
|
||||
logger.info("\n=== 第二阶段: 生成样本 ===")
|
||||
|
||||
generate_cmd = [
|
||||
sys.executable, "generate_optimized.py",
|
||||
"--checkpoint", str(model_checkpoint),
|
||||
"--output_dir", str(output_dir / "generated"),
|
||||
"--num_samples", str(args.num_samples),
|
||||
"--image_size", str(args.image_size),
|
||||
"--batch_size", str(args.gen_batch_size),
|
||||
"--num_steps", str(args.num_steps),
|
||||
"--seed", str(args.seed),
|
||||
"--timesteps", str(args.timesteps),
|
||||
"--schedule_type", args.schedule_type
|
||||
]
|
||||
|
||||
if args.use_ddim:
|
||||
generate_cmd.append("--use_ddim")
|
||||
|
||||
if args.use_post_process:
|
||||
generate_cmd.append("--use_post_process")
|
||||
|
||||
success = run_command(generate_cmd, "生成样本", logger)
|
||||
if not success:
|
||||
logger.error("生成阶段失败")
|
||||
return False
|
||||
|
||||
# 更新配置文件(如果提供了)
|
||||
if args.update_config and Path(args.update_config).exists():
|
||||
logger.info("\n=== 第三阶段: 更新配置文件 ===")
|
||||
|
||||
config_path = Path(args.update_config)
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# 更新扩散配置
|
||||
if 'synthetic' not in config:
|
||||
config['synthetic'] = {}
|
||||
|
||||
config['synthetic']['enabled'] = True
|
||||
config['synthetic']['ratio'] = 0.0 # 禁用程序化合成
|
||||
|
||||
if 'diffusion' not in config['synthetic']:
|
||||
config['synthetic']['diffusion'] = {}
|
||||
|
||||
config['synthetic']['diffusion']['enabled'] = True
|
||||
config['synthetic']['diffusion']['png_dir'] = str(output_dir / "generated")
|
||||
config['synthetic']['diffusion']['ratio'] = args.diffusion_ratio
|
||||
config['synthetic']['diffusion']['model_checkpoint'] = str(model_checkpoint)
|
||||
|
||||
# 保存配置
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
|
||||
|
||||
logger.info(f"配置文件已更新: {config_path}")
|
||||
logger.info(f"扩散数据比例: {args.diffusion_ratio}")
|
||||
|
||||
# 创建管线报告
|
||||
create_pipeline_report(output_dir, model_checkpoint, args, logger)
|
||||
|
||||
logger.info("\n=== 优化管线完成 ===")
|
||||
logger.info(f"模型: {model_checkpoint}")
|
||||
logger.info(f"生成数据: {output_dir / 'generated'}")
|
||||
logger.info(f"管线报告: {output_dir / 'pipeline_report.txt'}")
|
||||
|
||||
return True
|
||||
|
||||
def create_pipeline_report(output_dir, model_checkpoint, args, logger):
|
||||
"""创建管线报告"""
|
||||
report_content = f"""
|
||||
IC版图扩散模型优化管线报告
|
||||
============================
|
||||
|
||||
管线配置:
|
||||
- 数据目录: {args.data_dir}
|
||||
- 输出目录: {args.output_dir}
|
||||
- 图像尺寸: {args.image_size}x{args.image_size}
|
||||
- 训练轮数: {args.epochs}
|
||||
- 批次大小: {args.batch_size}
|
||||
- 学习率: {args.lr}
|
||||
- 时间步数: {args.timesteps}
|
||||
- 调度类型: {args.schedule_type}
|
||||
- 曼哈顿权重: {args.manhattan_weight}
|
||||
- 随机种子: {args.seed}
|
||||
|
||||
模型配置:
|
||||
- 边缘条件: {args.edge_condition}
|
||||
- 数据增强: {args.augment}
|
||||
- 最终模型: {model_checkpoint}
|
||||
|
||||
生成配置:
|
||||
- 生成样本数: {args.num_samples}
|
||||
- 生成批次大小: {args.gen_batch_size}
|
||||
- 采样步数: {args.num_steps}
|
||||
- DDIM采样: {args.use_ddim}
|
||||
- 后处理: {args.use_post_process}
|
||||
|
||||
优化特性:
|
||||
- 曼哈顿几何感知的U-Net架构
|
||||
- 边缘感知损失函数
|
||||
- 多尺度结构损失
|
||||
- 曼哈顿约束正则化
|
||||
- 几何保持的数据增强
|
||||
- 后处理优化
|
||||
|
||||
输出目录结构:
|
||||
- model/: 训练好的模型和检查点
|
||||
- generated/: 生成的IC版图样本
|
||||
- pipeline_report.txt: 本报告
|
||||
|
||||
质量评估:
|
||||
生成完成后,请查看 generated/quality_metrics.yaml 和 generation_report.txt 获取详细的质量评估。
|
||||
|
||||
使用说明:
|
||||
1. 训练数据应包含高质量的IC版图图像
|
||||
2. 建议使用边缘条件来提高生成质量
|
||||
3. 生成的样本可以使用后处理进一步优化
|
||||
4. 可根据质量评估结果调整训练参数
|
||||
"""
|
||||
|
||||
report_path = output_dir / 'pipeline_report.txt'
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
f.write(report_content)
|
||||
|
||||
logger.info(f"管线报告已保存: {report_path}")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="一键运行优化的IC版图扩散模型管线")
|
||||
|
||||
# 基本参数
|
||||
parser.add_argument("--data_dir", type=str, required=True, help="训练数据目录")
|
||||
parser.add_argument("--output_dir", type=str, required=True, help="输出目录")
|
||||
|
||||
# 训练参数
|
||||
parser.add_argument("--image_size", type=int, default=256, help="图像尺寸")
|
||||
parser.add_argument("--batch_size", type=int, default=4, help="训练批次大小")
|
||||
parser.add_argument("--epochs", type=int, default=100, help="训练轮数")
|
||||
parser.add_argument("--lr", type=float, default=1e-4, help="学习率")
|
||||
parser.add_argument("--timesteps", type=int, default=1000, help="扩散时间步数")
|
||||
parser.add_argument("--schedule_type", type=str, default='cosine', choices=['linear', 'cosine'], help="噪声调度类型")
|
||||
parser.add_argument("--manhattan_weight", type=float, default=0.1, help="曼哈顿正则化权重")
|
||||
parser.add_argument("--seed", type=int, default=42, help="随机种子")
|
||||
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
|
||||
parser.add_argument("--sample_interval", type=int, default=20, help="样本生成间隔")
|
||||
parser.add_argument("--train_samples", type=int, default=16, help="训练时生成的样本数量")
|
||||
|
||||
# 训练控制
|
||||
parser.add_argument("--skip_training", action='store_true', help="跳过训练,使用现有模型")
|
||||
parser.add_argument("--checkpoint", type=str, help="现有模型检查点路径(skip_training时使用)")
|
||||
parser.add_argument("--resume", type=str, help="恢复训练的检查点路径")
|
||||
parser.add_argument("--edge_condition", action='store_true', help="使用边缘条件")
|
||||
parser.add_argument("--augment", action='store_true', help="启用数据增强")
|
||||
|
||||
# 生成参数
|
||||
parser.add_argument("--num_samples", type=int, default=200, help="生成样本数量")
|
||||
parser.add_argument("--gen_batch_size", type=int, default=8, help="生成批次大小")
|
||||
parser.add_argument("--num_steps", type=int, default=50, help="采样步数")
|
||||
parser.add_argument("--use_ddim", action='store_true', default=True, help="使用DDIM采样")
|
||||
parser.add_argument("--use_post_process", action='store_true', default=True, help="启用后处理")
|
||||
|
||||
# 配置更新
|
||||
parser.add_argument("--update_config", type=str, help="要更新的配置文件路径")
|
||||
parser.add_argument("--diffusion_ratio", type=float, default=0.3, help="扩散数据在训练中的比例")
|
||||
|
||||
# 开发选项
|
||||
parser.add_argument("--create_sample_data", action='store_true', help="创建示例训练数据")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 验证参数
|
||||
if args.skip_training and not args.checkpoint:
|
||||
print("错误: 跳过训练时必须提供 --checkpoint 参数")
|
||||
sys.exit(1)
|
||||
|
||||
# 运行管线
|
||||
success = run_optimized_pipeline(args)
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user