Files
RoRD-Layout-Recognation/tools/diffusion/run_optimized_pipeline.py
2025-11-20 01:47:09 +08:00

355 lines
12 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
一键运行优化的IC版图扩散模型训练和生成管线
"""
import os
import sys
import yaml
import argparse
import subprocess
from pathlib import Path
import logging
import shutil
def setup_logging():
"""设置日志"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout),
logging.FileHandler('optimized_pipeline.log')
]
)
return logging.getLogger(__name__)
def run_command(cmd, description, logger):
"""运行命令并处理错误"""
logger.info(f"执行: {description}")
logger.info(f"命令: {' '.join(cmd)}")
try:
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
logger.info(f"{description} - 成功")
if result.stdout:
logger.debug(f"输出: {result.stdout}")
return True
except subprocess.CalledProcessError as e:
logger.error(f"{description} - 失败")
logger.error(f"错误码: {e.returncode}")
logger.error(f"错误输出: {e.stderr}")
return False
def validate_data_directory(data_dir, logger):
"""验证数据目录"""
data_path = Path(data_dir)
if not data_path.exists():
logger.error(f"数据目录不存在: {data_path}")
return False
# 检查图像文件
image_extensions = ['.png', '.jpg', '.jpeg']
image_files = []
for ext in image_extensions:
image_files.extend(data_path.glob(f"*{ext}"))
image_files.extend(data_path.glob(f"*{ext.upper()}"))
if len(image_files) == 0:
logger.error(f"数据目录中没有找到图像文件: {data_path}")
return False
logger.info(f"数据验证通过 - 找到 {len(image_files)} 个图像文件")
return True
def create_sample_images(output_dir, logger, num_samples=5):
"""创建示例图像"""
logger.info("创建示例图像...")
# 创建简单的曼哈顿几何图案
from PIL import Image, ImageDraw
import numpy as np
sample_dir = Path(output_dir) / "reference_samples"
sample_dir.mkdir(parents=True, exist_ok=True)
for i in range(num_samples):
# 创建空白图像
img = Image.new('L', (256, 256), 255) # 白色背景
draw = ImageDraw.Draw(img)
# 绘制曼哈顿几何图案
np.random.seed(i)
# 外框
draw.rectangle([20, 20, 236, 236], outline=0, width=2)
# 随机内部矩形
for _ in range(np.random.randint(3, 8)):
x1 = np.random.randint(40, 180)
y1 = np.random.randint(40, 180)
x2 = x1 + np.random.randint(20, 60)
y2 = y1 + np.random.randint(20, 60)
if x2 < 220 and y2 < 220: # 确保不超出边界
draw.rectangle([x1, y1, x2, y2], outline=0, width=1)
# 保存图像
img.save(sample_dir / f"sample_{i:03d}.png")
logger.info(f"示例图像已保存到: {sample_dir}")
def run_optimized_pipeline(args):
"""运行优化管线"""
logger = setup_logging()
logger.info("=== 开始优化的IC版图扩散模型管线 ===")
# 验证输入
if not validate_data_directory(args.data_dir, logger):
return False
# 创建输出目录
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# 如果需要,创建示例数据
if args.create_sample_data:
create_sample_images(args.data_dir, logger)
# 训练阶段
if not args.skip_training:
logger.info("\n=== 第一阶段: 训练优化模型 ===")
train_cmd = [
sys.executable, "train_optimized.py",
"--data_dir", args.data_dir,
"--output_dir", str(output_dir / "model"),
"--image_size", str(args.image_size),
"--batch_size", str(args.batch_size),
"--epochs", str(args.epochs),
"--lr", str(args.lr),
"--timesteps", str(args.timesteps),
"--schedule_type", args.schedule_type,
"--manhattan_weight", str(args.manhattan_weight),
"--seed", str(args.seed),
"--save_interval", str(args.save_interval),
"--sample_interval", str(args.sample_interval),
"--num_samples", str(args.train_samples)
]
if args.edge_condition:
train_cmd.append("--edge_condition")
if args.augment:
train_cmd.append("--augment")
if args.resume:
train_cmd.extend(["--resume", args.resume])
success = run_command(train_cmd, "训练优化模型", logger)
if not success:
logger.error("训练阶段失败")
return False
# 查找最佳模型
model_checkpoint = output_dir / "model" / "best_model.pth"
if not model_checkpoint.exists():
# 如果没有最佳模型,使用最终模型
model_checkpoint = output_dir / "model" / "final_model.pth"
if not model_checkpoint.exists():
logger.error("找不到训练好的模型")
return False
else:
logger.info("\n=== 跳过训练阶段 ===")
model_checkpoint = args.checkpoint
if not model_checkpoint:
logger.error("跳过训练时需要提供 --checkpoint 参数")
return False
if not Path(model_checkpoint).exists():
logger.error(f"指定的检查点不存在: {model_checkpoint}")
return False
# 生成阶段
logger.info("\n=== 第二阶段: 生成样本 ===")
generate_cmd = [
sys.executable, "generate_optimized.py",
"--checkpoint", str(model_checkpoint),
"--output_dir", str(output_dir / "generated"),
"--num_samples", str(args.num_samples),
"--image_size", str(args.image_size),
"--batch_size", str(args.gen_batch_size),
"--num_steps", str(args.num_steps),
"--seed", str(args.seed),
"--timesteps", str(args.timesteps),
"--schedule_type", args.schedule_type
]
if args.use_ddim:
generate_cmd.append("--use_ddim")
if args.use_post_process:
generate_cmd.append("--use_post_process")
success = run_command(generate_cmd, "生成样本", logger)
if not success:
logger.error("生成阶段失败")
return False
# 更新配置文件(如果提供了)
if args.update_config and Path(args.update_config).exists():
logger.info("\n=== 第三阶段: 更新配置文件 ===")
config_path = Path(args.update_config)
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
# 更新扩散配置
if 'synthetic' not in config:
config['synthetic'] = {}
config['synthetic']['enabled'] = True
config['synthetic']['ratio'] = 0.0 # 禁用程序化合成
if 'diffusion' not in config['synthetic']:
config['synthetic']['diffusion'] = {}
config['synthetic']['diffusion']['enabled'] = True
config['synthetic']['diffusion']['png_dir'] = str(output_dir / "generated")
config['synthetic']['diffusion']['ratio'] = args.diffusion_ratio
config['synthetic']['diffusion']['model_checkpoint'] = str(model_checkpoint)
# 保存配置
with open(config_path, 'w', encoding='utf-8') as f:
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
logger.info(f"配置文件已更新: {config_path}")
logger.info(f"扩散数据比例: {args.diffusion_ratio}")
# 创建管线报告
create_pipeline_report(output_dir, model_checkpoint, args, logger)
logger.info("\n=== 优化管线完成 ===")
logger.info(f"模型: {model_checkpoint}")
logger.info(f"生成数据: {output_dir / 'generated'}")
logger.info(f"管线报告: {output_dir / 'pipeline_report.txt'}")
return True
def create_pipeline_report(output_dir, model_checkpoint, args, logger):
"""创建管线报告"""
report_content = f"""
IC版图扩散模型优化管线报告
============================
管线配置:
- 数据目录: {args.data_dir}
- 输出目录: {args.output_dir}
- 图像尺寸: {args.image_size}x{args.image_size}
- 训练轮数: {args.epochs}
- 批次大小: {args.batch_size}
- 学习率: {args.lr}
- 时间步数: {args.timesteps}
- 调度类型: {args.schedule_type}
- 曼哈顿权重: {args.manhattan_weight}
- 随机种子: {args.seed}
模型配置:
- 边缘条件: {args.edge_condition}
- 数据增强: {args.augment}
- 最终模型: {model_checkpoint}
生成配置:
- 生成样本数: {args.num_samples}
- 生成批次大小: {args.gen_batch_size}
- 采样步数: {args.num_steps}
- DDIM采样: {args.use_ddim}
- 后处理: {args.use_post_process}
优化特性:
- 曼哈顿几何感知的U-Net架构
- 边缘感知损失函数
- 多尺度结构损失
- 曼哈顿约束正则化
- 几何保持的数据增强
- 后处理优化
输出目录结构:
- model/: 训练好的模型和检查点
- generated/: 生成的IC版图样本
- pipeline_report.txt: 本报告
质量评估:
生成完成后,请查看 generated/quality_metrics.yaml 和 generation_report.txt 获取详细的质量评估。
使用说明:
1. 训练数据应包含高质量的IC版图图像
2. 建议使用边缘条件来提高生成质量
3. 生成的样本可以使用后处理进一步优化
4. 可根据质量评估结果调整训练参数
"""
report_path = output_dir / 'pipeline_report.txt'
with open(report_path, 'w', encoding='utf-8') as f:
f.write(report_content)
logger.info(f"管线报告已保存: {report_path}")
def main():
parser = argparse.ArgumentParser(description="一键运行优化的IC版图扩散模型管线")
# 基本参数
parser.add_argument("--data_dir", type=str, required=True, help="训练数据目录")
parser.add_argument("--output_dir", type=str, required=True, help="输出目录")
# 训练参数
parser.add_argument("--image_size", type=int, default=256, help="图像尺寸")
parser.add_argument("--batch_size", type=int, default=4, help="训练批次大小")
parser.add_argument("--epochs", type=int, default=100, help="训练轮数")
parser.add_argument("--lr", type=float, default=1e-4, help="学习率")
parser.add_argument("--timesteps", type=int, default=1000, help="扩散时间步数")
parser.add_argument("--schedule_type", type=str, default='cosine', choices=['linear', 'cosine'], help="噪声调度类型")
parser.add_argument("--manhattan_weight", type=float, default=0.1, help="曼哈顿正则化权重")
parser.add_argument("--seed", type=int, default=42, help="随机种子")
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
parser.add_argument("--sample_interval", type=int, default=20, help="样本生成间隔")
parser.add_argument("--train_samples", type=int, default=16, help="训练时生成的样本数量")
# 训练控制
parser.add_argument("--skip_training", action='store_true', help="跳过训练,使用现有模型")
parser.add_argument("--checkpoint", type=str, help="现有模型检查点路径skip_training时使用")
parser.add_argument("--resume", type=str, help="恢复训练的检查点路径")
parser.add_argument("--edge_condition", action='store_true', help="使用边缘条件")
parser.add_argument("--augment", action='store_true', help="启用数据增强")
# 生成参数
parser.add_argument("--num_samples", type=int, default=200, help="生成样本数量")
parser.add_argument("--gen_batch_size", type=int, default=8, help="生成批次大小")
parser.add_argument("--num_steps", type=int, default=50, help="采样步数")
parser.add_argument("--use_ddim", action='store_true', default=True, help="使用DDIM采样")
parser.add_argument("--use_post_process", action='store_true', default=True, help="启用后处理")
# 配置更新
parser.add_argument("--update_config", type=str, help="要更新的配置文件路径")
parser.add_argument("--diffusion_ratio", type=float, default=0.3, help="扩散数据在训练中的比例")
# 开发选项
parser.add_argument("--create_sample_data", action='store_true', help="创建示例训练数据")
args = parser.parse_args()
# 验证参数
if args.skip_training and not args.checkpoint:
print("错误: 跳过训练时必须提供 --checkpoint 参数")
sys.exit(1)
# 运行管线
success = run_optimized_pipeline(args)
sys.exit(0 if success else 1)
if __name__ == "__main__":
main()