From bacf8cd69d9f7cf7a194fcabd796aa050a3023d3 Mon Sep 17 00:00:00 2001 From: Jiao77 Date: Thu, 20 Nov 2025 03:11:33 +0800 Subject: [PATCH] change some problem 7 --- tools/diffusion/ic_layout_diffusion_optimized.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/diffusion/ic_layout_diffusion_optimized.py b/tools/diffusion/ic_layout_diffusion_optimized.py index 2b26dd1..ec5fa56 100644 --- a/tools/diffusion/ic_layout_diffusion_optimized.py +++ b/tools/diffusion/ic_layout_diffusion_optimized.py @@ -108,9 +108,9 @@ class EdgeAwareLoss(nn.Module): def __init__(self): super().__init__() - # 注册为缓冲区以避免重复创建 - self.register_buffer('sobel_x', torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]])) - self.register_buffer('sobel_y', torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]])) + # 注册为缓冲区以避免重复创建,并指定为浮点类型 + self.register_buffer('sobel_x', torch.tensor([[[[-1.,0.,1.],[-2.,0.,2.],[-1.,0.,1.]]]])) + self.register_buffer('sobel_y', torch.tensor([[[[-1.,-2.,-1.],[0.,0.,0.],[1.,2.,1.]]]])) def forward(self, pred, target): # 原始MSE损失