change some problem 7

Jiao77
2025-11-20 03:09:18 +08:00
parent 3d75ed722a
commit 3258b7b6de
3 changed files with 407 additions and 100 deletions


@@ -13,6 +13,7 @@
 import os
 import sys
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -64,15 +65,23 @@ class ICDiffusionDataset(Dataset):
     def _extract_edges(self, image_tensor):
         """Extract the edge-condition map"""
         # Extract edges with the Sobel operator
-        sobel_x = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]],
+        # Fix the Sobel operator - use correct 3x3 Sobel kernels
+        if len(image_tensor.shape) == 3:  # [C, H, W]
+            image_tensor = image_tensor.unsqueeze(0)  # [1, C, H, W]
+        # Sobel kernels designed for single-channel images
+        sobel_x = torch.tensor([[[[-1.0, 0.0, 1.0],
+                                  [-2.0, 0.0, 2.0],
+                                  [-1.0, 0.0, 1.0]]]],
                               dtype=image_tensor.dtype, device=image_tensor.device)
-        sobel_y = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]],
+        sobel_y = torch.tensor([[[[-1.0, -2.0, -1.0],
+                                  [0.0, 0.0, 0.0],
+                                  [1.0, 2.0, 1.0]]]],
                               dtype=image_tensor.dtype, device=image_tensor.device)
-        edge_x = F.conv2d(image_tensor.unsqueeze(0), sobel_x, padding=1)
-        edge_y = F.conv2d(image_tensor.unsqueeze(0), sobel_y, padding=1)
-        edge_magnitude = torch.sqrt(edge_x**2 + edge_y**2)
+        edge_x = F.conv2d(image_tensor, sobel_x, padding=1)
+        edge_y = F.conv2d(image_tensor, sobel_y, padding=1)
+        edge_magnitude = torch.sqrt(edge_x**2 + edge_y**2 + 1e-8)  # eps keeps sqrt differentiable at zero
         return torch.clamp(edge_magnitude, 0, 1)
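
For context, a self-contained sketch of the fixed edge-extraction logic (function name and test shapes are illustrative, not from the commit); it confirms that a [C, H, W] input comes back as a [1, 1, H, W] edge map:

import torch
import torch.nn.functional as F

def extract_edges(image_tensor):
    # Accept [C, H, W] by promoting to [1, C, H, W], as in the fix
    if image_tensor.dim() == 3:
        image_tensor = image_tensor.unsqueeze(0)
    sobel_x = torch.tensor([[[[-1.0, 0.0, 1.0],
                              [-2.0, 0.0, 2.0],
                              [-1.0, 0.0, 1.0]]]],
                           dtype=image_tensor.dtype, device=image_tensor.device)
    sobel_y = sobel_x.transpose(2, 3)  # the y kernel is the transpose of the x kernel
    edge_x = F.conv2d(image_tensor, sobel_x, padding=1)
    edge_y = F.conv2d(image_tensor, sobel_y, padding=1)
    # Epsilon keeps the gradient of sqrt finite where the magnitude is zero
    return torch.clamp(torch.sqrt(edge_x**2 + edge_y**2 + 1e-8), 0, 1)

edges = extract_edges(torch.rand(1, 64, 64))  # single-channel 64x64 image
print(edges.shape)  # torch.Size([1, 1, 64, 64])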
@@ -170,8 +179,25 @@ def manhattan_regularization_loss(generated_image, device='cuda'):
     return torch.mean(angle_penalty * edge_magnitude)
 
+class SinusoidalPositionEmbeddings(nn.Module):
+    """Sinusoidal position encoding for timestep embeddings"""
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, time):
+        device = time.device
+        half_dim = self.dim // 2
+        embeddings = math.log(10000) / (half_dim - 1)
+        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
+        embeddings = time[:, None].float() * embeddings[None, :]
+        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=1)
+        return embeddings
+
 class ManhattanAwareUNet(nn.Module):
-    """Manhattan-geometry-aware U-Net architecture"""
+    """Manhattan-geometry-aware U-Net architecture - fixed version"""
     def __init__(self, in_channels=1, out_channels=1, time_dim=256, use_edge_condition=False):
         super().__init__()
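
The new embedding can be sanity-checked in isolation. A minimal sketch, with the class repeated so the snippet runs on its own (batch size and timestep range are illustrative):

import math
import torch
import torch.nn as nn

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    def forward(self, time):
        half_dim = self.dim // 2
        freqs = torch.exp(torch.arange(half_dim, device=time.device)
                          * -(math.log(10000) / (half_dim - 1)))
        args = time[:, None].float() * freqs[None, :]
        return torch.cat((args.sin(), args.cos()), dim=1)  # [B, dim]

# Mirrors the fixed time_mlp: sinusoidal encoding instead of nn.Linear(1, time_dim)
time_dim = 256
time_mlp = nn.Sequential(SinusoidalPositionEmbeddings(time_dim),
                         nn.SiLU(),
                         nn.Linear(time_dim, time_dim))
t = torch.randint(0, 1000, (8,))  # integer timesteps, batch of 8
print(time_mlp(t).shape)          # torch.Size([8, 256])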
@@ -180,9 +206,9 @@ class ManhattanAwareUNet(nn.Module):
         # Number of input channels (raw image + optional edge condition)
         input_channels = in_channels + (1 if use_edge_condition else 0)
-        # Time embedding
+        # Time embedding - fixed timestep encoding
         self.time_mlp = nn.Sequential(
-            nn.Linear(1, time_dim),
+            SinusoidalPositionEmbeddings(time_dim),
             nn.SiLU(),
             nn.Linear(time_dim, time_dim)
         )
@@ -199,43 +225,34 @@ class ManhattanAwareUNet(nn.Module):
             nn.SiLU()
         )
-        # Encoder - enhanced version
+        # Encoder channel configuration
+        encoder_channels = [64, 128, 256, 512]  # Fix: fewer channels to avoid overfitting
         self.encoder = nn.ModuleList([
-            self._make_block(64, 128),
-            self._make_block(128, 256, stride=2),
-            self._make_block(256, 512, stride=2),
-            self._make_block(512, 1024, stride=2),
-        ])
-        # Residual projection layers - match the channel count and spatial size of the residual connections
-        self.residual_projections = nn.ModuleList([
-            nn.Conv2d(64, 128, 1),                         # 64 -> 128, keep size
-            nn.Conv2d(128, 256, 3, stride=2, padding=1),   # 128 -> 256, halve size
-            nn.Conv2d(256, 512, 3, stride=2, padding=1),   # 256 -> 512, halve size
-            nn.Conv2d(512, 1024, 3, stride=2, padding=1),  # 512 -> 1024, halve size
+            self._make_block(encoder_channels[i], encoder_channels[i+1], stride=2 if i > 0 else 1)
+            for i in range(len(encoder_channels)-1)
         ])
         # Middle block
         self.middle = nn.Sequential(
-            nn.Conv2d(1024, 1024, 3, padding=1),
-            nn.GroupNorm(8, 1024),
+            nn.Conv2d(512, 512, 3, padding=1),
+            nn.GroupNorm(8, 512),
             nn.SiLU(),
-            nn.Conv2d(1024, 1024, 3, padding=1),
-            nn.GroupNorm(8, 1024),
+            nn.Conv2d(512, 512, 3, padding=1),
+            nn.GroupNorm(8, 512),
             nn.SiLU(),
         )
-        # Decoder
+        # Decoder - fixed channel arithmetic
         self.decoder = nn.ModuleList([
-            self._make_decoder_block(1024, 512),
-            self._make_decoder_block(512, 256),
-            self._make_decoder_block(256, 128),
-            self._make_decoder_block(128, 64),
+            self._make_decoder_block(512, 256),  # middle (512) -> 256
+            self._make_decoder_block(512, 128),  # 256+256(skip) -> 128
+            self._make_decoder_block(256, 64),   # 128+128(skip) -> 64
+            self._make_decoder_block(128, 64),   # 64+64(skip) -> 64
         ])
-        # Output layer
+        # Output layer - fixed input channel count
         self.output = nn.Sequential(
-            nn.Conv2d(64, 32, 3, padding=1),
+            nn.Conv2d(64, 32, 3, padding=1),  # 64 channels after the final skip connection
             nn.GroupNorm(8, 32),
             nn.SiLU(),
             nn.Conv2d(32, out_channels, 3, padding=1)
@@ -243,10 +260,7 @@ class ManhattanAwareUNet(nn.Module):
         # Time-fusion layers - match the encoder output channel counts
         self.time_fusion = nn.ModuleList([
-            nn.Linear(time_dim, 128),
-            nn.Linear(time_dim, 256),
-            nn.Linear(time_dim, 512),
-            nn.Linear(time_dim, 1024),
+            nn.Linear(time_dim, channels) for channels in encoder_channels[1:] + [512]
         ])
 
     def _make_block(self, in_channels, out_channels, stride=1):
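
For readers tracing the channel bookkeeping, the two comprehensions expand as below (a plain-Python sketch). Note that zip() in forward pairs the four time_fusion layers with only three encoder blocks, so the trailing Linear(time_dim, 512) is never reached there:

# Plain-Python expansion of the encoder/time-fusion comprehensions
encoder_channels = [64, 128, 256, 512]

encoder_specs = [(encoder_channels[i], encoder_channels[i + 1], 2 if i > 0 else 1)
                 for i in range(len(encoder_channels) - 1)]
print(encoder_specs)    # [(64, 128, 1), (128, 256, 2), (256, 512, 2)]

fusion_channels = encoder_channels[1:] + [512]
print(fusion_channels)  # [128, 256, 512, 512] - first three match the encoder outputs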
@@ -276,8 +290,8 @@ class ManhattanAwareUNet(nn.Module):
         if self.use_edge_condition and edge_condition is not None:
             x = torch.cat([x, edge_condition], dim=1)
-        # Time embedding
-        t_emb = self.time_mlp(t.float().unsqueeze(-1))  # [B, time_dim]
+        # Time embedding - use the sinusoidal position encoding
+        t_emb = self.time_mlp(t)  # [B, time_dim]
         # Manhattan-geometry-aware feature extraction
         h_features = F.silu(self.horiz_conv(x))
@@ -285,45 +299,58 @@ class ManhattanAwareUNet(nn.Module):
         s_features = F.silu(self.standard_conv(x))
         # Fuse the features
-        x = torch.cat([h_features, v_features, s_features], dim=1)
-        x = self.initial_fusion(x)
-        # Encoder path
+        x = torch.cat([h_features, v_features, s_features], dim=1)  # [B, 96, H, W]
+        x = self.initial_fusion(x)  # [B, 64, H, W]
+        # Encoder path - fixed skip-connection logic
         skips = []
         for i, (encoder, fusion) in enumerate(zip(self.encoder, self.time_fusion)):
-            # Save the residual and use a projection layer to match channel counts
-            residual = self.residual_projections[i](x)
+            # Save the skip connection (before encoding)
+            skips.append(x)
+            # Encode
             x = encoder(x)
             # Fuse in the time information
-            t_feat = fusion(t_emb).unsqueeze(-1).unsqueeze(-1)
-            x = x + t_feat
-            # Skip connection - add the projected residual
-            skips.append(x + residual)
+            t_feat = fusion(t_emb).unsqueeze(-1).unsqueeze(-1)  # [B, channels, 1, 1]
+            # Check whether the channel counts match
+            if x.shape[1] == t_feat.shape[1]:
+                x = x + t_feat
+            else:
+                # If they don't match, adjust the channels with a 1x1 convolution
+                if not hasattr(self, f'time_proj_{i}'):
+                    setattr(self, f'time_proj_{i}',
+                            nn.Conv2d(t_feat.shape[1], x.shape[1], 1).to(x.device))
+                time_proj = getattr(self, f'time_proj_{i}')
+                x = x + time_proj(t_feat)
         # Middle block
         x = self.middle(x)
-        # Decoder path
-        for i, (decoder, skip) in enumerate(zip(self.decoder, reversed(skips))):
+        # Decoder path - fixed skip-connection logic
+        for i, decoder in enumerate(self.decoder):
+            # Fetch the matching skip connection (reverse order)
+            skip = skips[-(i+1)]
             # Upsample
             x = decoder(x)
-            # Make the skip-connection sizes match - crop or pad if they differ
+            # Make the skip-connection sizes match
             if x.shape[2:] != skip.shape[2:]:
-                # Crop to the smaller size
-                h_min = min(x.shape[2], skip.shape[2])
-                w_min = min(x.shape[3], skip.shape[3])
-                if x.shape[2] > h_min:
-                    x = x[:, :, :h_min, :]
-                if x.shape[3] > w_min:
-                    x = x[:, :, :, :w_min]
-                if skip.shape[2] > h_min:
-                    skip = skip[:, :, :h_min, :]
-                if skip.shape[3] > w_min:
-                    skip = skip[:, :, :, :w_min]
-            x = x + skip  # skip connection
+                # Resize with interpolation instead of cropping
+                x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
+            # Skip connection (channel counts may need matching)
+            if x.shape[1] == skip.shape[1]:
+                x = x + skip
+            else:
+                # If the channel counts differ, adjust with a 1x1 convolution
+                if not hasattr(self, f'skip_proj_{i}'):
+                    setattr(self, f'skip_proj_{i}',
+                            nn.Conv2d(skip.shape[1], x.shape[1], 1).to(x.device))
+                skip_proj = getattr(self, f'skip_proj_{i}')
+                x = x + skip_proj(skip)
         # Output
         x = self.output(x)
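
A shape smoke test is the quickest way to validate fixes like these; the sketch below assumes the file is importable as ic_diffusion (a hypothetical module name). It would also surface that the encoder loop stores three skips while self.decoder holds four blocks, so skips[-(i+1)] reaches for a fourth entry on the last iteration. A further caveat: the time_proj_*/skip_proj_* modules created lazily inside forward are not seen by an optimizer built beforehand, so their weights stay untrained unless the optimizer is rebuilt.

# Hypothetical smoke test; `ic_diffusion` is an assumed module name
import torch
from ic_diffusion import ManhattanAwareUNet

model = ManhattanAwareUNet(in_channels=1, out_channels=1, time_dim=256)
x = torch.randn(2, 1, 64, 64)            # dummy batch
t = torch.randint(0, 1000, (2,))         # integer timesteps
out = model(x, t)
assert out.shape == x.shape, out.shape   # fails fast on channel/size bugs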
@@ -358,14 +385,11 @@ class OptimizedNoiseScheduler:
     def add_noise(self, x_0, t):
         """Add noise to a clean image"""
         noise = torch.randn_like(x_0)
-        # Make sure the scheduler tensors are on the same device as the input
-        device = x_0.device
-        if self.sqrt_alphas_cumprod.device != device:
-            self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
-            self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
-        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
-        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+        device = x_0.device
+        # Make sure the scheduler tensors are on the same device as the input
+        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
         return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
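
The indexing pattern generalizes: gather per-sample coefficients with t, move them to the input's device, and reshape to [B, 1, 1, 1] so they broadcast over [B, C, H, W]. A self-contained sketch with an illustrative linear beta schedule (the commit's actual schedule is not shown here):

import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sqrt_ac = torch.sqrt(alphas_cumprod)
sqrt_omac = torch.sqrt(1.0 - alphas_cumprod)

x_0 = torch.randn(4, 1, 32, 32)
t = torch.randint(0, T, (4,))
device = x_0.device
# .view(-1, 1, 1, 1) is equivalent to the three unsqueeze(-1) calls above
a = sqrt_ac[t].to(device).view(-1, 1, 1, 1)
b = sqrt_omac[t].to(device).view(-1, 1, 1, 1)
noise = torch.randn_like(x_0)
x_t = a * x_0 + b * noise  # forward-diffusion sample q(x_t | x_0)
print(x_t.shape)  # torch.Size([4, 1, 32, 32])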
@@ -381,18 +405,13 @@ class OptimizedNoiseScheduler:
         # Predict the noise
         predicted_noise = model(x_t, t)
-        # Make sure the scheduler tensors are on the same device as the input
         device = x_t.device
-        if self.alphas.device != device:
-            self.alphas = self.alphas.to(device)
-            self.betas = self.betas.to(device)
-            self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
-        # Compute the coefficients
-        alpha_t = self.alphas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+        # Compute the coefficients (index directly and move them to the device)
+        alpha_t = self.alphas[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
         sqrt_alpha_t = torch.sqrt(alpha_t)
-        beta_t = self.betas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
-        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+        beta_t = self.betas[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
         # Compute the mean
         model_mean = (1.0 / sqrt_alpha_t) * (x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * predicted_noise)
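
For reference, model_mean implements the standard DDPM reverse-process mean (Ho et al., 2020):

\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \, \epsilon_\theta(x_t, t) \right)

A full sampling step then draws x_{t-1} = \mu_\theta + \sigma_t z with z ~ N(0, I) for t > 0 (commonly \sigma_t^2 = \beta_t), and returns the mean directly at t = 0.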