# LayoutMatch/inference.py
# (source listing metadata: 2025-03-31 14:49:04 +08:00, 70 lines, 3.0 KiB, Python)
import torch
import cv2
import numpy as np
from models.superpoint_custom import SuperPointCustom
def get_keypoints_from_heatmap(semi, threshold=0.015):
    """Decode SuperPoint's raw detector output into OpenCV keypoints.

    Args:
        semi: torch tensor of shape [1, 65, H/8, W/8] (or squeezed
            [65, H/8, W/8]) — 64 per-cell position logits plus one
            "dustbin" (no-keypoint) channel.
        threshold: minimum softmax probability for a pixel to be kept.

    Returns:
        list[cv2.KeyPoint] at full-resolution (x, y) pixel coordinates.
    """
    semi = semi.squeeze().cpu().numpy()  # [65, Hc, Wc]
    # Channel-wise softmax in NumPy (OpenCV has no `cv2.softmax`); subtract
    # the per-pixel max for numerical stability, then drop the dustbin.
    exp = np.exp(semi - semi.max(axis=0, keepdims=True))
    prob = (exp / exp.sum(axis=0, keepdims=True))[:-1]  # [64, Hc, Wc]
    hc, wc = prob.shape[1], prob.shape[2]
    # Depth-to-space: channel k encodes offset (dy, dx) = (k // 8, k % 8)
    # inside each 8x8 cell, so full-res pixel (y, x) = (cy*8+dy, cx*8+dx).
    # The original transpose(0, 2, 1, 3) interleaved cells and offsets the
    # wrong way round; (2, 0, 3, 1) yields (Hc, 8, Wc, 8), the correct order.
    prob = prob.reshape(8, 8, hc, wc).transpose(2, 0, 3, 1).reshape(8 * hc, 8 * wc)
    # Vectorized thresholding instead of a Python double loop over pixels.
    ys, xs = np.where(prob > threshold)
    # cv2.KeyPoint requires float coordinates in modern OpenCV builds.
    return [cv2.KeyPoint(float(x), float(y), 1) for y, x in zip(ys, xs)]
def get_descriptors_from_map(desc, keypoints):
    """Sample coarse descriptors at keypoint locations.

    Args:
        desc: torch tensor of shape [1, D, H/8, W/8] (or squeezed
            [D, H/8, W/8]) descriptor map; D is 256 for SuperPoint.
        keypoints: iterable of keypoints with a ``.pt`` (x, y) attribute,
            given in full-resolution pixel coordinates.

    Returns:
        np.ndarray of shape [N, D] — one descriptor per keypoint that falls
        inside the map. An empty keypoint list yields shape [0, D] so
        downstream matchers always receive a well-formed 2-D array.
    """
    desc = desc.squeeze().cpu().numpy()  # [D, Hc, Wc]
    scale = 8  # the model downsamples by 8; map pixel coords to cell coords
    descriptors = []
    for kp in keypoints:
        x, y = int(kp.pt[0] / scale), int(kp.pt[1] / scale)
        # Skip keypoints whose cell falls outside the coarse map.
        if 0 <= x < desc.shape[2] and 0 <= y < desc.shape[1]:
            descriptors.append(desc[:, y, x])
    if not descriptors:
        # np.array([]) would be 1-D (shape (0,)); keep [0, D] explicit.
        return np.zeros((0, desc.shape[0]), dtype=desc.dtype)
    return np.asarray(descriptors)
def match_and_estimate(layout_path, module_path, model_path, num_channels, device='cuda'):
    """Locate a small module inside a large layout and estimate its pose.

    Runs a custom SuperPoint model on both inputs, matches descriptors with
    brute-force L2 + cross-check, fits a RANSAC homography, and reports the
    module's axis-aligned bounds and in-plane rotation inside the layout.

    Args:
        layout_path: path to a .npy array of shape [C, H, W] (the layout).
        module_path: path to a .npy array of shape [C, h, w] (the template).
        model_path: path to the trained SuperPointCustom state dict.
        num_channels: number of input channels C the model was trained with.
        device: torch device string; defaults to 'cuda'.

    Returns:
        (x_min, y_min, x_max, y_max, theta): integer bounds of the projected
        module corners in layout pixels and rotation angle in degrees.

    Raises:
        ValueError: if fewer than 4 matches are found, or RANSAC cannot
            estimate a homography from them.
    """
    model = SuperPointCustom(num_channels=num_channels).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    layout = np.load(layout_path)  # [C, H, W]
    module = np.load(module_path)  # [C, h, w]
    layout_tensor = torch.from_numpy(layout).float().unsqueeze(0).to(device)
    module_tensor = torch.from_numpy(module).float().unsqueeze(0).to(device)

    with torch.no_grad():
        semi_layout, desc_layout_map = model(layout_tensor)
        semi_module, desc_module_map = model(module_tensor)

    kp_layout = get_keypoints_from_heatmap(semi_layout)
    desc_layout = get_descriptors_from_map(desc_layout_map, kp_layout)
    kp_module = get_keypoints_from_heatmap(semi_module)
    desc_module = get_descriptors_from_map(desc_module_map, kp_module)

    # BFMatcher with NORM_L2 requires 2-D float32 descriptor arrays.
    bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
    matches = bf.match(np.float32(desc_module), np.float32(desc_layout))
    matches = sorted(matches, key=lambda m: m.distance)
    if len(matches) < 4:
        # cv2.findHomography needs at least 4 point correspondences.
        raise ValueError(f"Not enough matches to estimate a homography: {len(matches)}")

    src_pts = np.float32([kp_module[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
    dst_pts = np.float32([kp_layout[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
    H, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
    if H is None:
        raise ValueError("RANSAC failed to estimate a homography from the matches")

    # Project the module's corner rectangle into layout coordinates.
    h, w = module.shape[1], module.shape[2]
    corners = np.float32([[0, 0], [w, 0], [w, h], [0, h]]).reshape(-1, 1, 2)
    transformed_corners = cv2.perspectiveTransform(corners, H)
    x_min, y_min = np.min(transformed_corners, axis=0).ravel().astype(int)
    x_max, y_max = np.max(transformed_corners, axis=0).ravel().astype(int)
    # Rotation read from the homography's upper-left 2x2 block; this is only
    # meaningful when the transform is close to a similarity — TODO confirm.
    theta = np.arctan2(H[1, 0], H[0, 0]) * 180 / np.pi
    print(f"Matched region: [{x_min}, {y_min}, {x_max}, {y_max}], Rotation: {theta:.2f} degrees")
    return x_min, y_min, x_max, y_max, theta
if __name__ == "__main__":
    # Example invocation of the matching pipeline on the bundled data files.
    match_and_estimate(
        layout_path="data/large_layout.npy",
        module_path="data/small_module.npy",
        model_path="superpoint_custom_model.pth",
        num_channels=3,  # replace with the actual channel count
    )