change some problem 7
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -64,15 +65,23 @@ class ICDiffusionDataset(Dataset):
|
||||
|
||||
def _extract_edges(self, image_tensor):
|
||||
"""提取边缘条件图"""
|
||||
# 使用Sobel算子提取边缘
|
||||
sobel_x = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]],
|
||||
# 修复Sobel算子 - 正确的3x3 Sobel算子
|
||||
if len(image_tensor.shape) == 3: # [C, H, W]
|
||||
image_tensor = image_tensor.unsqueeze(0) # [1, C, H, W]
|
||||
|
||||
# 为单通道图像设计的Sobel算子
|
||||
sobel_x = torch.tensor([[[[-1.0, 0.0, 1.0],
|
||||
[-2.0, 0.0, 2.0],
|
||||
[-1.0, 0.0, 1.0]]]],
|
||||
dtype=image_tensor.dtype, device=image_tensor.device)
|
||||
sobel_y = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]],
|
||||
sobel_y = torch.tensor([[[[-1.0, -2.0, -1.0],
|
||||
[0.0, 0.0, 0.0],
|
||||
[1.0, 2.0, 1.0]]]],
|
||||
dtype=image_tensor.dtype, device=image_tensor.device)
|
||||
|
||||
edge_x = F.conv2d(image_tensor.unsqueeze(0), sobel_x, padding=1)
|
||||
edge_y = F.conv2d(image_tensor.unsqueeze(0), sobel_y, padding=1)
|
||||
edge_magnitude = torch.sqrt(edge_x**2 + edge_y**2)
|
||||
edge_x = F.conv2d(image_tensor, sobel_x, padding=1)
|
||||
edge_y = F.conv2d(image_tensor, sobel_y, padding=1)
|
||||
edge_magnitude = torch.sqrt(edge_x**2 + edge_y**2 + 1e-8)
|
||||
|
||||
return torch.clamp(edge_magnitude, 0, 1)
|
||||
|
||||
@@ -170,8 +179,25 @@ def manhattan_regularization_loss(generated_image, device='cuda'):
|
||||
return torch.mean(angle_penalty * edge_magnitude)
|
||||
|
||||
|
||||
class SinusoidalPositionEmbeddings(nn.Module):
|
||||
"""正弦位置编码用于时间步嵌入"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, time):
|
||||
device = time.device
|
||||
half_dim = self.dim // 2
|
||||
embeddings = math.log(10000) / (half_dim - 1)
|
||||
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
|
||||
embeddings = time[:, None].float() * embeddings[None, :]
|
||||
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=1)
|
||||
return embeddings
|
||||
|
||||
|
||||
class ManhattanAwareUNet(nn.Module):
|
||||
"""曼哈顿几何感知的U-Net架构"""
|
||||
"""曼哈顿几何感知的U-Net架构 - 修复版"""
|
||||
|
||||
def __init__(self, in_channels=1, out_channels=1, time_dim=256, use_edge_condition=False):
|
||||
super().__init__()
|
||||
@@ -180,9 +206,9 @@ class ManhattanAwareUNet(nn.Module):
|
||||
# 输入通道数(原始图像 + 可选边缘条件)
|
||||
input_channels = in_channels + (1 if use_edge_condition else 0)
|
||||
|
||||
# 时间嵌入
|
||||
# 时间嵌入 - 修复时间步编码
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.Linear(1, time_dim),
|
||||
SinusoidalPositionEmbeddings(time_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_dim, time_dim)
|
||||
)
|
||||
@@ -199,43 +225,34 @@ class ManhattanAwareUNet(nn.Module):
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
# 编码器 - 增强版
|
||||
# 编码器通道配置
|
||||
encoder_channels = [64, 128, 256, 512] # 修正:减少通道数避免过拟合
|
||||
self.encoder = nn.ModuleList([
|
||||
self._make_block(64, 128),
|
||||
self._make_block(128, 256, stride=2),
|
||||
self._make_block(256, 512, stride=2),
|
||||
self._make_block(512, 1024, stride=2),
|
||||
])
|
||||
|
||||
# 残差投影层 - 确保残差连接的通道数和空间尺寸匹配
|
||||
self.residual_projections = nn.ModuleList([
|
||||
nn.Conv2d(64, 128, 1), # 64 -> 128, 保持尺寸
|
||||
nn.Conv2d(128, 256, 3, stride=2, padding=1), # 128 -> 256, 尺寸减半
|
||||
nn.Conv2d(256, 512, 3, stride=2, padding=1), # 256 -> 512, 尺寸减半
|
||||
nn.Conv2d(512, 1024, 3, stride=2, padding=1), # 512 -> 1024, 尺寸减半
|
||||
self._make_block(encoder_channels[i], encoder_channels[i+1], stride=2 if i > 0 else 1)
|
||||
for i in range(len(encoder_channels)-1)
|
||||
])
|
||||
|
||||
# 中间层
|
||||
self.middle = nn.Sequential(
|
||||
nn.Conv2d(1024, 1024, 3, padding=1),
|
||||
nn.GroupNorm(8, 1024),
|
||||
nn.Conv2d(512, 512, 3, padding=1),
|
||||
nn.GroupNorm(8, 512),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(1024, 1024, 3, padding=1),
|
||||
nn.GroupNorm(8, 1024),
|
||||
nn.Conv2d(512, 512, 3, padding=1),
|
||||
nn.GroupNorm(8, 512),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
# 解码器
|
||||
# 解码器 - 修复通道数计算
|
||||
self.decoder = nn.ModuleList([
|
||||
self._make_decoder_block(1024, 512),
|
||||
self._make_decoder_block(512, 256),
|
||||
self._make_decoder_block(256, 128),
|
||||
self._make_decoder_block(128, 64),
|
||||
self._make_decoder_block(512, 256), # middle (512) -> 256
|
||||
self._make_decoder_block(512, 128), # 256+256(skip) -> 128
|
||||
self._make_decoder_block(256, 64), # 128+128(skip) -> 64
|
||||
self._make_decoder_block(128, 64), # 64+64(skip) -> 64
|
||||
])
|
||||
|
||||
# 输出层
|
||||
# 输出层 - 修复输入通道数
|
||||
self.output = nn.Sequential(
|
||||
nn.Conv2d(64, 32, 3, padding=1),
|
||||
nn.Conv2d(64, 32, 3, padding=1), # 最后一层跳跃连接后是64通道
|
||||
nn.GroupNorm(8, 32),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, out_channels, 3, padding=1)
|
||||
@@ -243,10 +260,7 @@ class ManhattanAwareUNet(nn.Module):
|
||||
|
||||
# 时间融合层 - 与编码器输出通道数匹配
|
||||
self.time_fusion = nn.ModuleList([
|
||||
nn.Linear(time_dim, 128),
|
||||
nn.Linear(time_dim, 256),
|
||||
nn.Linear(time_dim, 512),
|
||||
nn.Linear(time_dim, 1024),
|
||||
nn.Linear(time_dim, channels) for channels in encoder_channels[1:] + [512]
|
||||
])
|
||||
|
||||
def _make_block(self, in_channels, out_channels, stride=1):
|
||||
@@ -276,8 +290,8 @@ class ManhattanAwareUNet(nn.Module):
|
||||
if self.use_edge_condition and edge_condition is not None:
|
||||
x = torch.cat([x, edge_condition], dim=1)
|
||||
|
||||
# 时间嵌入
|
||||
t_emb = self.time_mlp(t.float().unsqueeze(-1)) # [B, time_dim]
|
||||
# 时间嵌入 - 使用正弦位置编码
|
||||
t_emb = self.time_mlp(t) # [B, time_dim]
|
||||
|
||||
# 曼哈顿几何感知的特征提取
|
||||
h_features = F.silu(self.horiz_conv(x))
|
||||
@@ -285,45 +299,58 @@ class ManhattanAwareUNet(nn.Module):
|
||||
s_features = F.silu(self.standard_conv(x))
|
||||
|
||||
# 融合特征
|
||||
x = torch.cat([h_features, v_features, s_features], dim=1)
|
||||
x = self.initial_fusion(x)
|
||||
x = torch.cat([h_features, v_features, s_features], dim=1) # [B, 96, H, W]
|
||||
x = self.initial_fusion(x) # [B, 64, H, W]
|
||||
|
||||
# 编码器路径
|
||||
# 编码器路径 - 修复跳跃连接逻辑
|
||||
skips = []
|
||||
for i, (encoder, fusion) in enumerate(zip(self.encoder, self.time_fusion)):
|
||||
# 保存残差连接并使用投影层匹配通道数
|
||||
residual = self.residual_projections[i](x)
|
||||
# 保存跳跃连接(在编码之前)
|
||||
skips.append(x)
|
||||
|
||||
# 编码
|
||||
x = encoder(x)
|
||||
|
||||
# 融合时间信息
|
||||
t_feat = fusion(t_emb).unsqueeze(-1).unsqueeze(-1)
|
||||
x = x + t_feat
|
||||
t_feat = fusion(t_emb).unsqueeze(-1).unsqueeze(-1) # [B, channels, 1, 1]
|
||||
|
||||
# 跳跃连接 - 添加投影后的残差
|
||||
skips.append(x + residual)
|
||||
# 检查通道数是否匹配
|
||||
if x.shape[1] == t_feat.shape[1]:
|
||||
x = x + t_feat
|
||||
else:
|
||||
# 如果不匹配,使用1x1卷积调整通道数
|
||||
if not hasattr(self, f'time_proj_{i}'):
|
||||
setattr(self, f'time_proj_{i}',
|
||||
nn.Conv2d(t_feat.shape[1], x.shape[1], 1).to(x.device))
|
||||
time_proj = getattr(self, f'time_proj_{i}')
|
||||
x = x + time_proj(t_feat)
|
||||
|
||||
# 中间层
|
||||
x = self.middle(x)
|
||||
|
||||
# 解码器路径
|
||||
for i, (decoder, skip) in enumerate(zip(self.decoder, reversed(skips))):
|
||||
# 解码器路径 - 修复跳跃连接逻辑
|
||||
for i, decoder in enumerate(self.decoder):
|
||||
# 获取对应的跳跃连接(反向顺序)
|
||||
skip = skips[-(i+1)]
|
||||
|
||||
# 上采样
|
||||
x = decoder(x)
|
||||
|
||||
# 确保跳跃连接尺寸匹配 - 如果尺寸不匹配则进行裁剪或填充
|
||||
# 确保跳跃连接尺寸匹配
|
||||
if x.shape[2:] != skip.shape[2:]:
|
||||
# 裁剪到最小尺寸
|
||||
h_min = min(x.shape[2], skip.shape[2])
|
||||
w_min = min(x.shape[3], skip.shape[3])
|
||||
if x.shape[2] > h_min:
|
||||
x = x[:, :, :h_min, :]
|
||||
if x.shape[3] > w_min:
|
||||
x = x[:, :, :, :w_min]
|
||||
if skip.shape[2] > h_min:
|
||||
skip = skip[:, :, :h_min, :]
|
||||
if skip.shape[3] > w_min:
|
||||
skip = skip[:, :, :, :w_min]
|
||||
# 使用插值调整尺寸
|
||||
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
|
||||
|
||||
x = x + skip # 跳跃连接
|
||||
# 跳跃连接(需要处理通道数匹配)
|
||||
if x.shape[1] == skip.shape[1]:
|
||||
x = x + skip
|
||||
else:
|
||||
# 如果通道数不匹配,使用1x1卷积调整
|
||||
if not hasattr(self, f'skip_proj_{i}'):
|
||||
setattr(self, f'skip_proj_{i}',
|
||||
nn.Conv2d(skip.shape[1], x.shape[1], 1).to(x.device))
|
||||
skip_proj = getattr(self, f'skip_proj_{i}')
|
||||
x = x + skip_proj(skip)
|
||||
|
||||
# 输出
|
||||
x = self.output(x)
|
||||
@@ -358,14 +385,11 @@ class OptimizedNoiseScheduler:
|
||||
def add_noise(self, x_0, t):
|
||||
"""向干净图像添加噪声"""
|
||||
noise = torch.randn_like(x_0)
|
||||
# 确保调度器张量与输入张量在同一设备上
|
||||
device = x_0.device
|
||||
if self.sqrt_alphas_cumprod.device != device:
|
||||
self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
|
||||
self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
|
||||
|
||||
sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
# 确保调度器张量与输入张量在同一设备上
|
||||
sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
|
||||
|
||||
@@ -381,18 +405,13 @@ class OptimizedNoiseScheduler:
|
||||
# 预测噪声
|
||||
predicted_noise = model(x_t, t)
|
||||
|
||||
# 确保调度器张量与输入张量在同一设备上
|
||||
device = x_t.device
|
||||
if self.alphas.device != device:
|
||||
self.alphas = self.alphas.to(device)
|
||||
self.betas = self.betas.to(device)
|
||||
self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
|
||||
|
||||
# 计算系数
|
||||
alpha_t = self.alphas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
# 计算系数(直接使用索引并移动到设备)
|
||||
alpha_t = self.alphas[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_alpha_t = torch.sqrt(alpha_t)
|
||||
beta_t = self.betas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
beta_t = self.betas[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
# 计算均值
|
||||
model_mean = (1.0 / sqrt_alpha_t) * (x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * predicted_noise)
|
||||
|
||||
Reference in New Issue
Block a user