change some problem 7
This commit is contained in:
@@ -108,9 +108,9 @@ class EdgeAwareLoss(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# 注册为缓冲区以避免重复创建
|
||||
self.register_buffer('sobel_x', torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]]))
|
||||
self.register_buffer('sobel_y', torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]]))
|
||||
# 注册为缓冲区以避免重复创建,并指定为浮点类型
|
||||
self.register_buffer('sobel_x', torch.tensor([[[[-1.,0.,1.],[-2.,0.,2.],[-1.,0.,1.]]]]))
|
||||
self.register_buffer('sobel_y', torch.tensor([[[[-1.,-2.,-1.],[0.,0.,0.],[1.,2.,1.]]]]))
|
||||
|
||||
def forward(self, pred, target):
|
||||
# 原始MSE损失
|
||||
|
||||
Reference in New Issue
Block a user