change some problem 7
This commit is contained in:
@@ -108,9 +108,9 @@ class EdgeAwareLoss(nn.Module):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 注册为缓冲区以避免重复创建
|
# 注册为缓冲区以避免重复创建,并指定为浮点类型
|
||||||
self.register_buffer('sobel_x', torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]]))
|
self.register_buffer('sobel_x', torch.tensor([[[[-1.,0.,1.],[-2.,0.,2.],[-1.,0.,1.]]]]))
|
||||||
self.register_buffer('sobel_y', torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]]))
|
self.register_buffer('sobel_y', torch.tensor([[[[-1.,-2.,-1.],[0.,0.,0.],[1.,2.,1.]]]]))
|
||||||
|
|
||||||
def forward(self, pred, target):
|
def forward(self, pred, target):
|
||||||
# 原始MSE损失
|
# 原始MSE损失
|
||||||
|
|||||||
Reference in New Issue
Block a user