From f8975b26b411607c2012517f96068c78b490e54c Mon Sep 17 00:00:00 2001 From: Jiao77 Date: Thu, 20 Nov 2025 03:04:56 +0800 Subject: [PATCH] change some problem 4 --- tools/diffusion/ic_layout_diffusion_optimized.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tools/diffusion/ic_layout_diffusion_optimized.py b/tools/diffusion/ic_layout_diffusion_optimized.py index 98064c4..8920741 100644 --- a/tools/diffusion/ic_layout_diffusion_optimized.py +++ b/tools/diffusion/ic_layout_diffusion_optimized.py @@ -207,6 +207,14 @@ class ManhattanAwareUNet(nn.Module): self._make_block(512, 1024, stride=2), ]) + # 残差投影层 - 确保残差连接的通道数匹配 + self.residual_projections = nn.ModuleList([ + nn.Conv2d(64, 128, 1), # 64 -> 128 + nn.Conv2d(128, 256, 1), # 128 -> 256 + nn.Conv2d(256, 512, 1), # 256 -> 512 + nn.Conv2d(512, 1024, 1), # 512 -> 1024 + ]) + # 中间层 self.middle = nn.Sequential( nn.Conv2d(1024, 1024, 3, padding=1), @@ -283,16 +291,16 @@ class ManhattanAwareUNet(nn.Module): # 编码器路径 skips = [] for i, (encoder, fusion) in enumerate(zip(self.encoder, self.time_fusion)): - # 残差连接 - residual = x + # 保存残差连接并使用投影层匹配通道数 + residual = self.residual_projections[i](x) x = encoder(x) # 融合时间信息 t_feat = fusion(t_emb).unsqueeze(-1).unsqueeze(-1) x = x + t_feat - # 跳跃连接 - skips.append(x + residual if i == 0 else x) + # 跳跃连接 - 添加投影后的残差 + skips.append(x + residual) # 中间层 x = self.middle(x)