chenge to english version
This commit is contained in:
36
evaluate.py
36
evaluate.py
@@ -10,7 +10,7 @@ import config
|
||||
from models.rord import RoRD
|
||||
from utils.data_utils import get_transform
|
||||
from data.ic_dataset import ICLayoutDataset
|
||||
# (已修改) 导入新的匹配函数
|
||||
# (Modified) Import new matching function
|
||||
from match import match_template_multiscale
|
||||
|
||||
def compute_iou(box1, box2):
|
||||
@@ -22,48 +22,48 @@ def compute_iou(box1, box2):
|
||||
union_area = w1 * h1 + w2 * h2 - inter_area
|
||||
return inter_area / union_area if union_area > 0 else 0
|
||||
|
||||
# --- (已修改) 评估函数 ---
|
||||
# --- (Modified) Evaluation function ---
|
||||
def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir):
|
||||
model.eval()
|
||||
all_tp, all_fp, all_fn = 0, 0, 0
|
||||
|
||||
# 只需要一个统一的 transform 给匹配函数内部使用
|
||||
# Only need a unified transform for internal use by matching function
|
||||
transform = get_transform()
|
||||
|
||||
template_paths = [os.path.join(template_dir, f) for f in os.listdir(template_dir) if f.endswith('.png')]
|
||||
layout_image_names = [f for f in os.listdir(val_dataset_dir) if f.endswith('.png')]
|
||||
|
||||
# (已修改) 循环遍历验证集中的每个版图文件
|
||||
# (Modified) Loop through each layout file in validation set
|
||||
for layout_name in layout_image_names:
|
||||
print(f"\n正在评估版图: {layout_name}")
|
||||
print(f"\nEvaluating layout: {layout_name}")
|
||||
layout_path = os.path.join(val_dataset_dir, layout_name)
|
||||
annotation_path = os.path.join(val_annotations_dir, layout_name.replace('.png', '.json'))
|
||||
|
||||
# 加载原始PIL图像,以支持滑动窗口
|
||||
# Load original PIL image to support sliding window
|
||||
layout_image = Image.open(layout_path).convert('L')
|
||||
|
||||
# 加载标注信息
|
||||
# Load annotation information
|
||||
if not os.path.exists(annotation_path):
|
||||
continue
|
||||
with open(annotation_path, 'r') as f:
|
||||
annotation = json.load(f)
|
||||
|
||||
# 按模板对真实标注进行分组
|
||||
# Group ground truth annotations by template
|
||||
gt_by_template = {os.path.basename(box['template']): [] for box in annotation.get('boxes', [])}
|
||||
for box in annotation.get('boxes', []):
|
||||
gt_by_template[os.path.basename(box['template'])].append(box)
|
||||
|
||||
# 遍历每个模板,在当前版图上进行匹配
|
||||
# Iterate through each template and perform matching on current layout
|
||||
for template_path in template_paths:
|
||||
template_name = os.path.basename(template_path)
|
||||
template_image = Image.open(template_path).convert('L')
|
||||
|
||||
# (已修改) 调用新的多尺度匹配函数
|
||||
# (Modified) Call new multi-scale matching function
|
||||
detected = match_template_multiscale(model, layout_image, template_image, transform)
|
||||
|
||||
gt_boxes = gt_by_template.get(template_name, [])
|
||||
|
||||
# 计算 TP, FP, FN (这部分逻辑不变)
|
||||
# Calculate TP, FP, FN (this logic remains unchanged)
|
||||
matched_gt = [False] * len(gt_boxes)
|
||||
tp = 0
|
||||
if len(detected) > 0:
|
||||
@@ -88,14 +88,14 @@ def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir):
|
||||
all_fp += fp
|
||||
all_fn += fn
|
||||
|
||||
# 计算最终指标
|
||||
# Calculate final metrics
|
||||
precision = all_tp / (all_tp + all_fp) if (all_tp + all_fp) > 0 else 0
|
||||
recall = all_tp / (all_tp + all_fn) if (all_tp + all_fn) > 0 else 0
|
||||
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
||||
return {'precision': precision, 'recall': recall, 'f1': f1}
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="评估 RoRD 模型性能")
|
||||
parser = argparse.ArgumentParser(description="Evaluate RoRD model performance")
|
||||
parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
|
||||
parser.add_argument('--val_dir', type=str, default=config.VAL_IMG_DIR)
|
||||
parser.add_argument('--annotations_dir', type=str, default=config.VAL_ANN_DIR)
|
||||
@@ -105,10 +105,10 @@ if __name__ == "__main__":
|
||||
model = RoRD().cuda()
|
||||
model.load_state_dict(torch.load(args.model_path))
|
||||
|
||||
# (已修改) 不再需要预加载数据集,直接传入路径
|
||||
# (Modified) No longer need to preload dataset, directly pass paths
|
||||
results = evaluate(model, args.val_dir, args.annotations_dir, args.templates_dir)
|
||||
|
||||
print("\n--- 评估结果 ---")
|
||||
print(f" 精确率 (Precision): {results['precision']:.4f}")
|
||||
print(f" 召回率 (Recall): {results['recall']:.4f}")
|
||||
print(f" F1 分数 (F1 Score): {results['f1']:.4f}")
|
||||
print("\n--- Evaluation Results ---")
|
||||
print(f" Precision: {results['precision']:.4f}")
|
||||
print(f" Recall: {results['recall']:.4f}")
|
||||
print(f" F1 Score: {results['f1']:.4f}")
|
||||
Reference in New Issue
Block a user