Files
LayoutMatch/inference.py
2025-03-26 22:33:36 +08:00

51 lines
1.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import faiss
import numpy as np
import torch
from models.rotation_cnn import RotationInvariantNet, get_rotational_features
from data_units import layout_to_tensor, tile_layout
def _as_feature_vector(feat):
    """Return *feat* as a flat, C-contiguous float32 NumPy vector.

    Faiss only accepts float32 arrays held in host memory, so a torch tensor
    is detached and copied to the CPU before conversion.
    """
    if isinstance(feat, torch.Tensor):
        feat = feat.detach().cpu().numpy()
    return np.ascontiguousarray(feat, dtype=np.float32).reshape(-1)


def _best_rotation(target_vec, feat):
    """Return ``(angle_in_degrees, distance)`` for the closest quarter turn.

    *feat* is viewed as a square grid, rotated by 0, 90, 180 and 270 degrees,
    and each rotation is compared to *target_vec* by Euclidean distance.

    Raises:
        ValueError: if *feat* cannot be arranged as a square grid.
    """
    # NOTE(review): the original reshaped to (block_size, block_size), which
    # contradicted the 64-d index it also assumed. The side is derived from
    # the vector itself so both steps agree; confirm that rotating the
    # feature grid is the intended way to estimate orientation.
    side = int(round(np.sqrt(feat.size)))
    if side * side != feat.size:
        raise ValueError(
            f"feature vector of length {feat.size} cannot be viewed as a square grid"
        )
    grid = feat.reshape(side, side)

    min_angle, min_dist = 90, float("inf")
    for quarter_turns in range(4):  # 0, 90, 180 and 270 degrees
        rotated = np.rot90(grid, k=quarter_turns)
        dist = np.linalg.norm(target_vec - rotated.ravel())
        if dist < min_dist:
            min_dist, min_angle = dist, quarter_turns * 90
    return min_angle, min_dist


def main():
    """Find the tiles of a large layout that best match a target module.

    Loads the rotation CNN weights, extracts a feature vector for the target
    image and for every tile of the large layout, indexes the tile features
    with Faiss, and prints the position, best quarter-turn angle and distance
    of the nearest tiles.

    Raises:
        ValueError: if the layout yields no tiles, or the target and tile
            feature vectors have different lengths.
    """
    # Configuration (adjust to the actual data).
    block_size = 64  # edge length the target module is resized to
    top_k = 10  # number of nearest tiles to report
    target_module_path = "target.png"
    large_layout_path = "layout_large.png"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RotationInvariantNet().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load("rotation_cnn.pth", map_location=device))
    model.eval()

    # Preprocess the target module and the large layout.
    target_tensor = layout_to_tensor(target_module_path, (block_size, block_size))
    large_layout = layout_to_tensor(large_layout_path)
    # Materialise the tiles: they are iterated once and indexed again later.
    tiles = list(tile_layout(large_layout))
    if not tiles:
        raise ValueError(f"no tiles were produced from {large_layout_path!r}")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        target_feat = _as_feature_vector(
            get_rotational_features(model, torch.as_tensor(target_tensor).to(device))
        )
        features_db = [
            _as_feature_vector(
                get_rotational_features(model, torch.as_tensor(tile).to(device))
            )
            for _, _, tile in tiles
        ]

    # Build the Faiss index once; its dimension comes from the model output.
    db = np.stack(features_db)
    if target_feat.size != db.shape[1]:
        raise ValueError(
            f"target feature length {target_feat.size} does not match "
            f"tile feature length {db.shape[1]}"
        )
    index = faiss.IndexFlatL2(db.shape[1])
    index.add(db)

    # Clamp k: Faiss pads missing neighbours with -1, and tiles[-1] would
    # silently report the last tile instead of a real match.
    k = min(top_k, len(tiles))
    _, neighbours = index.search(target_feat[np.newaxis, :], k)

    for idx in neighbours[0]:
        x, y, _ = tiles[idx]
        min_angle, min_dist = _best_rotation(target_feat, features_db[idx])
        print(f"坐标({x},{y}), 最佳旋转方向{min_angle}度,距离: {min_dist}")


if __name__ == "__main__":
    main()