#!/usr/bin/env python3
"""
Run the optimized IC-layout diffusion model training and generation pipeline
with a single command.
"""
|
||
|
||
import os
|
||
import sys
|
||
import yaml
|
||
import argparse
|
||
import subprocess
|
||
from pathlib import Path
|
||
import logging
|
||
import shutil
|
||
|
||
def setup_logging():
    """Configure logging for the pipeline and return this module's logger.

    Installs two handlers through ``logging.basicConfig``: one writing to
    stdout and one writing to ``optimized_pipeline.log`` in the current
    working directory. Records at INFO level and above are emitted.
    ``basicConfig`` leaves an already-configured root logger untouched.

    Returns:
        logging.Logger: The logger named after this module.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout),
            # The pipeline logs non-ASCII messages; pin the encoding so the
            # log file does not depend on the platform's locale encoding.
            logging.FileHandler('optimized_pipeline.log', encoding='utf-8'),
        ],
    )
    return logging.getLogger(__name__)
|
||
|
||
def run_command(cmd, description, logger):
    """Run an external command and report whether it succeeded.

    The command's stdout/stderr are captured; stdout is logged at DEBUG
    level on success and stderr at ERROR level on failure.

    Args:
        cmd: Argument list handed to ``subprocess.run`` (no shell involved).
        description: Human-readable label used in the log messages.
        logger: Logger that receives progress and error messages.

    Returns:
        bool: True when the command exits with status 0; False when it exits
        with a non-zero status or cannot be started at all.
    """
    logger.info(f"执行: {description}")
    # str() each part so path-like arguments do not break the log line.
    logger.info(f"命令: {' '.join(str(part) for part in cmd)}")

    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        logger.error(f"{description} - 失败")
        logger.error(f"错误码: {e.returncode}")
        logger.error(f"错误输出: {e.stderr}")
        return False
    except OSError as e:
        # The process never started (missing executable, no permission, ...).
        logger.error(f"{description} - 失败")
        logger.error(f"错误输出: {e}")
        return False

    logger.info(f"{description} - 成功")
    if result.stdout:
        logger.debug(f"输出: {result.stdout}")
    return True
|
||
|
||
def validate_data_directory(data_dir, logger):
    """Check that ``data_dir`` is a directory containing at least one image.

    Only the top level of the directory is inspected. A regular file counts
    as an image when its suffix is ``.png``, ``.jpg`` or ``.jpeg`` in any
    letter case.

    Args:
        data_dir: Path (str or ``Path``) of the training data directory.
        logger: Logger that receives the validation outcome.

    Returns:
        bool: True when at least one image file was found, False otherwise.
    """
    data_path = Path(data_dir)
    if not data_path.exists():
        logger.error(f"数据目录不存在: {data_path}")
        return False
    if not data_path.is_dir():
        logger.error(f"数据路径不是目录: {data_path}")
        return False

    # Compare lower-cased suffixes instead of globbing "*.png" and "*.PNG":
    # that catches mixed-case names and never counts a file twice on
    # case-insensitive filesystems.
    image_extensions = {'.png', '.jpg', '.jpeg'}
    image_files = [
        entry for entry in data_path.iterdir()
        if entry.is_file() and entry.suffix.lower() in image_extensions
    ]

    if not image_files:
        logger.error(f"数据目录中没有找到图像文件: {data_path}")
        return False

    logger.info(f"数据验证通过 - 找到 {len(image_files)} 个图像文件")
    return True
|
||
|
||
def create_sample_images(output_dir, logger, num_samples=5):
    """Write simple synthetic Manhattan-style layout images to disk.

    Each image is a 256x256 grayscale PNG with a white background, a black
    outer frame and a few randomly placed axis-aligned rectangle outlines.
    Files are saved as ``sample_000.png``, ``sample_001.png``, ... inside
    ``<output_dir>/reference_samples`` (created if necessary).

    Args:
        output_dir: Directory under which ``reference_samples`` is created.
        logger: Logger that receives progress messages.
        num_samples: Number of images to generate.
    """
    logger.info("创建示例图像...")

    # Imported here, so these packages are only required when sample data is
    # actually requested.
    from PIL import Image, ImageDraw
    import numpy as np

    sample_dir = Path(output_dir) / "reference_samples"
    sample_dir.mkdir(parents=True, exist_ok=True)

    for i in range(num_samples):
        # Blank canvas with a white background.
        img = Image.new('L', (256, 256), 255)
        draw = ImageDraw.Draw(img)

        # A private generator seeded with the image index yields the same
        # sequence as seeding the global NumPy RNG, without clobbering the
        # caller's global random state.
        rng = np.random.RandomState(i)

        # Outer frame.
        draw.rectangle([20, 20, 236, 236], outline=0, width=2)

        # Random inner rectangles.
        for _ in range(rng.randint(3, 8)):
            x1 = rng.randint(40, 180)
            y1 = rng.randint(40, 180)
            x2 = x1 + rng.randint(20, 60)
            y2 = y1 + rng.randint(20, 60)
            # Skip rectangles that would reach the outer frame.
            if x2 < 220 and y2 < 220:
                draw.rectangle([x1, y1, x2, y2], outline=0, width=1)

        img.save(sample_dir / f"sample_{i:03d}.png")

    logger.info(f"示例图像已保存到: {sample_dir}")
|
||
|
||
def _build_train_command(args, output_dir):
    """Assemble the argument list that launches ``train_optimized.py``."""
    cmd = [
        sys.executable, "train_optimized.py",
        "--data_dir", args.data_dir,
        "--output_dir", str(output_dir / "model"),
        "--image_size", str(args.image_size),
        "--batch_size", str(args.batch_size),
        "--epochs", str(args.epochs),
        "--lr", str(args.lr),
        "--timesteps", str(args.timesteps),
        "--schedule_type", args.schedule_type,
        "--manhattan_weight", str(args.manhattan_weight),
        "--seed", str(args.seed),
        "--save_interval", str(args.save_interval),
        "--sample_interval", str(args.sample_interval),
        "--num_samples", str(args.train_samples),
    ]
    if args.edge_condition:
        cmd.append("--edge_condition")
    if args.augment:
        cmd.append("--augment")
    if args.resume:
        cmd.extend(["--resume", args.resume])
    return cmd


def _find_trained_checkpoint(model_dir):
    """Return ``best_model.pth`` if present, else ``final_model.pth``, else None."""
    for name in ("best_model.pth", "final_model.pth"):
        candidate = model_dir / name
        if candidate.exists():
            return candidate
    return None


def _build_generate_command(args, output_dir, model_checkpoint):
    """Assemble the argument list that launches ``generate_optimized.py``."""
    cmd = [
        sys.executable, "generate_optimized.py",
        "--checkpoint", str(model_checkpoint),
        "--output_dir", str(output_dir / "generated"),
        "--num_samples", str(args.num_samples),
        "--image_size", str(args.image_size),
        "--batch_size", str(args.gen_batch_size),
        "--num_steps", str(args.num_steps),
        "--seed", str(args.seed),
        "--timesteps", str(args.timesteps),
        "--schedule_type", args.schedule_type,
    ]
    if args.use_ddim:
        cmd.append("--use_ddim")
    if args.use_post_process:
        cmd.append("--use_post_process")
    return cmd


def _update_pipeline_config(config_path, output_dir, model_checkpoint, diffusion_ratio):
    """Point the YAML config at ``config_path`` to the generated diffusion data.

    The file is loaded, its ``synthetic`` / ``synthetic.diffusion`` sections
    are created if missing and updated, and the result is written back with
    ``yaml.dump``.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document; treat it as {}.
        config = yaml.safe_load(f) or {}

    # A key that is present but empty ("synthetic:") also loads as None.
    synthetic = config.get('synthetic') or {}
    config['synthetic'] = synthetic
    synthetic['enabled'] = True
    synthetic['ratio'] = 0.0  # Disable procedural synthesis.

    diffusion = synthetic.get('diffusion') or {}
    synthetic['diffusion'] = diffusion
    diffusion['enabled'] = True
    diffusion['png_dir'] = str(output_dir / "generated")
    diffusion['ratio'] = diffusion_ratio
    diffusion['model_checkpoint'] = str(model_checkpoint)

    with open(config_path, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, default_flow_style=False, allow_unicode=True)


def run_optimized_pipeline(args):
    """Run the full pipeline: validate data, train, generate, update config.

    Stages:
        1. Train via ``train_optimized.py`` (or, with ``args.skip_training``,
           use ``args.checkpoint`` instead).
        2. Generate samples via ``generate_optimized.py``.
        3. Optionally update the YAML file named by ``args.update_config``.
    A text report is written to ``<output_dir>/pipeline_report.txt`` at the end.

    Args:
        args: Parsed command-line namespace as produced by ``main``.

    Returns:
        bool: True when every stage succeeded, False as soon as one fails.
    """
    logger = setup_logging()

    logger.info("=== 开始优化的IC版图扩散模型管线 ===")

    # Validate the input data before doing anything expensive.
    if not validate_data_directory(args.data_dir, logger):
        return False

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): this runs after validation has already required images
    # in data_dir, and the samples land in a "reference_samples"
    # subdirectory of it - confirm that this is the intended use.
    if args.create_sample_data:
        create_sample_images(args.data_dir, logger)

    # Stage 1: training (or reuse of an existing checkpoint).
    if not args.skip_training:
        logger.info("\n=== 第一阶段: 训练优化模型 ===")

        train_cmd = _build_train_command(args, output_dir)
        if not run_command(train_cmd, "训练优化模型", logger):
            logger.error("训练阶段失败")
            return False

        # Prefer the best checkpoint, fall back to the final one.
        model_checkpoint = _find_trained_checkpoint(output_dir / "model")
        if model_checkpoint is None:
            logger.error("找不到训练好的模型")
            return False
    else:
        logger.info("\n=== 跳过训练阶段 ===")
        model_checkpoint = args.checkpoint
        if not model_checkpoint:
            logger.error("跳过训练时需要提供 --checkpoint 参数")
            return False

        if not Path(model_checkpoint).exists():
            logger.error(f"指定的检查点不存在: {model_checkpoint}")
            return False

    # Stage 2: sample generation.
    logger.info("\n=== 第二阶段: 生成样本 ===")

    generate_cmd = _build_generate_command(args, output_dir, model_checkpoint)
    if not run_command(generate_cmd, "生成样本", logger):
        logger.error("生成阶段失败")
        return False

    # Stage 3: update the downstream config file, if one was given.
    if args.update_config:
        config_path = Path(args.update_config)
        if config_path.exists():
            logger.info("\n=== 第三阶段: 更新配置文件 ===")
            _update_pipeline_config(
                config_path, output_dir, model_checkpoint, args.diffusion_ratio
            )
            logger.info(f"配置文件已更新: {config_path}")
            logger.info(f"扩散数据比例: {args.diffusion_ratio}")
        else:
            # Previously a wrong path was skipped without any notice.
            logger.warning(f"配置文件不存在,跳过更新: {config_path}")

    create_pipeline_report(output_dir, model_checkpoint, args, logger)

    logger.info("\n=== 优化管线完成 ===")
    logger.info(f"模型: {model_checkpoint}")
    logger.info(f"生成数据: {output_dir / 'generated'}")
    logger.info(f"管线报告: {output_dir / 'pipeline_report.txt'}")

    return True
|
||
|
||
def create_pipeline_report(output_dir, model_checkpoint, args, logger):
    """Write a plain-text summary of the pipeline run.

    The report is saved as ``pipeline_report.txt`` (UTF-8) inside
    ``output_dir`` and lists the training, model and generation settings
    taken from ``args`` together with the checkpoint that was used.

    Args:
        output_dir: Directory (str or ``Path``) receiving the report file.
        model_checkpoint: Path of the model checkpoint used for generation.
        args: Parsed command-line namespace holding the pipeline settings.
        logger: Logger that receives the confirmation message.
    """
    report_content = f"""
IC版图扩散模型优化管线报告
============================

管线配置:
- 数据目录: {args.data_dir}
- 输出目录: {args.output_dir}
- 图像尺寸: {args.image_size}x{args.image_size}
- 训练轮数: {args.epochs}
- 批次大小: {args.batch_size}
- 学习率: {args.lr}
- 时间步数: {args.timesteps}
- 调度类型: {args.schedule_type}
- 曼哈顿权重: {args.manhattan_weight}
- 随机种子: {args.seed}

模型配置:
- 边缘条件: {args.edge_condition}
- 数据增强: {args.augment}
- 最终模型: {model_checkpoint}

生成配置:
- 生成样本数: {args.num_samples}
- 生成批次大小: {args.gen_batch_size}
- 采样步数: {args.num_steps}
- DDIM采样: {args.use_ddim}
- 后处理: {args.use_post_process}

优化特性:
- 曼哈顿几何感知的U-Net架构
- 边缘感知损失函数
- 多尺度结构损失
- 曼哈顿约束正则化
- 几何保持的数据增强
- 后处理优化

输出目录结构:
- model/: 训练好的模型和检查点
- generated/: 生成的IC版图样本
- pipeline_report.txt: 本报告

质量评估:
生成完成后,请查看 generated/quality_metrics.yaml 和 generation_report.txt 获取详细的质量评估。

使用说明:
1. 训练数据应包含高质量的IC版图图像
2. 建议使用边缘条件来提高生成质量
3. 生成的样本可以使用后处理进一步优化
4. 可根据质量评估结果调整训练参数
"""

    # Coerce to Path so callers may pass either a str or a Path.
    report_path = Path(output_dir) / 'pipeline_report.txt'
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(report_content)

    logger.info(f"管线报告已保存: {report_path}")
|
||
|
||
def main():
    """Parse command-line arguments and run the pipeline.

    Exits the process with status 0 when the pipeline succeeds and 1 when it
    fails or when ``--skip_training`` is given without ``--checkpoint``.
    """
    parser = argparse.ArgumentParser(description="一键运行优化的IC版图扩散模型管线")

    # Basic arguments.
    parser.add_argument("--data_dir", type=str, required=True, help="训练数据目录")
    parser.add_argument("--output_dir", type=str, required=True, help="输出目录")

    # Training hyper-parameters.
    parser.add_argument("--image_size", type=int, default=256, help="图像尺寸")
    parser.add_argument("--batch_size", type=int, default=4, help="训练批次大小")
    parser.add_argument("--epochs", type=int, default=100, help="训练轮数")
    parser.add_argument("--lr", type=float, default=1e-4, help="学习率")
    parser.add_argument("--timesteps", type=int, default=1000, help="扩散时间步数")
    parser.add_argument("--schedule_type", type=str, default='cosine', choices=['linear', 'cosine'], help="噪声调度类型")
    parser.add_argument("--manhattan_weight", type=float, default=0.1, help="曼哈顿正则化权重")
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
    parser.add_argument("--sample_interval", type=int, default=20, help="样本生成间隔")
    parser.add_argument("--train_samples", type=int, default=16, help="训练时生成的样本数量")

    # Training control.
    parser.add_argument("--skip_training", action='store_true', help="跳过训练,使用现有模型")
    parser.add_argument("--checkpoint", type=str, help="现有模型检查点路径(skip_training时使用)")
    parser.add_argument("--resume", type=str, help="恢复训练的检查点路径")
    parser.add_argument("--edge_condition", action='store_true', help="使用边缘条件")
    parser.add_argument("--augment", action='store_true', help="启用数据增强")

    # Generation arguments.
    parser.add_argument("--num_samples", type=int, default=200, help="生成样本数量")
    parser.add_argument("--gen_batch_size", type=int, default=8, help="生成批次大小")
    parser.add_argument("--num_steps", type=int, default=50, help="采样步数")
    # These two flags default to True, so on their own they could never be
    # switched off; the --no_* counterparts write False to the same dest.
    parser.add_argument("--use_ddim", action='store_true', default=True, help="使用DDIM采样")
    parser.add_argument("--no_ddim", dest="use_ddim", action='store_false', help="禁用DDIM采样")
    parser.add_argument("--use_post_process", action='store_true', default=True, help="启用后处理")
    parser.add_argument("--no_post_process", dest="use_post_process", action='store_false', help="禁用后处理")

    # Config update.
    parser.add_argument("--update_config", type=str, help="要更新的配置文件路径")
    parser.add_argument("--diffusion_ratio", type=float, default=0.3, help="扩散数据在训练中的比例")

    # Development options.
    parser.add_argument("--create_sample_data", action='store_true', help="创建示例训练数据")

    args = parser.parse_args()

    # Skipping training only makes sense with an existing checkpoint.
    if args.skip_training and not args.checkpoint:
        print("错误: 跳过训练时必须提供 --checkpoint 参数")
        sys.exit(1)

    success = run_optimized_pipeline(args)
    sys.exit(0 if success else 1)
|
||
|
||
|
||
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()