diff --git a/tools/diffusion/ic_layout_diffusion_optimized.py b/tools/diffusion/ic_layout_diffusion_optimized.py index 2b26dd1..ec5fa56 100644 --- a/tools/diffusion/ic_layout_diffusion_optimized.py +++ b/tools/diffusion/ic_layout_diffusion_optimized.py @@ -108,9 +108,9 @@ class EdgeAwareLoss(nn.Module): def __init__(self): super().__init__() - # 注册为缓冲区以避免重复创建 - self.register_buffer('sobel_x', torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]])) - self.register_buffer('sobel_y', torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]])) + # 注册为缓冲区以避免重复创建,并指定为浮点类型 + self.register_buffer('sobel_x', torch.tensor([[[[-1.,0.,1.],[-2.,0.,2.],[-1.,0.,1.]]]])) + self.register_buffer('sobel_y', torch.tensor([[[[-1.,-2.,-1.],[0.,0.,0.],[1.,2.,1.]]]])) def forward(self, pred, target): # 原始MSE损失