Files
LayoutMatch/inference.py
2025-03-25 01:42:26 +08:00

48 lines
1.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import faiss
import numpy as np
import torch
# Import the RotationInvariantNet class from the models.rotation_cnn module
from models.rotation_cnn import RotationInvariantNet
from models.rotation_cnn import get_rotational_features
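# Import layout_to_tensor from data_units (assumes the function exists there).
# NOTE(review): this comment previously said data_utils; confirm the module name.
from data_units import layout_to_tensor # only valid if data_units.py defines this function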
from data_units import tile_layout
def _as_float32_array(feature):
    """Return *feature* as a flat, C-contiguous float32 NumPy vector.

    Faiss only accepts contiguous float32 NumPy arrays. The feature
    extractor is defined elsewhere; it presumably returns either a torch
    tensor (possibly on the GPU) or an array-like, so both are handled.
    """
    if isinstance(feature, torch.Tensor):
        feature = feature.detach().cpu().numpy()
    return np.ascontiguousarray(feature, dtype=np.float32).reshape(-1)


def main(
    target_module_path="target.png",
    large_layout_path="layout_large.png",
    weights_path="rotation_cnn.pth",
    block_size=64,
    top_k=10,
):
    """Find the tiles of a large layout that best match a target module.

    Loads the rotation-invariant network, extracts a feature vector for the
    target image and for every tile of the large layout, indexes the tile
    features with Faiss and prints the coordinates of the nearest tiles.

    Args:
        target_module_path: Image of the module to search for.
        large_layout_path: Image of the layout that is searched.
        weights_path: Checkpoint holding the model's ``state_dict``.
        block_size: Side length the target image is converted to.
        top_k: Maximum number of matches to report.

    Raises:
        ValueError: If ``top_k`` is not positive or the layout yields no tiles.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    # Load the model.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RotationInvariantNet().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    # Pre-process the target module and the large layout.
    target_tensor = layout_to_tensor(target_module_path, (block_size, block_size))
    large_layout = layout_to_tensor(large_layout_path)
    # Materialise the tiles: they are iterated once to build the index and
    # later looked up by position, which a one-shot iterator cannot support.
    tiles = list(tile_layout(large_layout))
    if not tiles:
        raise ValueError("tile_layout returned no tiles; nothing to search")

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        target_feat = _as_float32_array(
            get_rotational_features(model, torch.tensor(target_tensor).to(device))
        )
        features_db = np.stack([
            _as_float32_array(
                get_rotational_features(model, torch.tensor(tile).to(device))
            )
            for _, _, tile in tiles
        ])

    # Build the Faiss index; the dimension comes from the extracted features
    # rather than a hard-coded constant.
    index = faiss.IndexFlatL2(features_db.shape[1])
    index.add(features_db)

    # Retrieve similar regions. Never ask for more neighbours than there are
    # tiles, otherwise Faiss pads the result with -1 ids.
    k = min(top_k, len(tiles))
    D, I = index.search(target_feat[np.newaxis, :], k=k)
    for rank, idx in enumerate(I[0]):
        if idx < 0:
            continue
        x, y, _ = tiles[idx]
        # D is ordered by result rank, not by database id. For IndexFlatL2 it
        # holds squared L2 distances, so smaller values mean closer matches.
        print(f"匹配区域坐标: ({x}, {y}), 相似度: {D[0][rank]}")


if __name__ == "__main__":
    main()