Finish the inference and matching part.

This commit is contained in:
Jiao77
2025-09-25 22:05:39 +08:00
parent 2ccfe7b07f
commit 419a7db543
6 changed files with 346 additions and 18 deletions

View File

@@ -44,6 +44,24 @@ def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh):
return keypoints, descriptors
# --- (新增) 简单半径 NMS 去重 ---
def radius_nms(kps: torch.Tensor, scores: torch.Tensor, radius: float) -> torch.Tensor:
    """Greedy radius-based non-maximum suppression for 2-D keypoints.

    Keypoints are visited from highest to lowest score. Each visited keypoint
    that has not been suppressed yet is kept, and every keypoint whose squared
    distance to it is at most ``radius ** 2`` is suppressed.

    Args:
        kps: Keypoint coordinates; the first two columns are used as the
            position.
        scores: One score per keypoint. Any shape is accepted as long as it
            holds exactly ``len(kps)`` values (e.g. ``(N,)`` or ``(N, 1)``).
        radius: Suppression radius, in the same units as ``kps``. Only its
            square is used, so the sign is irrelevant.

    Returns:
        1-D ``torch.long`` tensor on ``kps.device`` with the indices of the
        kept keypoints, ordered by descending score.

    Raises:
        ValueError: If the number of scores differs from the number of
            keypoints.
    """
    if kps.numel() == 0:
        return torch.empty((0,), dtype=torch.long, device=kps.device)

    num_points = kps.shape[0]
    # Flatten so an (N, 1) score tensor is sorted across keypoints instead of
    # along its size-1 last dimension (which would yield all-zero indices).
    flat_scores = scores.reshape(-1)
    if flat_scores.numel() != num_points:
        raise ValueError(
            f"radius_nms expects one score per keypoint, got "
            f"{flat_scores.numel()} scores for {num_points} keypoints"
        )

    # Stable sort: keypoints with equal scores keep their input order, which
    # makes the result deterministic.
    order = torch.sort(flat_scores, descending=True, stable=True).indices.tolist()

    radius_sq = radius * radius
    suppressed = torch.zeros(num_points, dtype=torch.bool, device=kps.device)
    keep = []
    for i in order:
        if suppressed[i]:
            continue
        keep.append(i)
        delta = kps - kps[i]
        dist_sq = delta[:, 0] ** 2 + delta[:, 1] ** 2
        suppressed |= dist_sq <= radius_sq
        # Explicitly mark the kept point itself (robust even if its own
        # distance evaluated to NaN).
        suppressed[i] = True
    return torch.tensor(keep, dtype=torch.long, device=kps.device)
# --- (新增) 滑动窗口特征提取函数 ---
def extract_features_sliding_window(model, large_image, transform, matching_cfg):
"""
@@ -88,6 +106,40 @@ def extract_features_sliding_window(model, large_image, transform, matching_cfg)
return torch.cat(all_kps, dim=0), torch.cat(all_descs, dim=0)
# --- (新增) FPN 路径的关键点与描述子抽取 ---
def extract_from_pyramid(model, image_tensor, kp_thresh, nms_cfg):
    """Extract keypoints and descriptors from every level of a feature pyramid.

    The model is called as ``model(image_tensor, return_pyramid=True)`` and
    must return a mapping whose values are ``(det, desc, stride)`` triples.
    For each level, detection values strictly above ``kp_thresh`` become
    keypoints, descriptors are sampled at those positions, L2-normalised, and
    the keypoint coordinates are multiplied by ``stride``.

    Args:
        model: Callable accepting ``(image_tensor, return_pyramid=True)``.
        image_tensor: Input forwarded to the model unchanged.
        kp_thresh: Detection threshold (strict ``>`` comparison).
        nms_cfg: Optional mapping. When its ``'enabled'`` entry is truthy,
            ``radius_nms`` is applied per level with ``'radius'``
            (default 4). NMS is not applied across levels.

    Returns:
        Tuple ``(kps, descs)``: ``kps`` is ``(N, 2)`` in ``(x, y)`` order and
        ``descs`` is ``(N, C)``. If no level yields a keypoint, two empty
        1-D tensors on ``image_tensor.device`` are returned.
    """
    with torch.no_grad():
        pyramid = model(image_tensor, return_pyramid=True)

    all_kps = []
    all_desc = []
    for det, desc, stride in pyramid.values():
        # NOTE(review): assumes det is (1, 1, H, W), i.e. batch size 1 - confirm.
        # Only the two leading dims are squeezed so that a map with H == 1 or
        # W == 1 keeps its 2-D shape and stays aligned with the mask.
        det_map = det.squeeze(0).squeeze(0)
        binary = det_map > kp_thresh
        coords = torch.nonzero(binary).float()  # (N, 2) in (y, x) order
        if len(coords) == 0:
            continue
        # Boolean-mask indexing and nonzero() both enumerate in row-major
        # order, so scores[i] belongs to coords[i].
        scores = det_map[binary]

        # Sample descriptors at the keypoint positions.
        # NOTE(review): pixel coordinates of `det` are normalised with the
        # size of `desc`; this presumes both maps share the same spatial
        # resolution - confirm against the model.
        xy = coords.flip(1)  # (N, 2) in (x, y) order
        half_extent = torch.tensor(
            [
                # max(..., 1) avoids a 0/0 when a spatial dimension has size 1.
                max(desc.shape[3] - 1, 1) / 2,
                max(desc.shape[2] - 1, 1) / 2,
            ],
            device=desc.device,
        )
        grid = xy.view(1, -1, 1, 2) / half_extent - 1  # map to [-1, 1]
        sampled = F.grid_sample(desc, grid, align_corners=True)  # (1, C, N, 1)
        # Drop only the batch and trailing dims; a bare squeeze() would also
        # collapse N == 1 (or C == 1) and break normalize(dim=1) below.
        descs = sampled.squeeze(0).squeeze(-1).T  # (N, C)
        descs = F.normalize(descs, p=2, dim=1)

        # Map back to original-image coordinates.
        kps = xy * float(stride)

        # Optional per-level NMS.
        if nms_cfg and nms_cfg.get('enabled', False):
            keep = radius_nms(kps, scores, float(nms_cfg.get('radius', 4)))
            if len(keep) > 0:
                kps = kps[keep]
                descs = descs[keep]

        all_kps.append(kps)
        all_desc.append(descs)

    if not all_kps:
        return (
            torch.tensor([], device=image_tensor.device),
            torch.tensor([], device=image_tensor.device),
        )
    return torch.cat(all_kps, dim=0), torch.cat(all_desc, dim=0)
# --- 互近邻匹配 (无变动) ---
def mutual_nearest_neighbor(descs1, descs2):
if len(descs1) == 0 or len(descs2) == 0:
@@ -113,8 +165,13 @@ def match_template_multiscale(
"""
在不同尺度下搜索模板,并检测多个实例
"""
# 1. 对大版图使用滑动窗口提取全部特征
layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg)
# 1. 版图特征提取:根据配置选择 FPN 或滑窗
device = next(model.parameters()).device
if getattr(matching_cfg, 'use_fpn', False):
layout_tensor = transform(layout_image).unsqueeze(0).to(device)
layout_kps, layout_descs = extract_from_pyramid(model, layout_tensor, float(matching_cfg.keypoint_threshold), getattr(matching_cfg, 'nms', {}))
else:
layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg)
if log_writer:
log_writer.add_scalar("match/layout_keypoints", len(layout_kps), log_step)
@@ -154,8 +211,11 @@ def match_template_multiscale(
scaled_template = template_image.resize((new_W, new_H), Image.LANCZOS)
template_tensor = transform(scaled_template).unsqueeze(0).to(layout_kps.device)
# 提取缩放后模板的特征
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, keypoint_threshold)
# 提取缩放后模板的特征FPN 或单尺度
if getattr(matching_cfg, 'use_fpn', False):
template_kps, template_descs = extract_from_pyramid(model, template_tensor, keypoint_threshold, getattr(matching_cfg, 'nms', {}))
else:
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, keypoint_threshold)
if len(template_kps) < 4: continue
@@ -227,6 +287,8 @@ if __name__ == "__main__":
parser.add_argument('--experiment_name', type=str, default=None, help="TensorBoard 实验名称,覆盖配置文件设置")
parser.add_argument('--tb_log_matches', action='store_true', help="启用模板匹配过程的 TensorBoard 记录")
parser.add_argument('--disable_tensorboard', action='store_true', help="禁用 TensorBoard 记录")
parser.add_argument('--fpn_off', action='store_true', help="关闭 FPN 匹配路径(等同于 matching.use_fpn=false")
parser.add_argument('--no_nms', action='store_true', help="关闭关键点去重NMS")
parser.add_argument('--layout', type=str, required=True)
parser.add_argument('--template', type=str, required=True)
parser.add_argument('--output', type=str)
@@ -262,6 +324,16 @@ if __name__ == "__main__":
tb_path.parent.mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(tb_path.as_posix())
# CLI shortcut switches override the YAML configuration.
import warnings  # local import: only this override step needs it

try:
    if args.fpn_off:
        matching_cfg.use_fpn = False
    if args.no_nms and hasattr(matching_cfg, 'nms'):
        matching_cfg.nms.enabled = False
except Exception as exc:
    # Best effort: the config object may not be writable. Execution continues,
    # but later getattr() reads will still see the unmodified values, so the
    # requested CLI switch has no effect - tell the user instead of failing
    # silently.
    warnings.warn(
        "Could not apply CLI overrides (--fpn_off / --no_nms) to the matching "
        f"config; the values from the config file remain in effect: {exc}"
    )
transform = get_transform()
model = RoRD().cuda()
model.load_state_dict(torch.load(model_path))