complete code struction update

This commit is contained in:
Jiao77
2025-09-25 20:20:24 +08:00
parent e0b250e77f
commit 8c9926c815
10 changed files with 480 additions and 290 deletions

View File

@@ -1,17 +1,17 @@
# evaluate.py
import argparse
import json
import os
from pathlib import Path
import torch
from PIL import Image
import json
import os
import argparse
import config
from models.rord import RoRD
from utils.data_utils import get_transform
from data.ic_dataset import ICLayoutDataset
# (已修改) 导入新的匹配函数
from match import match_template_multiscale
from models.rord import RoRD
from utils.config_loader import load_config, to_absolute_path
from utils.data_utils import get_transform
def compute_iou(box1, box2):
x1, y1, w1, h1 = box1['x'], box1['y'], box1['width'], box1['height']
@@ -23,7 +23,7 @@ def compute_iou(box1, box2):
return inter_area / union_area if union_area > 0 else 0
# --- (已修改) 评估函数 ---
def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir):
def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir, matching_cfg, iou_threshold):
model.eval()
all_tp, all_fp, all_fn = 0, 0, 0
@@ -59,7 +59,7 @@ def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir):
template_image = Image.open(template_path).convert('L')
# (已修改) 调用新的多尺度匹配函数
detected = match_template_multiscale(model, layout_image, template_image, transform)
detected = match_template_multiscale(model, layout_image, template_image, transform, matching_cfg)
gt_boxes = gt_by_template.get(template_name, [])
@@ -76,7 +76,7 @@ def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir):
if iou > best_iou:
best_iou, best_gt_idx = iou, i
if best_iou > config.IOU_THRESHOLD:
if best_iou > iou_threshold:
if not matched_gt[best_gt_idx]:
tp += 1
matched_gt[best_gt_idx] = True
@@ -96,17 +96,29 @@ def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="评估 RoRD 模型性能")
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
parser.add_argument('--val_dir', type=str, default=config.VAL_IMG_DIR)
parser.add_argument('--annotations_dir', type=str, default=config.VAL_ANN_DIR)
parser.add_argument('--templates_dir', type=str, default=config.TEMPLATE_DIR)
parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="YAML 配置文件路径")
parser.add_argument('--model_path', type=str, default=None, help="模型权重路径,若未提供则使用配置文件中的路径")
parser.add_argument('--val_dir', type=str, default=None, help="验证图像目录,若未提供则使用配置文件中的路径")
parser.add_argument('--annotations_dir', type=str, default=None, help="验证标注目录,若未提供则使用配置文件中的路径")
parser.add_argument('--templates_dir', type=str, default=None, help="模板目录,若未提供则使用配置文件中的路径")
args = parser.parse_args()
cfg = load_config(args.config)
config_dir = Path(args.config).resolve().parent
paths_cfg = cfg.paths
matching_cfg = cfg.matching
eval_cfg = cfg.evaluation
model_path = args.model_path or str(to_absolute_path(paths_cfg.model_path, config_dir))
val_dir = args.val_dir or str(to_absolute_path(paths_cfg.val_img_dir, config_dir))
annotations_dir = args.annotations_dir or str(to_absolute_path(paths_cfg.val_ann_dir, config_dir))
templates_dir = args.templates_dir or str(to_absolute_path(paths_cfg.template_dir, config_dir))
iou_threshold = float(eval_cfg.iou_threshold)
model = RoRD().cuda()
model.load_state_dict(torch.load(args.model_path))
model.load_state_dict(torch.load(model_path))
# (已修改) 不再需要预加载数据集,直接传入路径
results = evaluate(model, args.val_dir, args.annotations_dir, args.templates_dir)
results = evaluate(model, val_dir, annotations_dir, templates_dir, matching_cfg, iou_threshold)
print("\n--- 评估结果 ---")
print(f" 精确率 (Precision): {results['precision']:.4f}")