Complete code structure update
This commit is contained in:
55
match.py
55
match.py
@@ -1,15 +1,17 @@
|
||||
# match.py
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
|
||||
import config
|
||||
from models.rord import RoRD
|
||||
from utils.config_loader import load_config, to_absolute_path
|
||||
from utils.data_utils import get_transform
|
||||
|
||||
# --- 特征提取函数 (基本无变动) ---
|
||||
@@ -39,15 +41,16 @@ def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh):
|
||||
return keypoints, descriptors
|
||||
|
||||
# --- (新增) 滑动窗口特征提取函数 ---
|
||||
def extract_features_sliding_window(model, large_image, transform):
|
||||
def extract_features_sliding_window(model, large_image, transform, matching_cfg):
|
||||
"""
|
||||
使用滑动窗口从大图上提取所有关键点和描述子
|
||||
"""
|
||||
print("使用滑动窗口提取大版图特征...")
|
||||
device = next(model.parameters()).device
|
||||
W, H = large_image.size
|
||||
window_size = config.INFERENCE_WINDOW_SIZE
|
||||
stride = config.INFERENCE_STRIDE
|
||||
window_size = int(matching_cfg.inference_window_size)
|
||||
stride = int(matching_cfg.inference_stride)
|
||||
keypoint_threshold = float(matching_cfg.keypoint_threshold)
|
||||
|
||||
all_kps = []
|
||||
all_descs = []
|
||||
@@ -65,7 +68,7 @@ def extract_features_sliding_window(model, large_image, transform):
|
||||
patch_tensor = transform(patch).unsqueeze(0).to(device)
|
||||
|
||||
# 提取特征
|
||||
kps, descs = extract_keypoints_and_descriptors(model, patch_tensor, config.KEYPOINT_THRESHOLD)
|
||||
kps, descs = extract_keypoints_and_descriptors(model, patch_tensor, keypoint_threshold)
|
||||
|
||||
if len(kps) > 0:
|
||||
# 将局部坐标转换为全局坐标
|
||||
@@ -94,26 +97,30 @@ def mutual_nearest_neighbor(descs1, descs2):
|
||||
return matches
|
||||
|
||||
# --- (已修改) 多尺度、多实例匹配主函数 ---
|
||||
def match_template_multiscale(model, layout_image, template_image, transform):
|
||||
def match_template_multiscale(model, layout_image, template_image, transform, matching_cfg):
|
||||
"""
|
||||
在不同尺度下搜索模板,并检测多个实例
|
||||
"""
|
||||
# 1. 对大版图使用滑动窗口提取全部特征
|
||||
layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform)
|
||||
layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg)
|
||||
|
||||
if len(layout_kps) < config.MIN_INLIERS:
|
||||
min_inliers = int(matching_cfg.min_inliers)
|
||||
if len(layout_kps) < min_inliers:
|
||||
print("从大版图中提取的关键点过少,无法进行匹配。")
|
||||
return []
|
||||
|
||||
found_instances = []
|
||||
active_layout_mask = torch.ones(len(layout_kps), dtype=bool, device=layout_kps.device)
|
||||
pyramid_scales = [float(s) for s in matching_cfg.pyramid_scales]
|
||||
keypoint_threshold = float(matching_cfg.keypoint_threshold)
|
||||
ransac_threshold = float(matching_cfg.ransac_reproj_threshold)
|
||||
|
||||
# 2. 多实例迭代检测
|
||||
while True:
|
||||
current_active_indices = torch.nonzero(active_layout_mask).squeeze(1)
|
||||
|
||||
# 如果剩余活动关键点过少,则停止
|
||||
if len(current_active_indices) < config.MIN_INLIERS:
|
||||
if len(current_active_indices) < min_inliers:
|
||||
break
|
||||
|
||||
current_layout_kps = layout_kps[current_active_indices]
|
||||
@@ -123,7 +130,7 @@ def match_template_multiscale(model, layout_image, template_image, transform):
|
||||
|
||||
# 3. 图像金字塔:遍历模板的每个尺度
|
||||
print("在新尺度下搜索模板...")
|
||||
for scale in config.PYRAMID_SCALES:
|
||||
for scale in pyramid_scales:
|
||||
W, H = template_image.size
|
||||
new_W, new_H = int(W * scale), int(H * scale)
|
||||
|
||||
@@ -132,7 +139,7 @@ def match_template_multiscale(model, layout_image, template_image, transform):
|
||||
template_tensor = transform(scaled_template).unsqueeze(0).to(layout_kps.device)
|
||||
|
||||
# 提取缩放后模板的特征
|
||||
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, config.KEYPOINT_THRESHOLD)
|
||||
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, keypoint_threshold)
|
||||
|
||||
if len(template_kps) < 4: continue
|
||||
|
||||
@@ -147,13 +154,13 @@ def match_template_multiscale(model, layout_image, template_image, transform):
|
||||
dst_pts_indices = current_active_indices[matches[:, 1]]
|
||||
dst_pts = layout_kps[dst_pts_indices].cpu().numpy()
|
||||
|
||||
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, config.RANSAC_REPROJ_THRESHOLD)
|
||||
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransac_threshold)
|
||||
|
||||
if H is not None and mask.sum() > best_match_info['inliers']:
|
||||
best_match_info = {'inliers': mask.sum(), 'H': H, 'mask': mask, 'scale': scale, 'dst_pts': dst_pts}
|
||||
|
||||
# 4. 如果在所有尺度中找到了最佳匹配,则记录并屏蔽
|
||||
if best_match_info['inliers'] > config.MIN_INLIERS:
|
||||
if best_match_info['inliers'] > min_inliers:
|
||||
print(f"找到一个匹配实例!内点数: {best_match_info['inliers']}, 使用的模板尺度: {best_match_info['scale']:.2f}x")
|
||||
|
||||
inlier_mask = best_match_info['mask'].ravel().astype(bool)
|
||||
@@ -191,21 +198,27 @@ def visualize_matches(layout_path, bboxes, output_path):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="使用 RoRD 进行多尺度模板匹配")
|
||||
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
|
||||
parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="YAML 配置文件路径")
|
||||
parser.add_argument('--model_path', type=str, default=None, help="模型权重路径,若未提供则使用配置文件中的路径")
|
||||
parser.add_argument('--layout', type=str, required=True)
|
||||
parser.add_argument('--template', type=str, required=True)
|
||||
parser.add_argument('--output', type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
cfg = load_config(args.config)
|
||||
config_dir = Path(args.config).resolve().parent
|
||||
matching_cfg = cfg.matching
|
||||
model_path = args.model_path or str(to_absolute_path(cfg.paths.model_path, config_dir))
|
||||
|
||||
transform = get_transform()
|
||||
model = RoRD().cuda()
|
||||
model.load_state_dict(torch.load(args.model_path))
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
model.eval()
|
||||
|
||||
layout_image = Image.open(args.layout).convert('L')
|
||||
template_image = Image.open(args.template).convert('L')
|
||||
|
||||
detected_bboxes = match_template_multiscale(model, layout_image, template_image, transform)
|
||||
detected_bboxes = match_template_multiscale(model, layout_image, template_image, transform, matching_cfg)
|
||||
|
||||
print("\n检测到的边界框:")
|
||||
for bbox in detected_bboxes:
|
||||
|
||||
Reference in New Issue
Block a user