diff --git a/tools/diffusion/ic_layout_diffusion_optimized.py b/tools/diffusion/ic_layout_diffusion_optimized.py index 8920741..4cc0997 100644 --- a/tools/diffusion/ic_layout_diffusion_optimized.py +++ b/tools/diffusion/ic_layout_diffusion_optimized.py @@ -207,12 +207,12 @@ class ManhattanAwareUNet(nn.Module): self._make_block(512, 1024, stride=2), ]) - # 残差投影层 - 确保残差连接的通道数匹配 + # 残差投影层 - 确保残差连接的通道数和空间尺寸匹配 self.residual_projections = nn.ModuleList([ - nn.Conv2d(64, 128, 1), # 64 -> 128 - nn.Conv2d(128, 256, 1), # 128 -> 256 - nn.Conv2d(256, 512, 1), # 256 -> 512 - nn.Conv2d(512, 1024, 1), # 512 -> 1024 + nn.Conv2d(64, 128, 1), # 64 -> 128, 保持尺寸 + nn.Conv2d(128, 256, 3, stride=2, padding=1), # 128 -> 256, 尺寸减半 + nn.Conv2d(256, 512, 3, stride=2, padding=1), # 256 -> 512, 尺寸减半 + nn.Conv2d(512, 1024, 3, stride=2, padding=1), # 512 -> 1024, 尺寸减半 ]) # 中间层