change some problem 7
This commit is contained in:
@@ -13,6 +13,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -64,15 +65,23 @@ class ICDiffusionDataset(Dataset):
|
|||||||
|
|
||||||
def _extract_edges(self, image_tensor):
|
def _extract_edges(self, image_tensor):
|
||||||
"""提取边缘条件图"""
|
"""提取边缘条件图"""
|
||||||
# 使用Sobel算子提取边缘
|
# 修复Sobel算子 - 正确的3x3 Sobel算子
|
||||||
sobel_x = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]],
|
if len(image_tensor.shape) == 3: # [C, H, W]
|
||||||
|
image_tensor = image_tensor.unsqueeze(0) # [1, C, H, W]
|
||||||
|
|
||||||
|
# 为单通道图像设计的Sobel算子
|
||||||
|
sobel_x = torch.tensor([[[[-1.0, 0.0, 1.0],
|
||||||
|
[-2.0, 0.0, 2.0],
|
||||||
|
[-1.0, 0.0, 1.0]]]],
|
||||||
dtype=image_tensor.dtype, device=image_tensor.device)
|
dtype=image_tensor.dtype, device=image_tensor.device)
|
||||||
sobel_y = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]],
|
sobel_y = torch.tensor([[[[-1.0, -2.0, -1.0],
|
||||||
|
[0.0, 0.0, 0.0],
|
||||||
|
[1.0, 2.0, 1.0]]]],
|
||||||
dtype=image_tensor.dtype, device=image_tensor.device)
|
dtype=image_tensor.dtype, device=image_tensor.device)
|
||||||
|
|
||||||
edge_x = F.conv2d(image_tensor.unsqueeze(0), sobel_x, padding=1)
|
edge_x = F.conv2d(image_tensor, sobel_x, padding=1)
|
||||||
edge_y = F.conv2d(image_tensor.unsqueeze(0), sobel_y, padding=1)
|
edge_y = F.conv2d(image_tensor, sobel_y, padding=1)
|
||||||
edge_magnitude = torch.sqrt(edge_x**2 + edge_y**2)
|
edge_magnitude = torch.sqrt(edge_x**2 + edge_y**2 + 1e-8)
|
||||||
|
|
||||||
return torch.clamp(edge_magnitude, 0, 1)
|
return torch.clamp(edge_magnitude, 0, 1)
|
||||||
|
|
||||||
@@ -170,8 +179,25 @@ def manhattan_regularization_loss(generated_image, device='cuda'):
|
|||||||
return torch.mean(angle_penalty * edge_magnitude)
|
return torch.mean(angle_penalty * edge_magnitude)
|
||||||
|
|
||||||
|
|
||||||
|
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal position encoding used to embed diffusion timesteps.

    For a 1-D tensor of timesteps the module returns a ``[B, dim]`` tensor
    whose first ``dim // 2`` columns are sines and whose next ``dim // 2``
    columns are cosines of the timestep multiplied by a geometric series of
    frequencies running from 1 down to 1/10000. When ``dim`` is odd, one
    trailing zero column is appended so the output width is always ``dim``.

    Args:
        dim: Number of features in the returned embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return the sinusoidal embedding of ``time``.

        Args:
            time: 1-D tensor of timesteps (indexed as ``time[:, None]``);
                it is cast to float before use.

        Returns:
            Tensor of shape ``[len(time), dim]`` on the same device as
            ``time``.
        """
        device = time.device
        half_dim = self.dim // 2
        # With dim in (2, 3) half_dim is 1, so the plain ``half_dim - 1``
        # divisor would be zero; clamp it so the single frequency is 1.0.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        angles = time[:, None].float() * frequencies[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover dim - 1 columns for odd dim; pad with
            # zeros so downstream layers sized by ``dim`` still fit.
            padding = embeddings.new_zeros(embeddings.shape[0], 1)
            embeddings = torch.cat((embeddings, padding), dim=1)
        return embeddings
|
||||||
|
|
||||||
|
|
||||||
class ManhattanAwareUNet(nn.Module):
|
class ManhattanAwareUNet(nn.Module):
|
||||||
"""曼哈顿几何感知的U-Net架构"""
|
"""曼哈顿几何感知的U-Net架构 - 修复版"""
|
||||||
|
|
||||||
def __init__(self, in_channels=1, out_channels=1, time_dim=256, use_edge_condition=False):
|
def __init__(self, in_channels=1, out_channels=1, time_dim=256, use_edge_condition=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -180,9 +206,9 @@ class ManhattanAwareUNet(nn.Module):
|
|||||||
# 输入通道数(原始图像 + 可选边缘条件)
|
# 输入通道数(原始图像 + 可选边缘条件)
|
||||||
input_channels = in_channels + (1 if use_edge_condition else 0)
|
input_channels = in_channels + (1 if use_edge_condition else 0)
|
||||||
|
|
||||||
# 时间嵌入
|
# 时间嵌入 - 修复时间步编码
|
||||||
self.time_mlp = nn.Sequential(
|
self.time_mlp = nn.Sequential(
|
||||||
nn.Linear(1, time_dim),
|
SinusoidalPositionEmbeddings(time_dim),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Linear(time_dim, time_dim)
|
nn.Linear(time_dim, time_dim)
|
||||||
)
|
)
|
||||||
@@ -199,43 +225,34 @@ class ManhattanAwareUNet(nn.Module):
|
|||||||
nn.SiLU()
|
nn.SiLU()
|
||||||
)
|
)
|
||||||
|
|
||||||
# 编码器 - 增强版
|
# 编码器通道配置
|
||||||
|
encoder_channels = [64, 128, 256, 512] # 修正:减少通道数避免过拟合
|
||||||
self.encoder = nn.ModuleList([
|
self.encoder = nn.ModuleList([
|
||||||
self._make_block(64, 128),
|
self._make_block(encoder_channels[i], encoder_channels[i+1], stride=2 if i > 0 else 1)
|
||||||
self._make_block(128, 256, stride=2),
|
for i in range(len(encoder_channels)-1)
|
||||||
self._make_block(256, 512, stride=2),
|
|
||||||
self._make_block(512, 1024, stride=2),
|
|
||||||
])
|
|
||||||
|
|
||||||
# 残差投影层 - 确保残差连接的通道数和空间尺寸匹配
|
|
||||||
self.residual_projections = nn.ModuleList([
|
|
||||||
nn.Conv2d(64, 128, 1), # 64 -> 128, 保持尺寸
|
|
||||||
nn.Conv2d(128, 256, 3, stride=2, padding=1), # 128 -> 256, 尺寸减半
|
|
||||||
nn.Conv2d(256, 512, 3, stride=2, padding=1), # 256 -> 512, 尺寸减半
|
|
||||||
nn.Conv2d(512, 1024, 3, stride=2, padding=1), # 512 -> 1024, 尺寸减半
|
|
||||||
])
|
])
|
||||||
|
|
||||||
# 中间层
|
# 中间层
|
||||||
self.middle = nn.Sequential(
|
self.middle = nn.Sequential(
|
||||||
nn.Conv2d(1024, 1024, 3, padding=1),
|
nn.Conv2d(512, 512, 3, padding=1),
|
||||||
nn.GroupNorm(8, 1024),
|
nn.GroupNorm(8, 512),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Conv2d(1024, 1024, 3, padding=1),
|
nn.Conv2d(512, 512, 3, padding=1),
|
||||||
nn.GroupNorm(8, 1024),
|
nn.GroupNorm(8, 512),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 解码器
|
# 解码器 - 修复通道数计算
|
||||||
self.decoder = nn.ModuleList([
|
self.decoder = nn.ModuleList([
|
||||||
self._make_decoder_block(1024, 512),
|
self._make_decoder_block(512, 256), # middle (512) -> 256
|
||||||
self._make_decoder_block(512, 256),
|
self._make_decoder_block(512, 128), # 256+256(skip) -> 128
|
||||||
self._make_decoder_block(256, 128),
|
self._make_decoder_block(256, 64), # 128+128(skip) -> 64
|
||||||
self._make_decoder_block(128, 64),
|
self._make_decoder_block(128, 64), # 64+64(skip) -> 64
|
||||||
])
|
])
|
||||||
|
|
||||||
# 输出层
|
# 输出层 - 修复输入通道数
|
||||||
self.output = nn.Sequential(
|
self.output = nn.Sequential(
|
||||||
nn.Conv2d(64, 32, 3, padding=1),
|
nn.Conv2d(64, 32, 3, padding=1), # 最后一层跳跃连接后是64通道
|
||||||
nn.GroupNorm(8, 32),
|
nn.GroupNorm(8, 32),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Conv2d(32, out_channels, 3, padding=1)
|
nn.Conv2d(32, out_channels, 3, padding=1)
|
||||||
@@ -243,10 +260,7 @@ class ManhattanAwareUNet(nn.Module):
|
|||||||
|
|
||||||
# 时间融合层 - 与编码器输出通道数匹配
|
# 时间融合层 - 与编码器输出通道数匹配
|
||||||
self.time_fusion = nn.ModuleList([
|
self.time_fusion = nn.ModuleList([
|
||||||
nn.Linear(time_dim, 128),
|
nn.Linear(time_dim, channels) for channels in encoder_channels[1:] + [512]
|
||||||
nn.Linear(time_dim, 256),
|
|
||||||
nn.Linear(time_dim, 512),
|
|
||||||
nn.Linear(time_dim, 1024),
|
|
||||||
])
|
])
|
||||||
|
|
||||||
def _make_block(self, in_channels, out_channels, stride=1):
|
def _make_block(self, in_channels, out_channels, stride=1):
|
||||||
@@ -276,8 +290,8 @@ class ManhattanAwareUNet(nn.Module):
|
|||||||
if self.use_edge_condition and edge_condition is not None:
|
if self.use_edge_condition and edge_condition is not None:
|
||||||
x = torch.cat([x, edge_condition], dim=1)
|
x = torch.cat([x, edge_condition], dim=1)
|
||||||
|
|
||||||
# 时间嵌入
|
# 时间嵌入 - 使用正弦位置编码
|
||||||
t_emb = self.time_mlp(t.float().unsqueeze(-1)) # [B, time_dim]
|
t_emb = self.time_mlp(t) # [B, time_dim]
|
||||||
|
|
||||||
# 曼哈顿几何感知的特征提取
|
# 曼哈顿几何感知的特征提取
|
||||||
h_features = F.silu(self.horiz_conv(x))
|
h_features = F.silu(self.horiz_conv(x))
|
||||||
@@ -285,45 +299,58 @@ class ManhattanAwareUNet(nn.Module):
|
|||||||
s_features = F.silu(self.standard_conv(x))
|
s_features = F.silu(self.standard_conv(x))
|
||||||
|
|
||||||
# 融合特征
|
# 融合特征
|
||||||
x = torch.cat([h_features, v_features, s_features], dim=1)
|
x = torch.cat([h_features, v_features, s_features], dim=1) # [B, 96, H, W]
|
||||||
x = self.initial_fusion(x)
|
x = self.initial_fusion(x) # [B, 64, H, W]
|
||||||
|
|
||||||
# 编码器路径
|
# 编码器路径 - 修复跳跃连接逻辑
|
||||||
skips = []
|
skips = []
|
||||||
for i, (encoder, fusion) in enumerate(zip(self.encoder, self.time_fusion)):
|
for i, (encoder, fusion) in enumerate(zip(self.encoder, self.time_fusion)):
|
||||||
# 保存残差连接并使用投影层匹配通道数
|
# 保存跳跃连接(在编码之前)
|
||||||
residual = self.residual_projections[i](x)
|
skips.append(x)
|
||||||
|
|
||||||
|
# 编码
|
||||||
x = encoder(x)
|
x = encoder(x)
|
||||||
|
|
||||||
# 融合时间信息
|
# 融合时间信息
|
||||||
t_feat = fusion(t_emb).unsqueeze(-1).unsqueeze(-1)
|
t_feat = fusion(t_emb).unsqueeze(-1).unsqueeze(-1) # [B, channels, 1, 1]
|
||||||
x = x + t_feat
|
|
||||||
|
|
||||||
# 跳跃连接 - 添加投影后的残差
|
# 检查通道数是否匹配
|
||||||
skips.append(x + residual)
|
if x.shape[1] == t_feat.shape[1]:
|
||||||
|
x = x + t_feat
|
||||||
|
else:
|
||||||
|
# 如果不匹配,使用1x1卷积调整通道数
|
||||||
|
if not hasattr(self, f'time_proj_{i}'):
|
||||||
|
setattr(self, f'time_proj_{i}',
|
||||||
|
nn.Conv2d(t_feat.shape[1], x.shape[1], 1).to(x.device))
|
||||||
|
time_proj = getattr(self, f'time_proj_{i}')
|
||||||
|
x = x + time_proj(t_feat)
|
||||||
|
|
||||||
# 中间层
|
# 中间层
|
||||||
x = self.middle(x)
|
x = self.middle(x)
|
||||||
|
|
||||||
# 解码器路径
|
# 解码器路径 - 修复跳跃连接逻辑
|
||||||
for i, (decoder, skip) in enumerate(zip(self.decoder, reversed(skips))):
|
for i, decoder in enumerate(self.decoder):
|
||||||
|
# 获取对应的跳跃连接(反向顺序)
|
||||||
|
skip = skips[-(i+1)]
|
||||||
|
|
||||||
|
# 上采样
|
||||||
x = decoder(x)
|
x = decoder(x)
|
||||||
|
|
||||||
# 确保跳跃连接尺寸匹配 - 如果尺寸不匹配则进行裁剪或填充
|
# 确保跳跃连接尺寸匹配
|
||||||
if x.shape[2:] != skip.shape[2:]:
|
if x.shape[2:] != skip.shape[2:]:
|
||||||
# 裁剪到最小尺寸
|
# 使用插值调整尺寸
|
||||||
h_min = min(x.shape[2], skip.shape[2])
|
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
|
||||||
w_min = min(x.shape[3], skip.shape[3])
|
|
||||||
if x.shape[2] > h_min:
|
|
||||||
x = x[:, :, :h_min, :]
|
|
||||||
if x.shape[3] > w_min:
|
|
||||||
x = x[:, :, :, :w_min]
|
|
||||||
if skip.shape[2] > h_min:
|
|
||||||
skip = skip[:, :, :h_min, :]
|
|
||||||
if skip.shape[3] > w_min:
|
|
||||||
skip = skip[:, :, :, :w_min]
|
|
||||||
|
|
||||||
x = x + skip # 跳跃连接
|
# 跳跃连接(需要处理通道数匹配)
|
||||||
|
if x.shape[1] == skip.shape[1]:
|
||||||
|
x = x + skip
|
||||||
|
else:
|
||||||
|
# 如果通道数不匹配,使用1x1卷积调整
|
||||||
|
if not hasattr(self, f'skip_proj_{i}'):
|
||||||
|
setattr(self, f'skip_proj_{i}',
|
||||||
|
nn.Conv2d(skip.shape[1], x.shape[1], 1).to(x.device))
|
||||||
|
skip_proj = getattr(self, f'skip_proj_{i}')
|
||||||
|
x = x + skip_proj(skip)
|
||||||
|
|
||||||
# 输出
|
# 输出
|
||||||
x = self.output(x)
|
x = self.output(x)
|
||||||
@@ -358,14 +385,11 @@ class OptimizedNoiseScheduler:
|
|||||||
def add_noise(self, x_0, t):
|
def add_noise(self, x_0, t):
|
||||||
"""向干净图像添加噪声"""
|
"""向干净图像添加噪声"""
|
||||||
noise = torch.randn_like(x_0)
|
noise = torch.randn_like(x_0)
|
||||||
# 确保调度器张量与输入张量在同一设备上
|
|
||||||
device = x_0.device
|
device = x_0.device
|
||||||
if self.sqrt_alphas_cumprod.device != device:
|
|
||||||
self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
|
|
||||||
self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
|
|
||||||
|
|
||||||
sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
# 确保调度器张量与输入张量在同一设备上
|
||||||
sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
|
||||||
return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
|
return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
|
||||||
|
|
||||||
@@ -381,18 +405,13 @@ class OptimizedNoiseScheduler:
|
|||||||
# 预测噪声
|
# 预测噪声
|
||||||
predicted_noise = model(x_t, t)
|
predicted_noise = model(x_t, t)
|
||||||
|
|
||||||
# 确保调度器张量与输入张量在同一设备上
|
|
||||||
device = x_t.device
|
device = x_t.device
|
||||||
if self.alphas.device != device:
|
|
||||||
self.alphas = self.alphas.to(device)
|
|
||||||
self.betas = self.betas.to(device)
|
|
||||||
self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
|
|
||||||
|
|
||||||
# 计算系数
|
# 计算系数(直接使用索引并移动到设备)
|
||||||
alpha_t = self.alphas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
alpha_t = self.alphas[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||||
sqrt_alpha_t = torch.sqrt(alpha_t)
|
sqrt_alpha_t = torch.sqrt(alpha_t)
|
||||||
beta_t = self.betas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
beta_t = self.betas[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||||
sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
|
||||||
# 计算均值
|
# 计算均值
|
||||||
model_mean = (1.0 / sqrt_alpha_t) * (x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * predicted_noise)
|
model_mean = (1.0 / sqrt_alpha_t) * (x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * predicted_noise)
|
||||||
|
|||||||
257
tools/diffusion/test_fixes.py
Normal file
257
tools/diffusion/test_fixes.py
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
测试修复后的IC版图扩散模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 导入修复后的模型
|
||||||
|
from ic_layout_diffusion_optimized import (
|
||||||
|
ManhattanAwareUNet,
|
||||||
|
OptimizedNoiseScheduler,
|
||||||
|
OptimizedDiffusionTrainer,
|
||||||
|
ICDiffusionDataset
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_model_architecture():
    """Smoke-test the ManhattanAwareUNet forward pass.

    The model is built with ``time_dim=256`` once without and once with edge
    conditioning. Each variant is fed a random batch of 64x64 single-channel
    images and random integer timesteps in ``[0, 1000)``; the output must
    have the same shape as the input batch.

    Returns:
        bool: True when every configuration produced output of the expected
        shape, False as soon as one configuration raises or returns a
        wrongly shaped tensor.
    """
    print("测试模型架构...")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # Configurations to exercise
    test_configs = [
        {'use_edge_condition': False, 'batch_size': 2},
        {'use_edge_condition': True, 'batch_size': 2},
    ]

    for config in test_configs:
        print(f"\n测试配置: {config}")

        try:
            # Build the model
            model = ManhattanAwareUNet(
                in_channels=1,
                out_channels=1,
                time_dim=256,
                use_edge_condition=config['use_edge_condition']
            ).to(device)

            # Build the test inputs
            batch_size = config['batch_size']
            image_size = 64  # small images keep the test fast

            x = torch.randn(batch_size, 1, image_size, image_size).to(device)
            t = torch.randint(0, 1000, (batch_size,)).to(device)

            if config['use_edge_condition']:
                edge_condition = torch.randn(batch_size, 1, image_size, image_size).to(device)
                output = model(x, t, edge_condition)
            else:
                output = model(x, t)

            print(f"✓ 输入形状: {x.shape}")
            print(f"✓ 输出形状: {output.shape}")
            print(f"✓ 时间步形状: {t.shape}")

            # Check the output shape
            expected_shape = (batch_size, 1, image_size, image_size)
            if output.shape == expected_shape:
                print(f"✓ 输出形状正确: {output.shape}")
            else:
                print(f"✗ 输出形状错误: 期望 {expected_shape}, 得到 {output.shape}")
                # A wrong shape is a failure; without this return the function
                # would still report that all architecture tests passed.
                return False

        except Exception as e:
            # Test-harness boundary: report the failure instead of crashing
            # so the remaining test groups can still run.
            print(f"✗ 模型测试失败: {e}")
            return False

    print("\n✓ 所有模型架构测试通过!")
    return True
|
||||||
|
|
||||||
|
def test_scheduler():
    """Smoke-test noise injection and a single denoising step.

    A cosine ``OptimizedNoiseScheduler`` with 1000 timesteps adds noise to a
    random ``[4, 1, 32, 32]`` batch, then ``scheduler.step`` is called with a
    default-constructed ``ManhattanAwareUNet``. The step result must keep
    the shape of the noisy input.

    Returns:
        bool: True on success, False if anything raises or the denoised
        tensor has a different shape from the noisy one.
    """
    print("\n测试噪声调度器...")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    try:
        noise_scheduler = OptimizedNoiseScheduler(num_timesteps=1000, schedule_type='cosine')

        # Forward process: add noise to a clean batch
        clean = torch.randn(4, 1, 32, 32).to(device)
        timesteps = torch.randint(0, 1000, (4,)).to(device)
        noisy, noise = noise_scheduler.add_noise(clean, timesteps)

        print(f"✓ 原始图像形状: {clean.shape}")
        print(f"✓ 噪声图像形状: {noisy.shape}")
        print(f"✓ 噪声形状: {noise.shape}")

        # Reverse process: one denoising step
        net = ManhattanAwareUNet().to(device)
        denoised = noise_scheduler.step(net, noisy, timesteps)

        print(f"✓ 去噪图像形状: {denoised.shape}")

        # The step must preserve the tensor shape
        if denoised.shape != noisy.shape:
            print(f"✗ 去噪步骤形状错误: 期望 {noisy.shape}, 得到 {denoised.shape}")
            return False
        print("✓ 去噪步骤形状正确")

    except Exception as e:
        print(f"✗ 调度器测试失败: {e}")
        return False

    print("✓ 调度器测试通过!")
    return True
|
||||||
|
|
||||||
|
def test_trainer():
    """Smoke-test one training step and one sampling call of the trainer.

    Builds a default ``ManhattanAwareUNet``, a 100-step
    ``OptimizedNoiseScheduler`` and an ``OptimizedDiffusionTrainer`` without
    edge conditioning, runs ``train_step`` on a one-batch stand-in
    dataloader, then calls ``generate`` for two 32x32 samples.

    Returns:
        bool: True when both calls complete, False (after printing the
        traceback) if anything raises.
    """
    print("\n测试训练器...")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    class MockDataloader:
        """Dataloader stand-in that yields one pre-built batch per pass."""

        def __init__(self, data):
            self.data = data

        def __iter__(self):
            yield self.data

    try:
        # Model, scheduler and trainer
        net = ManhattanAwareUNet().to(device)
        noise_scheduler = OptimizedNoiseScheduler(num_timesteps=100)
        trainer = OptimizedDiffusionTrainer(net, noise_scheduler, device, use_edge_condition=False)

        # Dummy data
        batch_size = 2
        fake_batch = torch.randn(batch_size, 1, 32, 32).to(device)

        # One training step over the single-batch loader
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
        loader = MockDataloader(fake_batch)
        losses = trainer.train_step(optimizer, loader, manhattan_weight=0.1)

        print(f"✓ 训练步骤完成,损失: {losses}")

        # Sampling
        samples = trainer.generate(
            num_samples=2,
            image_size=32,
            save_dir=None,
            use_post_process=False,
        )

        print(f"✓ 生成样本形状: {samples.shape}")

    except Exception as e:
        print(f"✗ 训练器测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False

    print("✓ 训练器测试通过!")
    return True
|
||||||
|
|
||||||
|
def test_data_loading():
    """Smoke-test ``ICDiffusionDataset`` with and without edge conditioning.

    If the ``test_images`` directory does not exist it is created and filled
    with three random 64x64 grayscale PNGs; an existing directory is used
    as-is. The dataset is then built at ``image_size=32`` without
    augmentation, once returning a single tensor per item and once returning
    an ``(image, edge)`` pair.

    Returns:
        bool: True on success, False if the dataset is empty or anything
        raises (the traceback is printed in that case).
    """
    print("\n测试数据加载...")

    # Temporary test directory (created only when missing)
    test_dir = Path("test_images")

    try:
        if not test_dir.exists():
            print("创建测试图像目录...")
            test_dir.mkdir(exist_ok=True)

            # Write a few simple test images
            from PIL import Image
            import numpy as np

            for index in range(3):
                # Random grayscale image; the upper bound 255 is exclusive
                pixels = np.random.randint(0, 255, (64, 64), dtype=np.uint8)
                picture = Image.fromarray(pixels, mode='L')
                picture.save(test_dir / f"test_{index}.png")

        # Dataset without edge conditioning
        plain_dataset = ICDiffusionDataset(
            image_dir=str(test_dir),
            image_size=32,
            augment=False,
            use_edge_condition=False,
        )

        print(f"✓ 数据集大小: {len(plain_dataset)}")

        if len(plain_dataset) == 0:
            print("✗ 数据集为空")
            return False

        # Load one item
        first_item = plain_dataset[0]
        print(f"✓ 样本形状: {first_item.shape}")

        # Dataset with edge conditioning: items unpack into (image, edge)
        edge_dataset = ICDiffusionDataset(
            image_dir=str(test_dir),
            image_size=32,
            augment=False,
            use_edge_condition=True,
        )

        image, edge = edge_dataset[0]
        print(f"✓ 图像形状: {image.shape}, 边缘形状: {edge.shape}")

    except Exception as e:
        print(f"✗ 数据加载测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False

    print("✓ 数据加载测试通过!")
    return True
|
||||||
|
|
||||||
|
def main():
    """Run every test group and report the overall result.

    All groups are executed even after a failure, each preceded by a banner
    with its name.

    Returns:
        bool: True when every test function returned a truthy value.
    """
    print("开始测试修复后的IC版图扩散模型...")
    print("=" * 50)

    # Test groups, executed in this order
    tests = [
        ("模型架构", test_model_architecture),
        ("噪声调度器", test_scheduler),
        ("训练器", test_trainer),
        ("数据加载", test_data_loading),
    ]

    outcomes = []
    for test_name, test_func in tests:
        print(f"\n{'='*20} {test_name} {'='*20}")
        # Collect every outcome so that no group is skipped after a failure.
        outcomes.append(bool(test_func()))

    all_tests_passed = all(outcomes)

    print("\n" + "=" * 50)
    summary = "🎉 所有测试通过!模型修复成功。" if all_tests_passed else "❌ 部分测试失败,需要进一步修复。"
    print(summary)

    return all_tests_passed
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    success = main()
    # Exit status 0 on success, 1 on failure. SystemExit is raised directly
    # because the ``exit`` helper is installed by the ``site`` module and is
    # not guaranteed to exist in every interpreter configuration.
    raise SystemExit(0 if success else 1)
|
||||||
@@ -131,28 +131,54 @@ def train_optimized_diffusion(args):
|
|||||||
use_edge_condition=args.edge_condition
|
use_edge_condition=args.edge_condition
|
||||||
)
|
)
|
||||||
|
|
||||||
# 数据集分割
|
# 检查数据集是否为空
|
||||||
|
if len(dataset) == 0:
|
||||||
|
logger.error(f"数据集为空!请检查数据目录: {args.data_dir}")
|
||||||
|
raise ValueError(f"数据集为空,在目录 {args.data_dir} 中未找到图像文件")
|
||||||
|
|
||||||
|
logger.info(f"找到 {len(dataset)} 个训练样本")
|
||||||
|
|
||||||
|
# 数据集分割 - 修复空数据集问题
|
||||||
total_size = len(dataset)
|
total_size = len(dataset)
|
||||||
|
if total_size < 10: # 如果数据集太小,全部用于训练
|
||||||
|
logger.warning(f"数据集较小 ({total_size} 样本),全部用于训练")
|
||||||
|
train_dataset = dataset
|
||||||
|
val_dataset = None
|
||||||
|
else:
|
||||||
train_size = int(0.9 * total_size)
|
train_size = int(0.9 * total_size)
|
||||||
val_size = total_size - train_size
|
val_size = total_size - train_size
|
||||||
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
|
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
|
||||||
|
logger.info(f"训练集: {len(train_dataset)}, 验证集: {len(val_dataset)}")
|
||||||
|
|
||||||
# 数据加载器
|
# 数据加载器 - 修复None验证集问题
|
||||||
|
if device.type == 'cuda':
|
||||||
train_dataloader = DataLoader(
|
train_dataloader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=args.batch_size,
|
batch_size=min(args.batch_size, len(train_dataset)), # 确保批次大小不超过数据集大小
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=4,
|
num_workers=min(4, max(1, len(train_dataset) // args.batch_size)),
|
||||||
pin_memory=True
|
pin_memory=True,
|
||||||
|
drop_last=True # 避免最后一个不完整的批次
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# CPU模式下使用较少的worker
|
||||||
|
train_dataloader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=min(args.batch_size, len(train_dataset)),
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=0, # CPU模式下避免多进程
|
||||||
|
drop_last=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if val_dataset is not None:
|
||||||
val_dataloader = DataLoader(
|
val_dataloader = DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=args.batch_size,
|
batch_size=min(args.batch_size, len(val_dataset)),
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_workers=2
|
num_workers=2
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
logger.info(f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset)}")
|
val_dataloader = None
|
||||||
|
|
||||||
# 创建模型
|
# 创建模型
|
||||||
logger.info("创建优化模型...")
|
logger.info("创建优化模型...")
|
||||||
@@ -209,8 +235,13 @@ def train_optimized_diffusion(args):
|
|||||||
optimizer, train_dataloader, args.manhattan_weight
|
optimizer, train_dataloader, args.manhattan_weight
|
||||||
)
|
)
|
||||||
|
|
||||||
# 验证
|
# 验证 - 修复None验证集问题
|
||||||
|
if val_dataloader is not None:
|
||||||
val_loss = validate_model(trainer, val_dataloader, device)
|
val_loss = validate_model(trainer, val_dataloader, device)
|
||||||
|
else:
|
||||||
|
# 如果没有验证集,使用训练损失作为验证损失
|
||||||
|
val_loss = train_losses['total_loss']
|
||||||
|
logger.warning("未使用验证集 - 使用训练损失作为参考")
|
||||||
|
|
||||||
# 学习率调度
|
# 学习率调度
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
@@ -238,7 +269,7 @@ def train_optimized_diffusion(args):
|
|||||||
f"LR: {current_lr:.2e}"
|
f"LR: {current_lr:.2e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存最佳模型
|
# 保存最佳模型 - 即使没有验证集也保存
|
||||||
if val_loss < best_val_loss:
|
if val_loss < best_val_loss:
|
||||||
best_val_loss = val_loss
|
best_val_loss = val_loss
|
||||||
best_model_path = output_dir / "best_model.pth"
|
best_model_path = output_dir / "best_model.pth"
|
||||||
|
|||||||
Reference in New Issue
Block a user