Compare commits
13 Commits
f95a2bd2db
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d8dabd1951 | ||
|
|
d29bc650c3 | ||
|
|
afd48c2d86 | ||
|
|
bacf8cd69d | ||
|
|
26763fa75c | ||
|
|
6e3d01bc83 | ||
|
|
3258b7b6de | ||
|
|
3d75ed722a | ||
|
|
116551af18 | ||
|
|
f8975b26b4 | ||
|
|
ebda75fa5e | ||
|
|
0a45856b14 | ||
|
|
d2c75a2d14 |
@@ -1,283 +0,0 @@
|
||||
# 优化的IC版图扩散模型
|
||||
|
||||
针对曼哈顿多边形IC版图光栅化图像生成的去噪扩散模型优化版本。
|
||||
|
||||
## 🎯 优化目标
|
||||
|
||||
专门优化以曼哈顿多边形为全部组成元素的IC版图光栅化图像生成,主要特点:
|
||||
|
||||
- **曼哈顿几何感知**:模型架构专门处理水平/垂直线条特征
|
||||
- **边缘锐化**:保持IC版图清晰的边缘特性
|
||||
- **多尺度结构**:保持从微观到宏观的结构一致性
|
||||
- **几何约束**:确保生成结果符合曼哈顿几何规则
|
||||
- **后处理优化**:进一步提升生成质量
|
||||
|
||||
## 📁 文件结构
|
||||
|
||||
```
|
||||
tools/diffusion/
|
||||
├── ic_layout_diffusion_optimized.py # 优化的核心模型实现
|
||||
├── train_optimized.py # 训练脚本
|
||||
├── generate_optimized.py # 生成脚本
|
||||
├── run_optimized_pipeline.py # 一键运行管线
|
||||
├── README_OPTIMIZED.md # 本文档
|
||||
└── original/ # 原始实现(参考用)
|
||||
├── ic_layout_diffusion.py
|
||||
└── ...
|
||||
```
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 1. 基本使用 - 一键运行
|
||||
|
||||
```bash
|
||||
# 完整管线(训练 + 生成)
|
||||
python tools/diffusion/run_optimized_pipeline.py \
|
||||
--data_dir data/ic_layouts \
|
||||
--output_dir outputs/diffusion_optimized \
|
||||
--epochs 50 \
|
||||
--num_samples 200
|
||||
|
||||
# 仅生成(使用已有模型)
|
||||
python tools/diffusion/run_optimized_pipeline.py \
|
||||
--skip_training \
|
||||
--checkpoint outputs/diffusion_optimized/model/best_model.pth \
|
||||
--data_dir data/ic_layouts \
|
||||
--output_dir outputs/diffusion_generated \
|
||||
--num_samples 500
|
||||
```
|
||||
|
||||
### 2. 分步使用
|
||||
|
||||
#### 训练模型
|
||||
|
||||
```bash
|
||||
python tools/diffusion/train_optimized.py \
|
||||
--data_dir data/ic_layouts \
|
||||
--output_dir models/diffusion_optimized \
|
||||
--image_size 256 \
|
||||
--batch_size 4 \
|
||||
--epochs 100 \
|
||||
--lr 1e-4 \
|
||||
--edge_condition \
|
||||
--augment \
|
||||
--manhattan_weight 0.1
|
||||
```
|
||||
|
||||
#### 生成样本
|
||||
|
||||
```bash
|
||||
python tools/diffusion/generate_optimized.py \
|
||||
--checkpoint models/diffusion_optimized/best_model.pth \
|
||||
--output_dir generated_layouts \
|
||||
--num_samples 200 \
|
||||
--num_steps 50 \
|
||||
--use_ddim \
|
||||
--use_post_process
|
||||
```
|
||||
|
||||
## 🔧 关键优化特性
|
||||
|
||||
### 1. 曼哈顿几何感知U-Net
|
||||
|
||||
```python
|
||||
class ManhattanAwareUNet(nn.Module):
|
||||
"""曼哈顿几何感知的U-Net架构"""
|
||||
|
||||
def __init__(self, use_edge_condition=False):
|
||||
# 专门的方向感知卷积
|
||||
self.horiz_conv = nn.Conv2d(in_channels, 32, (1, 7), padding=(0, 3))
|
||||
self.vert_conv = nn.Conv2d(in_channels, 32, (7, 1), padding=(3, 0))
|
||||
self.standard_conv = nn.Conv2d(in_channels, 32, 3, padding=1)
|
||||
|
||||
# 特征融合
|
||||
self.fusion = nn.Conv2d(96, 64, 3, padding=1)
|
||||
```
|
||||
|
||||
**优势**:
|
||||
- 专门提取水平和垂直特征
|
||||
- 保持曼哈顿几何结构
|
||||
- 增强线条检测能力
|
||||
|
||||
### 2. 多目标损失函数
|
||||
|
||||
```python
|
||||
# 组合损失函数
|
||||
total_loss = mse_loss +
|
||||
0.3 * edge_loss + # 边缘感知损失
|
||||
0.2 * structure_loss + # 多尺度结构损失
|
||||
0.1 * manhattan_loss # 曼哈顿约束损失
|
||||
```
|
||||
|
||||
**优势**:
|
||||
- 保持边缘锐利度
|
||||
- 维持多尺度结构一致性
|
||||
- 强制曼哈顿几何约束
|
||||
|
||||
### 3. 几何保持的数据增强
|
||||
|
||||
```python
|
||||
# 只使用不破坏曼哈顿几何的增强
|
||||
self.aug_transform = transforms.Compose([
|
||||
transforms.RandomHorizontalFlip(p=0.5),
|
||||
transforms.RandomVerticalFlip(p=0.5),
|
||||
# 移除旋转,保持几何约束
|
||||
])
|
||||
```
|
||||
|
||||
### 4. 后处理优化
|
||||
|
||||
```python
|
||||
def manhattan_post_process(image):
|
||||
"""曼哈顿化后处理"""
|
||||
# 形态学操作强化直角特征
|
||||
# 水平和垂直增强
|
||||
# 二值化处理
|
||||
return processed_image
|
||||
```
|
||||
|
||||
## 📊 质量评估指标
|
||||
|
||||
生成样本会自动评估以下指标:
|
||||
|
||||
1. **曼哈顿几何合规性** - 角度偏差损失(越低越好)
|
||||
2. **边缘锐度** - 边缘强度平均值
|
||||
3. **对比度** - 图像标准差
|
||||
4. **稀疏性** - 低像素值比例(IC版图特性)
|
||||
|
||||
## 🎛️ 参数调优指南
|
||||
|
||||
### 训练参数
|
||||
|
||||
| 参数 | 推荐值 | 说明 |
|
||||
|------|--------|------|
|
||||
| `manhattan_weight` | 0.05 - 0.2 | 曼哈顿约束权重 |
|
||||
| `schedule_type` | cosine | 余弦调度通常效果更好 |
|
||||
| `edge_condition` | True | 使用边缘条件提高质量 |
|
||||
| `batch_size` | 4 - 8 | 根据GPU内存调整 |
|
||||
|
||||
### 生成参数
|
||||
|
||||
| 参数 | 推荐值 | 说明 |
|
||||
|------|--------|------|
|
||||
| `num_steps` | 20 - 50 | DDIM采样步数 |
|
||||
| `eta` | 0.0 - 0.3 | 随机性控制(0=确定性) |
|
||||
| `guidance_scale` | 1.0 - 3.0 | 引导强度 |
|
||||
| `post_process_threshold` | 0.4 - 0.6 | 后处理阈值 |
|
||||
|
||||
## 🔍 故障排除
|
||||
|
||||
### 1. 训练问题
|
||||
|
||||
**Q: 损失不下降**
|
||||
- 检查数据质量和格式
|
||||
- 降低学习率
|
||||
- 增加批次大小
|
||||
- 调整曼哈顿权重
|
||||
|
||||
**Q: 生成的图像模糊**
|
||||
- 增加边缘损失权重
|
||||
- 使用边缘条件训练
|
||||
- 调整后处理阈值
|
||||
- 增加训练轮数
|
||||
|
||||
### 2. 生成问题
|
||||
|
||||
**Q: 生成结果不符合曼哈顿几何**
|
||||
- 增加 `manhattan_weight`
|
||||
- 启用后处理
|
||||
- 降低 `eta` 参数
|
||||
|
||||
**Q: 生成速度慢**
|
||||
- 使用DDIM采样
|
||||
- 减少 `num_steps`
|
||||
- 增加 `batch_size`
|
||||
|
||||
### 3. 内存问题
|
||||
|
||||
**Q: GPU内存不足**
|
||||
- 减少批次大小
|
||||
- 减小图像尺寸
|
||||
- 使用梯度累积
|
||||
|
||||
## 📈 性能对比
|
||||
|
||||
| 特性 | 原始模型 | 优化模型 |
|
||||
|------|----------|----------|
|
||||
| 曼哈顿几何合规性 | ❌ | ✅ |
|
||||
| 边缘锐度 | 中等 | 优秀 |
|
||||
| 训练稳定性 | 一般 | 优秀 |
|
||||
| 生成质量 | 基础 | 优秀 |
|
||||
| 后处理 | 无 | 有 |
|
||||
| 质量评估 | 无 | 有 |
|
||||
|
||||
## 🔄 与现有管线集成
|
||||
|
||||
更新配置文件以使用优化的扩散数据:
|
||||
|
||||
```yaml
|
||||
synthetic:
|
||||
enabled: true
|
||||
ratio: 0.0 # 禁用程序化合成
|
||||
|
||||
diffusion:
|
||||
enabled: true
|
||||
png_dir: "outputs/diffusion_optimized/generated"
|
||||
ratio: 0.3 # 扩散数据在训练中的比例
|
||||
model_checkpoint: "outputs/diffusion_optimized/model/best_model.pth"
|
||||
```
|
||||
|
||||
## 📚 技术原理
|
||||
|
||||
### 曼哈顿几何约束
|
||||
|
||||
IC版图具有以下几何特征:
|
||||
- 所有线条都是水平或垂直的
|
||||
- 角度只能是90°
|
||||
- 结构具有高度的规则性
|
||||
|
||||
模型通过以下方式强制这些约束:
|
||||
1. 方向感知卷积核
|
||||
2. 角度偏差损失函数
|
||||
3. 几何保持后处理
|
||||
|
||||
### 多尺度结构损失
|
||||
|
||||
确保生成结果在不同尺度下都保持结构一致性:
|
||||
- 原始分辨率:细节保持
|
||||
- 2x下采样:中层结构
|
||||
- 4x下采样:整体布局
|
||||
|
||||
## 🛠️ 开发者指南
|
||||
|
||||
### 添加新的损失函数
|
||||
|
||||
```python
|
||||
class CustomLoss(nn.Module):
|
||||
def forward(self, pred, target):
|
||||
# 实现自定义损失
|
||||
return loss
|
||||
|
||||
# 在训练器中使用
|
||||
self.custom_loss = CustomLoss()
|
||||
```
|
||||
|
||||
### 自定义后处理
|
||||
|
||||
```python
|
||||
def custom_post_process(image):
|
||||
# 实现自定义后处理逻辑
|
||||
return processed_image
|
||||
```
|
||||
|
||||
## 📄 许可证
|
||||
|
||||
本项目遵循与主项目相同的许可证。
|
||||
|
||||
## 🤝 贡献
|
||||
|
||||
欢迎提交问题报告和改进建议!
|
||||
|
||||
---
|
||||
|
||||
**注意**:这是针对特定IC版图生成任务的优化版本,对于一般的图像生成任务,请使用原始的扩散模型实现。
|
||||
@@ -1,198 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
优化扩散模型使用示例
|
||||
演示如何使用优化后的IC版图扩散模型
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
def example_basic_training():
|
||||
"""基本训练示例"""
|
||||
print("=== 基本训练示例 ===")
|
||||
|
||||
# 创建示例数据目录
|
||||
data_dir = "example_data/ic_layouts"
|
||||
output_dir = "example_outputs/basic_training"
|
||||
|
||||
# 训练命令
|
||||
cmd = [
|
||||
sys.executable, "run_optimized_pipeline.py",
|
||||
"--data_dir", data_dir,
|
||||
"--output_dir", output_dir,
|
||||
"--epochs", 20, # 示例用较少轮数
|
||||
"--batch_size", 2,
|
||||
"--image_size", 256,
|
||||
"--num_samples", 10,
|
||||
"--create_sample_data", # 创建示例数据
|
||||
"--edge_condition",
|
||||
"--augment"
|
||||
]
|
||||
|
||||
print(f"运行命令: {' '.join(map(str, cmd))}")
|
||||
print("注意:这将创建示例数据并开始训练")
|
||||
|
||||
# 实际运行时取消注释
|
||||
# subprocess.run(cmd)
|
||||
|
||||
def example_advanced_training():
|
||||
"""高级训练示例"""
|
||||
print("\n=== 高级训练示例 ===")
|
||||
|
||||
cmd = [
|
||||
sys.executable, "train_optimized.py",
|
||||
"--data_dir", "data/high_quality_ic_layouts",
|
||||
"--output_dir", "models/advanced_diffusion",
|
||||
"--image_size", 512, # 更高分辨率
|
||||
"--batch_size", 8,
|
||||
"--epochs", 200,
|
||||
"--lr", 5e-5, # 更低学习率
|
||||
"--manhattan_weight", 0.15, # 更强的几何约束
|
||||
"--edge_condition",
|
||||
"--augment",
|
||||
"--schedule_type", "cosine",
|
||||
"--save_interval", 5,
|
||||
"--sample_interval", 10
|
||||
]
|
||||
|
||||
print(f"高级训练命令: {' '.join(map(str, cmd))}")
|
||||
|
||||
def example_generation_only():
|
||||
"""仅生成示例"""
|
||||
print("\n=== 仅生成示例 ===")
|
||||
|
||||
cmd = [
|
||||
sys.executable, "generate_optimized.py",
|
||||
"--checkpoint", "models/diffusion_optimized/best_model.pth",
|
||||
"--output_dir", "generated_samples/high_quality",
|
||||
"--num_samples", 100,
|
||||
"--num_steps", 30, # 更快采样
|
||||
"--use_ddim",
|
||||
"--batch_size", 16,
|
||||
"--use_post_process",
|
||||
"--post_process_threshold", 0.45
|
||||
]
|
||||
|
||||
print(f"生成命令: {' '.join(map(str, cmd))}")
|
||||
|
||||
def example_custom_parameters():
|
||||
"""自定义参数示例"""
|
||||
print("\n=== 自定义参数示例 ===")
|
||||
|
||||
# 针对特定需求的参数调整
|
||||
scenarios = {
|
||||
"高质量生成": {
|
||||
"description": "追求最高质量的生成结果",
|
||||
"params": {
|
||||
"--num_steps": 100,
|
||||
"--guidance_scale": 2.0,
|
||||
"--eta": 0.0, # 完全确定性
|
||||
"--use_post_process": True,
|
||||
"--post_process_threshold": 0.5
|
||||
}
|
||||
},
|
||||
"快速生成": {
|
||||
"description": "快速生成大量样本",
|
||||
"params": {
|
||||
"--num_steps": 20,
|
||||
"--batch_size": 32,
|
||||
"--eta": 0.3, # 增加随机性
|
||||
"--use_ddim": True
|
||||
}
|
||||
},
|
||||
"几何约束严格": {
|
||||
"description": "严格要求曼哈顿几何",
|
||||
"params": {
|
||||
"--manhattan_weight": 0.3, # 更强约束
|
||||
"--use_post_process": True,
|
||||
"--post_process_threshold": 0.4
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for scenario_name, config in scenarios.items():
|
||||
print(f"\n{scenario_name}: {config['description']}")
|
||||
for param, value in config['params'].items():
|
||||
print(f" {param}: {value}")
|
||||
|
||||
def example_integration():
|
||||
"""集成到现有管线示例"""
|
||||
print("\n=== 集成示例 ===")
|
||||
|
||||
# 更新配置文件
|
||||
config_update = {
|
||||
"config_file": "configs/train_config.yaml",
|
||||
"updates": {
|
||||
"synthetic.enabled": True,
|
||||
"synthetic.ratio": 0.0,
|
||||
"synthetic.diffusion.enabled": True,
|
||||
"synthetic.diffusion.png_dir": "outputs/diffusion_optimized/generated",
|
||||
"synthetic.diffusion.ratio": 0.4,
|
||||
"synthetic.diffusion.model_checkpoint": "outputs/diffusion_optimized/model/best_model.pth"
|
||||
}
|
||||
}
|
||||
|
||||
print("配置文件更新示例:")
|
||||
print(f"配置文件: {config_update['config_file']}")
|
||||
for key, value in config_update['updates'].items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
integration_cmd = [
|
||||
sys.executable, "run_optimized_pipeline.py",
|
||||
"--data_dir", "data/training_layouts",
|
||||
"--output_dir", "outputs/integration",
|
||||
"--update_config", config_update["config_file"],
|
||||
"--diffusion_ratio", 0.4,
|
||||
"--epochs", 100,
|
||||
"--num_samples", 500
|
||||
]
|
||||
|
||||
print(f"\n集成命令: {' '.join(map(str, integration_cmd))}")
|
||||
|
||||
def show_tips():
|
||||
"""显示使用建议"""
|
||||
print("\n=== 使用建议 ===")
|
||||
|
||||
tips = [
|
||||
"🎯 数据质量是关键:使用高质量、多样化的IC版图数据进行训练",
|
||||
"⚖️ 平衡约束:曼哈顿权重不宜过高(0.05-0.2),避免过度约束影响生成多样性",
|
||||
"🔄 迭代优化:根据生成结果调整损失函数权重和后处理参数",
|
||||
"📊 质量监控:定期检查生成样本的质量指标",
|
||||
"💾 定期保存:设置合理的保存间隔,避免训练中断导致损失",
|
||||
"🚀 性能优化:使用DDIM采样可以显著提高生成速度",
|
||||
"🔧 参数调优:根据具体任务需求调整模型参数"
|
||||
]
|
||||
|
||||
for tip in tips:
|
||||
print(tip)
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("IC版图扩散模型优化版本 - 使用示例")
|
||||
print("=" * 50)
|
||||
|
||||
# 检查是否在正确的目录
|
||||
if not Path("ic_layout_diffusion_optimized.py").exists():
|
||||
print("错误:请在 tools/diffusion/ 目录下运行此脚本")
|
||||
sys.exit(1)
|
||||
|
||||
# 显示示例
|
||||
example_basic_training()
|
||||
example_advanced_training()
|
||||
example_generation_only()
|
||||
example_custom_parameters()
|
||||
example_integration()
|
||||
show_tips()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("运行示例:")
|
||||
print("1. 基本使用:python run_optimized_pipeline.py --data_dir data/ic_layouts --output_dir outputs")
|
||||
print("2. 查看完整参数:python train_optimized.py --help")
|
||||
print("3. 查看生成参数:python generate_optimized.py --help")
|
||||
print("4. 阅读详细文档:README_OPTIMIZED.md")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,253 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
一键生成扩散数据的脚本:
|
||||
1. 基于原始数据训练扩散模型
|
||||
2. 使用训练好的模型生成相似图像
|
||||
3. 更新配置文件
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import yaml
|
||||
import argparse
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
|
||||
def setup_logging():
|
||||
"""设置日志"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
)
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
|
||||
def train_diffusion_model(data_dir, model_dir, logger, **train_kwargs):
|
||||
"""训练扩散模型"""
|
||||
logger.info("开始训练扩散模型...")
|
||||
|
||||
# 构建训练命令
|
||||
cmd = [
|
||||
sys.executable, "tools/diffusion/ic_layout_diffusion.py", "train",
|
||||
"--data_dir", data_dir,
|
||||
"--output_dir", model_dir,
|
||||
"--image_size", str(train_kwargs.get("image_size", 256)),
|
||||
"--batch_size", str(train_kwargs.get("batch_size", 8)),
|
||||
"--epochs", str(train_kwargs.get("epochs", 100)),
|
||||
"--lr", str(train_kwargs.get("lr", 1e-4)),
|
||||
"--timesteps", str(train_kwargs.get("timesteps", 1000)),
|
||||
"--num_samples", str(train_kwargs.get("num_samples", 50)),
|
||||
"--save_interval", str(train_kwargs.get("save_interval", 10))
|
||||
]
|
||||
|
||||
if train_kwargs.get("augment", False):
|
||||
cmd.append("--augment")
|
||||
|
||||
# 执行训练
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
logger.error(f"扩散模型训练失败: {result.stderr}")
|
||||
return False
|
||||
|
||||
logger.info("扩散模型训练完成")
|
||||
return True
|
||||
|
||||
|
||||
def generate_samples(model_dir, output_dir, num_samples, logger, **gen_kwargs):
|
||||
"""生成样本"""
|
||||
logger.info(f"开始生成 {num_samples} 个样本...")
|
||||
|
||||
# 查找最终模型
|
||||
model_path = Path(model_dir) / "diffusion_final.pth"
|
||||
if not model_path.exists():
|
||||
# 如果没有最终模型,查找最新的检查点
|
||||
checkpoints = list(Path(model_dir).glob("diffusion_epoch_*.pth"))
|
||||
if checkpoints:
|
||||
model_path = max(checkpoints, key=lambda x: int(x.stem.split('_')[-1]))
|
||||
else:
|
||||
logger.error(f"在 {model_dir} 中找不到模型检查点")
|
||||
return False
|
||||
|
||||
logger.info(f"使用模型: {model_path}")
|
||||
|
||||
# 构建生成命令
|
||||
cmd = [
|
||||
sys.executable, "tools/diffusion/ic_layout_diffusion.py", "generate",
|
||||
"--checkpoint", str(model_path),
|
||||
"--output_dir", output_dir,
|
||||
"--num_samples", str(num_samples),
|
||||
"--image_size", str(gen_kwargs.get("image_size", 256)),
|
||||
"--timesteps", str(gen_kwargs.get("timesteps", 1000))
|
||||
]
|
||||
|
||||
# 执行生成
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
logger.error(f"样本生成失败: {result.stderr}")
|
||||
return False
|
||||
|
||||
logger.info("样本生成完成")
|
||||
return True
|
||||
|
||||
|
||||
def update_config(config_path, output_dir, ratio, logger):
|
||||
"""更新配置文件"""
|
||||
logger.info(f"更新配置文件: {config_path}")
|
||||
|
||||
# 读取配置
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# 确保必要的结构存在
|
||||
if 'synthetic' not in config:
|
||||
config['synthetic'] = {}
|
||||
|
||||
# 更新扩散配置
|
||||
config['synthetic']['enabled'] = True
|
||||
config['synthetic']['ratio'] = 0.0 # 禁用程序化合成
|
||||
|
||||
if 'diffusion' not in config['synthetic']:
|
||||
config['synthetic']['diffusion'] = {}
|
||||
|
||||
config['synthetic']['diffusion']['enabled'] = True
|
||||
config['synthetic']['diffusion']['png_dir'] = output_dir
|
||||
config['synthetic']['diffusion']['ratio'] = ratio
|
||||
|
||||
# 保存配置
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
|
||||
|
||||
logger.info(f"配置文件更新完成,扩散数据比例: {ratio}")
|
||||
|
||||
|
||||
def validate_generated_data(output_dir, logger):
|
||||
"""验证生成的数据"""
|
||||
logger.info("验证生成的数据...")
|
||||
|
||||
output_path = Path(output_dir)
|
||||
if not output_path.exists():
|
||||
logger.error(f"输出目录不存在: {output_dir}")
|
||||
return False
|
||||
|
||||
# 统计生成的图像
|
||||
png_files = list(output_path.glob("*.png"))
|
||||
if not png_files:
|
||||
logger.error("没有找到生成的PNG图像")
|
||||
return False
|
||||
|
||||
logger.info(f"验证通过,生成了 {len(png_files)} 个图像")
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="一键生成扩散数据管线")
|
||||
parser.add_argument("--config", type=str, required=True, help="配置文件路径")
|
||||
parser.add_argument("--data_dir", type=str, help="原始数据目录(覆盖配置文件)")
|
||||
parser.add_argument("--model_dir", type=str, default="models/diffusion", help="扩散模型保存目录")
|
||||
parser.add_argument("--output_dir", type=str, default="data/diffusion_generated", help="生成数据保存目录")
|
||||
parser.add_argument("--num_samples", type=int, default=200, help="生成的样本数量")
|
||||
parser.add_argument("--ratio", type=float, default=0.3, help="扩散数据在训练中的比例")
|
||||
parser.add_argument("--skip_training", action="store_true", help="跳过训练,直接生成")
|
||||
parser.add_argument("--model_checkpoint", type=str, help="指定模型检查点路径(skip_training时使用)")
|
||||
|
||||
# 训练参数
|
||||
parser.add_argument("--epochs", type=int, default=100, help="训练轮数")
|
||||
parser.add_argument("--batch_size", type=int, default=8, help="批次大小")
|
||||
parser.add_argument("--lr", type=float, default=1e-4, help="学习率")
|
||||
parser.add_argument("--image_size", type=int, default=256, help="图像尺寸")
|
||||
parser.add_argument("--augment", action="store_true", help="启用数据增强")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 设置日志
|
||||
logger = setup_logging()
|
||||
|
||||
# 读取配置文件获取数据目录
|
||||
config_path = Path(args.config)
|
||||
if not config_path.exists():
|
||||
logger.error(f"配置文件不存在: {config_path}")
|
||||
return False
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# 确定数据目录
|
||||
if args.data_dir:
|
||||
data_dir = args.data_dir
|
||||
else:
|
||||
# 从配置文件获取数据目录
|
||||
config_dir = config_path.parent
|
||||
layout_dir = config.get('paths', {}).get('layout_dir', 'data/layouts')
|
||||
data_dir = str(config_dir / layout_dir)
|
||||
|
||||
data_path = Path(data_dir)
|
||||
if not data_path.exists():
|
||||
logger.error(f"数据目录不存在: {data_path}")
|
||||
return False
|
||||
|
||||
logger.info(f"使用数据目录: {data_path}")
|
||||
logger.info(f"模型保存目录: {args.model_dir}")
|
||||
logger.info(f"生成数据目录: {args.output_dir}")
|
||||
logger.info(f"生成样本数量: {args.num_samples}")
|
||||
logger.info(f"训练比例: {args.ratio}")
|
||||
|
||||
# 1. 训练扩散模型(如果需要)
|
||||
if not args.skip_training:
|
||||
success = train_diffusion_model(
|
||||
data_dir=data_dir,
|
||||
model_dir=args.model_dir,
|
||||
logger=logger,
|
||||
image_size=args.image_size,
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
lr=args.lr,
|
||||
num_samples=args.num_samples,
|
||||
augment=args.augment
|
||||
)
|
||||
if not success:
|
||||
logger.error("扩散模型训练失败")
|
||||
return False
|
||||
else:
|
||||
logger.info("跳过训练步骤")
|
||||
|
||||
# 2. 生成样本
|
||||
success = generate_samples(
|
||||
model_dir=args.model_dir,
|
||||
output_dir=args.output_dir,
|
||||
num_samples=args.num_samples,
|
||||
logger=logger,
|
||||
image_size=args.image_size
|
||||
)
|
||||
if not success:
|
||||
logger.error("样本生成失败")
|
||||
return False
|
||||
|
||||
# 3. 验证生成的数据
|
||||
if not validate_generated_data(args.output_dir, logger):
|
||||
logger.error("数据验证失败")
|
||||
return False
|
||||
|
||||
# 4. 更新配置文件
|
||||
update_config(
|
||||
config_path=args.config,
|
||||
output_dir=args.output_dir,
|
||||
ratio=args.ratio,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
logger.info("=== 扩散数据生成管线完成 ===")
|
||||
logger.info(f"生成数据位置: {args.output_dir}")
|
||||
logger.info(f"配置文件已更新: {args.config}")
|
||||
logger.info(f"扩散数据比例: {args.ratio}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -1,319 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
使用优化后的扩散模型生成IC版图图像
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import argparse
|
||||
import yaml
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
# 导入优化后的模块
|
||||
from ic_layout_diffusion_optimized import (
|
||||
ManhattanAwareUNet,
|
||||
OptimizedNoiseScheduler,
|
||||
OptimizedDiffusionTrainer,
|
||||
manhattan_post_process,
|
||||
manhattan_regularization_loss
|
||||
)
|
||||
|
||||
def setup_logging():
|
||||
"""设置日志"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout),
|
||||
logging.FileHandler('diffusion_generation.log')
|
||||
]
|
||||
)
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
def load_model(checkpoint_path, device):
|
||||
"""加载训练好的模型"""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 加载检查点
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
|
||||
# 从检查点中获取配置信息(如果有)
|
||||
config = checkpoint.get('config', {})
|
||||
|
||||
# 创建模型
|
||||
model = ManhattanAwareUNet(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
use_edge_condition=config.get('edge_condition', False)
|
||||
).to(device)
|
||||
|
||||
# 加载权重
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model.eval()
|
||||
|
||||
logger.info(f"模型已加载: {checkpoint_path}")
|
||||
logger.info(f"模型参数数量: {sum(p.numel() for p in model.parameters()):,}")
|
||||
|
||||
return model, config
|
||||
|
||||
def ddim_sample(model, scheduler, num_samples, image_size, device, num_steps=50, eta=0.0):
|
||||
"""DDIM采样,比标准DDPM更快"""
|
||||
model.eval()
|
||||
|
||||
# 从纯噪声开始
|
||||
x = torch.randn(num_samples, 1, image_size, image_size).to(device)
|
||||
|
||||
# 选择时间步
|
||||
timesteps = torch.linspace(scheduler.num_timesteps - 1, 0, num_steps).long().to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
for i, t in enumerate(timesteps):
|
||||
t_batch = torch.full((num_samples,), t, device=device)
|
||||
|
||||
# 预测噪声
|
||||
predicted_noise = model(x, t_batch)
|
||||
|
||||
# 计算原始图像的估计
|
||||
alpha_t = scheduler.alphas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
alpha_cumprod_t = scheduler.alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
beta_t = scheduler.betas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_cumprod_t = scheduler.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
# 计算x_0的估计
|
||||
x_0_pred = (x - sqrt_one_minus_alpha_cumprod_t * predicted_noise) / torch.sqrt(alpha_cumprod_t)
|
||||
|
||||
# 计算前一时间步的方向
|
||||
if i < len(timesteps) - 1:
|
||||
alpha_t_prev = scheduler.alphas[timesteps[i+1]]
|
||||
alpha_cumprod_t_prev = scheduler.alphas_cumprod[timesteps[i+1]]
|
||||
sqrt_alpha_cumprod_t_prev = torch.sqrt(alpha_cumprod_t_prev).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_cumprod_t_prev = torch.sqrt(1 - alpha_cumprod_t_prev).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
# 计算方差
|
||||
variance = eta * torch.sqrt(beta_t).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
# 计算前一时间步的x
|
||||
x = sqrt_alpha_cumprod_t_prev * x_0_pred + torch.sqrt(1 - alpha_cumprod_t_prev - variance**2) * predicted_noise
|
||||
|
||||
if eta > 0:
|
||||
noise = torch.randn_like(x)
|
||||
x += variance * noise
|
||||
else:
|
||||
x = x_0_pred
|
||||
|
||||
# 限制范围
|
||||
x = torch.clamp(x, -2.0, 2.0)
|
||||
|
||||
return torch.clamp(x, 0.0, 1.0)
|
||||
|
||||
def generate_with_guidance(model, scheduler, num_samples, image_size, device,
|
||||
guidance_scale=1.0, num_steps=50, use_ddim=True):
|
||||
"""带引导的采样(可扩展为classifier-free guidance)"""
|
||||
|
||||
if use_ddim:
|
||||
# 使用DDIM采样
|
||||
samples = ddim_sample(model, scheduler, num_samples, image_size, device, num_steps)
|
||||
else:
|
||||
# 使用标准DDPM采样
|
||||
trainer = OptimizedDiffusionTrainer(model, scheduler, device)
|
||||
samples = trainer.generate(num_samples, image_size, save_dir=None, use_post_process=False)
|
||||
|
||||
return samples
|
||||
|
||||
def evaluate_generation_quality(samples, device):
|
||||
"""评估生成质量"""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
quality_metrics = {}
|
||||
|
||||
# 1. 曼哈顿几何合规性
|
||||
manhattan_loss = manhattan_regularization_loss(samples, device)
|
||||
quality_metrics['manhattan_compliance'] = float(manhattan_loss.item())
|
||||
|
||||
# 2. 边缘锐度
|
||||
sobel_x = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], device=device, dtype=samples.dtype)
|
||||
sobel_y = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], device=device, dtype=samples.dtype)
|
||||
|
||||
edge_x = F.conv2d(samples, sobel_x, padding=1)
|
||||
edge_y = F.conv2d(samples, sobel_y, padding=1)
|
||||
edge_magnitude = torch.sqrt(edge_x**2 + edge_y**2)
|
||||
|
||||
quality_metrics['edge_sharpness'] = float(torch.mean(edge_magnitude).item())
|
||||
|
||||
# 3. 对比度
|
||||
quality_metrics['contrast'] = float(torch.std(samples).item())
|
||||
|
||||
# 4. 稀疏性(IC版图通常是稀疏的)
|
||||
quality_metrics['sparsity'] = float((samples < 0.1).float().mean().item())
|
||||
|
||||
logger.info("生成质量评估:")
|
||||
for metric, value in quality_metrics.items():
|
||||
logger.info(f" {metric}: {value:.4f}")
|
||||
|
||||
return quality_metrics
|
||||
|
||||
def generate_optimized_samples(args):
|
||||
"""生成优化样本的主函数"""
|
||||
logger = setup_logging()
|
||||
|
||||
# 设备检查
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
logger.info(f"使用设备: {device}")
|
||||
|
||||
# 设置随机种子
|
||||
torch.manual_seed(args.seed)
|
||||
if device.type == 'cuda':
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 加载模型
|
||||
model, config = load_model(args.checkpoint, device)
|
||||
|
||||
# 创建调度器
|
||||
scheduler = OptimizedNoiseScheduler(
|
||||
num_timesteps=config.get('timesteps', args.timesteps),
|
||||
schedule_type=config.get('schedule_type', args.schedule_type)
|
||||
)
|
||||
|
||||
# 生成参数
|
||||
generation_config = {
|
||||
'num_samples': args.num_samples,
|
||||
'image_size': args.image_size,
|
||||
'guidance_scale': args.guidance_scale,
|
||||
'num_steps': args.num_steps,
|
||||
'use_ddim': args.use_ddim,
|
||||
'eta': args.eta,
|
||||
'seed': args.seed
|
||||
}
|
||||
|
||||
# 保存生成配置
|
||||
with open(output_dir / 'generation_config.yaml', 'w') as f:
|
||||
yaml.dump(generation_config, f, default_flow_style=False)
|
||||
|
||||
logger.info(f"开始生成 {args.num_samples} 个样本...")
|
||||
logger.info(f"采样步数: {args.num_steps}, DDIM: {args.use_ddim}, ETA: {args.eta}")
|
||||
|
||||
# 分批生成以避免内存不足
|
||||
all_samples = []
|
||||
batch_size = min(args.batch_size, args.num_samples)
|
||||
num_batches = (args.num_samples + batch_size - 1) // batch_size
|
||||
|
||||
for batch_idx in range(num_batches):
|
||||
start_idx = batch_idx * batch_size
|
||||
end_idx = min(start_idx + batch_size, args.num_samples)
|
||||
current_batch_size = end_idx - start_idx
|
||||
|
||||
logger.info(f"生成批次 {batch_idx + 1}/{num_batches} ({current_batch_size} 个样本)")
|
||||
|
||||
# 生成样本
|
||||
with torch.no_grad():
|
||||
samples = generate_with_guidance(
|
||||
model, scheduler, current_batch_size, args.image_size, device,
|
||||
args.guidance_scale, args.num_steps, args.use_ddim
|
||||
)
|
||||
|
||||
# 后处理
|
||||
if args.use_post_process:
|
||||
samples = manhattan_post_process(samples, threshold=args.post_process_threshold)
|
||||
|
||||
all_samples.append(samples)
|
||||
|
||||
# 立即保存当前批次
|
||||
batch_dir = output_dir / f"batch_{batch_idx + 1}"
|
||||
batch_dir.mkdir(exist_ok=True)
|
||||
|
||||
for i in range(current_batch_size):
|
||||
img_tensor = samples[i].cpu()
|
||||
img_array = (img_tensor.squeeze().numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(img_array, mode='L')
|
||||
img.save(batch_dir / f"sample_{start_idx + i:06d}.png")
|
||||
|
||||
# 合并所有样本
|
||||
all_samples = torch.cat(all_samples, dim=0)
|
||||
|
||||
# 评估生成质量
|
||||
quality_metrics = evaluate_generation_quality(all_samples, device)
|
||||
|
||||
# 保存质量评估结果
|
||||
with open(output_dir / 'quality_metrics.yaml', 'w') as f:
|
||||
yaml.dump(quality_metrics, f, default_flow_style=False)
|
||||
|
||||
# 创建质量报告
|
||||
report_content = f"""
|
||||
IC版图扩散模型生成报告
|
||||
======================
|
||||
|
||||
生成配置:
|
||||
- 模型检查点: {args.checkpoint}
|
||||
- 样本数量: {args.num_samples}
|
||||
- 图像尺寸: {args.image_size}x{args.image_size}
|
||||
- 采样步数: {args.num_steps}
|
||||
- DDIM采样: {args.use_ddim}
|
||||
- 后处理: {args.use_post_process}
|
||||
|
||||
质量指标:
|
||||
- 曼哈顿几何合规性: {quality_metrics['manhattan_compliance']:.4f} (越低越好)
|
||||
- 边缘锐度: {quality_metrics['edge_sharpness']:.4f}
|
||||
- 对比度: {quality_metrics['contrast']:.4f}
|
||||
- 稀疏性: {quality_metrics['sparsity']:.4f}
|
||||
|
||||
输出目录: {output_dir}
|
||||
"""
|
||||
|
||||
with open(output_dir / 'generation_report.txt', 'w') as f:
|
||||
f.write(report_content)
|
||||
|
||||
logger.info("生成完成!")
|
||||
logger.info(f"样本保存目录: {output_dir}")
|
||||
logger.info(f"质量报告: {output_dir / 'generation_report.txt'}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="使用优化的扩散模型生成IC版图")
|
||||
|
||||
# 必需参数
|
||||
parser.add_argument('--checkpoint', type=str, required=True, help='模型检查点路径')
|
||||
parser.add_argument('--output_dir', type=str, required=True, help='输出目录')
|
||||
|
||||
# 生成参数
|
||||
parser.add_argument('--num_samples', type=int, default=200, help='生成样本数量')
|
||||
parser.add_argument('--image_size', type=int, default=256, help='图像尺寸')
|
||||
parser.add_argument('--seed', type=int, default=42, help='随机种子')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='批次大小')
|
||||
|
||||
# 采样参数
|
||||
parser.add_argument('--num_steps', type=int, default=50, help='采样步数')
|
||||
parser.add_argument('--use_ddim', action='store_true', default=True, help='使用DDIM采样')
|
||||
parser.add_argument('--guidance_scale', type=float, default=1.0, help='引导尺度')
|
||||
parser.add_argument('--eta', type=float, default=0.0, help='DDIM eta参数 (0=确定性, 1=随机)')
|
||||
|
||||
# 后处理参数
|
||||
parser.add_argument('--use_post_process', action='store_true', default=True, help='启用后处理')
|
||||
parser.add_argument('--post_process_threshold', type=float, default=0.5, help='后处理阈值')
|
||||
|
||||
# 模型配置(用于覆盖检查点中的配置)
|
||||
parser.add_argument('--timesteps', type=int, default=1000, help='扩散时间步数')
|
||||
parser.add_argument('--schedule_type', type=str, default='cosine',
|
||||
choices=['linear', 'cosine'], help='噪声调度类型')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 检查检查点文件
|
||||
if not Path(args.checkpoint).exists():
|
||||
print(f"错误: 检查点文件不存在: {args.checkpoint}")
|
||||
sys.exit(1)
|
||||
|
||||
# 开始生成
|
||||
generate_optimized_samples(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,393 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
基于原始IC版图数据训练扩散模型,生成相似图像的完整实现。
|
||||
|
||||
使用DDPM (Denoising Diffusion Probabilistic Models)
|
||||
针对单通道灰度IC版图图像进行优化。
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
import logging
|
||||
|
||||
# 尝试导入tqdm,如果没有则使用简单的进度显示
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
except ImportError:
|
||||
def tqdm(iterable, **kwargs):
|
||||
return iterable
|
||||
|
||||
|
||||
class ICDiffusionDataset(Dataset):
|
||||
"""IC版图扩散模型训练数据集"""
|
||||
|
||||
def __init__(self, image_dir, image_size=256, augment=True):
|
||||
self.image_dir = Path(image_dir)
|
||||
self.image_size = image_size
|
||||
|
||||
# 获取所有PNG图像
|
||||
self.image_paths = []
|
||||
for ext in ['*.png', '*.jpg', '*.jpeg']:
|
||||
self.image_paths.extend(list(self.image_dir.glob(ext)))
|
||||
|
||||
# 数据变换
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Resize((image_size, image_size)),
|
||||
transforms.ToTensor(), # 转换到[0,1]范围
|
||||
])
|
||||
|
||||
# 数据增强
|
||||
self.augment = augment
|
||||
if augment:
|
||||
self.aug_transform = transforms.Compose([
|
||||
transforms.RandomHorizontalFlip(p=0.5),
|
||||
transforms.RandomVerticalFlip(p=0.5),
|
||||
transforms.RandomRotation(90, fill=0),
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_paths)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = self.image_paths[idx]
|
||||
image = Image.open(img_path).convert('L') # 确保是灰度图
|
||||
|
||||
# 基础变换
|
||||
image = self.transform(image)
|
||||
|
||||
# 数据增强
|
||||
if self.augment and np.random.random() > 0.5:
|
||||
image = self.aug_transform(image)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
"""简化的U-Net架构用于扩散模型"""
|
||||
|
||||
def __init__(self, in_channels=1, out_channels=1, time_dim=256):
|
||||
super().__init__()
|
||||
|
||||
# 时间嵌入
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.Linear(1, time_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_dim, time_dim)
|
||||
)
|
||||
|
||||
# 编码器
|
||||
self.encoder = nn.ModuleList([
|
||||
nn.Conv2d(in_channels, 64, 3, padding=1),
|
||||
nn.Conv2d(64, 128, 3, stride=2, padding=1),
|
||||
nn.Conv2d(128, 256, 3, stride=2, padding=1),
|
||||
nn.Conv2d(256, 512, 3, stride=2, padding=1),
|
||||
])
|
||||
|
||||
# 中间层
|
||||
self.middle = nn.Sequential(
|
||||
nn.Conv2d(512, 512, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(512, 512, 3, padding=1)
|
||||
)
|
||||
|
||||
# 解码器
|
||||
self.decoder = nn.ModuleList([
|
||||
nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1),
|
||||
nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
|
||||
nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
|
||||
])
|
||||
|
||||
# 输出层
|
||||
self.output = nn.Conv2d(64, out_channels, 3, padding=1)
|
||||
|
||||
# 时间融合层
|
||||
self.time_fusion = nn.ModuleList([
|
||||
nn.Linear(time_dim, 64),
|
||||
nn.Linear(time_dim, 128),
|
||||
nn.Linear(time_dim, 256),
|
||||
nn.Linear(time_dim, 512),
|
||||
])
|
||||
|
||||
# 归一化层
|
||||
self.norms = nn.ModuleList([
|
||||
nn.GroupNorm(8, 64),
|
||||
nn.GroupNorm(8, 128),
|
||||
nn.GroupNorm(8, 256),
|
||||
nn.GroupNorm(8, 512),
|
||||
])
|
||||
|
||||
def forward(self, x, t):
|
||||
# 时间嵌入
|
||||
t_emb = self.time_mlp(t.float().unsqueeze(-1)) # [B, time_dim]
|
||||
|
||||
# 编码器路径
|
||||
skips = []
|
||||
for i, (conv, norm, fusion) in enumerate(zip(self.encoder, self.norms, self.time_fusion)):
|
||||
x = conv(x)
|
||||
x = norm(x)
|
||||
# 融合时间信息
|
||||
t_feat = fusion(t_emb).unsqueeze(-1).unsqueeze(-1)
|
||||
x = x + t_feat
|
||||
x = F.silu(x)
|
||||
skips.append(x)
|
||||
if i < len(self.encoder) - 1:
|
||||
x = F.silu(x)
|
||||
|
||||
# 中间层
|
||||
x = self.middle(x)
|
||||
x = F.silu(x)
|
||||
|
||||
# 解码器路径
|
||||
for i, (deconv, skip) in enumerate(zip(self.decoder, reversed(skips[:-1]))):
|
||||
x = deconv(x)
|
||||
x = x + skip # 跳跃连接
|
||||
x = F.silu(x)
|
||||
|
||||
# 输出
|
||||
x = self.output(x)
|
||||
return x
|
||||
|
||||
|
||||
class NoiseScheduler:
|
||||
"""噪声调度器"""
|
||||
|
||||
def __init__(self, num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
|
||||
self.num_timesteps = num_timesteps
|
||||
|
||||
# beta调度
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
|
||||
|
||||
# 预计算
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
|
||||
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
|
||||
|
||||
def add_noise(self, x_0, t):
|
||||
"""向干净图像添加噪声"""
|
||||
noise = torch.randn_like(x_0)
|
||||
sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
|
||||
|
||||
def sample_timestep(self, batch_size):
|
||||
"""采样时间步"""
|
||||
return torch.randint(0, self.num_timesteps, (batch_size,))
|
||||
|
||||
def step(self, model, x_t, t):
|
||||
"""单步去噪"""
|
||||
# 预测噪声
|
||||
predicted_noise = model(x_t, t)
|
||||
|
||||
# 计算系数
|
||||
alpha_t = self.alphas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_alpha_t = torch.sqrt(alpha_t)
|
||||
beta_t = self.betas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
# 计算均值
|
||||
model_mean = (1.0 / sqrt_alpha_t) * (x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * predicted_noise)
|
||||
|
||||
if t.min() == 0:
|
||||
return model_mean
|
||||
else:
|
||||
noise = torch.randn_like(x_t)
|
||||
return model_mean + torch.sqrt(beta_t) * noise
|
||||
|
||||
|
||||
class DiffusionTrainer:
|
||||
"""扩散模型训练器"""
|
||||
|
||||
def __init__(self, model, scheduler, device='cuda'):
|
||||
self.model = model.to(device)
|
||||
self.scheduler = scheduler
|
||||
self.device = device
|
||||
self.loss_fn = nn.MSELoss()
|
||||
|
||||
def train_step(self, optimizer, dataloader):
|
||||
"""单步训练"""
|
||||
self.model.train()
|
||||
total_loss = 0
|
||||
|
||||
for batch in dataloader:
|
||||
batch = batch.to(self.device)
|
||||
|
||||
# 采样时间步
|
||||
t = self.scheduler.sample_timestep(batch.shape[0]).to(self.device)
|
||||
|
||||
# 添加噪声
|
||||
noisy_batch, noise = self.scheduler.add_noise(batch, t)
|
||||
|
||||
# 预测噪声
|
||||
predicted_noise = self.model(noisy_batch, t)
|
||||
|
||||
# 计算损失
|
||||
loss = self.loss_fn(predicted_noise, noise)
|
||||
|
||||
# 反向传播
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
return total_loss / len(dataloader)
|
||||
|
||||
def generate(self, num_samples, image_size=256, save_dir=None):
|
||||
"""生成图像"""
|
||||
self.model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
# 从纯噪声开始
|
||||
x = torch.randn(num_samples, 1, image_size, image_size).to(self.device)
|
||||
|
||||
# 逐步去噪
|
||||
for t in reversed(range(self.scheduler.num_timesteps)):
|
||||
t_batch = torch.full((num_samples,), t, device=self.device)
|
||||
x = self.scheduler.step(self.model, x, t_batch)
|
||||
|
||||
# 限制到[0,1]范围
|
||||
x = torch.clamp(x, 0.0, 1.0)
|
||||
|
||||
# 保存图像
|
||||
if save_dir:
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i in range(num_samples):
|
||||
img_tensor = x[i].cpu()
|
||||
img_array = (img_tensor.squeeze().numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(img_array, mode='L')
|
||||
img.save(save_dir / f"generated_{i:06d}.png")
|
||||
|
||||
return x.cpu()
|
||||
|
||||
|
||||
def train_diffusion_model(args):
|
||||
"""训练扩散模型的主函数"""
|
||||
# 设置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 设备检查
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
logger.info(f"使用设备: {device}")
|
||||
|
||||
# 创建数据集和数据加载器
|
||||
dataset = ICDiffusionDataset(args.data_dir, args.image_size, args.augment)
|
||||
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
|
||||
logger.info(f"数据集大小: {len(dataset)}")
|
||||
|
||||
# 创建模型和调度器
|
||||
model = UNet(in_channels=1, out_channels=1)
|
||||
scheduler = NoiseScheduler(num_timesteps=args.timesteps)
|
||||
trainer = DiffusionTrainer(model, scheduler, device)
|
||||
|
||||
# 优化器
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||
|
||||
# 训练循环
|
||||
logger.info(f"开始训练 {args.epochs} 个epoch...")
|
||||
for epoch in range(args.epochs):
|
||||
loss = trainer.train_step(optimizer, dataloader)
|
||||
logger.info(f"Epoch {epoch+1}/{args.epochs}, Loss: {loss:.6f}")
|
||||
|
||||
# 定期保存模型
|
||||
if (epoch + 1) % args.save_interval == 0:
|
||||
checkpoint = {
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': loss,
|
||||
}
|
||||
checkpoint_path = Path(args.output_dir) / f"diffusion_epoch_{epoch+1}.pth"
|
||||
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(checkpoint, checkpoint_path)
|
||||
logger.info(f"保存检查点: {checkpoint_path}")
|
||||
|
||||
# 生成样本
|
||||
logger.info("生成示例图像...")
|
||||
trainer.generate(
|
||||
num_samples=args.num_samples,
|
||||
image_size=args.image_size,
|
||||
save_dir=os.path.join(args.output_dir, 'samples')
|
||||
)
|
||||
|
||||
# 保存最终模型
|
||||
final_checkpoint = {
|
||||
'epoch': args.epochs,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': loss,
|
||||
}
|
||||
final_path = Path(args.output_dir) / "diffusion_final.pth"
|
||||
torch.save(final_checkpoint, final_path)
|
||||
logger.info(f"训练完成,最终模型保存在: {final_path}")
|
||||
|
||||
|
||||
def generate_with_trained_model(args):
|
||||
"""使用训练好的模型生成图像"""
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# 加载模型
|
||||
model = UNet(in_channels=1, out_channels=1)
|
||||
checkpoint = torch.load(args.checkpoint, map_location=device)
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model.to(device)
|
||||
|
||||
# 创建调度器和训练器
|
||||
scheduler = NoiseScheduler(num_timesteps=args.timesteps)
|
||||
trainer = DiffusionTrainer(model, scheduler, device)
|
||||
|
||||
# 生成图像
|
||||
trainer.generate(
|
||||
num_samples=args.num_samples,
|
||||
image_size=args.image_size,
|
||||
save_dir=args.output_dir
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="IC版图扩散模型训练和生成")
|
||||
subparsers = parser.add_subparsers(dest='command', help='命令')
|
||||
|
||||
# 训练命令
|
||||
train_parser = subparsers.add_parser('train', help='训练扩散模型')
|
||||
train_parser.add_argument('--data_dir', type=str, required=True, help='训练数据目录')
|
||||
train_parser.add_argument('--output_dir', type=str, required=True, help='输出目录')
|
||||
train_parser.add_argument('--image_size', type=int, default=256, help='图像尺寸')
|
||||
train_parser.add_argument('--batch_size', type=int, default=8, help='批次大小')
|
||||
train_parser.add_argument('--epochs', type=int, default=100, help='训练轮数')
|
||||
train_parser.add_argument('--lr', type=float, default=1e-4, help='学习率')
|
||||
train_parser.add_argument('--timesteps', type=int, default=1000, help='扩散时间步数')
|
||||
train_parser.add_argument('--num_samples', type=int, default=50, help='生成的样本数量')
|
||||
train_parser.add_argument('--save_interval', type=int, default=10, help='保存间隔')
|
||||
train_parser.add_argument('--augment', action='store_true', help='启用数据增强')
|
||||
|
||||
# 生成命令
|
||||
gen_parser = subparsers.add_parser('generate', help='使用训练好的模型生成图像')
|
||||
gen_parser.add_argument('--checkpoint', type=str, required=True, help='模型检查点路径')
|
||||
gen_parser.add_argument('--output_dir', type=str, required=True, help='输出目录')
|
||||
gen_parser.add_argument('--num_samples', type=int, default=200, help='生成样本数量')
|
||||
gen_parser.add_argument('--image_size', type=int, default=256, help='图像尺寸')
|
||||
gen_parser.add_argument('--timesteps', type=int, default=1000, help='扩散时间步数')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == 'train':
|
||||
train_diffusion_model(args)
|
||||
elif args.command == 'generate':
|
||||
generate_with_trained_model(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
@@ -1,539 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
针对IC版图优化的去噪扩散模型
|
||||
|
||||
专门针对以曼哈顿多边形为全部组成元素的IC版图光栅化图像进行优化:
|
||||
- 曼哈顿几何感知的U-Net架构
|
||||
- 边缘感知损失函数
|
||||
- 多尺度结构损失
|
||||
- 曼哈顿约束正则化
|
||||
- 几何保持的数据增强
|
||||
- 后处理优化
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
import logging
|
||||
import cv2
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
except ImportError:
|
||||
def tqdm(iterable, **kwargs):
|
||||
return iterable
|
||||
|
||||
|
||||
class ICDiffusionDataset(Dataset):
|
||||
"""IC版图扩散模型训练数据集 - 优化版"""
|
||||
|
||||
def __init__(self, image_dir, image_size=256, augment=True, use_edge_condition=False):
|
||||
self.image_dir = Path(image_dir)
|
||||
self.image_size = image_size
|
||||
self.use_edge_condition = use_edge_condition
|
||||
|
||||
# 获取所有PNG图像
|
||||
self.image_paths = []
|
||||
for ext in ['*.png', '*.jpg', '*.jpeg']:
|
||||
self.image_paths.extend(list(self.image_dir.glob(ext)))
|
||||
|
||||
# 基础变换
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Resize((image_size, image_size)),
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
# 几何保持的数据增强
|
||||
self.augment = augment
|
||||
if augment:
|
||||
self.aug_transform = transforms.Compose([
|
||||
transforms.RandomHorizontalFlip(p=0.5),
|
||||
transforms.RandomVerticalFlip(p=0.5),
|
||||
# 移除旋转,保持曼哈顿几何
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_paths)
|
||||
|
||||
def _extract_edges(self, image_tensor):
|
||||
"""提取边缘条件图"""
|
||||
# 使用Sobel算子提取边缘
|
||||
sobel_x = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]],
|
||||
dtype=image_tensor.dtype, device=image_tensor.device)
|
||||
sobel_y = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]],
|
||||
dtype=image_tensor.dtype, device=image_tensor.device)
|
||||
|
||||
edge_x = F.conv2d(image_tensor.unsqueeze(0), sobel_x, padding=1)
|
||||
edge_y = F.conv2d(image_tensor.unsqueeze(0), sobel_y, padding=1)
|
||||
edge_magnitude = torch.sqrt(edge_x**2 + edge_y**2)
|
||||
|
||||
return torch.clamp(edge_magnitude, 0, 1)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = self.image_paths[idx]
|
||||
image = Image.open(img_path).convert('L')
|
||||
|
||||
# 基础变换
|
||||
image = self.transform(image)
|
||||
|
||||
# 几何保持的数据增强
|
||||
if self.augment and np.random.random() > 0.5:
|
||||
image = self.aug_transform(image)
|
||||
|
||||
if self.use_edge_condition:
|
||||
edge_condition = self._extract_edges(image)
|
||||
return image, edge_condition.squeeze(0)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class EdgeAwareLoss(nn.Module):
|
||||
"""边缘感知损失函数"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# 注册为缓冲区以避免重复创建
|
||||
self.register_buffer('sobel_x', torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]]))
|
||||
self.register_buffer('sobel_y', torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]]))
|
||||
|
||||
def forward(self, pred, target):
|
||||
# 原始MSE损失
|
||||
mse_loss = F.mse_loss(pred, target)
|
||||
|
||||
# 计算边缘
|
||||
pred_edge_x = F.conv2d(pred, self.sobel_x, padding=1)
|
||||
pred_edge_y = F.conv2d(pred, self.sobel_y, padding=1)
|
||||
target_edge_x = F.conv2d(target, self.sobel_x, padding=1)
|
||||
target_edge_y = F.conv2d(target, self.sobel_y, padding=1)
|
||||
|
||||
# 边缘损失
|
||||
edge_loss = F.mse_loss(pred_edge_x, target_edge_x) + F.mse_loss(pred_edge_y, target_edge_y)
|
||||
|
||||
return mse_loss + 0.5 * edge_loss
|
||||
|
||||
|
||||
class MultiScaleStructureLoss(nn.Module):
|
||||
"""多尺度结构损失"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, pred, target):
|
||||
# 原始分辨率损失
|
||||
loss_1x = F.mse_loss(pred, target)
|
||||
|
||||
# 2x下采样损失
|
||||
pred_2x = F.avg_pool2d(pred, 2)
|
||||
target_2x = F.avg_pool2d(target, 2)
|
||||
loss_2x = F.mse_loss(pred_2x, target_2x)
|
||||
|
||||
# 4x下采样损失
|
||||
pred_4x = F.avg_pool2d(pred, 4)
|
||||
target_4x = F.avg_pool2d(target, 4)
|
||||
loss_4x = F.mse_loss(pred_4x, target_4x)
|
||||
|
||||
return loss_1x + 0.5 * loss_2x + 0.25 * loss_4x
|
||||
|
||||
|
||||
def manhattan_regularization_loss(generated_image, device='cuda'):
|
||||
"""曼哈顿约束正则化损失"""
|
||||
if device == 'cuda':
|
||||
device = generated_image.device
|
||||
|
||||
# Sobel算子
|
||||
sobel_x = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], device=device, dtype=generated_image.dtype)
|
||||
sobel_y = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], device=device, dtype=generated_image.dtype)
|
||||
|
||||
# 检测边缘
|
||||
edge_x = F.conv2d(generated_image, sobel_x, padding=1)
|
||||
edge_y = F.conv2d(generated_image, sobel_y, padding=1)
|
||||
|
||||
# 边缘强度
|
||||
edge_magnitude = torch.sqrt(edge_x**2 + edge_y**2 + 1e-8)
|
||||
|
||||
# 计算角度偏差
|
||||
angles = torch.atan2(edge_y, edge_x)
|
||||
|
||||
# 惩罚不接近0°、90°、180°或270°的角度
|
||||
angle_penalty = torch.min(
|
||||
torch.min(torch.abs(angles), torch.abs(angles - np.pi/2)),
|
||||
torch.min(torch.abs(angles - np.pi), torch.abs(angles - 3*np.pi/2))
|
||||
)
|
||||
|
||||
return torch.mean(angle_penalty * edge_magnitude)
|
||||
|
||||
|
||||
class ManhattanAwareUNet(nn.Module):
|
||||
"""曼哈顿几何感知的U-Net架构"""
|
||||
|
||||
def __init__(self, in_channels=1, out_channels=1, time_dim=256, use_edge_condition=False):
|
||||
super().__init__()
|
||||
self.use_edge_condition = use_edge_condition
|
||||
|
||||
# 输入通道数(原始图像 + 可选边缘条件)
|
||||
input_channels = in_channels + (1 if use_edge_condition else 0)
|
||||
|
||||
# 时间嵌入
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.Linear(1, time_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_dim, time_dim)
|
||||
)
|
||||
|
||||
# 曼哈顿几何感知的初始卷积层
|
||||
self.horiz_conv = nn.Conv2d(input_channels, 32, (1, 7), padding=(0, 3))
|
||||
self.vert_conv = nn.Conv2d(input_channels, 32, (7, 1), padding=(3, 0))
|
||||
self.standard_conv = nn.Conv2d(input_channels, 32, 3, padding=1)
|
||||
|
||||
# 特征融合
|
||||
self.initial_fusion = nn.Sequential(
|
||||
nn.Conv2d(96, 64, 3, padding=1),
|
||||
nn.GroupNorm(8, 64),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
# 编码器 - 增强版
|
||||
self.encoder = nn.ModuleList([
|
||||
self._make_block(64, 128),
|
||||
self._make_block(128, 256, stride=2),
|
||||
self._make_block(256, 512, stride=2),
|
||||
self._make_block(512, 1024, stride=2),
|
||||
])
|
||||
|
||||
# 中间层
|
||||
self.middle = nn.Sequential(
|
||||
nn.Conv2d(1024, 1024, 3, padding=1),
|
||||
nn.GroupNorm(8, 1024),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(1024, 1024, 3, padding=1),
|
||||
nn.GroupNorm(8, 1024),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
# 解码器
|
||||
self.decoder = nn.ModuleList([
|
||||
self._make_decoder_block(1024, 512),
|
||||
self._make_decoder_block(512, 256),
|
||||
self._make_decoder_block(256, 128),
|
||||
self._make_decoder_block(128, 64),
|
||||
])
|
||||
|
||||
# 输出层
|
||||
self.output = nn.Sequential(
|
||||
nn.Conv2d(64, 32, 3, padding=1),
|
||||
nn.GroupNorm(8, 32),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, out_channels, 3, padding=1)
|
||||
)
|
||||
|
||||
# 时间融合层
|
||||
self.time_fusion = nn.ModuleList([
|
||||
nn.Linear(time_dim, 64),
|
||||
nn.Linear(time_dim, 128),
|
||||
nn.Linear(time_dim, 256),
|
||||
nn.Linear(time_dim, 512),
|
||||
nn.Linear(time_dim, 1024),
|
||||
])
|
||||
|
||||
def _make_block(self, in_channels, out_channels, stride=1):
|
||||
"""创建残差块"""
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1),
|
||||
nn.GroupNorm(8, out_channels),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(out_channels, out_channels, 3, padding=1),
|
||||
nn.GroupNorm(8, out_channels),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
def _make_decoder_block(self, in_channels, out_channels):
|
||||
"""创建解码器块"""
|
||||
return nn.Sequential(
|
||||
nn.ConvTranspose2d(in_channels, out_channels, 3, stride=2, padding=1, output_padding=1),
|
||||
nn.GroupNorm(8, out_channels),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(out_channels, out_channels, 3, padding=1),
|
||||
nn.GroupNorm(8, out_channels),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
def forward(self, x, t, edge_condition=None):
|
||||
# 如果有边缘条件,连接到输入
|
||||
if self.use_edge_condition and edge_condition is not None:
|
||||
x = torch.cat([x, edge_condition], dim=1)
|
||||
|
||||
# 时间嵌入
|
||||
t_emb = self.time_mlp(t.float().unsqueeze(-1)) # [B, time_dim]
|
||||
|
||||
# 曼哈顿几何感知的特征提取
|
||||
h_features = F.silu(self.horiz_conv(x))
|
||||
v_features = F.silu(self.vert_conv(x))
|
||||
s_features = F.silu(self.standard_conv(x))
|
||||
|
||||
# 融合特征
|
||||
x = torch.cat([h_features, v_features, s_features], dim=1)
|
||||
x = self.initial_fusion(x)
|
||||
|
||||
# 编码器路径
|
||||
skips = []
|
||||
for i, (encoder, fusion) in enumerate(zip(self.encoder, self.time_fusion)):
|
||||
# 残差连接
|
||||
residual = x
|
||||
x = encoder(x)
|
||||
|
||||
# 融合时间信息
|
||||
t_feat = fusion(t_emb).unsqueeze(-1).unsqueeze(-1)
|
||||
x = x + t_feat
|
||||
|
||||
# 跳跃连接
|
||||
skips.append(x + residual if i == 0 else x)
|
||||
|
||||
# 中间层
|
||||
x = self.middle(x)
|
||||
|
||||
# 解码器路径
|
||||
for i, (decoder, skip) in enumerate(zip(self.decoder, reversed(skips))):
|
||||
x = decoder(x)
|
||||
x = x + skip # 跳跃连接
|
||||
|
||||
# 输出
|
||||
x = self.output(x)
|
||||
return x
|
||||
|
||||
|
||||
class OptimizedNoiseScheduler:
|
||||
"""优化的噪声调度器"""
|
||||
|
||||
def __init__(self, num_timesteps=1000, beta_start=1e-4, beta_end=0.02, schedule_type='linear'):
|
||||
self.num_timesteps = num_timesteps
|
||||
|
||||
# 不同调度策略
|
||||
if schedule_type == 'cosine':
|
||||
# 余弦调度,通常效果更好
|
||||
steps = num_timesteps + 1
|
||||
x = torch.linspace(0, num_timesteps, steps, dtype=torch.float64)
|
||||
alphas_cumprod = torch.cos(((x / num_timesteps) + 0.008) / 1.008 * np.pi / 2) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
self.betas = torch.clip(betas, 0, 0.999)
|
||||
else:
|
||||
# 线性调度
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
|
||||
|
||||
# 预计算
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
|
||||
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
|
||||
|
||||
def add_noise(self, x_0, t):
|
||||
"""向干净图像添加噪声"""
|
||||
noise = torch.randn_like(x_0)
|
||||
sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
|
||||
|
||||
def sample_timestep(self, batch_size):
|
||||
"""采样时间步"""
|
||||
return torch.randint(0, self.num_timesteps, (batch_size,))
|
||||
|
||||
def step(self, model, x_t, t):
|
||||
"""单步去噪"""
|
||||
# 预测噪声
|
||||
predicted_noise = model(x_t, t)
|
||||
|
||||
# 计算系数
|
||||
alpha_t = self.alphas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_alpha_t = torch.sqrt(alpha_t)
|
||||
beta_t = self.betas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
# 计算均值
|
||||
model_mean = (1.0 / sqrt_alpha_t) * (x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * predicted_noise)
|
||||
|
||||
if t.min() == 0:
|
||||
return model_mean
|
||||
else:
|
||||
noise = torch.randn_like(x_t)
|
||||
return model_mean + torch.sqrt(beta_t) * noise
|
||||
|
||||
|
||||
def manhattan_post_process(image, threshold=0.5):
|
||||
"""曼哈顿化后处理"""
|
||||
device = image.device
|
||||
|
||||
# 二值化
|
||||
binary = (image > threshold).float()
|
||||
|
||||
# 形态学操作强化直角特征
|
||||
kernel_h = torch.tensor([[[[1,1,1]]]], device=device)
|
||||
kernel_v = torch.tensor([[[[1],[1],[1]]]], device=device)
|
||||
|
||||
# 水平和垂直增强
|
||||
horizontal = F.conv2d(binary, kernel_h, padding=(0,1))
|
||||
vertical = F.conv2d(binary, kernel_v, padding=(1,0))
|
||||
|
||||
# 合并结果
|
||||
result = torch.clamp(horizontal + vertical - binary, 0, 1)
|
||||
|
||||
# 最终阈值处理
|
||||
result = (result > 0.5).float()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class OptimizedDiffusionTrainer:
|
||||
"""优化的扩散模型训练器"""
|
||||
|
||||
def __init__(self, model, scheduler, device='cuda', use_edge_condition=False):
|
||||
self.model = model.to(device)
|
||||
self.scheduler = scheduler
|
||||
self.device = device
|
||||
self.use_edge_condition = use_edge_condition
|
||||
|
||||
# 组合损失函数
|
||||
self.edge_loss = EdgeAwareLoss()
|
||||
self.structure_loss = MultiScaleStructureLoss()
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def train_step(self, optimizer, dataloader, manhattan_weight=0.1):
|
||||
"""单步训练"""
|
||||
self.model.train()
|
||||
total_loss = 0
|
||||
total_edge_loss = 0
|
||||
total_structure_loss = 0
|
||||
total_manhattan_loss = 0
|
||||
|
||||
for batch in dataloader:
|
||||
if self.use_edge_condition:
|
||||
images, edge_conditions = batch
|
||||
edge_conditions = edge_conditions.to(self.device)
|
||||
else:
|
||||
images = batch
|
||||
edge_conditions = None
|
||||
|
||||
images = images.to(self.device)
|
||||
|
||||
# 采样时间步
|
||||
t = self.scheduler.sample_timestep(images.shape[0]).to(self.device)
|
||||
|
||||
# 添加噪声
|
||||
noisy_images, noise = self.scheduler.add_noise(images, t)
|
||||
|
||||
# 预测噪声
|
||||
predicted_noise = self.model(noisy_images, t, edge_conditions)
|
||||
|
||||
# 计算多种损失
|
||||
mse_loss = self.mse_loss(predicted_noise, noise)
|
||||
edge_loss = self.edge_loss(predicted_noise, noise)
|
||||
structure_loss = self.structure_loss(predicted_noise, noise)
|
||||
|
||||
# 曼哈顿正则化损失
|
||||
with torch.no_grad():
|
||||
# 对去噪结果应用曼哈顿约束
|
||||
denoised = noisy_images - predicted_noise
|
||||
manhattan_loss = manhattan_regularization_loss(denoised, self.device)
|
||||
|
||||
# 总损失
|
||||
total_step_loss = mse_loss + 0.3 * edge_loss + 0.2 * structure_loss + manhattan_weight * manhattan_loss
|
||||
|
||||
# 反向传播
|
||||
optimizer.zero_grad()
|
||||
total_step_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) # 梯度裁剪
|
||||
optimizer.step()
|
||||
|
||||
total_loss += total_step_loss.item()
|
||||
total_edge_loss += edge_loss.item()
|
||||
total_structure_loss += structure_loss.item()
|
||||
total_manhattan_loss += manhattan_loss.item()
|
||||
|
||||
num_batches = len(dataloader)
|
||||
return {
|
||||
'total_loss': total_loss / num_batches,
|
||||
'mse_loss': total_loss / num_batches, # 近似值
|
||||
'edge_loss': total_edge_loss / num_batches,
|
||||
'structure_loss': total_structure_loss / num_batches,
|
||||
'manhattan_loss': total_manhattan_loss / num_batches
|
||||
}
|
||||
|
||||
def generate(self, num_samples, image_size=256, save_dir=None, use_post_process=True):
|
||||
"""生成图像"""
|
||||
self.model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
# 从纯噪声开始
|
||||
x = torch.randn(num_samples, 1, image_size, image_size).to(self.device)
|
||||
|
||||
# 逐步去噪
|
||||
for t in reversed(range(self.scheduler.num_timesteps)):
|
||||
t_batch = torch.full((num_samples,), t, device=self.device)
|
||||
x = self.scheduler.step(self.model, x, t_batch)
|
||||
|
||||
# 限制到合理范围
|
||||
x = torch.clamp(x, -2.0, 2.0)
|
||||
|
||||
# 最终处理
|
||||
x = torch.clamp(x, 0.0, 1.0)
|
||||
|
||||
# 后处理
|
||||
if use_post_process:
|
||||
x = manhattan_post_process(x)
|
||||
|
||||
# 保存图像
|
||||
if save_dir:
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i in range(num_samples):
|
||||
img_tensor = x[i].cpu()
|
||||
img_array = (img_tensor.squeeze().numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(img_array, mode='L')
|
||||
img.save(save_dir / f"generated_{i:06d}.png")
|
||||
|
||||
return x.cpu()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="优化的IC版图扩散模型训练和生成")
|
||||
subparsers = parser.add_subparsers(dest='command', help='命令')
|
||||
|
||||
# 训练命令
|
||||
train_parser = subparsers.add_parser('train', help='训练扩散模型')
|
||||
train_parser.add_argument('--data_dir', type=str, required=True, help='训练数据目录')
|
||||
train_parser.add_argument('--output_dir', type=str, required=True, help='输出目录')
|
||||
train_parser.add_argument('--image_size', type=int, default=256, help='图像尺寸')
|
||||
train_parser.add_argument('--batch_size', type=int, default=4, help='批次大小')
|
||||
train_parser.add_argument('--epochs', type=int, default=100, help='训练轮数')
|
||||
train_parser.add_argument('--lr', type=float, default=1e-4, help='学习率')
|
||||
train_parser.add_argument('--timesteps', type=int, default=1000, help='扩散时间步数')
|
||||
train_parser.add_argument('--num_samples', type=int, default=50, help='生成的样本数量')
|
||||
train_parser.add_argument('--save_interval', type=int, default=10, help='保存间隔')
|
||||
train_parser.add_argument('--augment', action='store_true', help='启用数据增强')
|
||||
train_parser.add_argument('--edge_condition', action='store_true', help='使用边缘条件')
|
||||
train_parser.add_argument('--manhattan_weight', type=float, default=0.1, help='曼哈顿正则化权重')
|
||||
train_parser.add_argument('--schedule_type', type=str, default='cosine', choices=['linear', 'cosine'], help='噪声调度类型')
|
||||
|
||||
# 生成命令
|
||||
gen_parser = subparsers.add_parser('generate', help='使用训练好的模型生成图像')
|
||||
gen_parser.add_argument('--checkpoint', type=str, required=True, help='模型检查点路径')
|
||||
gen_parser.add_argument('--output_dir', type=str, required=True, help='输出目录')
|
||||
gen_parser.add_argument('--num_samples', type=int, default=200, help='生成样本数量')
|
||||
gen_parser.add_argument('--image_size', type=int, default=256, help='图像尺寸')
|
||||
gen_parser.add_argument('--timesteps', type=int, default=1000, help='扩散时间步数')
|
||||
gen_parser.add_argument('--use_post_process', action='store_true', default=True, help='启用后处理')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# TODO: 实现训练和生成函数,使用优化后的组件
|
||||
print("[TODO] 实现完整的训练和生成流程,使用优化后的模型架构和损失函数")
|
||||
@@ -1,46 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Prepare raster patch dataset and optional condition maps for diffusion training.
|
||||
|
||||
Planned inputs:
|
||||
- --src_dirs: one or more directories containing PNG layout images
|
||||
- --out_dir: output root for images/ and conditions/
|
||||
- --size: patch size (e.g., 256)
|
||||
- --stride: sliding stride for patch extraction
|
||||
- --min_fg_ratio: minimum foreground ratio to keep a patch (0-1)
|
||||
- --make_conditions: flags to generate edge/skeleton/distance maps
|
||||
|
||||
Current status: CLI skeleton and TODOs only.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Prepare patch dataset for diffusion training (skeleton)")
|
||||
parser.add_argument("--src_dirs", type=str, nargs="+", help="Source PNG dirs for layouts")
|
||||
parser.add_argument("--out_dir", type=str, required=True, help="Output root directory")
|
||||
parser.add_argument("--size", type=int, default=256, help="Patch size")
|
||||
parser.add_argument("--stride", type=int, default=256, help="Patch stride")
|
||||
parser.add_argument("--min_fg_ratio", type=float, default=0.02, help="Min foreground ratio to keep a patch")
|
||||
parser.add_argument("--make_edge", action="store_true", help="Generate edge map conditions (e.g., Sobel/Canny)")
|
||||
parser.add_argument("--make_skeleton", action="store_true", help="Generate morphological skeleton condition")
|
||||
parser.add_argument("--make_dist", action="store_true", help="Generate distance transform condition")
|
||||
args = parser.parse_args()
|
||||
|
||||
out_root = Path(args.out_dir)
|
||||
out_root.mkdir(parents=True, exist_ok=True)
|
||||
(out_root / "images").mkdir(exist_ok=True)
|
||||
(out_root / "conditions").mkdir(exist_ok=True)
|
||||
|
||||
# TODO: implement extraction loop over src_dirs, crop patches, filter by min_fg_ratio,
|
||||
# and save into images/; generate optional condition maps into conditions/ mirroring filenames.
|
||||
# Keep file naming consistent: images/xxx.png, conditions/xxx_edge.png, etc.
|
||||
|
||||
print("[TODO] Implement patch extraction and condition map generation.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,355 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
一键运行优化的IC版图扩散模型训练和生成管线
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import yaml
|
||||
import argparse
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import shutil
|
||||
|
||||
def setup_logging():
|
||||
"""设置日志"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout),
|
||||
logging.FileHandler('optimized_pipeline.log')
|
||||
]
|
||||
)
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
def run_command(cmd, description, logger):
|
||||
"""运行命令并处理错误"""
|
||||
logger.info(f"执行: {description}")
|
||||
logger.info(f"命令: {' '.join(cmd)}")
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
logger.info(f"{description} - 成功")
|
||||
if result.stdout:
|
||||
logger.debug(f"输出: {result.stdout}")
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"{description} - 失败")
|
||||
logger.error(f"错误码: {e.returncode}")
|
||||
logger.error(f"错误输出: {e.stderr}")
|
||||
return False
|
||||
|
||||
def validate_data_directory(data_dir, logger):
|
||||
"""验证数据目录"""
|
||||
data_path = Path(data_dir)
|
||||
if not data_path.exists():
|
||||
logger.error(f"数据目录不存在: {data_path}")
|
||||
return False
|
||||
|
||||
# 检查图像文件
|
||||
image_extensions = ['.png', '.jpg', '.jpeg']
|
||||
image_files = []
|
||||
for ext in image_extensions:
|
||||
image_files.extend(data_path.glob(f"*{ext}"))
|
||||
image_files.extend(data_path.glob(f"*{ext.upper()}"))
|
||||
|
||||
if len(image_files) == 0:
|
||||
logger.error(f"数据目录中没有找到图像文件: {data_path}")
|
||||
return False
|
||||
|
||||
logger.info(f"数据验证通过 - 找到 {len(image_files)} 个图像文件")
|
||||
return True
|
||||
|
||||
def create_sample_images(output_dir, logger, num_samples=5):
|
||||
"""创建示例图像"""
|
||||
logger.info("创建示例图像...")
|
||||
|
||||
# 创建简单的曼哈顿几何图案
|
||||
from PIL import Image, ImageDraw
|
||||
import numpy as np
|
||||
|
||||
sample_dir = Path(output_dir) / "reference_samples"
|
||||
sample_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i in range(num_samples):
|
||||
# 创建空白图像
|
||||
img = Image.new('L', (256, 256), 255) # 白色背景
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
# 绘制曼哈顿几何图案
|
||||
np.random.seed(i)
|
||||
|
||||
# 外框
|
||||
draw.rectangle([20, 20, 236, 236], outline=0, width=2)
|
||||
|
||||
# 随机内部矩形
|
||||
for _ in range(np.random.randint(3, 8)):
|
||||
x1 = np.random.randint(40, 180)
|
||||
y1 = np.random.randint(40, 180)
|
||||
x2 = x1 + np.random.randint(20, 60)
|
||||
y2 = y1 + np.random.randint(20, 60)
|
||||
if x2 < 220 and y2 < 220: # 确保不超出边界
|
||||
draw.rectangle([x1, y1, x2, y2], outline=0, width=1)
|
||||
|
||||
# 保存图像
|
||||
img.save(sample_dir / f"sample_{i:03d}.png")
|
||||
|
||||
logger.info(f"示例图像已保存到: {sample_dir}")
|
||||
|
||||
def run_optimized_pipeline(args):
|
||||
"""运行优化管线"""
|
||||
logger = setup_logging()
|
||||
|
||||
logger.info("=== 开始优化的IC版图扩散模型管线 ===")
|
||||
|
||||
# 验证输入
|
||||
if not validate_data_directory(args.data_dir, logger):
|
||||
return False
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 如果需要,创建示例数据
|
||||
if args.create_sample_data:
|
||||
create_sample_images(args.data_dir, logger)
|
||||
|
||||
# 训练阶段
|
||||
if not args.skip_training:
|
||||
logger.info("\n=== 第一阶段: 训练优化模型 ===")
|
||||
|
||||
train_cmd = [
|
||||
sys.executable, "train_optimized.py",
|
||||
"--data_dir", args.data_dir,
|
||||
"--output_dir", str(output_dir / "model"),
|
||||
"--image_size", str(args.image_size),
|
||||
"--batch_size", str(args.batch_size),
|
||||
"--epochs", str(args.epochs),
|
||||
"--lr", str(args.lr),
|
||||
"--timesteps", str(args.timesteps),
|
||||
"--schedule_type", args.schedule_type,
|
||||
"--manhattan_weight", str(args.manhattan_weight),
|
||||
"--seed", str(args.seed),
|
||||
"--save_interval", str(args.save_interval),
|
||||
"--sample_interval", str(args.sample_interval),
|
||||
"--num_samples", str(args.train_samples)
|
||||
]
|
||||
|
||||
if args.edge_condition:
|
||||
train_cmd.append("--edge_condition")
|
||||
|
||||
if args.augment:
|
||||
train_cmd.append("--augment")
|
||||
|
||||
if args.resume:
|
||||
train_cmd.extend(["--resume", args.resume])
|
||||
|
||||
success = run_command(train_cmd, "训练优化模型", logger)
|
||||
if not success:
|
||||
logger.error("训练阶段失败")
|
||||
return False
|
||||
|
||||
# 查找最佳模型
|
||||
model_checkpoint = output_dir / "model" / "best_model.pth"
|
||||
if not model_checkpoint.exists():
|
||||
# 如果没有最佳模型,使用最终模型
|
||||
model_checkpoint = output_dir / "model" / "final_model.pth"
|
||||
|
||||
if not model_checkpoint.exists():
|
||||
logger.error("找不到训练好的模型")
|
||||
return False
|
||||
|
||||
else:
|
||||
logger.info("\n=== 跳过训练阶段 ===")
|
||||
model_checkpoint = args.checkpoint
|
||||
if not model_checkpoint:
|
||||
logger.error("跳过训练时需要提供 --checkpoint 参数")
|
||||
return False
|
||||
|
||||
if not Path(model_checkpoint).exists():
|
||||
logger.error(f"指定的检查点不存在: {model_checkpoint}")
|
||||
return False
|
||||
|
||||
# 生成阶段
|
||||
logger.info("\n=== 第二阶段: 生成样本 ===")
|
||||
|
||||
generate_cmd = [
|
||||
sys.executable, "generate_optimized.py",
|
||||
"--checkpoint", str(model_checkpoint),
|
||||
"--output_dir", str(output_dir / "generated"),
|
||||
"--num_samples", str(args.num_samples),
|
||||
"--image_size", str(args.image_size),
|
||||
"--batch_size", str(args.gen_batch_size),
|
||||
"--num_steps", str(args.num_steps),
|
||||
"--seed", str(args.seed),
|
||||
"--timesteps", str(args.timesteps),
|
||||
"--schedule_type", args.schedule_type
|
||||
]
|
||||
|
||||
if args.use_ddim:
|
||||
generate_cmd.append("--use_ddim")
|
||||
|
||||
if args.use_post_process:
|
||||
generate_cmd.append("--use_post_process")
|
||||
|
||||
success = run_command(generate_cmd, "生成样本", logger)
|
||||
if not success:
|
||||
logger.error("生成阶段失败")
|
||||
return False
|
||||
|
||||
# 更新配置文件(如果提供了)
|
||||
if args.update_config and Path(args.update_config).exists():
|
||||
logger.info("\n=== 第三阶段: 更新配置文件 ===")
|
||||
|
||||
config_path = Path(args.update_config)
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# 更新扩散配置
|
||||
if 'synthetic' not in config:
|
||||
config['synthetic'] = {}
|
||||
|
||||
config['synthetic']['enabled'] = True
|
||||
config['synthetic']['ratio'] = 0.0 # 禁用程序化合成
|
||||
|
||||
if 'diffusion' not in config['synthetic']:
|
||||
config['synthetic']['diffusion'] = {}
|
||||
|
||||
config['synthetic']['diffusion']['enabled'] = True
|
||||
config['synthetic']['diffusion']['png_dir'] = str(output_dir / "generated")
|
||||
config['synthetic']['diffusion']['ratio'] = args.diffusion_ratio
|
||||
config['synthetic']['diffusion']['model_checkpoint'] = str(model_checkpoint)
|
||||
|
||||
# 保存配置
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
|
||||
|
||||
logger.info(f"配置文件已更新: {config_path}")
|
||||
logger.info(f"扩散数据比例: {args.diffusion_ratio}")
|
||||
|
||||
# 创建管线报告
|
||||
create_pipeline_report(output_dir, model_checkpoint, args, logger)
|
||||
|
||||
logger.info("\n=== 优化管线完成 ===")
|
||||
logger.info(f"模型: {model_checkpoint}")
|
||||
logger.info(f"生成数据: {output_dir / 'generated'}")
|
||||
logger.info(f"管线报告: {output_dir / 'pipeline_report.txt'}")
|
||||
|
||||
return True
|
||||
|
||||
def create_pipeline_report(output_dir, model_checkpoint, args, logger):
|
||||
"""创建管线报告"""
|
||||
report_content = f"""
|
||||
IC版图扩散模型优化管线报告
|
||||
============================
|
||||
|
||||
管线配置:
|
||||
- 数据目录: {args.data_dir}
|
||||
- 输出目录: {args.output_dir}
|
||||
- 图像尺寸: {args.image_size}x{args.image_size}
|
||||
- 训练轮数: {args.epochs}
|
||||
- 批次大小: {args.batch_size}
|
||||
- 学习率: {args.lr}
|
||||
- 时间步数: {args.timesteps}
|
||||
- 调度类型: {args.schedule_type}
|
||||
- 曼哈顿权重: {args.manhattan_weight}
|
||||
- 随机种子: {args.seed}
|
||||
|
||||
模型配置:
|
||||
- 边缘条件: {args.edge_condition}
|
||||
- 数据增强: {args.augment}
|
||||
- 最终模型: {model_checkpoint}
|
||||
|
||||
生成配置:
|
||||
- 生成样本数: {args.num_samples}
|
||||
- 生成批次大小: {args.gen_batch_size}
|
||||
- 采样步数: {args.num_steps}
|
||||
- DDIM采样: {args.use_ddim}
|
||||
- 后处理: {args.use_post_process}
|
||||
|
||||
优化特性:
|
||||
- 曼哈顿几何感知的U-Net架构
|
||||
- 边缘感知损失函数
|
||||
- 多尺度结构损失
|
||||
- 曼哈顿约束正则化
|
||||
- 几何保持的数据增强
|
||||
- 后处理优化
|
||||
|
||||
输出目录结构:
|
||||
- model/: 训练好的模型和检查点
|
||||
- generated/: 生成的IC版图样本
|
||||
- pipeline_report.txt: 本报告
|
||||
|
||||
质量评估:
|
||||
生成完成后,请查看 generated/quality_metrics.yaml 和 generation_report.txt 获取详细的质量评估。
|
||||
|
||||
使用说明:
|
||||
1. 训练数据应包含高质量的IC版图图像
|
||||
2. 建议使用边缘条件来提高生成质量
|
||||
3. 生成的样本可以使用后处理进一步优化
|
||||
4. 可根据质量评估结果调整训练参数
|
||||
"""
|
||||
|
||||
report_path = output_dir / 'pipeline_report.txt'
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
f.write(report_content)
|
||||
|
||||
logger.info(f"管线报告已保存: {report_path}")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="一键运行优化的IC版图扩散模型管线")
|
||||
|
||||
# 基本参数
|
||||
parser.add_argument("--data_dir", type=str, required=True, help="训练数据目录")
|
||||
parser.add_argument("--output_dir", type=str, required=True, help="输出目录")
|
||||
|
||||
# 训练参数
|
||||
parser.add_argument("--image_size", type=int, default=256, help="图像尺寸")
|
||||
parser.add_argument("--batch_size", type=int, default=4, help="训练批次大小")
|
||||
parser.add_argument("--epochs", type=int, default=100, help="训练轮数")
|
||||
parser.add_argument("--lr", type=float, default=1e-4, help="学习率")
|
||||
parser.add_argument("--timesteps", type=int, default=1000, help="扩散时间步数")
|
||||
parser.add_argument("--schedule_type", type=str, default='cosine', choices=['linear', 'cosine'], help="噪声调度类型")
|
||||
parser.add_argument("--manhattan_weight", type=float, default=0.1, help="曼哈顿正则化权重")
|
||||
parser.add_argument("--seed", type=int, default=42, help="随机种子")
|
||||
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
|
||||
parser.add_argument("--sample_interval", type=int, default=20, help="样本生成间隔")
|
||||
parser.add_argument("--train_samples", type=int, default=16, help="训练时生成的样本数量")
|
||||
|
||||
# 训练控制
|
||||
parser.add_argument("--skip_training", action='store_true', help="跳过训练,使用现有模型")
|
||||
parser.add_argument("--checkpoint", type=str, help="现有模型检查点路径(skip_training时使用)")
|
||||
parser.add_argument("--resume", type=str, help="恢复训练的检查点路径")
|
||||
parser.add_argument("--edge_condition", action='store_true', help="使用边缘条件")
|
||||
parser.add_argument("--augment", action='store_true', help="启用数据增强")
|
||||
|
||||
# 生成参数
|
||||
parser.add_argument("--num_samples", type=int, default=200, help="生成样本数量")
|
||||
parser.add_argument("--gen_batch_size", type=int, default=8, help="生成批次大小")
|
||||
parser.add_argument("--num_steps", type=int, default=50, help="采样步数")
|
||||
parser.add_argument("--use_ddim", action='store_true', default=True, help="使用DDIM采样")
|
||||
parser.add_argument("--use_post_process", action='store_true', default=True, help="启用后处理")
|
||||
|
||||
# 配置更新
|
||||
parser.add_argument("--update_config", type=str, help="要更新的配置文件路径")
|
||||
parser.add_argument("--diffusion_ratio", type=float, default=0.3, help="扩散数据在训练中的比例")
|
||||
|
||||
# 开发选项
|
||||
parser.add_argument("--create_sample_data", action='store_true', help="创建示例训练数据")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 验证参数
|
||||
if args.skip_training and not args.checkpoint:
|
||||
print("错误: 跳过训练时必须提供 --checkpoint 参数")
|
||||
sys.exit(1)
|
||||
|
||||
# 运行管线
|
||||
success = run_optimized_pipeline(args)
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,38 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Sample layout patches using a trained diffusion model (skeleton).
|
||||
|
||||
Outputs raster PNGs into a target directory compatible with current training pipeline (no H pairing).
|
||||
|
||||
Current status: CLI skeleton and TODOs only.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Sample layout patches from diffusion model (skeleton)")
|
||||
parser.add_argument("--ckpt", type=str, required=True, help="Path to trained diffusion checkpoint or HF repo id")
|
||||
parser.add_argument("--out_dir", type=str, required=True, help="Directory to write sampled PNGs")
|
||||
parser.add_argument("--num", type=int, default=200)
|
||||
parser.add_argument("--image_size", type=int, default=256)
|
||||
parser.add_argument("--guidance", type=float, default=5.0)
|
||||
parser.add_argument("--steps", type=int, default=50)
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--cond_dir", type=str, default=None, help="Optional condition maps directory")
|
||||
parser.add_argument("--cond_types", type=str, nargs="*", default=None, help="e.g., edge skeleton dist")
|
||||
args = parser.parse_args()
|
||||
|
||||
out_dir = Path(args.out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TODO: load pipeline from ckpt, set scheduler, handle conditions if provided,
|
||||
# sample args.num images, save as PNG files into out_dir.
|
||||
|
||||
print("[TODO] Implement diffusion sampling and PNG saving.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,37 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Train a diffusion model for layout patch generation (skeleton).
|
||||
|
||||
Planned: fine-tune Stable Diffusion (or Latent Diffusion) with optional ControlNet edge/skeleton conditions.
|
||||
|
||||
Dependencies to consider: diffusers, transformers, accelerate, torch, torchvision, opencv-python.
|
||||
|
||||
Current status: CLI skeleton and TODOs only.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Train diffusion model for layout patches (skeleton)")
|
||||
parser.add_argument("--data_dir", type=str, required=True, help="Prepared dataset root (images/ + conditions/)")
|
||||
parser.add_argument("--output_dir", type=str, required=True, help="Checkpoint output directory")
|
||||
parser.add_argument("--image_size", type=int, default=256)
|
||||
parser.add_argument("--batch_size", type=int, default=8)
|
||||
parser.add_argument("--lr", type=float, default=1e-4)
|
||||
parser.add_argument("--max_steps", type=int, default=100000)
|
||||
parser.add_argument("--use_controlnet", action="store_true", help="Train with ControlNet conditioning")
|
||||
parser.add_argument("--condition_types", type=str, nargs="*", default=["edge"], help="e.g., edge skeleton dist")
|
||||
args = parser.parse_args()
|
||||
|
||||
# TODO: implement dataset/dataloader (images and optional conditions)
|
||||
# TODO: load base pipeline (Stable Diffusion or Latent Diffusion) and optionally ControlNet
|
||||
# TODO: set up optimizer, LR schedule, EMA, gradient accumulation, and run training loop
|
||||
# TODO: save periodic checkpoints to output_dir
|
||||
|
||||
print("[TODO] Implement diffusion training loop and checkpoints.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,333 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
使用优化后的扩散模型进行训练的完整脚本
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import yaml
|
||||
from torch.utils.data import DataLoader
|
||||
import argparse
|
||||
|
||||
# 导入优化后的模块
|
||||
from ic_layout_diffusion_optimized import (
|
||||
ICDiffusionDataset,
|
||||
ManhattanAwareUNet,
|
||||
OptimizedNoiseScheduler,
|
||||
OptimizedDiffusionTrainer
|
||||
)
|
||||
|
||||
def setup_logging():
|
||||
"""设置日志"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout),
|
||||
logging.FileHandler('diffusion_training.log')
|
||||
]
|
||||
)
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
def save_checkpoint(model, optimizer, scheduler, epoch, losses, checkpoint_path):
|
||||
"""保存检查点"""
|
||||
checkpoint = {
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'scheduler_state_dict': scheduler.state_dict() if hasattr(scheduler, 'state_dict') else None,
|
||||
'losses': losses
|
||||
}
|
||||
torch.save(checkpoint, checkpoint_path)
|
||||
logging.info(f"检查点已保存: {checkpoint_path}")
|
||||
|
||||
def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None):
|
||||
"""加载检查点"""
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
|
||||
if optimizer is not None and 'optimizer_state_dict' in checkpoint:
|
||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
|
||||
if scheduler is not None and 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict']:
|
||||
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
|
||||
start_epoch = checkpoint.get('epoch', 0)
|
||||
losses = checkpoint.get('losses', {})
|
||||
|
||||
logging.info(f"检查点已加载: {checkpoint_path}, 从epoch {start_epoch}继续")
|
||||
return start_epoch, losses
|
||||
|
||||
def validate_model(trainer, val_dataloader, device):
|
||||
"""验证模型"""
|
||||
trainer.model.eval()
|
||||
total_loss = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in val_dataloader:
|
||||
if trainer.use_edge_condition:
|
||||
images, edge_conditions = batch
|
||||
edge_conditions = edge_conditions.to(device)
|
||||
else:
|
||||
images = batch
|
||||
edge_conditions = None
|
||||
|
||||
images = images.to(device)
|
||||
t = trainer.scheduler.sample_timestep(images.shape[0]).to(device)
|
||||
noisy_images, noise = trainer.scheduler.add_noise(images, t)
|
||||
predicted_noise = trainer.model(noisy_images, t, edge_conditions)
|
||||
|
||||
loss = trainer.mse_loss(predicted_noise, noise)
|
||||
total_loss += loss.item()
|
||||
|
||||
trainer.model.train()
|
||||
return total_loss / len(val_dataloader)
|
||||
|
||||
def train_optimized_diffusion(args):
|
||||
"""训练优化的扩散模型"""
|
||||
logger = setup_logging()
|
||||
|
||||
# 设备检查
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
logger.info(f"使用设备: {device}")
|
||||
|
||||
# 设置随机种子
|
||||
torch.manual_seed(args.seed)
|
||||
if device.type == 'cuda':
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 保存训练配置
|
||||
config = {
|
||||
'image_size': args.image_size,
|
||||
'batch_size': args.batch_size,
|
||||
'epochs': args.epochs,
|
||||
'lr': args.lr,
|
||||
'timesteps': args.timesteps,
|
||||
'schedule_type': args.schedule_type,
|
||||
'edge_condition': args.edge_condition,
|
||||
'manhattan_weight': args.manhattan_weight,
|
||||
'augment': args.augment,
|
||||
'seed': args.seed
|
||||
}
|
||||
|
||||
with open(output_dir / 'training_config.yaml', 'w') as f:
|
||||
yaml.dump(config, f, default_flow_style=False)
|
||||
|
||||
# 创建数据集
|
||||
logger.info(f"加载数据集: {args.data_dir}")
|
||||
dataset = ICDiffusionDataset(
|
||||
image_dir=args.data_dir,
|
||||
image_size=args.image_size,
|
||||
augment=args.augment,
|
||||
use_edge_condition=args.edge_condition
|
||||
)
|
||||
|
||||
# 数据集分割
|
||||
total_size = len(dataset)
|
||||
train_size = int(0.9 * total_size)
|
||||
val_size = total_size - train_size
|
||||
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
|
||||
|
||||
# 数据加载器
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
pin_memory=True
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=2
|
||||
)
|
||||
|
||||
logger.info(f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset)}")
|
||||
|
||||
# 创建模型
|
||||
logger.info("创建优化模型...")
|
||||
model = ManhattanAwareUNet(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
use_edge_condition=args.edge_condition
|
||||
).to(device)
|
||||
|
||||
# 创建调度器
|
||||
scheduler = OptimizedNoiseScheduler(
|
||||
num_timesteps=args.timesteps,
|
||||
schedule_type=args.schedule_type
|
||||
)
|
||||
|
||||
# 创建训练器
|
||||
trainer = OptimizedDiffusionTrainer(
|
||||
model, scheduler, device, args.edge_condition
|
||||
)
|
||||
|
||||
# 优化器和学习率调度器
|
||||
optimizer = optim.AdamW(
|
||||
model.parameters(),
|
||||
lr=args.lr,
|
||||
weight_decay=0.01,
|
||||
betas=(0.9, 0.999)
|
||||
)
|
||||
|
||||
lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||
optimizer, T_0=10, T_mult=2, eta_min=1e-6
|
||||
)
|
||||
|
||||
# 检查点恢复
|
||||
start_epoch = 0
|
||||
losses_history = []
|
||||
|
||||
if args.resume:
|
||||
checkpoint_path = Path(args.resume)
|
||||
if checkpoint_path.exists():
|
||||
start_epoch, losses_history = load_checkpoint(
|
||||
checkpoint_path, model, optimizer, lr_scheduler
|
||||
)
|
||||
else:
|
||||
logger.warning(f"检查点文件不存在: {checkpoint_path}")
|
||||
|
||||
logger.info(f"开始训练 {args.epochs} 个epoch (从epoch {start_epoch}开始)...")
|
||||
|
||||
# 训练循环
|
||||
best_val_loss = float('inf')
|
||||
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
# 训练
|
||||
train_losses = trainer.train_step(
|
||||
optimizer, train_dataloader, args.manhattan_weight
|
||||
)
|
||||
|
||||
# 验证
|
||||
val_loss = validate_model(trainer, val_dataloader, device)
|
||||
|
||||
# 学习率调度
|
||||
lr_scheduler.step()
|
||||
|
||||
# 记录损失
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
losses_history.append({
|
||||
'epoch': epoch,
|
||||
'train_loss': train_losses['total_loss'],
|
||||
'val_loss': val_loss,
|
||||
'edge_loss': train_losses['edge_loss'],
|
||||
'structure_loss': train_losses['structure_loss'],
|
||||
'manhattan_loss': train_losses['manhattan_loss'],
|
||||
'lr': current_lr
|
||||
})
|
||||
|
||||
# 日志输出
|
||||
logger.info(
|
||||
f"Epoch {epoch+1}/{args.epochs} | "
|
||||
f"Train Loss: {train_losses['total_loss']:.6f} | "
|
||||
f"Val Loss: {val_loss:.6f} | "
|
||||
f"Edge: {train_losses['edge_loss']:.6f} | "
|
||||
f"Structure: {train_losses['structure_loss']:.6f} | "
|
||||
f"Manhattan: {train_losses['manhattan_loss']:.6f} | "
|
||||
f"LR: {current_lr:.2e}"
|
||||
)
|
||||
|
||||
# 保存最佳模型
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
best_model_path = output_dir / "best_model.pth"
|
||||
save_checkpoint(
|
||||
model, optimizer, lr_scheduler, epoch, losses_history, best_model_path
|
||||
)
|
||||
|
||||
# 定期保存检查点
|
||||
if (epoch + 1) % args.save_interval == 0:
|
||||
checkpoint_path = output_dir / f"checkpoint_epoch_{epoch+1}.pth"
|
||||
save_checkpoint(
|
||||
model, optimizer, lr_scheduler, epoch, losses_history, checkpoint_path
|
||||
)
|
||||
|
||||
# 生成样本
|
||||
if (epoch + 1) % args.sample_interval == 0:
|
||||
sample_dir = output_dir / f"samples_epoch_{epoch+1}"
|
||||
logger.info(f"生成样本到 {sample_dir}")
|
||||
trainer.generate(
|
||||
num_samples=args.num_samples,
|
||||
image_size=args.image_size,
|
||||
save_dir=sample_dir,
|
||||
use_post_process=True
|
||||
)
|
||||
|
||||
# 保存最终模型
|
||||
final_model_path = output_dir / "final_model.pth"
|
||||
save_checkpoint(
|
||||
model, optimizer, lr_scheduler, args.epochs-1, losses_history, final_model_path
|
||||
)
|
||||
|
||||
# 保存损失历史
|
||||
with open(output_dir / 'loss_history.yaml', 'w') as f:
|
||||
yaml.dump(losses_history, f, default_flow_style=False)
|
||||
|
||||
# 最终生成
|
||||
logger.info("生成最终样本...")
|
||||
final_sample_dir = output_dir / "final_samples"
|
||||
trainer.generate(
|
||||
num_samples=args.num_samples * 2, # 生成更多样本
|
||||
image_size=args.image_size,
|
||||
save_dir=final_sample_dir,
|
||||
use_post_process=True
|
||||
)
|
||||
|
||||
logger.info("训练完成!")
|
||||
logger.info(f"最佳模型: {output_dir / 'best_model.pth'}")
|
||||
logger.info(f"最终模型: {final_model_path}")
|
||||
logger.info(f"最终样本: {final_sample_dir}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="训练优化的IC版图扩散模型")
|
||||
|
||||
# 数据参数
|
||||
parser.add_argument('--data_dir', type=str, required=True, help='训练数据目录')
|
||||
parser.add_argument('--output_dir', type=str, required=True, help='输出目录')
|
||||
|
||||
# 模型参数
|
||||
parser.add_argument('--image_size', type=int, default=256, help='图像尺寸')
|
||||
parser.add_argument('--timesteps', type=int, default=1000, help='扩散时间步数')
|
||||
parser.add_argument('--schedule_type', type=str, default='cosine',
|
||||
choices=['linear', 'cosine'], help='噪声调度类型')
|
||||
parser.add_argument('--edge_condition', action='store_true', help='使用边缘条件')
|
||||
|
||||
# 训练参数
|
||||
parser.add_argument('--batch_size', type=int, default=4, help='批次大小')
|
||||
parser.add_argument('--epochs', type=int, default=100, help='训练轮数')
|
||||
parser.add_argument('--lr', type=float, default=1e-4, help='学习率')
|
||||
parser.add_argument('--manhattan_weight', type=float, default=0.1, help='曼哈顿正则化权重')
|
||||
parser.add_argument('--seed', type=int, default=42, help='随机种子')
|
||||
|
||||
# 训练控制
|
||||
parser.add_argument('--augment', action='store_true', help='启用数据增强')
|
||||
parser.add_argument('--resume', type=str, default=None, help='恢复训练的检查点路径')
|
||||
parser.add_argument('--save_interval', type=int, default=10, help='保存间隔')
|
||||
parser.add_argument('--sample_interval', type=int, default=20, help='生成样本间隔')
|
||||
parser.add_argument('--num_samples', type=int, default=16, help='每次生成的样本数量')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 检查数据目录
|
||||
if not Path(args.data_dir).exists():
|
||||
print(f"错误: 数据目录不存在: {args.data_dir}")
|
||||
sys.exit(1)
|
||||
|
||||
# 开始训练
|
||||
train_optimized_diffusion(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user