Compare commits
	
		
			2 Commits
		
	
	
		
			2ccfe7b07f
			...
			3566ae6bfb
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 3566ae6bfb | |||
|   | 419a7db543 | 
							
								
								
									
										45
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								README.md
									
									
									
									
									
								
							| @@ -274,6 +274,51 @@ RoRD 模型基于 D2-Net 架构,使用 VGG-16 作为骨干网络,**专门针 | |||||||
| - **二值化特征距离**: 强化几何边界特征,弱化灰度变化 | - **二值化特征距离**: 强化几何边界特征,弱化灰度变化 | ||||||
| - **几何感知困难负样本**: 基于结构相似性而非像素相似性选择负样本 | - **几何感知困难负样本**: 基于结构相似性而非像素相似性选择负样本 | ||||||
|  |  | ||||||
|  | ## 🔎 推理与匹配(FPN 路径与 NMS) | ||||||
|  |  | ||||||
|  | 项目已支持通过 FPN 单次推理产生多尺度特征,并在匹配阶段引入半径 NMS 去重以减少冗余关键点: | ||||||
|  |  | ||||||
|  | 在 `configs/base_config.yaml` 中启用 FPN 与 NMS: | ||||||
|  |  | ||||||
|  | ```yaml | ||||||
|  | model: | ||||||
|  |   fpn: | ||||||
|  |     enabled: true | ||||||
|  |     out_channels: 256 | ||||||
|  |     levels: [2, 3, 4] | ||||||
|  |  | ||||||
|  | matching: | ||||||
|  |   use_fpn: true | ||||||
|  |   nms: | ||||||
|  |     enabled: true | ||||||
|  |     radius: 4 | ||||||
|  |     score_threshold: 0.5 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | 运行匹配并将过程写入 TensorBoard: | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | uv run python match.py \ | ||||||
|  |   --config configs/base_config.yaml \ | ||||||
|  |   --layout /path/to/layout.png \ | ||||||
|  |   --template /path/to/template.png \ | ||||||
|  |   --tb_log_matches | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | 如需回退旧“图像金字塔”路径,将 `matching.use_fpn` 设为 `false` 即可。 | ||||||
|  |  | ||||||
|  | 也可使用 CLI 快捷开关临时覆盖: | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | # 关闭 FPN(等同 matching.use_fpn=false) | ||||||
|  | uv run python match.py --config configs/base_config.yaml --fpn_off \ | ||||||
|  |   --layout /path/to/layout.png --template /path/to/template.png | ||||||
|  |  | ||||||
|  | # 关闭关键点去重(NMS) | ||||||
|  | uv run python match.py --config configs/base_config.yaml --no_nms \ | ||||||
|  |   --layout /path/to/layout.png --template /path/to/template.png | ||||||
|  | ``` | ||||||
|  |  | ||||||
| ### 训练策略 - 几何结构学习 | ### 训练策略 - 几何结构学习 | ||||||
| 模型通过**几何结构学习**策略进行训练: | 模型通过**几何结构学习**策略进行训练: | ||||||
| - **曼哈顿变换生成训练对**: 利用90度旋转等曼哈顿变换 | - **曼哈顿变换生成训练对**: 利用90度旋转等曼哈顿变换 | ||||||
|   | |||||||
| @@ -5,6 +5,13 @@ training: | |||||||
|   patch_size: 256 |   patch_size: 256 | ||||||
|   scale_jitter_range: [0.8, 1.2] |   scale_jitter_range: [0.8, 1.2] | ||||||
|  |  | ||||||
|  | model: | ||||||
|  |   fpn: | ||||||
|  |     enabled: true | ||||||
|  |     out_channels: 256 | ||||||
|  |     levels: [2, 3, 4] | ||||||
|  |     norm: "bn" | ||||||
|  |  | ||||||
| matching: | matching: | ||||||
|   keypoint_threshold: 0.5 |   keypoint_threshold: 0.5 | ||||||
|   ransac_reproj_threshold: 5.0 |   ransac_reproj_threshold: 5.0 | ||||||
| @@ -12,6 +19,11 @@ matching: | |||||||
|   pyramid_scales: [0.75, 1.0, 1.5] |   pyramid_scales: [0.75, 1.0, 1.5] | ||||||
|   inference_window_size: 1024 |   inference_window_size: 1024 | ||||||
|   inference_stride: 768 |   inference_stride: 768 | ||||||
|  |   use_fpn: true | ||||||
|  |   nms: | ||||||
|  |     enabled: true | ||||||
|  |     radius: 4 | ||||||
|  |     score_threshold: 0.5 | ||||||
|  |  | ||||||
| evaluation: | evaluation: | ||||||
|   iou_threshold: 0.5 |   iou_threshold: 0.5 | ||||||
|   | |||||||
							
								
								
									
										133
									
								
								docs/NextStep.md
									
									
									
									
									
								
							
							
						
						
									
										133
									
								
								docs/NextStep.md
									
									
									
									
									
								
							| @@ -122,3 +122,136 @@ | |||||||
| - [ ] 修改配置和脚本,接入 SummaryWriter。 | - [ ] 修改配置和脚本,接入 SummaryWriter。 | ||||||
| - [ ] 准备示例 Notebook/文档,展示 TensorBoard 面板截图。 | - [ ] 准备示例 Notebook/文档,展示 TensorBoard 面板截图。 | ||||||
| - [ ] 后续评估是否需要接入 W&B、MLflow 等更高级平台。 | - [ ] 后续评估是否需要接入 W&B、MLflow 等更高级平台。 | ||||||
|  |  | ||||||
|  | --- | ||||||
|  |  | ||||||
|  | # 推理与匹配改造计划(FPN + NMS) | ||||||
|  |  | ||||||
|  | 日期:2025-09-25 | ||||||
|  |  | ||||||
|  | ## 目标 | ||||||
|  | - 将当前的“图像金字塔 + 多次推理”的匹配流程,升级为“单次推理 + 特征金字塔 (FPN)”以显著提速。 | ||||||
|  | - 在滑动窗口提取关键点后增加去重(NMS/半径抑制),降低冗余点与后续 RANSAC 的计算量。 | ||||||
|  | - 保持与现有 YAML 配置、TensorBoard 记录和命令行接口的一致性;以 uv 为包管理器管理依赖和运行。 | ||||||
|  |  | ||||||
|  | ## 设计概览 | ||||||
|  | - FPN:在 `models/rord.py` 中,从骨干网络多层提取特征(例如 VGG 的 relu2_2/relu3_3/relu4_3),通过横向 1x1 卷积与自顶向下上采样构建 P2/P3/P4 金字塔特征;为每个尺度共享或独立地接上检测头与描述子头,导出同维度描述子。 | ||||||
|  | - 匹配路径:`match.py` 新增 FPN 路径,单次前向获得多尺度特征,逐层与模板进行匹配与几何验证;保留旧路径(图像金字塔)作为回退,通过配置开关切换。 | ||||||
|  | - 去重策略:在滑窗聚合关键点后,基于“分数优先 + 半径抑制 (radius NMS)”进行去重;半径和分数阈值配置化。 | ||||||
|  |  | ||||||
|  | ## 配置变更(YAML) | ||||||
|  | 在 `configs/base_config.yaml` 中新增/扩展: | ||||||
|  |  | ||||||
|  | ```yaml | ||||||
|  | model: | ||||||
|  |    fpn: | ||||||
|  |       enabled: true            # 开启 FPN 推理 | ||||||
|  |       out_channels: 256        # 金字塔特征通道数 | ||||||
|  |       levels: [2, 3, 4]        # 输出层级,对应 P2/P3/P4 | ||||||
|  |       norm: "bn"              # 归一化类型:bn/gn/none | ||||||
|  |  | ||||||
|  | matching: | ||||||
|  |    use_fpn: true              # 使用 FPN 路径;false 则沿用图像金字塔 | ||||||
|  |    nms: | ||||||
|  |       enabled: true | ||||||
|  |       radius: 4                # 半径抑制像素半径 | ||||||
|  |       score_threshold: 0.5     # 关键点保留的最低分数 | ||||||
|  |    # 其余已有参数保留,如 ransac_reproj_threshold/min_inliers/inference_window_size... | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | 注意:所有相对路径依旧使用 `utils.config_loader.to_absolute_path` 以配置文件所在目录为基准解析。 | ||||||
|  |  | ||||||
|  | ## 实施步骤 | ||||||
|  |  | ||||||
|  | 1) 基线分支与依赖 | ||||||
|  | - 新开分支保存改造: | ||||||
|  |    ```bash | ||||||
|  |    git checkout -b feature/fpn-matching | ||||||
|  |    uv sync | ||||||
|  |    ``` | ||||||
|  | - 目前不引入新三方库,继续使用现有 `torch/opencv/numpy`。 | ||||||
|  |  | ||||||
|  | 2) 模型侧改造(`models/rord.py`) | ||||||
|  | - 提取多层特征:在骨干网络中暴露中间层输出(如 C2/C3/C4)。 | ||||||
|  | - 构建 FPN: | ||||||
|  |    - 使用 1x1 conv 降维对齐通道; | ||||||
|  |    - 自顶向下上采样并逐级相加; | ||||||
|  |    - 3x3 conv 平滑,得到 P2/P3/P4; | ||||||
|  |    - 可选归一化(BN/GN)。 | ||||||
|  | - 头部适配:复用或复制现有检测头/描述子头到每个 P 层,输出: | ||||||
|  |    - det_scores[L]:B×1×H_L×W_L | ||||||
|  |    - descriptors[L]:B×D×H_L×W_L(D 与现有描述子维度一致) | ||||||
|  | - 前向接口: | ||||||
|  |    - 训练模式:维持现有输出以兼容训练; | ||||||
|  |    - 匹配/评估模式:支持 `return_pyramid=True` 返回 {P2,P3,P4} 的 det/desc。 | ||||||
|  |  | ||||||
|  | 3) 匹配侧改造(`match.py`) | ||||||
|  | - 配置读取:根据 `matching.use_fpn` 决定走 FPN 或旧图像金字塔路径。 | ||||||
|  | - FPN 路径: | ||||||
|  |    - 对 layout 与 template 各做一次前向,获得 {det, desc}×L; | ||||||
|  |    - 对每个层级 L: | ||||||
|  |       - 从 det_scores[L] 以 `score_threshold` 抽取关键点坐标与分数; | ||||||
|  |       - 半径 NMS 去重(见步骤 4); | ||||||
|  |       - 使用 desc[L] 在对应层做特征最近邻匹配(可选比值测试)并估计单应性 H_L(RANSAC); | ||||||
|  |    - 融合多个层级的候选:选取内点数最多或综合打分最佳的实例; | ||||||
|  |    - 将层级坐标映射回原图坐标;输出 bbox 与 H。 | ||||||
|  | - 旧路径保留:若 `use_fpn=false`,继续使用当前图像金字塔多次推理策略,便于回退与对照实验。 | ||||||
|  |  | ||||||
|  | 4) 关键点去重(NMS/半径抑制) | ||||||
|  | - 输入:关键点集合 K = {(x, y, score)}。 | ||||||
|  | - 算法:按 score 降序遍历,若与已保留点的欧氏距离 < radius 则丢弃,否则保留。 | ||||||
|  | - 复杂度:O(N log N) 排序 + O(N·k) 检查(k 为邻域个数,可通过网格划分加速)。 | ||||||
|  | - 参数:`matching.nms.radius`、`matching.nms.score_threshold`。 | ||||||
|  |  | ||||||
|  | 5) TensorBoard 记录(扩展) | ||||||
|  | - Scalars: | ||||||
|  |    - `match_fpn/level_L/keypoints_before_nms`、`keypoints_after_nms` | ||||||
|  |    - `match_fpn/level_L/inliers`、`best_instance_inliers` | ||||||
|  |    - `match_fpn/instances_found`、`runtime_ms` | ||||||
|  | - Text/Image: | ||||||
|  |    - 关键点可视化(可选),最佳实例覆盖图; | ||||||
|  |    - 记录使用的层级与最终选中尺度信息。 | ||||||
|  |  | ||||||
|  | 6) 兼容性与回退 | ||||||
|  | - 通过 YAML `matching.use_fpn` 开关控制路径; | ||||||
|  | - 保持 CLI 不变,新增可选 `--fpn-off`(等同 use_fpn=false)仅作为临时调试; | ||||||
|  | - 若新路径异常可快速回退旧路径,保证生产可用性。 | ||||||
|  |  | ||||||
|  | ## 开发里程碑与工时预估 | ||||||
|  | - M1(0.5 天):配置与分支、占位接口、日志钩子。 | ||||||
|  | - M2(1.5 天):FPN 实现与前向接口;单图 smoke 测试。 | ||||||
|  | - M3(1 天):`match.py` FPN 路径、尺度回映射与候选融合。 | ||||||
|  | - M4(0.5 天):NMS 实现与参数打通; | ||||||
|  | - M5(0.5 天):TensorBoard 指标与可视化; | ||||||
|  | - M6(0.5 天):对照基线的性能与速度评估,整理报告。 | ||||||
|  |  | ||||||
|  | ## 质量门禁与验收标准 | ||||||
|  | - 构建:`uv sync` 无错误;`python -m compileall` 通过; | ||||||
|  | - 功能:在 2–3 张样例上,FPN 路径输出的实例数量与旧路径相当或更优; | ||||||
|  | - 速度:相同输入,FPN 路径总耗时较旧路径下降 ≥ 30%; | ||||||
|  | - 稳定性:无异常崩溃;在找不到匹配时能优雅返回空结果; | ||||||
|  | - 指标:TensorBoard 中关键点数量、NMS 前后对比、内点数、总实例数与运行时均可见。 | ||||||
|  |  | ||||||
|  | ## 快速试用(命令) | ||||||
|  | ```bash | ||||||
|  | # 同步环境 | ||||||
|  | uv sync | ||||||
|  |  | ||||||
|  | # 基于 YAML 启用 FPN 匹配(推荐) | ||||||
|  | uv run python match.py \ | ||||||
|  |    --config configs/base_config.yaml \ | ||||||
|  |    --layout /path/to/layout.png \ | ||||||
|  |    --template /path/to/template.png \ | ||||||
|  |    --tb_log_matches | ||||||
|  |  | ||||||
|  | # 临时关闭 FPN(对照实验) | ||||||
|  | # 可通过把 configs 中 matching.use_fpn 设为 false,或后续提供 --fpn-off 开关 | ||||||
|  |  | ||||||
|  | # 打开 TensorBoard 查看匹配指标 | ||||||
|  | uv run tensorboard --logdir runs | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## 风险与回滚 | ||||||
|  | - FPN 输出与原检测/描述子头的维度/分布不一致,需在实现时对齐通道与归一化; | ||||||
|  | - 多层融合策略(如何选取最佳实例)可能影响稳定性,可先以“内点数最大”作为启发式; | ||||||
|  | - 如出现精度下降或不稳定,立即回退 `matching.use_fpn=false`,保留旧流程并开启数据记录比对差异。 | ||||||
|   | |||||||
| @@ -62,13 +62,13 @@ | |||||||
|  |  | ||||||
| > *目标:大幅提升大尺寸版图的匹配速度和多尺度检测能力。* | > *目标:大幅提升大尺寸版图的匹配速度和多尺度检测能力。* | ||||||
|  |  | ||||||
| - [ ] **将模型改造为特征金字塔网络 (FPN) 架构** | - [x] **将模型改造为特征金字塔网络 (FPN) 架构** | ||||||
|   - **✔️ 价值**: 当前的多尺度匹配需要多次缩放图像并推理,速度慢。FPN 只需一次推理即可获得所有尺度的特征,极大加速匹配过程。 |   - **✔️ 价值**: 当前的多尺度匹配需要多次缩放图像并推理,速度慢。FPN 只需一次推理即可获得所有尺度的特征,极大加速匹配过程。 | ||||||
|   - **📝 执行方案**: |   - **📝 执行方案**: | ||||||
|     1. 修改 `models/rord.py`,从骨干网络的不同层级(如 VGG 的 `relu2_2`, `relu3_3`, `relu4_3`)提取特征图。 |     1. 修改 `models/rord.py`,从骨干网络的不同层级(如 VGG 的 `relu2_2`, `relu3_3`, `relu4_3`)提取特征图。 | ||||||
|     2. 添加上采样和横向连接层来融合这些特征图,构建出特征金字塔。 |     2. 添加上采样和横向连接层来融合这些特征图,构建出特征金字塔。 | ||||||
|     3. 修改 `match.py`,使其能够直接从 FPN 的不同层级获取特征,替代原有的图像金字塔循环。 |     3. 修改 `match.py`,使其能够直接从 FPN 的不同层级获取特征,替代原有的图像金字塔循环。 | ||||||
| - [ ] **在滑动窗口匹配后增加关键点去重** | - [x] **在滑动窗口匹配后增加关键点去重** | ||||||
|   - **✔️ 价值**: `match.py` 中的滑动窗口在重叠区域会产生大量重复的关键点,增加后续匹配的计算量并可能影响精度。 |   - **✔️ 价值**: `match.py` 中的滑动窗口在重叠区域会产生大量重复的关键点,增加后续匹配的计算量并可能影响精度。 | ||||||
|   - **📝 执行方案**: |   - **📝 执行方案**: | ||||||
|     1. 在 `match.py` 的 `extract_features_sliding_window` 函数返回前。 |     1. 在 `match.py` 的 `extract_features_sliding_window` 函数返回前。 | ||||||
|   | |||||||
							
								
								
									
										76
									
								
								match.py
									
									
									
									
									
								
							
							
						
						
									
										76
									
								
								match.py
									
									
									
									
									
								
							| @@ -44,6 +44,24 @@ def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh): | |||||||
|  |  | ||||||
|     return keypoints, descriptors |     return keypoints, descriptors | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # --- (新增) 简单半径 NMS 去重 --- | ||||||
|  | def radius_nms(kps: torch.Tensor, scores: torch.Tensor, radius: float) -> torch.Tensor: | ||||||
|  |     if kps.numel() == 0: | ||||||
|  |         return torch.empty((0,), dtype=torch.long, device=kps.device) | ||||||
|  |     idx = torch.argsort(scores, descending=True) | ||||||
|  |     keep = [] | ||||||
|  |     taken = torch.zeros(len(kps), dtype=torch.bool, device=kps.device) | ||||||
|  |     for i in idx: | ||||||
|  |         if taken[i]: | ||||||
|  |             continue | ||||||
|  |         keep.append(i.item()) | ||||||
|  |         di = kps - kps[i] | ||||||
|  |         dist2 = (di[:, 0]**2 + di[:, 1]**2) | ||||||
|  |         taken |= dist2 <= (radius * radius) | ||||||
|  |         taken[i] = True | ||||||
|  |     return torch.tensor(keep, dtype=torch.long, device=kps.device) | ||||||
|  |  | ||||||
| # --- (新增) 滑动窗口特征提取函数 --- | # --- (新增) 滑动窗口特征提取函数 --- | ||||||
| def extract_features_sliding_window(model, large_image, transform, matching_cfg): | def extract_features_sliding_window(model, large_image, transform, matching_cfg): | ||||||
|     """ |     """ | ||||||
| @@ -88,6 +106,40 @@ def extract_features_sliding_window(model, large_image, transform, matching_cfg) | |||||||
|     return torch.cat(all_kps, dim=0), torch.cat(all_descs, dim=0) |     return torch.cat(all_kps, dim=0), torch.cat(all_descs, dim=0) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # --- (新增) FPN 路径的关键点与描述子抽取 --- | ||||||
|  | def extract_from_pyramid(model, image_tensor, kp_thresh, nms_cfg): | ||||||
|  |     with torch.no_grad(): | ||||||
|  |         pyramid = model(image_tensor, return_pyramid=True) | ||||||
|  |     all_kps = [] | ||||||
|  |     all_desc = [] | ||||||
|  |     for level_name, (det, desc, stride) in pyramid.items(): | ||||||
|  |         binary = (det > kp_thresh).squeeze(0).squeeze(0) | ||||||
|  |         coords = torch.nonzero(binary).float()  # y,x | ||||||
|  |         if len(coords) == 0: | ||||||
|  |             continue | ||||||
|  |         scores = det.squeeze()[binary] | ||||||
|  |         # 采样描述子 | ||||||
|  |         coords_for_grid = coords.flip(1).view(1, -1, 1, 2) | ||||||
|  |         coords_for_grid = coords_for_grid / torch.tensor([(desc.shape[3]-1)/2, (desc.shape[2]-1)/2], device=desc.device) - 1 | ||||||
|  |         descs = F.grid_sample(desc, coords_for_grid, align_corners=True).squeeze().T | ||||||
|  |         descs = F.normalize(descs, p=2, dim=1) | ||||||
|  |  | ||||||
|  |         # 映射回原图坐标 | ||||||
|  |         kps = coords.flip(1) * float(stride) | ||||||
|  |  | ||||||
|  |         # NMS | ||||||
|  |         if nms_cfg and nms_cfg.get('enabled', False): | ||||||
|  |             keep = radius_nms(kps, scores, float(nms_cfg.get('radius', 4))) | ||||||
|  |             if len(keep) > 0: | ||||||
|  |                 kps = kps[keep] | ||||||
|  |                 descs = descs[keep] | ||||||
|  |         all_kps.append(kps) | ||||||
|  |         all_desc.append(descs) | ||||||
|  |     if not all_kps: | ||||||
|  |         return torch.tensor([], device=image_tensor.device), torch.tensor([], device=image_tensor.device) | ||||||
|  |     return torch.cat(all_kps, dim=0), torch.cat(all_desc, dim=0) | ||||||
|  |  | ||||||
|  |  | ||||||
| # --- 互近邻匹配 (无变动) --- | # --- 互近邻匹配 (无变动) --- | ||||||
| def mutual_nearest_neighbor(descs1, descs2): | def mutual_nearest_neighbor(descs1, descs2): | ||||||
|     if len(descs1) == 0 or len(descs2) == 0: |     if len(descs1) == 0 or len(descs2) == 0: | ||||||
| @@ -113,7 +165,12 @@ def match_template_multiscale( | |||||||
|     """ |     """ | ||||||
|     在不同尺度下搜索模板,并检测多个实例 |     在不同尺度下搜索模板,并检测多个实例 | ||||||
|     """ |     """ | ||||||
|     # 1. 对大版图使用滑动窗口提取全部特征 |     # 1. 版图特征提取:根据配置选择 FPN 或滑窗 | ||||||
|  |     device = next(model.parameters()).device | ||||||
|  |     if getattr(matching_cfg, 'use_fpn', False): | ||||||
|  |         layout_tensor = transform(layout_image).unsqueeze(0).to(device) | ||||||
|  |         layout_kps, layout_descs = extract_from_pyramid(model, layout_tensor, float(matching_cfg.keypoint_threshold), getattr(matching_cfg, 'nms', {})) | ||||||
|  |     else: | ||||||
|         layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg) |         layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg) | ||||||
|     if log_writer: |     if log_writer: | ||||||
|         log_writer.add_scalar("match/layout_keypoints", len(layout_kps), log_step) |         log_writer.add_scalar("match/layout_keypoints", len(layout_kps), log_step) | ||||||
| @@ -154,7 +211,10 @@ def match_template_multiscale( | |||||||
|             scaled_template = template_image.resize((new_W, new_H), Image.LANCZOS) |             scaled_template = template_image.resize((new_W, new_H), Image.LANCZOS) | ||||||
|             template_tensor = transform(scaled_template).unsqueeze(0).to(layout_kps.device) |             template_tensor = transform(scaled_template).unsqueeze(0).to(layout_kps.device) | ||||||
|              |              | ||||||
|             # 提取缩放后模板的特征 |             # 提取缩放后模板的特征:FPN 或单尺度 | ||||||
|  |             if getattr(matching_cfg, 'use_fpn', False): | ||||||
|  |                 template_kps, template_descs = extract_from_pyramid(model, template_tensor, keypoint_threshold, getattr(matching_cfg, 'nms', {})) | ||||||
|  |             else: | ||||||
|                 template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, keypoint_threshold) |                 template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, keypoint_threshold) | ||||||
|              |              | ||||||
|             if len(template_kps) < 4: continue |             if len(template_kps) < 4: continue | ||||||
| @@ -227,6 +287,8 @@ if __name__ == "__main__": | |||||||
|     parser.add_argument('--experiment_name', type=str, default=None, help="TensorBoard 实验名称,覆盖配置文件设置") |     parser.add_argument('--experiment_name', type=str, default=None, help="TensorBoard 实验名称,覆盖配置文件设置") | ||||||
|     parser.add_argument('--tb_log_matches', action='store_true', help="启用模板匹配过程的 TensorBoard 记录") |     parser.add_argument('--tb_log_matches', action='store_true', help="启用模板匹配过程的 TensorBoard 记录") | ||||||
|     parser.add_argument('--disable_tensorboard', action='store_true', help="禁用 TensorBoard 记录") |     parser.add_argument('--disable_tensorboard', action='store_true', help="禁用 TensorBoard 记录") | ||||||
|  |     parser.add_argument('--fpn_off', action='store_true', help="关闭 FPN 匹配路径(等同于 matching.use_fpn=false)") | ||||||
|  |     parser.add_argument('--no_nms', action='store_true', help="关闭关键点去重(NMS)") | ||||||
|     parser.add_argument('--layout', type=str, required=True) |     parser.add_argument('--layout', type=str, required=True) | ||||||
|     parser.add_argument('--template', type=str, required=True) |     parser.add_argument('--template', type=str, required=True) | ||||||
|     parser.add_argument('--output', type=str) |     parser.add_argument('--output', type=str) | ||||||
| @@ -262,6 +324,16 @@ if __name__ == "__main__": | |||||||
|         tb_path.parent.mkdir(parents=True, exist_ok=True) |         tb_path.parent.mkdir(parents=True, exist_ok=True) | ||||||
|         writer = SummaryWriter(tb_path.as_posix()) |         writer = SummaryWriter(tb_path.as_posix()) | ||||||
|  |  | ||||||
|  |     # CLI 快捷开关覆盖 YAML 配置 | ||||||
|  |     try: | ||||||
|  |         if args.fpn_off: | ||||||
|  |             matching_cfg.use_fpn = False | ||||||
|  |         if args.no_nms and hasattr(matching_cfg, 'nms'): | ||||||
|  |             matching_cfg.nms.enabled = False | ||||||
|  |     except Exception: | ||||||
|  |         # 若 OmegaConf 结构不可写,忽略并在后续逻辑中以 getattr 的方式读取 | ||||||
|  |         pass | ||||||
|  |  | ||||||
|     transform = get_transform() |     transform = get_transform() | ||||||
|     model = RoRD().cuda() |     model = RoRD().cuda() | ||||||
|     model.load_state_dict(torch.load(model_path)) |     model.load_state_dict(torch.load(model_path)) | ||||||
|   | |||||||
| @@ -2,20 +2,26 @@ | |||||||
|  |  | ||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
|  | import torch.nn.functional as F | ||||||
| from torchvision import models | from torchvision import models | ||||||
|  |  | ||||||
| class RoRD(nn.Module): | class RoRD(nn.Module): | ||||||
|     def __init__(self): |     def __init__(self, fpn_out_channels: int = 256, fpn_levels=(2, 3, 4)): | ||||||
|         """ |         """ | ||||||
|         修复后的 RoRD 模型。 |         修复后的 RoRD 模型。 | ||||||
|         - 实现了共享骨干网络,以提高计算效率和减少内存占用。 |         - 实现了共享骨干网络,以提高计算效率和减少内存占用。 | ||||||
|         - 确保检测头和描述子头使用相同尺寸的特征图。 |         - 确保检测头和描述子头使用相同尺寸的特征图。 | ||||||
|  |         - 新增(可选)FPN 推理路径,提供多尺度特征用于高效匹配。 | ||||||
|         """ |         """ | ||||||
|         super(RoRD, self).__init__() |         super(RoRD, self).__init__() | ||||||
|          |          | ||||||
|         vgg16_features = models.vgg16(pretrained=False).features |         vgg16_features = models.vgg16(pretrained=False).features | ||||||
|  |  | ||||||
|         # 共享骨干网络 - 只使用到 relu4_3,确保特征图尺寸一致 |         # VGG16 特征各阶段索引(conv & relu 层序列) | ||||||
|  |         # relu2_2 索引 8,relu3_3 索引 15,relu4_3 索引 22 | ||||||
|  |         self.features = vgg16_features | ||||||
|  |  | ||||||
|  |         # 共享骨干(向后兼容单尺度路径,使用到 relu4_3) | ||||||
|         self.backbone = nn.Sequential(*list(vgg16_features.children())[:23]) |         self.backbone = nn.Sequential(*list(vgg16_features.children())[:23]) | ||||||
|  |  | ||||||
|         # 检测头 |         # 检测头 | ||||||
| @@ -38,12 +44,72 @@ class RoRD(nn.Module): | |||||||
|             nn.InstanceNorm2d(128) |             nn.InstanceNorm2d(128) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def forward(self, x): |         # --- FPN 组件(用于可选多尺度推理) --- | ||||||
|         # 共享特征提取 |         self.fpn_out_channels = fpn_out_channels | ||||||
|         features = self.backbone(x) |         self.fpn_levels = tuple(sorted(set(fpn_levels)))  # e.g., (2,3,4) | ||||||
|  |  | ||||||
|         # 检测器和描述子使用相同的特征图 |         # 横向连接 1x1 将 C2(128)/C3(256)/C4(512) 对齐到相同通道数 | ||||||
|  |         self.lateral_c2 = nn.Conv2d(128, fpn_out_channels, kernel_size=1) | ||||||
|  |         self.lateral_c3 = nn.Conv2d(256, fpn_out_channels, kernel_size=1) | ||||||
|  |         self.lateral_c4 = nn.Conv2d(512, fpn_out_channels, kernel_size=1) | ||||||
|  |  | ||||||
|  |         # 平滑 3x3 conv | ||||||
|  |         self.smooth_p2 = nn.Conv2d(fpn_out_channels, fpn_out_channels, kernel_size=3, padding=1) | ||||||
|  |         self.smooth_p3 = nn.Conv2d(fpn_out_channels, fpn_out_channels, kernel_size=3, padding=1) | ||||||
|  |         self.smooth_p4 = nn.Conv2d(fpn_out_channels, fpn_out_channels, kernel_size=3, padding=1) | ||||||
|  |  | ||||||
|  |         # 共享的 FPN 检测/描述子头(输入通道为 fpn_out_channels) | ||||||
|  |         self.det_head_fpn = nn.Sequential( | ||||||
|  |             nn.Conv2d(fpn_out_channels, 128, kernel_size=3, padding=1), | ||||||
|  |             nn.ReLU(inplace=True), | ||||||
|  |             nn.Conv2d(128, 1, kernel_size=1), | ||||||
|  |             nn.Sigmoid(), | ||||||
|  |         ) | ||||||
|  |         self.desc_head_fpn = nn.Sequential( | ||||||
|  |             nn.Conv2d(fpn_out_channels, 128, kernel_size=3, padding=1), | ||||||
|  |             nn.ReLU(inplace=True), | ||||||
|  |             nn.Conv2d(128, 128, kernel_size=1), | ||||||
|  |             nn.InstanceNorm2d(128), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def forward(self, x: torch.Tensor, return_pyramid: bool = False): | ||||||
|  |         if not return_pyramid: | ||||||
|  |             # 向后兼容的单尺度路径(relu4_3) | ||||||
|  |             features = self.backbone(x) | ||||||
|             detection_map = self.detection_head(features) |             detection_map = self.detection_head(features) | ||||||
|             descriptors = self.descriptor_head(features) |             descriptors = self.descriptor_head(features) | ||||||
|          |  | ||||||
|             return detection_map, descriptors |             return detection_map, descriptors | ||||||
|  |  | ||||||
|  |         # --- FPN 路径:提取 C2/C3/C4 --- | ||||||
|  |         c2, c3, c4 = self._extract_c234(x) | ||||||
|  |         p4 = self.lateral_c4(c4) | ||||||
|  |         p3 = self.lateral_c3(c3) + F.interpolate(p4, size=c3.shape[-2:], mode="nearest") | ||||||
|  |         p2 = self.lateral_c2(c2) + F.interpolate(p3, size=c2.shape[-2:], mode="nearest") | ||||||
|  |  | ||||||
|  |         p4 = self.smooth_p4(p4) | ||||||
|  |         p3 = self.smooth_p3(p3) | ||||||
|  |         p2 = self.smooth_p2(p2) | ||||||
|  |  | ||||||
|  |         pyramid = {} | ||||||
|  |         if 4 in self.fpn_levels: | ||||||
|  |             pyramid["P4"] = (self.det_head_fpn(p4), self.desc_head_fpn(p4), 8) | ||||||
|  |         if 3 in self.fpn_levels: | ||||||
|  |             pyramid["P3"] = (self.det_head_fpn(p3), self.desc_head_fpn(p3), 4) | ||||||
|  |         if 2 in self.fpn_levels: | ||||||
|  |             pyramid["P2"] = (self.det_head_fpn(p2), self.desc_head_fpn(p2), 2) | ||||||
|  |         return pyramid | ||||||
|  |  | ||||||
|  |     def _extract_c234(self, x: torch.Tensor): | ||||||
|  |         """提取 VGG 中间层特征:C2(relU2_2), C3(relu3_3), C4(relu4_3).""" | ||||||
|  |         c2 = c3 = c4 = None | ||||||
|  |         for i, layer in enumerate(self.features): | ||||||
|  |             x = layer(x) | ||||||
|  |             if i == 8:   # relu2_2 | ||||||
|  |                 c2 = x | ||||||
|  |             elif i == 15:  # relu3_3 | ||||||
|  |                 c3 = x | ||||||
|  |             elif i == 22:  # relu4_3 | ||||||
|  |                 c4 = x | ||||||
|  |                 break | ||||||
|  |         assert c2 is not None and c3 is not None and c4 is not None | ||||||
|  |         return c2, c3, c4 | ||||||
		Reference in New Issue
	
	Block a user