improve IC Layout Diffusion model 20251120

This commit is contained in:
Jiao77
2025-11-20 01:47:09 +08:00
parent 930f1952d5
commit 49fe21fb2f
8 changed files with 2254 additions and 0 deletions

View File

@@ -0,0 +1,539 @@
#!/usr/bin/env python3
"""
针对IC版图优化的去噪扩散模型
专门针对以曼哈顿多边形为全部组成元素的IC版图光栅化图像进行优化
- 曼哈顿几何感知的U-Net架构
- 边缘感知损失函数
- 多尺度结构损失
- 曼哈顿约束正则化
- 几何保持的数据增强
- 后处理优化
"""
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import logging
import cv2
# Optional progress-bar dependency: fall back to a no-op wrapper so that
# training still runs when tqdm is not installed.
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, **kwargs):
        # Minimal stand-in: ignore tqdm keyword arguments, iterate unchanged.
        return iterable
class ICDiffusionDataset(Dataset):
    """Dataset of rasterized IC-layout images for diffusion training.

    Loads grayscale PNG/JPG images, resizes them to a square, and applies
    geometry-preserving augmentation (horizontal/vertical flips only —
    rotations would break the Manhattan, axis-aligned geometry of layouts).
    Optionally returns a Sobel edge map as a conditioning signal.
    """

    def __init__(self, image_dir, image_size=256, augment=True, use_edge_condition=False):
        self.image_dir = Path(image_dir)
        self.image_size = image_size
        self.use_edge_condition = use_edge_condition
        # Collect images; sort for a deterministic order across filesystems
        # (glob order is platform-dependent) and dedupe overlapping patterns.
        paths = []
        for ext in ('*.png', '*.jpg', '*.jpeg'):
            paths.extend(self.image_dir.glob(ext))
        self.image_paths = sorted(set(paths))
        # Base transform: resize to a square and convert to a [0, 1] tensor.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        # Geometry-preserving augmentation (flips keep Manhattan structure).
        self.augment = augment
        if augment:
            self.aug_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
            ])

    def __len__(self):
        return len(self.image_paths)

    def _extract_edges(self, image_tensor):
        """Return a Sobel edge-magnitude map for a [C, H, W] image tensor.

        Output is [1, C_out, H, W] clamped to [0, 1], used as the optional
        conditioning channel.
        """
        sobel_x = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]],
                               dtype=image_tensor.dtype, device=image_tensor.device)
        sobel_y = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]],
                               dtype=image_tensor.dtype, device=image_tensor.device)
        edge_x = F.conv2d(image_tensor.unsqueeze(0), sobel_x, padding=1)
        edge_y = F.conv2d(image_tensor.unsqueeze(0), sobel_y, padding=1)
        edge_magnitude = torch.sqrt(edge_x ** 2 + edge_y ** 2)
        return torch.clamp(edge_magnitude, 0, 1)

    def __getitem__(self, idx):
        """Return an image tensor, or (image, edge_condition) when conditioning."""
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('L')  # force single-channel grayscale
        image = self.transform(image)
        # Apply flips to only half the samples on top of the per-flip p=0.5.
        if self.augment and np.random.random() > 0.5:
            image = self.aug_transform(image)
        if self.use_edge_condition:
            # Edge map extracted AFTER augmentation so it matches the image.
            edge_condition = self._extract_edges(image)
            return image, edge_condition.squeeze(0)
        return image
class EdgeAwareLoss(nn.Module):
    """MSE loss augmented with a Sobel edge-consistency term.

    Returns ``mse(pred, target) + 0.5 * edge_mse``, where the edge term
    compares horizontal and vertical Sobel responses of prediction and
    target. Sharp polygon boundaries dominate IC layouts, so matching
    edges is weighted explicitly. Expects single-channel [B, 1, H, W] inputs.
    """

    def __init__(self):
        super().__init__()
        # Register Sobel kernels as buffers so they follow .to(device)/.cuda().
        # They MUST be float: the original integer literals produced int64
        # tensors and F.conv2d(float_input, long_weight) raises a dtype error.
        self.register_buffer(
            'sobel_x',
            torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32))
        self.register_buffer(
            'sobel_y',
            torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32))

    def forward(self, pred, target):
        """Return combined pixel + edge MSE between ``pred`` and ``target``."""
        mse_loss = F.mse_loss(pred, target)
        # Match the input dtype so the loss also works under autocast/half.
        kx = self.sobel_x.to(dtype=pred.dtype)
        ky = self.sobel_y.to(dtype=pred.dtype)
        pred_edge_x = F.conv2d(pred, kx, padding=1)
        pred_edge_y = F.conv2d(pred, ky, padding=1)
        target_edge_x = F.conv2d(target, kx, padding=1)
        target_edge_y = F.conv2d(target, ky, padding=1)
        edge_loss = F.mse_loss(pred_edge_x, target_edge_x) + F.mse_loss(pred_edge_y, target_edge_y)
        return mse_loss + 0.5 * edge_loss
class MultiScaleStructureLoss(nn.Module):
    """Multi-scale structural loss.

    MSE at full, half and quarter resolution, weighted 1.0 / 0.5 / 0.25,
    so coarse layout structure is matched alongside fine detail.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """Return the weighted sum of MSE losses across the three scales."""
        total = F.mse_loss(pred, target)  # full resolution, weight 1.0
        for factor, weight in ((2, 0.5), (4, 0.25)):
            coarse_pred = F.avg_pool2d(pred, factor)
            coarse_target = F.avg_pool2d(target, factor)
            total = total + weight * F.mse_loss(coarse_pred, coarse_target)
        return total
def manhattan_regularization_loss(generated_image, device='cuda'):
"""曼哈顿约束正则化损失"""
if device == 'cuda':
device = generated_image.device
# Sobel算子
sobel_x = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], device=device, dtype=generated_image.dtype)
sobel_y = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], device=device, dtype=generated_image.dtype)
# 检测边缘
edge_x = F.conv2d(generated_image, sobel_x, padding=1)
edge_y = F.conv2d(generated_image, sobel_y, padding=1)
# 边缘强度
edge_magnitude = torch.sqrt(edge_x**2 + edge_y**2 + 1e-8)
# 计算角度偏差
angles = torch.atan2(edge_y, edge_x)
# 惩罚不接近0°、90°、180°或270°的角度
angle_penalty = torch.min(
torch.min(torch.abs(angles), torch.abs(angles - np.pi/2)),
torch.min(torch.abs(angles - np.pi), torch.abs(angles - 3*np.pi/2))
)
return torch.mean(angle_penalty * edge_magnitude)
class ManhattanAwareUNet(nn.Module):
    """Manhattan-geometry-aware U-Net noise predictor.

    A stem of horizontal (1x7), vertical (7x1) and standard (3x3) convolutions
    captures axis-aligned structure before a 3-level U-Net encoder/decoder.
    Per-stage time embeddings are added channel-wise after each encoder block.

    Fixes vs. the original:
      * time_fusion projections now match encoder output channels
        (128/256/512/1024) — previously 64-ch features were added to 128-ch
        maps, which cannot broadcast;
      * the ``x + residual`` skip mixed 64-ch with 128-ch tensors (shape error);
      * decoder/skip pairing now matches both channels and spatial size;
      * 3 upsampling stages mirror the 3 stride-2 encoder stages, so the
        output resolution equals the input resolution (previously 4 upsamples
        doubled it).

    Input H and W must be divisible by 8 (three stride-2 stages).
    """

    def __init__(self, in_channels=1, out_channels=1, time_dim=256, use_edge_condition=False):
        super().__init__()
        self.use_edge_condition = use_edge_condition
        # Input channels: image plus optional edge-condition channel.
        input_channels = in_channels + (1 if use_edge_condition else 0)
        # Time embedding: scalar timestep -> time_dim vector.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )
        # Manhattan-aware stem: direction-sensitive + standard convolutions.
        self.horiz_conv = nn.Conv2d(input_channels, 32, (1, 7), padding=(0, 3))
        self.vert_conv = nn.Conv2d(input_channels, 32, (7, 1), padding=(3, 0))
        self.standard_conv = nn.Conv2d(input_channels, 32, 3, padding=1)
        # Fuse the three 32-ch stem outputs (96 ch) down to 64.
        self.initial_fusion = nn.Sequential(
            nn.Conv2d(96, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
        )
        # Encoder: one full-resolution block, then three stride-2 blocks.
        self.encoder = nn.ModuleList([
            self._make_block(64, 128),              # 1/1 resolution
            self._make_block(128, 256, stride=2),   # 1/2
            self._make_block(256, 512, stride=2),   # 1/4
            self._make_block(512, 1024, stride=2),  # 1/8
        ])
        # Bottleneck.
        self.middle = nn.Sequential(
            nn.Conv2d(1024, 1024, 3, padding=1),
            nn.GroupNorm(8, 1024),
            nn.SiLU(),
            nn.Conv2d(1024, 1024, 3, padding=1),
            nn.GroupNorm(8, 1024),
            nn.SiLU(),
        )
        # Decoder: three upsampling stages mirroring the stride-2 encoders.
        self.decoder = nn.ModuleList([
            self._make_decoder_block(1024, 512),  # -> 1/4
            self._make_decoder_block(512, 256),   # -> 1/2
            self._make_decoder_block(256, 128),   # -> 1/1
        ])
        # Final full-resolution block back to 64 channels (no upsampling).
        self.final_block = self._make_block(128, 64)
        # Output head.
        self.output = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.GroupNorm(8, 32),
            nn.SiLU(),
            nn.Conv2d(32, out_channels, 3, padding=1),
        )
        # Per-stage time projections, matching each encoder's output channels.
        self.time_fusion = nn.ModuleList([
            nn.Linear(time_dim, 128),
            nn.Linear(time_dim, 256),
            nn.Linear(time_dim, 512),
            nn.Linear(time_dim, 1024),
        ])

    def _make_block(self, in_channels, out_channels, stride=1):
        """Two conv+GroupNorm+SiLU layers; first conv optionally downsamples."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
        )

    def _make_decoder_block(self, in_channels, out_channels):
        """Transposed-conv 2x upsample followed by a refining conv layer."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 3, stride=2,
                               padding=1, output_padding=1),
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
        )

    def forward(self, x, t, edge_condition=None):
        """Predict noise for ``x`` at timesteps ``t``.

        Args:
            x: [B, in_channels, H, W] noisy image, H and W divisible by 8.
            t: [B] integer timesteps.
            edge_condition: optional [B, 1, H, W] edge map (only used when
                ``use_edge_condition`` is True).

        Returns:
            [B, out_channels, H, W] predicted noise.
        """
        if self.use_edge_condition and edge_condition is not None:
            x = torch.cat([x, edge_condition], dim=1)
        t_emb = self.time_mlp(t.float().unsqueeze(-1))  # [B, time_dim]
        # Manhattan-aware stem features.
        h_features = F.silu(self.horiz_conv(x))
        v_features = F.silu(self.vert_conv(x))
        s_features = F.silu(self.standard_conv(x))
        x = self.initial_fusion(torch.cat([h_features, v_features, s_features], dim=1))
        # Encoder path; inject time information after each stage.
        skips = []
        for encoder, fusion in zip(self.encoder, self.time_fusion):
            x = encoder(x)
            x = x + fusion(t_emb).unsqueeze(-1).unsqueeze(-1)
            skips.append(x)
        x = self.middle(x)
        # Decoder path: skips[:-1] are the pre-bottleneck features at
        # 1/4, 1/2 and 1/1 resolution, consumed deepest-first.
        for decoder, skip in zip(self.decoder, reversed(skips[:-1])):
            x = decoder(x)
            x = x + skip  # additive skip connection (shapes now match)
        x = self.final_block(x)
        return self.output(x)
class OptimizedNoiseScheduler:
    """DDPM noise scheduler supporting linear and cosine beta schedules.

    Precomputes alphas / cumulative products on CPU; coefficient lookups move
    to the timestep tensor's device, so the scheduler works transparently with
    CUDA batches.
    """

    def __init__(self, num_timesteps=1000, beta_start=1e-4, beta_end=0.02, schedule_type='linear'):
        self.num_timesteps = num_timesteps
        if schedule_type == 'cosine':
            # Nichol & Dhariwal cosine schedule. Computed in float64 for
            # accuracy, then stored as float32 so coefficients mix cleanly
            # with float32 model tensors (the original leaked float64 into
            # every downstream computation).
            steps = num_timesteps + 1
            x = torch.linspace(0, num_timesteps, steps, dtype=torch.float64)
            alphas_cumprod = torch.cos(((x / num_timesteps) + 0.008) / 1.008 * np.pi / 2) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
            self.betas = torch.clip(betas, 0, 0.999).float()
        else:
            # Linear schedule.
            self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        # Precomputed coefficient tables.
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

    def _gather(self, table, t):
        """Index a coefficient table with (possibly GPU) timesteps ``t`` and
        reshape to [B, 1, 1, 1] for broadcasting over image batches."""
        return table.to(t.device)[t].view(-1, 1, 1, 1)

    def add_noise(self, x_0, t):
        """Forward-diffuse clean images ``x_0`` to timesteps ``t``.

        Returns:
            (x_t, noise): the noised images and the Gaussian noise used.
        """
        noise = torch.randn_like(x_0)
        sqrt_ac = self._gather(self.sqrt_alphas_cumprod, t)
        sqrt_om = self._gather(self.sqrt_one_minus_alphas_cumprod, t)
        return sqrt_ac * x_0 + sqrt_om * noise, noise

    def sample_timestep(self, batch_size):
        """Uniformly sample ``batch_size`` training timesteps on CPU."""
        return torch.randint(0, self.num_timesteps, (batch_size,))

    def step(self, model, x_t, t):
        """One reverse-diffusion step for a batch at timesteps ``t``."""
        predicted_noise = model(x_t, t)
        alpha_t = self._gather(self.alphas, t)
        beta_t = self._gather(self.betas, t)
        sqrt_om = self._gather(self.sqrt_one_minus_alphas_cumprod, t)
        # Posterior mean of x_{t-1} given the predicted noise.
        model_mean = (x_t - (beta_t / sqrt_om) * predicted_noise) / torch.sqrt(alpha_t)
        # Add sampling noise only where t > 0, per sample. The original
        # ``t.min() == 0`` check was all-or-nothing and wrong for batches
        # mixing zero and nonzero timesteps.
        nonzero_mask = (t > 0).to(x_t.dtype).view(-1, 1, 1, 1)
        noise = torch.randn_like(x_t)
        return model_mean + nonzero_mask * torch.sqrt(beta_t) * noise
def manhattan_post_process(image, threshold=0.5):
    """Binarize an image and reinforce axis-aligned structure.

    Thresholds the input, then dilates with 1x3 (horizontal) and 3x1
    (vertical) kernels and keeps the union, which thickens Manhattan
    features while staying binary.

    Args:
        image: [B, 1, H, W] tensor with values roughly in [0, 1].
        threshold: binarization threshold.

    Returns:
        Binary (0/1) float tensor of the same shape as ``image``.
    """
    device = image.device
    binary = (image > threshold).float()
    # Kernels must be float: the original integer-literal kernels were int64
    # and F.conv2d(float_input, long_weight) raises a dtype RuntimeError.
    kernel_h = torch.tensor([[[[1.0, 1.0, 1.0]]]], device=device, dtype=binary.dtype)
    kernel_v = torch.tensor([[[[1.0], [1.0], [1.0]]]], device=device, dtype=binary.dtype)
    # Directional dilations.
    horizontal = F.conv2d(binary, kernel_h, padding=(0, 1))
    vertical = F.conv2d(binary, kernel_v, padding=(1, 0))
    # Union of the two dilations; subtract binary to undo double-counted centers.
    result = torch.clamp(horizontal + vertical - binary, 0, 1)
    # Re-binarize.
    return (result > 0.5).float()
class OptimizedDiffusionTrainer:
    """Trainer for the Manhattan-aware diffusion model.

    Combines plain MSE on the predicted noise with edge-aware, multi-scale
    structural and Manhattan-orientation regularization terms.
    """

    def __init__(self, model, scheduler, device='cuda', use_edge_condition=False):
        self.model = model.to(device)
        self.scheduler = scheduler
        self.device = device
        self.use_edge_condition = use_edge_condition
        # Loss modules moved to the training device: EdgeAwareLoss registers
        # Sobel buffers that must live where the predictions live (the
        # original left them on CPU, crashing on CUDA inputs).
        self.edge_loss = EdgeAwareLoss().to(device)
        self.structure_loss = MultiScaleStructureLoss().to(device)
        self.mse_loss = nn.MSELoss()

    def train_step(self, optimizer, dataloader, manhattan_weight=0.1):
        """Run one epoch over ``dataloader``.

        Returns:
            dict of per-epoch average metrics: total/mse/edge/structure/
            manhattan losses.
        """
        self.model.train()
        total_loss = 0.0
        total_mse_loss = 0.0
        total_edge_loss = 0.0
        total_structure_loss = 0.0
        total_manhattan_loss = 0.0
        for batch in dataloader:
            if self.use_edge_condition:
                images, edge_conditions = batch
                edge_conditions = edge_conditions.to(self.device)
            else:
                images = batch
                edge_conditions = None
            images = images.to(self.device)
            # Sample timesteps and noise the batch.
            t = self.scheduler.sample_timestep(images.shape[0]).to(self.device)
            noisy_images, noise = self.scheduler.add_noise(images, t)
            predicted_noise = self.model(noisy_images, t, edge_conditions)
            # Component losses on the predicted noise.
            mse_loss = self.mse_loss(predicted_noise, noise)
            edge_loss = self.edge_loss(predicted_noise, noise)
            structure_loss = self.structure_loss(predicted_noise, noise)
            # Manhattan regularization must be computed WITH gradients: the
            # original wrapped it in torch.no_grad(), which made the term a
            # constant and silently removed it from the gradient.
            # NOTE(review): noisy - predicted is only a rough proxy for the
            # denoised image; the exact x0 estimate would use the scheduler's
            # sqrt_alphas_cumprod coefficients — confirm intent.
            denoised = noisy_images - predicted_noise
            manhattan_loss = manhattan_regularization_loss(denoised, self.device)
            total_step_loss = (mse_loss + 0.3 * edge_loss + 0.2 * structure_loss
                               + manhattan_weight * manhattan_loss)
            optimizer.zero_grad()
            total_step_loss.backward()
            # Gradient clipping for training stability.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()
            total_loss += total_step_loss.item()
            total_mse_loss += mse_loss.item()
            total_edge_loss += edge_loss.item()
            total_structure_loss += structure_loss.item()
            total_manhattan_loss += manhattan_loss.item()
        # Guard against an empty dataloader (avoids ZeroDivisionError).
        num_batches = max(len(dataloader), 1)
        return {
            'total_loss': total_loss / num_batches,
            # Exact tracked MSE (the original reported total_loss here).
            'mse_loss': total_mse_loss / num_batches,
            'edge_loss': total_edge_loss / num_batches,
            'structure_loss': total_structure_loss / num_batches,
            'manhattan_loss': total_manhattan_loss / num_batches,
        }

    def generate(self, num_samples, image_size=256, save_dir=None, use_post_process=True):
        """Sample ``num_samples`` images from pure noise.

        Args:
            num_samples: number of images to generate.
            image_size: square output resolution.
            save_dir: optional directory to write PNGs into.
            use_post_process: apply Manhattan binarization post-processing.

        Returns:
            [num_samples, 1, H, W] CPU tensor of generated images.
        """
        self.model.eval()
        with torch.no_grad():
            # Start from pure Gaussian noise.
            x = torch.randn(num_samples, 1, image_size, image_size, device=self.device)
            # Reverse diffusion, one step per timestep.
            for t in reversed(range(self.scheduler.num_timesteps)):
                t_batch = torch.full((num_samples,), t, device=self.device, dtype=torch.long)
                x = self.scheduler.step(self.model, x, t_batch)
                x = torch.clamp(x, -2.0, 2.0)  # keep samples in a sane range
            x = torch.clamp(x, 0.0, 1.0)
            if use_post_process:
                x = manhattan_post_process(x)
            if save_dir:
                save_dir = Path(save_dir)
                save_dir.mkdir(parents=True, exist_ok=True)
                for i in range(num_samples):
                    img_array = (x[i].cpu().squeeze().numpy() * 255).astype(np.uint8)
                    Image.fromarray(img_array, mode='L').save(
                        save_dir / f"generated_{i:06d}.png")
        return x.cpu()
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="优化的IC版图扩散模型训练和生成")
    subparsers = parser.add_subparsers(dest='command', help='命令')
    # --- train subcommand ---
    train_parser = subparsers.add_parser('train', help='训练扩散模型')
    train_parser.add_argument('--data_dir', type=str, required=True, help='训练数据目录')
    train_parser.add_argument('--output_dir', type=str, required=True, help='输出目录')
    train_parser.add_argument('--image_size', type=int, default=256, help='图像尺寸')
    train_parser.add_argument('--batch_size', type=int, default=4, help='批次大小')
    train_parser.add_argument('--epochs', type=int, default=100, help='训练轮数')
    train_parser.add_argument('--lr', type=float, default=1e-4, help='学习率')
    train_parser.add_argument('--timesteps', type=int, default=1000, help='扩散时间步数')
    train_parser.add_argument('--num_samples', type=int, default=50, help='生成的样本数量')
    train_parser.add_argument('--save_interval', type=int, default=10, help='保存间隔')
    train_parser.add_argument('--augment', action='store_true', help='启用数据增强')
    train_parser.add_argument('--edge_condition', action='store_true', help='使用边缘条件')
    train_parser.add_argument('--manhattan_weight', type=float, default=0.1, help='曼哈顿正则化权重')
    train_parser.add_argument('--schedule_type', type=str, default='cosine',
                              choices=['linear', 'cosine'], help='噪声调度类型')
    # --- generate subcommand ---
    gen_parser = subparsers.add_parser('generate', help='使用训练好的模型生成图像')
    gen_parser.add_argument('--checkpoint', type=str, required=True, help='模型检查点路径')
    gen_parser.add_argument('--output_dir', type=str, required=True, help='输出目录')
    gen_parser.add_argument('--num_samples', type=int, default=200, help='生成样本数量')
    gen_parser.add_argument('--image_size', type=int, default=256, help='图像尺寸')
    gen_parser.add_argument('--timesteps', type=int, default=1000, help='扩散时间步数')
    # The original used action='store_true' with default=True, so the flag
    # could never turn post-processing OFF. Keep --use_post_process for
    # backward compatibility and add --no_post_process to actually disable it.
    gen_parser.add_argument('--use_post_process', dest='use_post_process',
                            action='store_true', default=True, help='启用后处理')
    gen_parser.add_argument('--no_post_process', dest='use_post_process',
                            action='store_false', help='禁用后处理')
    args = parser.parse_args()
    # TODO: 实现训练和生成函数,使用优化后的组件
    print("[TODO] 实现完整的训练和生成流程,使用优化后的模型架构和损失函数")