finish Experiment Tracking and Evaluation
This commit is contained in:
71
match.py
71
match.py
@@ -9,6 +9,10 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError: # pragma: no cover - fallback for environments without torch tensorboard
|
||||
from tensorboardX import SummaryWriter # type: ignore
|
||||
|
||||
from models.rord import RoRD
|
||||
from utils.config_loader import load_config, to_absolute_path
|
||||
@@ -97,16 +101,28 @@ def mutual_nearest_neighbor(descs1, descs2):
|
||||
return matches
|
||||
|
||||
# --- (Modified) Main function for multi-scale, multi-instance matching ---
|
||||
def match_template_multiscale(model, layout_image, template_image, transform, matching_cfg):
|
||||
def match_template_multiscale(
|
||||
model,
|
||||
layout_image,
|
||||
template_image,
|
||||
transform,
|
||||
matching_cfg,
|
||||
log_writer: SummaryWriter | None = None,
|
||||
log_step: int = 0,
|
||||
):
|
||||
"""
|
||||
在不同尺度下搜索模板,并检测多个实例
|
||||
"""
|
||||
# 1. 对大版图使用滑动窗口提取全部特征
|
||||
layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg)
|
||||
if log_writer:
|
||||
log_writer.add_scalar("match/layout_keypoints", len(layout_kps), log_step)
|
||||
|
||||
min_inliers = int(matching_cfg.min_inliers)
|
||||
if len(layout_kps) < min_inliers:
|
||||
print("从大版图中提取的关键点过少,无法进行匹配。")
|
||||
if log_writer:
|
||||
log_writer.add_scalar("match/instances_found", 0, log_step)
|
||||
return []
|
||||
|
||||
found_instances = []
|
||||
@@ -162,6 +178,10 @@ def match_template_multiscale(model, layout_image, template_image, transform, ma
|
||||
# 4. 如果在所有尺度中找到了最佳匹配,则记录并屏蔽
|
||||
if best_match_info['inliers'] > min_inliers:
|
||||
print(f"找到一个匹配实例!内点数: {best_match_info['inliers']}, 使用的模板尺度: {best_match_info['scale']:.2f}x")
|
||||
if log_writer:
|
||||
instance_index = len(found_instances)
|
||||
log_writer.add_scalar("match/instance_inliers", int(best_match_info['inliers']), log_step + instance_index)
|
||||
log_writer.add_scalar("match/instance_scale", float(best_match_info['scale']), log_step + instance_index)
|
||||
|
||||
inlier_mask = best_match_info['mask'].ravel().astype(bool)
|
||||
inlier_layout_kps = best_match_info['dst_pts'][inlier_mask]
|
||||
@@ -183,6 +203,9 @@ def match_template_multiscale(model, layout_image, template_image, transform, ma
|
||||
print("在所有尺度下均未找到新的匹配实例,搜索结束。")
|
||||
break
|
||||
|
||||
if log_writer:
|
||||
log_writer.add_scalar("match/instances_found", len(found_instances), log_step)
|
||||
|
||||
return found_instances
|
||||
|
||||
|
||||
@@ -200,6 +223,10 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="使用 RoRD 进行多尺度模板匹配")
|
||||
parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="YAML 配置文件路径")
|
||||
parser.add_argument('--model_path', type=str, default=None, help="模型权重路径,若未提供则使用配置文件中的路径")
|
||||
parser.add_argument('--log_dir', type=str, default=None, help="TensorBoard 日志根目录,覆盖配置文件设置")
|
||||
parser.add_argument('--experiment_name', type=str, default=None, help="TensorBoard 实验名称,覆盖配置文件设置")
|
||||
parser.add_argument('--tb_log_matches', action='store_true', help="启用模板匹配过程的 TensorBoard 记录")
|
||||
parser.add_argument('--disable_tensorboard', action='store_true', help="禁用 TensorBoard 记录")
|
||||
parser.add_argument('--layout', type=str, required=True)
|
||||
parser.add_argument('--template', type=str, required=True)
|
||||
parser.add_argument('--output', type=str)
|
||||
@@ -208,8 +235,33 @@ if __name__ == "__main__":
|
||||
cfg = load_config(args.config)
|
||||
config_dir = Path(args.config).resolve().parent
|
||||
matching_cfg = cfg.matching
|
||||
logging_cfg = cfg.get("logging", None)
|
||||
model_path = args.model_path or str(to_absolute_path(cfg.paths.model_path, config_dir))
|
||||
|
||||
use_tensorboard = False
|
||||
log_dir = None
|
||||
experiment_name = None
|
||||
if logging_cfg is not None:
|
||||
use_tensorboard = bool(logging_cfg.get("use_tensorboard", False))
|
||||
log_dir = logging_cfg.get("log_dir", "runs")
|
||||
experiment_name = logging_cfg.get("experiment_name", "default")
|
||||
|
||||
if args.disable_tensorboard:
|
||||
use_tensorboard = False
|
||||
if args.log_dir is not None:
|
||||
log_dir = args.log_dir
|
||||
if args.experiment_name is not None:
|
||||
experiment_name = args.experiment_name
|
||||
|
||||
should_log_matches = args.tb_log_matches and use_tensorboard and log_dir is not None
|
||||
writer = None
|
||||
if should_log_matches:
|
||||
log_root = Path(log_dir).expanduser()
|
||||
exp_folder = experiment_name or "default"
|
||||
tb_path = log_root / "match" / exp_folder
|
||||
tb_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
writer = SummaryWriter(tb_path.as_posix())
|
||||
|
||||
transform = get_transform()
|
||||
model = RoRD().cuda()
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
@@ -218,11 +270,24 @@ if __name__ == "__main__":
|
||||
layout_image = Image.open(args.layout).convert('L')
|
||||
template_image = Image.open(args.template).convert('L')
|
||||
|
||||
detected_bboxes = match_template_multiscale(model, layout_image, template_image, transform, matching_cfg)
|
||||
detected_bboxes = match_template_multiscale(
|
||||
model,
|
||||
layout_image,
|
||||
template_image,
|
||||
transform,
|
||||
matching_cfg,
|
||||
log_writer=writer,
|
||||
log_step=0,
|
||||
)
|
||||
|
||||
print("\n检测到的边界框:")
|
||||
for bbox in detected_bboxes:
|
||||
print(bbox)
|
||||
|
||||
if args.output:
|
||||
visualize_matches(args.layout, detected_bboxes, args.output)
|
||||
visualize_matches(args.layout, detected_bboxes, args.output)
|
||||
|
||||
if writer:
|
||||
writer.add_scalar("match/output_instances", len(detected_bboxes), 0)
|
||||
writer.add_text("match/layout_path", args.layout, 0)
|
||||
writer.close()
|
||||
Reference in New Issue
Block a user