add midtern report and change data source
This commit is contained in:
343
match.py
343
match.py
@@ -1,6 +1,7 @@
|
||||
# match.py
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -18,6 +19,127 @@ from models.rord import RoRD
|
||||
from utils.config_loader import load_config, to_absolute_path
|
||||
from utils.data_utils import get_transform
|
||||
|
||||
# --- 新增:功能增强函数 ---
|
||||
def extract_rotation_angle(H):
|
||||
"""
|
||||
从单应性矩阵中提取旋转角度
|
||||
返回0°, 90°, 180°, 270°之一
|
||||
"""
|
||||
if H is None:
|
||||
return 0
|
||||
|
||||
# 提取旋转分量
|
||||
cos_theta = H[0, 0] / np.sqrt(H[0, 0]**2 + H[1, 0]**2 + 1e-8)
|
||||
sin_theta = H[1, 0] / np.sqrt(H[0, 0]**2 + H[1, 0]**2 + 1e-8)
|
||||
|
||||
# 计算角度(弧度转角度)
|
||||
angle = np.arctan2(sin_theta, cos_theta) * 180 / np.pi
|
||||
|
||||
# 四舍五入到最近的90度倍数
|
||||
angles = [0, 90, 180, 270]
|
||||
nearest_angle = min(angles, key=lambda x: abs(x - angle))
|
||||
|
||||
return nearest_angle
|
||||
|
||||
|
||||
def calculate_match_score(inlier_count, total_keypoints, H, inlier_ratio=None):
|
||||
"""
|
||||
计算匹配质量评分 (0-1)
|
||||
|
||||
Args:
|
||||
inlier_count: 内点数量
|
||||
total_keypoints: 总关键点数量
|
||||
H: 单应性矩阵
|
||||
inlier_ratio: 内点比例(可选)
|
||||
"""
|
||||
if inlier_ratio is None:
|
||||
inlier_ratio = inlier_count / max(total_keypoints, 1)
|
||||
|
||||
# 基于内点比例的基础分数
|
||||
base_score = inlier_ratio
|
||||
|
||||
# 基于变换矩阵质量的分数(越接近单位矩阵分数越高)
|
||||
if H is not None:
|
||||
# 计算变换的"理想程度"
|
||||
det = np.linalg.det(H)
|
||||
ideal_det = 1.0
|
||||
det_score = 1.0 / (1.0 + abs(np.log(det + 1e-8)))
|
||||
|
||||
# 综合评分
|
||||
final_score = base_score * 0.7 + det_score * 0.3
|
||||
else:
|
||||
final_score = base_score
|
||||
|
||||
return min(max(final_score, 0.0), 1.0)
|
||||
|
||||
|
||||
def calculate_similarity(matches_count, template_kps_count, layout_kps_count):
|
||||
"""
|
||||
计算模板和版图之间的相似度
|
||||
|
||||
Args:
|
||||
matches_count: 匹配对数量
|
||||
template_kps_count: 模板关键点数量
|
||||
layout_kps_count: 版图关键点数量
|
||||
"""
|
||||
# 匹配率
|
||||
template_match_rate = matches_count / max(template_kps_count, 1)
|
||||
|
||||
# 覆盖率(简化计算)
|
||||
coverage_rate = min(matches_count / max(layout_kps_count, 1), 1.0)
|
||||
|
||||
# 综合相似度
|
||||
similarity = (template_match_rate * 0.6 + coverage_rate * 0.4)
|
||||
|
||||
return min(max(similarity, 0.0), 1.0)
|
||||
|
||||
|
||||
def generate_difference_description(H, inlier_count, total_matches, angle_diff=0):
|
||||
"""
|
||||
生成差异描述
|
||||
|
||||
Args:
|
||||
H: 单应性矩阵
|
||||
inlier_count: 内点数量
|
||||
total_matches: 总匹配数
|
||||
angle_diff: 角度差异
|
||||
"""
|
||||
descriptions = []
|
||||
|
||||
# 基于内点比例的描述
|
||||
if total_matches > 0:
|
||||
inlier_ratio = inlier_count / total_matches
|
||||
if inlier_ratio > 0.8:
|
||||
descriptions.append("高度匹配")
|
||||
elif inlier_ratio > 0.6:
|
||||
descriptions.append("良好匹配")
|
||||
elif inlier_ratio > 0.4:
|
||||
descriptions.append("中等匹配")
|
||||
else:
|
||||
descriptions.append("低质量匹配")
|
||||
|
||||
# 基于旋转的描述
|
||||
if angle_diff != 0:
|
||||
descriptions.append(f"旋转{angle_diff}度")
|
||||
else:
|
||||
descriptions.append("无旋转")
|
||||
|
||||
# 基于变换的描述
|
||||
if H is not None:
|
||||
# 检查缩放
|
||||
scale_x = np.sqrt(H[0,0]**2 + H[1,0]**2)
|
||||
scale_y = np.sqrt(H[0,1]**2 + H[1,1]**2)
|
||||
avg_scale = (scale_x + scale_y) / 2
|
||||
|
||||
if abs(avg_scale - 1.0) > 0.1:
|
||||
if avg_scale > 1.0:
|
||||
descriptions.append(f"放大{avg_scale:.2f}倍")
|
||||
else:
|
||||
descriptions.append(f"缩小{1/avg_scale:.2f}倍")
|
||||
|
||||
return ", ".join(descriptions) if descriptions else "无法评估差异"
|
||||
|
||||
|
||||
# --- 特征提取函数 (基本无变动) ---
|
||||
def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh):
|
||||
with torch.no_grad():
|
||||
@@ -161,9 +283,23 @@ def match_template_multiscale(
|
||||
matching_cfg,
|
||||
log_writer: SummaryWriter | None = None,
|
||||
log_step: int = 0,
|
||||
return_detailed_info: bool = True,
|
||||
):
|
||||
"""
|
||||
在不同尺度下搜索模板,并检测多个实例
|
||||
|
||||
Args:
|
||||
model: RoRD模型
|
||||
layout_image: 大版图图像
|
||||
template_image: 小版图图像
|
||||
transform: 图像预处理变换
|
||||
matching_cfg: 匹配配置
|
||||
log_writer: TensorBoard日志记录器
|
||||
log_step: 日志步数
|
||||
return_detailed_info: 是否返回详细信息
|
||||
|
||||
Returns:
|
||||
匹配结果列表,包含坐标、旋转角度、置信度等信息
|
||||
"""
|
||||
# 1. 版图特征提取:根据配置选择 FPN 或滑窗
|
||||
device = next(model.parameters()).device
|
||||
@@ -248,8 +384,59 @@ def match_template_multiscale(
|
||||
|
||||
x_min, y_min = inlier_layout_kps.min(axis=0)
|
||||
x_max, y_max = inlier_layout_kps.max(axis=0)
|
||||
|
||||
instance = {'x': int(x_min), 'y': int(y_min), 'width': int(x_max - x_min), 'height': int(y_max - y_min), 'homography': best_match_info['H']}
|
||||
|
||||
# 提取旋转角度
|
||||
rotation_angle = extract_rotation_angle(best_match_info['H'])
|
||||
|
||||
# 计算匹配质量评分
|
||||
confidence = calculate_match_score(
|
||||
inlier_count=int(best_match_info['inliers']),
|
||||
total_keypoints=len(current_layout_kps),
|
||||
H=best_match_info['H']
|
||||
)
|
||||
|
||||
# 计算相似度
|
||||
similarity = calculate_similarity(
|
||||
matches_count=int(best_match_info['inliers']),
|
||||
template_kps_count=len(template_kps),
|
||||
layout_kps_count=len(current_layout_kps)
|
||||
)
|
||||
|
||||
# 生成差异描述
|
||||
diff_description = generate_difference_description(
|
||||
H=best_match_info['H'],
|
||||
inlier_count=int(best_match_info['inliers']),
|
||||
total_matches=len(matches),
|
||||
angle_diff=rotation_angle
|
||||
)
|
||||
|
||||
# 构建详细实例信息
|
||||
if return_detailed_info:
|
||||
instance = {
|
||||
'bbox': {
|
||||
'x': int(x_min),
|
||||
'y': int(y_min),
|
||||
'width': int(x_max - x_min),
|
||||
'height': int(y_max - y_min)
|
||||
},
|
||||
'rotation': rotation_angle,
|
||||
'confidence': round(confidence, 3),
|
||||
'similarity': round(similarity, 3),
|
||||
'inliers': int(best_match_info['inliers']),
|
||||
'scale': best_match_info.get('scale', 1.0),
|
||||
'homography': best_match_info['H'].tolist() if best_match_info['H'] is not None else None,
|
||||
'description': diff_description
|
||||
}
|
||||
else:
|
||||
# 兼容旧格式
|
||||
instance = {
|
||||
'x': int(x_min),
|
||||
'y': int(y_min),
|
||||
'width': int(x_max - x_min),
|
||||
'height': int(y_max - y_min),
|
||||
'homography': best_match_info['H']
|
||||
}
|
||||
|
||||
found_instances.append(instance)
|
||||
|
||||
# 屏蔽已匹配区域的关键点,以便检测下一个实例
|
||||
@@ -269,16 +456,124 @@ def match_template_multiscale(
|
||||
return found_instances
|
||||
|
||||
|
||||
def visualize_matches(layout_path, bboxes, output_path):
|
||||
def visualize_matches(layout_path, matches, output_path):
|
||||
"""
|
||||
可视化匹配结果,支持新的详细格式
|
||||
|
||||
Args:
|
||||
layout_path: 大版图路径
|
||||
matches: 匹配结果列表
|
||||
output_path: 输出图像路径
|
||||
"""
|
||||
layout_img = cv2.imread(layout_path)
|
||||
for i, bbox in enumerate(bboxes):
|
||||
x, y, w, h = bbox['x'], bbox['y'], bbox['width'], bbox['height']
|
||||
if layout_img is None:
|
||||
print(f"错误:无法读取图像 {layout_path}")
|
||||
return
|
||||
|
||||
for i, match in enumerate(matches):
|
||||
# 支持新旧格式
|
||||
if 'bbox' in match:
|
||||
x, y, w, h = match['bbox']['x'], match['bbox']['y'], match['bbox']['width'], match['bbox']['height']
|
||||
confidence = match.get('confidence', 0)
|
||||
rotation = match.get('rotation', 0)
|
||||
description = match.get('description', '')
|
||||
else:
|
||||
# 兼容旧格式
|
||||
x, y, w, h = match['x'], match['y'], match['width'], match['height']
|
||||
confidence = 0
|
||||
rotation = 0
|
||||
description = ''
|
||||
|
||||
# 绘制边界框
|
||||
cv2.rectangle(layout_img, (x, y), (x + w, y + h), (0, 255, 0), 2)
|
||||
cv2.putText(layout_img, f"Match {i+1}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
|
||||
|
||||
# 准备标签文本
|
||||
label_parts = [f"Match {i+1}"]
|
||||
if confidence > 0:
|
||||
label_parts.append(f"Conf: {confidence:.2f}")
|
||||
if rotation != 0:
|
||||
label_parts.append(f"Rot: {rotation}°")
|
||||
if description:
|
||||
label_parts.append(f"{description[:20]}...") # 截断长描述
|
||||
|
||||
label = " | ".join(label_parts)
|
||||
|
||||
# 绘制标签背景
|
||||
(label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
|
||||
cv2.rectangle(layout_img, (x, y - label_height - 10), (x + label_width, y), (0, 255, 0), -1)
|
||||
cv2.putText(layout_img, label, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
|
||||
|
||||
cv2.imwrite(output_path, layout_img)
|
||||
print(f"可视化结果已保存至: {output_path}")
|
||||
|
||||
|
||||
def save_matches_json(matches, output_path):
|
||||
"""
|
||||
保存匹配结果到JSON文件
|
||||
|
||||
Args:
|
||||
matches: 匹配结果列表
|
||||
output_path: 输出JSON文件路径
|
||||
"""
|
||||
result = {
|
||||
'found_matches': len(matches) > 0,
|
||||
'total_matches': len(matches),
|
||||
'matches': matches
|
||||
}
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"匹配结果已保存至: {output_path}")
|
||||
|
||||
|
||||
def print_detailed_results(matches):
|
||||
"""
|
||||
打印详细的匹配结果
|
||||
|
||||
Args:
|
||||
matches: 匹配结果列表
|
||||
"""
|
||||
print("\n" + "="*60)
|
||||
print("🎯 版图匹配结果详情")
|
||||
print("="*60)
|
||||
|
||||
if not matches:
|
||||
print("❌ 未找到任何匹配区域")
|
||||
return
|
||||
|
||||
print(f"✅ 共找到 {len(matches)} 个匹配区域\n")
|
||||
|
||||
for i, match in enumerate(matches, 1):
|
||||
print(f"📍 匹配区域 #{i}")
|
||||
print("-" * 40)
|
||||
|
||||
# 支持新旧格式
|
||||
if 'bbox' in match:
|
||||
bbox = match['bbox']
|
||||
print(f"📐 位置: ({bbox['x']}, {bbox['y']})")
|
||||
print(f"📏 尺寸: {bbox['width']} × {bbox['height']} 像素")
|
||||
|
||||
if 'rotation' in match:
|
||||
print(f"🔄 旋转角度: {match['rotation']}°")
|
||||
if 'confidence' in match:
|
||||
print(f"🎯 置信度: {match['confidence']:.3f}")
|
||||
if 'similarity' in match:
|
||||
print(f"📊 相似度: {match['similarity']:.3f}")
|
||||
if 'inliers' in match:
|
||||
print(f"🔗 内点数量: {match['inliers']}")
|
||||
if 'scale' in match:
|
||||
print(f"📈 匹配尺度: {match['scale']:.2f}x")
|
||||
if 'description' in match:
|
||||
print(f"📝 差异描述: {match['description']}")
|
||||
else:
|
||||
# 兼容旧格式
|
||||
print(f"📐 位置: ({match['x']}, {match['y']})")
|
||||
print(f"📏 尺寸: {match['width']} × {match['height']} 像素")
|
||||
|
||||
print() # 空行分隔
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="使用 RoRD 进行多尺度模板匹配")
|
||||
parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="YAML 配置文件路径")
|
||||
@@ -289,9 +584,11 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--disable_tensorboard', action='store_true', help="禁用 TensorBoard 记录")
|
||||
parser.add_argument('--fpn_off', action='store_true', help="关闭 FPN 匹配路径(等同于 matching.use_fpn=false)")
|
||||
parser.add_argument('--no_nms', action='store_true', help="关闭关键点去重(NMS)")
|
||||
parser.add_argument('--layout', type=str, required=True)
|
||||
parser.add_argument('--template', type=str, required=True)
|
||||
parser.add_argument('--output', type=str)
|
||||
parser.add_argument('--layout', type=str, required=True, help="大版图图像路径")
|
||||
parser.add_argument('--template', type=str, required=True, help="小版图(模板)图像路径")
|
||||
parser.add_argument('--output', type=str, help="可视化结果保存路径")
|
||||
parser.add_argument('--json_output', type=str, help="JSON结果保存路径")
|
||||
parser.add_argument('--simple_format', action='store_true', help="使用简单的输出格式(兼容旧版本)")
|
||||
args = parser.parse_args()
|
||||
|
||||
cfg = load_config(args.config)
|
||||
@@ -342,7 +639,8 @@ if __name__ == "__main__":
|
||||
layout_image = Image.open(args.layout).convert('L')
|
||||
template_image = Image.open(args.template).convert('L')
|
||||
|
||||
detected_bboxes = match_template_multiscale(
|
||||
# 执行匹配,根据参数选择详细或简单格式
|
||||
detected_matches = match_template_multiscale(
|
||||
model,
|
||||
layout_image,
|
||||
template_image,
|
||||
@@ -350,16 +648,27 @@ if __name__ == "__main__":
|
||||
matching_cfg,
|
||||
log_writer=writer,
|
||||
log_step=0,
|
||||
return_detailed_info=not args.simple_format,
|
||||
)
|
||||
|
||||
print("\n检测到的边界框:")
|
||||
for bbox in detected_bboxes:
|
||||
print(bbox)
|
||||
|
||||
# 打印详细结果
|
||||
print_detailed_results(detected_matches)
|
||||
|
||||
# 保存JSON结果
|
||||
if args.json_output:
|
||||
save_matches_json(detected_matches, args.json_output)
|
||||
|
||||
# 可视化结果
|
||||
if args.output:
|
||||
visualize_matches(args.layout, detected_bboxes, args.output)
|
||||
visualize_matches(args.layout, detected_matches, args.output)
|
||||
|
||||
if writer:
|
||||
writer.add_scalar("match/output_instances", len(detected_bboxes), 0)
|
||||
writer.add_scalar("match/output_instances", len(detected_matches), 0)
|
||||
writer.add_text("match/layout_path", args.layout, 0)
|
||||
writer.close()
|
||||
writer.close()
|
||||
|
||||
print("\n🎉 匹配完成!")
|
||||
if args.json_output:
|
||||
print(f"📄 详细结果已保存到: {args.json_output}")
|
||||
if args.output:
|
||||
print(f"🖼️ 可视化结果已保存到: {args.output}")
|
||||
Reference in New Issue
Block a user