change some problem 6
This commit is contained in:
@@ -308,6 +308,21 @@ class ManhattanAwareUNet(nn.Module):
|
|||||||
# 解码器路径
|
# 解码器路径
|
||||||
for i, (decoder, skip) in enumerate(zip(self.decoder, reversed(skips))):
|
for i, (decoder, skip) in enumerate(zip(self.decoder, reversed(skips))):
|
||||||
x = decoder(x)
|
x = decoder(x)
|
||||||
|
|
||||||
|
# 确保跳跃连接尺寸匹配 - 如果尺寸不匹配则进行裁剪或填充
|
||||||
|
if x.shape[2:] != skip.shape[2:]:
|
||||||
|
# 裁剪到最小尺寸
|
||||||
|
h_min = min(x.shape[2], skip.shape[2])
|
||||||
|
w_min = min(x.shape[3], skip.shape[3])
|
||||||
|
if x.shape[2] > h_min:
|
||||||
|
x = x[:, :, :h_min, :]
|
||||||
|
if x.shape[3] > w_min:
|
||||||
|
x = x[:, :, :, :w_min]
|
||||||
|
if skip.shape[2] > h_min:
|
||||||
|
skip = skip[:, :, :h_min, :]
|
||||||
|
if skip.shape[3] > w_min:
|
||||||
|
skip = skip[:, :, :, :w_min]
|
||||||
|
|
||||||
x = x + skip # 跳跃连接
|
x = x + skip # 跳跃连接
|
||||||
|
|
||||||
# 输出
|
# 输出
|
||||||
|
|||||||
Reference in New Issue
Block a user