change some problem 7

This commit is contained in:
Jiao77
2025-11-20 03:11:33 +08:00
parent 26763fa75c
commit bacf8cd69d

View File

@@ -108,9 +108,9 @@ class EdgeAwareLoss(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
# 注册为缓冲区以避免重复创建 # 注册为缓冲区以避免重复创建,并指定为浮点类型
self.register_buffer('sobel_x', torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]])) self.register_buffer('sobel_x', torch.tensor([[[[-1.,0.,1.],[-2.,0.,2.],[-1.,0.,1.]]]]))
self.register_buffer('sobel_y', torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]])) self.register_buffer('sobel_y', torch.tensor([[[[-1.,-2.,-1.],[0.,0.,0.],[1.,2.,1.]]]]))
def forward(self, pred, target): def forward(self, pred, target):
# 原始MSE损失 # 原始MSE损失