change some problem 6
This commit is contained in:
@@ -308,6 +308,21 @@ class ManhattanAwareUNet(nn.Module):
|
||||
# 解码器路径
|
||||
for i, (decoder, skip) in enumerate(zip(self.decoder, reversed(skips))):
|
||||
x = decoder(x)
|
||||
|
||||
# 确保跳跃连接尺寸匹配 - 如果尺寸不匹配则进行裁剪或填充
|
||||
if x.shape[2:] != skip.shape[2:]:
|
||||
# 裁剪到最小尺寸
|
||||
h_min = min(x.shape[2], skip.shape[2])
|
||||
w_min = min(x.shape[3], skip.shape[3])
|
||||
if x.shape[2] > h_min:
|
||||
x = x[:, :, :h_min, :]
|
||||
if x.shape[3] > w_min:
|
||||
x = x[:, :, :, :w_min]
|
||||
if skip.shape[2] > h_min:
|
||||
skip = skip[:, :, :h_min, :]
|
||||
if skip.shape[3] > w_min:
|
||||
skip = skip[:, :, :, :w_min]
|
||||
|
||||
x = x + skip # 跳跃连接
|
||||
|
||||
# 输出
|
||||
|
||||
Reference in New Issue
Block a user