diff --git a/ai_layout_match/dataset.py b/ai_layout_match/dataset.py new file mode 100644 index 0000000..e69de29 diff --git a/ai_layout_match/models.py b/ai_layout_match/models.py new file mode 100644 index 0000000..e69de29 diff --git a/ai_layout_match/utils.py b/ai_layout_match/utils.py new file mode 100644 index 0000000..e69de29 diff --git a/data_units.py b/data_units.py index 4518af2..93ff820 100644 --- a/data_units.py +++ b/data_units.py @@ -19,10 +19,10 @@ def layout_to_tensor(layout_path, target_size=(256, 256)): img = img.resize(target_size, resample=Image.BILINEAR) return np.array(img) / 255.0 # 归一化到[0,1] -def tile_layout(large_layout, block_size=64): +def tile_layout(large_layout, block_size=64, overlap_ratio=0.5): """将大版图分割为小块(滑动窗口方式)""" height, width = large_layout.shape - stride = block_size // 2 # 步长设置重叠区域 + stride = int(block_size * (1 - overlap_ratio)) # 步长设置重叠区域 tiles = [] for y in range(0, height - block_size +1, stride): for x in range(0, width - block_size +1, stride): diff --git a/inference.py b/inference.py index e85f41f..325d1a3 100644 --- a/inference.py +++ b/inference.py @@ -1,15 +1,8 @@ import faiss import numpy as np import torch -# 导入 models.rotation_cnn 模块中的 RotationInvariantNet 类 -from models.rotation_cnn import RotationInvariantNet - -from models.rotation_cnn import get_rotational_features - -# 导入 data_utils 中的 layout_to_tensor 函数(假设该函数存在) -from data_units import layout_to_tensor # 如果 data_utils.py 存在此函数 - -from data_units import tile_layout +from models.rotation_cnn import RotationInvariantNet, get_rotational_features +from data_units import layout_to_tensor, tile_layout def main(): # 配置参数(需根据实际调整) @@ -17,7 +10,6 @@ def main(): target_module_path = "target.png" large_layout_path = "layout_large.png" - # 加载模型 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = RotationInvariantNet().to(device) model.load_state_dict(torch.load("rotation_cnn.pth")) @@ -33,16 +25,27 @@ def main(): # 构建特征索引(使用Faiss加速) index = faiss.IndexFlatL2(64) # 特征维度由模型决定 features_db = [] - for (x,y,tile) in tiles: + for (x, y, tile) in tiles: feat = get_rotational_features(model, torch.tensor(tile).to(device)) features_db.append(feat) index.add(np.stack(features_db)) # 检索相似区域 - D, I = index.search(target_feat[np.newaxis,:], k=10) - for idx in I[0]: - x,y,_ = tiles[idx] - print(f"匹配区域坐标: ({x}, {y}), 相似度: {D[0][idx]}") + D, I = index.search(target_feat[np.newaxis, :], k=10) + for idx in I[0]: + x, y, _ = tiles[idx] + + # 计算最佳匹配角度的显式计算 + min_angle, min_dist = 90, float('inf') + target_vec = target_feat + feat = features_db[idx] + for a in [0, 1, 2, 3]: # 代表0°、90°、180°、270° + rotated_feat = np.rot90(feat.reshape(block_size, block_size), k=a) + dist = np.linalg.norm(target_vec - rotated_feat.flatten()) + if dist < min_dist: + min_dist, min_angle = dist, a * 90 + + print(f"坐标({x},{y}), 最佳旋转方向{min_angle}度,距离: {min_dist}") if __name__ == "__main__": main() \ No newline at end of file diff --git a/models/rotation_cnn.py b/models/rotation_cnn.py index 1e81d84..226cbde 100644 --- a/models/rotation_cnn.py +++ b/models/rotation_cnn.py @@ -3,7 +3,7 @@ import torch.nn as nn class RotationInvariantNet(nn.Module): """轻量级旋转不变特征提取网络""" - def __init__(self, input_channels=1, num_features=64): + def __init__(self, input_channels=1): super().__init__() self.cnn = nn.Sequential( # 基础卷积层 @@ -12,13 +12,14 @@ class RotationInvariantNet(nn.Module): nn.MaxPool2d(2), # 下采样 nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), - nn.AdaptiveAvgPool2d((1,1)) # 全局池化获取全局特征 + nn.Conv2d(64, 64, kernel_size=3, stride=2), # 更大感受野 + nn.AdaptiveAvgPool2d((4,4)), # 全局池化获取全局特征,调整输出尺寸为4x4 + nn.Flatten(), # 展平为一维向量 + nn.Linear(64*16, 128) # 增加全连接层以降低维度到128 ) - + def forward(self, x): - features = self.cnn(x) - return torch.flatten(features, 1) # 展平为特征向量 - + return self.cnn(x) def get_rotational_features(model, input_image): """计算输入图像所有旋转角度的特征平均值""" rotations = [0, 90, 180, 270]