Finish the inference and matching part.
This commit is contained in:
80
match.py
80
match.py
@@ -44,6 +44,24 @@ def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh):
|
||||
|
||||
return keypoints, descriptors
|
||||
|
||||
|
||||
# --- (New) Simple radius-based NMS for keypoint de-duplication ---
|
||||
def radius_nms(kps: torch.Tensor, scores: torch.Tensor, radius: float) -> torch.Tensor:
    """Greedy radius-based non-maximum suppression over 2-D keypoints.

    Keypoints are visited from highest to lowest score. Each keypoint that
    has not yet been suppressed is kept, and every keypoint whose squared
    distance to it (computed from columns 0 and 1 of ``kps``) is at most
    ``radius ** 2`` is suppressed.

    Args:
        kps: Keypoint coordinates; indexed as ``kps[:, 0]`` and ``kps[:, 1]``.
        scores: One score per keypoint, used only for ordering.
        radius: Suppression radius, in the same units as ``kps``.

    Returns:
        A 1-D ``torch.long`` tensor on ``kps.device`` holding the indices of
        the kept keypoints, in descending score order. Empty when ``kps`` has
        no elements.
    """
    device = kps.device
    if kps.numel() == 0:
        return torch.empty((0,), dtype=torch.long, device=device)

    radius_sq = radius * radius
    order = torch.argsort(scores, descending=True)
    suppressed = torch.zeros(len(kps), dtype=torch.bool, device=device)
    selected = []
    for candidate in order:
        if not suppressed[candidate]:
            selected.append(candidate.item())
            offsets = kps - kps[candidate]
            sq_dist = offsets[:, 0] ** 2 + offsets[:, 1] ** 2
            suppressed |= sq_dist <= radius_sq
            # Mark the winner explicitly so it is never revisited, even if
            # its own distance comparison is False (e.g. NaN coordinates).
            suppressed[candidate] = True
    return torch.tensor(selected, dtype=torch.long, device=device)
|
||||
|
||||
# --- (新增) 滑动窗口特征提取函数 ---
|
||||
def extract_features_sliding_window(model, large_image, transform, matching_cfg):
|
||||
"""
|
||||
@@ -88,6 +106,40 @@ def extract_features_sliding_window(model, large_image, transform, matching_cfg)
|
||||
return torch.cat(all_kps, dim=0), torch.cat(all_descs, dim=0)
|
||||
|
||||
|
||||
# --- (New) Keypoint and descriptor extraction for the FPN path ---
|
||||
def extract_from_pyramid(model, image_tensor, kp_thresh, nms_cfg):
    """Extract keypoints and descriptors from every level of the model's pyramid.

    Args:
        model: Callable invoked as ``model(image_tensor, return_pyramid=True)``.
            It must return a mapping whose values are ``(det, desc, stride)``
            triples.
        image_tensor: Input passed to the model unchanged; its ``device`` is
            also used for the empty result.
        kp_thresh: Detection values strictly greater than this become keypoints.
        nms_cfg: Optional mapping-like config. When it is truthy and its
            ``'enabled'`` entry is true, ``radius_nms`` is applied per level
            using its ``'radius'`` entry (default 4).

    Returns:
        ``(keypoints, descriptors)``. Keypoints are ``(N, 2)`` in ``(x, y)``
        order, scaled by each level's stride; descriptors are ``(N, C)`` and
        L2-normalized per row. Levels are concatenated in the mapping's
        iteration order. When no level produces a keypoint, two empty 1-D
        tensors on ``image_tensor.device`` are returned.

    Note:
        Assumes ``det`` is ``(1, 1, H, W)`` and ``desc`` is ``(1, C, H, W)``
        with matching spatial size -- TODO confirm against the model.
    """
    with torch.no_grad():
        pyramid = model(image_tensor, return_pyramid=True)

    use_nms = bool(nms_cfg) and bool(nms_cfg.get('enabled', False))
    nms_radius = float(nms_cfg.get('radius', 4)) if use_nms else 0.0

    all_kps = []
    all_desc = []
    for det, desc, stride in pyramid.values():
        # Drop only the batch and channel axes so the mask and the score map
        # always share the same 2-D shape, even when H or W equals 1.
        det_map = det.squeeze(0).squeeze(0)
        mask = det_map > kp_thresh
        coords_yx = torch.nonzero(mask).float()  # rows are (y, x)
        if len(coords_yx) == 0:
            continue
        scores = det_map[mask]
        coords_xy = coords_yx.flip(1)

        # Map pixel coordinates to grid_sample's [-1, 1] range. The max(.., 1)
        # avoids a zero divisor when the descriptor map is 1 pixel wide/tall.
        half_extent = torch.tensor(
            [max(desc.shape[3] - 1, 1) / 2, max(desc.shape[2] - 1, 1) / 2],
            device=desc.device,
        )
        grid = coords_xy.view(1, -1, 1, 2) / half_extent - 1
        sampled = F.grid_sample(desc, grid, align_corners=True)  # (1, C, N, 1)
        # Squeeze explicit axes only: a bare squeeze() would also drop the
        # keypoint axis when N == 1 and break the row-wise normalization.
        descs = sampled.squeeze(0).squeeze(-1).T  # (N, C)
        descs = F.normalize(descs, p=2, dim=1)

        # Map back to input-image coordinates.
        kps = coords_xy * float(stride)

        if use_nms:
            keep = radius_nms(kps, scores, nms_radius)
            kps = kps[keep]
            descs = descs[keep]

        all_kps.append(kps)
        all_desc.append(descs)

    if not all_kps:
        device = image_tensor.device
        return torch.tensor([], device=device), torch.tensor([], device=device)
    return torch.cat(all_kps, dim=0), torch.cat(all_desc, dim=0)
|
||||
|
||||
|
||||
# --- 互近邻匹配 (无变动) ---
|
||||
def mutual_nearest_neighbor(descs1, descs2):
|
||||
if len(descs1) == 0 or len(descs2) == 0:
|
||||
@@ -113,8 +165,13 @@ def match_template_multiscale(
|
||||
"""
|
||||
在不同尺度下搜索模板,并检测多个实例
|
||||
"""
|
||||
# 1. 对大版图使用滑动窗口提取全部特征
|
||||
layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg)
|
||||
# 1. 版图特征提取:根据配置选择 FPN 或滑窗
|
||||
device = next(model.parameters()).device
|
||||
if getattr(matching_cfg, 'use_fpn', False):
|
||||
layout_tensor = transform(layout_image).unsqueeze(0).to(device)
|
||||
layout_kps, layout_descs = extract_from_pyramid(model, layout_tensor, float(matching_cfg.keypoint_threshold), getattr(matching_cfg, 'nms', {}))
|
||||
else:
|
||||
layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg)
|
||||
if log_writer:
|
||||
log_writer.add_scalar("match/layout_keypoints", len(layout_kps), log_step)
|
||||
|
||||
@@ -154,8 +211,11 @@ def match_template_multiscale(
|
||||
scaled_template = template_image.resize((new_W, new_H), Image.LANCZOS)
|
||||
template_tensor = transform(scaled_template).unsqueeze(0).to(layout_kps.device)
|
||||
|
||||
# 提取缩放后模板的特征
|
||||
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, keypoint_threshold)
|
||||
# 提取缩放后模板的特征:FPN 或单尺度
|
||||
if getattr(matching_cfg, 'use_fpn', False):
|
||||
template_kps, template_descs = extract_from_pyramid(model, template_tensor, keypoint_threshold, getattr(matching_cfg, 'nms', {}))
|
||||
else:
|
||||
template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, keypoint_threshold)
|
||||
|
||||
if len(template_kps) < 4: continue
|
||||
|
||||
@@ -227,6 +287,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--experiment_name', type=str, default=None, help="TensorBoard 实验名称,覆盖配置文件设置")
|
||||
parser.add_argument('--tb_log_matches', action='store_true', help="启用模板匹配过程的 TensorBoard 记录")
|
||||
parser.add_argument('--disable_tensorboard', action='store_true', help="禁用 TensorBoard 记录")
|
||||
parser.add_argument('--fpn_off', action='store_true', help="关闭 FPN 匹配路径(等同于 matching.use_fpn=false)")
|
||||
parser.add_argument('--no_nms', action='store_true', help="关闭关键点去重(NMS)")
|
||||
parser.add_argument('--layout', type=str, required=True)
|
||||
parser.add_argument('--template', type=str, required=True)
|
||||
parser.add_argument('--output', type=str)
|
||||
@@ -262,6 +324,16 @@ if __name__ == "__main__":
|
||||
tb_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
writer = SummaryWriter(tb_path.as_posix())
|
||||
|
||||
# CLI 快捷开关覆盖 YAML 配置
|
||||
try:
|
||||
if args.fpn_off:
|
||||
matching_cfg.use_fpn = False
|
||||
if args.no_nms and hasattr(matching_cfg, 'nms'):
|
||||
matching_cfg.nms.enabled = False
|
||||
except Exception:
|
||||
# 若 OmegaConf 结构不可写,忽略并在后续逻辑中以 getattr 的方式读取
|
||||
pass
|
||||
|
||||
transform = get_transform()
|
||||
model = RoRD().cuda()
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
|
||||
Reference in New Issue
Block a user