diff --git a/tools/diffusion/ic_layout_diffusion_optimized.py b/tools/diffusion/ic_layout_diffusion_optimized.py index ec5fa56..dd80cad 100644 --- a/tools/diffusion/ic_layout_diffusion_optimized.py +++ b/tools/diffusion/ic_layout_diffusion_optimized.py @@ -429,9 +429,9 @@ def manhattan_post_process(image, threshold=0.5): # 二值化 binary = (image > threshold).float() - # 形态学操作强化直角特征 - kernel_h = torch.tensor([[[[1,1,1]]]], device=device) - kernel_v = torch.tensor([[[[1],[1],[1]]]], device=device) + # 形态学操作强化直角特征 - 使用浮点类型 + kernel_h = torch.tensor([[[[1.,1.,1.]]]], device=device, dtype=image.dtype) + kernel_v = torch.tensor([[[[1.],[1.],[1.]]]], device=device, dtype=image.dtype) # 水平和垂直增强 horizontal = F.conv2d(binary, kernel_h, padding=(0,1))