complete code struction update

This commit is contained in:
Jiao77
2025-09-25 20:20:24 +08:00
parent e0b250e77f
commit 8c9926c815
10 changed files with 480 additions and 290 deletions

View File

@@ -1,15 +1,17 @@
# match.py
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import argparse
import os
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import config
from models.rord import RoRD
from utils.config_loader import load_config, to_absolute_path
from utils.data_utils import get_transform
# --- 特征提取函数 (基本无变动) ---
@@ -39,15 +41,16 @@ def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh):
return keypoints, descriptors
# --- (新增) 滑动窗口特征提取函数 ---
def extract_features_sliding_window(model, large_image, transform):
def extract_features_sliding_window(model, large_image, transform, matching_cfg):
"""
使用滑动窗口从大图上提取所有关键点和描述子
"""
print("使用滑动窗口提取大版图特征...")
device = next(model.parameters()).device
W, H = large_image.size
window_size = config.INFERENCE_WINDOW_SIZE
stride = config.INFERENCE_STRIDE
window_size = int(matching_cfg.inference_window_size)
stride = int(matching_cfg.inference_stride)
keypoint_threshold = float(matching_cfg.keypoint_threshold)
all_kps = []
all_descs = []
@@ -65,7 +68,7 @@ def extract_features_sliding_window(model, large_image, transform):
patch_tensor = transform(patch).unsqueeze(0).to(device)
# 提取特征
kps, descs = extract_keypoints_and_descriptors(model, patch_tensor, config.KEYPOINT_THRESHOLD)
kps, descs = extract_keypoints_and_descriptors(model, patch_tensor, keypoint_threshold)
if len(kps) > 0:
# 将局部坐标转换为全局坐标
@@ -94,26 +97,30 @@ def mutual_nearest_neighbor(descs1, descs2):
return matches
# --- (已修改) 多尺度、多实例匹配主函数 ---
def match_template_multiscale(model, layout_image, template_image, transform):
def match_template_multiscale(model, layout_image, template_image, transform, matching_cfg):
"""
在不同尺度下搜索模板,并检测多个实例
"""
# 1. 对大版图使用滑动窗口提取全部特征
layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform)
layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg)
if len(layout_kps) < config.MIN_INLIERS:
min_inliers = int(matching_cfg.min_inliers)
if len(layout_kps) < min_inliers:
print("从大版图中提取的关键点过少,无法进行匹配。")
return []
found_instances = []
active_layout_mask = torch.ones(len(layout_kps), dtype=bool, device=layout_kps.device)
pyramid_scales = [float(s) for s in matching_cfg.pyramid_scales]
keypoint_threshold = float(matching_cfg.keypoint_threshold)
ransac_threshold = float(matching_cfg.ransac_reproj_threshold)
# 2. 多实例迭代检测
while True:
current_active_indices = torch.nonzero(active_layout_mask).squeeze(1)
# 如果剩余活动关键点过少,则停止
if len(current_active_indices) < config.MIN_INLIERS:
if len(current_active_indices) < min_inliers:
break
current_layout_kps = layout_kps[current_active_indices]
@@ -123,7 +130,7 @@ def match_template_multiscale(model, layout_image, template_image, transform):
# 3. 图像金字塔:遍历模板的每个尺度
print("在新尺度下搜索模板...")
for scale in config.PYRAMID_SCALES:
for scale in pyramid_scales:
W, H = template_image.size
new_W, new_H = int(W * scale), int(H * scale)
@@ -132,7 +139,7 @@ def match_template_multiscale(model, layout_image, template_image, transform):
template_tensor = transform(scaled_template).unsqueeze(0).to(layout_kps.device)
# 提取缩放后模板的特征
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, config.KEYPOINT_THRESHOLD)
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, keypoint_threshold)
if len(template_kps) < 4: continue
@@ -147,13 +154,13 @@ def match_template_multiscale(model, layout_image, template_image, transform):
dst_pts_indices = current_active_indices[matches[:, 1]]
dst_pts = layout_kps[dst_pts_indices].cpu().numpy()
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, config.RANSAC_REPROJ_THRESHOLD)
H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransac_threshold)
if H is not None and mask.sum() > best_match_info['inliers']:
best_match_info = {'inliers': mask.sum(), 'H': H, 'mask': mask, 'scale': scale, 'dst_pts': dst_pts}
# 4. 如果在所有尺度中找到了最佳匹配,则记录并屏蔽
if best_match_info['inliers'] > config.MIN_INLIERS:
if best_match_info['inliers'] > min_inliers:
print(f"找到一个匹配实例!内点数: {best_match_info['inliers']}, 使用的模板尺度: {best_match_info['scale']:.2f}x")
inlier_mask = best_match_info['mask'].ravel().astype(bool)
@@ -191,21 +198,27 @@ def visualize_matches(layout_path, bboxes, output_path):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="使用 RoRD 进行多尺度模板匹配")
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="YAML 配置文件路径")
parser.add_argument('--model_path', type=str, default=None, help="模型权重路径,若未提供则使用配置文件中的路径")
parser.add_argument('--layout', type=str, required=True)
parser.add_argument('--template', type=str, required=True)
parser.add_argument('--output', type=str)
args = parser.parse_args()
cfg = load_config(args.config)
config_dir = Path(args.config).resolve().parent
matching_cfg = cfg.matching
model_path = args.model_path or str(to_absolute_path(cfg.paths.model_path, config_dir))
transform = get_transform()
model = RoRD().cuda()
model.load_state_dict(torch.load(args.model_path))
model.load_state_dict(torch.load(model_path))
model.eval()
layout_image = Image.open(args.layout).convert('L')
template_image = Image.open(args.template).convert('L')
detected_bboxes = match_template_multiscale(model, layout_image, template_image, transform)
detected_bboxes = match_template_multiscale(model, layout_image, template_image, transform, matching_cfg)
print("\n检测到的边界框:")
for bbox in detected_bboxes: