Finish Experiment Tracking and Evaluation

Jiao77
2025-09-25 21:24:41 +08:00
parent 05ec32bac1
commit 17d3f419f6
9 changed files with 565 additions and 37 deletions


@@ -7,6 +7,7 @@ from pathlib import Path
import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from match import match_template_multiscale
from models.rord import RoRD
@@ -23,7 +24,16 @@ def compute_iou(box1, box2):
return inter_area / union_area if union_area > 0 else 0
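The compute_iou helper is only partially visible in this hunk. For context, a minimal sketch of the full function consistent with the return line above, assuming axis-aligned boxes in (x1, y1, x2, y2) format (the box layout is an assumption, not shown in the diff):

def compute_iou(box1, box2):
    # Assumed box format: (x1, y1, x2, y2); a sketch, not the commit's code.
    ix1, iy1 = max(box1[0], box2[0]), max(box1[1], box2[1])
    ix2, iy2 = min(box1[2], box2[2]), min(box1[3], box2[3])
    inter_area = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = area1 + area2 - inter_area
    # Matches the context line above: guard against an empty union
    return inter_area / union_area if union_area > 0 else 0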
# --- (modified) Evaluation function ---
def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir, matching_cfg, iou_threshold):
def evaluate(
model,
val_dataset_dir,
val_annotations_dir,
template_dir,
matching_cfg,
iou_threshold,
summary_writer: SummaryWriter | None = None,
global_step: int = 0,
):
model.eval()
all_tp, all_fp, all_fn = 0, 0, 0
@@ -33,6 +43,13 @@ def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir, matching
template_paths = [os.path.join(template_dir, f) for f in os.listdir(template_dir) if f.endswith('.png')]
layout_image_names = [f for f in os.listdir(val_dataset_dir) if f.endswith('.png')]
if summary_writer:
summary_writer.add_text(
"dataset/info",
f"layouts={len(layout_image_names)}, templates={len(template_paths)}",
global_step,
)
# (modified) Iterate over every layout file in the validation set
for layout_name in layout_image_names:
print(f"\n正在评估版图: {layout_name}")
@@ -92,6 +109,15 @@ def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir, matching
precision = all_tp / (all_tp + all_fp) if (all_tp + all_fp) > 0 else 0
recall = all_tp / (all_tp + all_fn) if (all_tp + all_fn) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
if summary_writer:
summary_writer.add_scalar("metrics/precision", precision, global_step)
summary_writer.add_scalar("metrics/recall", recall, global_step)
summary_writer.add_scalar("metrics/f1", f1, global_step)
summary_writer.add_scalar("counts/true_positive", all_tp, global_step)
summary_writer.add_scalar("counts/false_positive", all_fp, global_step)
summary_writer.add_scalar("counts/false_negative", all_fn, global_step)
return {'precision': precision, 'recall': recall, 'f1': f1}
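A quick sanity check of the formulas above, with hypothetical counts (not results from this commit): all_tp=8, all_fp=2, all_fn=4 gives precision = 8/10 = 0.8, recall = 8/12 ≈ 0.667, and f1 = 2 * (0.8 * 0.667) / (0.8 + 0.667) ≈ 0.727.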
if __name__ == "__main__":
@@ -101,6 +127,9 @@ if __name__ == "__main__":
parser.add_argument('--val_dir', type=str, default=None, help="Validation image directory; falls back to the config file path if omitted")
parser.add_argument('--annotations_dir', type=str, default=None, help="Validation annotation directory; falls back to the config file path if omitted")
parser.add_argument('--templates_dir', type=str, default=None, help="Template directory; falls back to the config file path if omitted")
parser.add_argument('--log_dir', type=str, default=None, help="TensorBoard log root directory; overrides the config file setting")
parser.add_argument('--experiment_name', type=str, default=None, help="TensorBoard experiment name; overrides the config file setting")
parser.add_argument('--disable_tensorboard', action='store_true', help="Disable TensorBoard logging")
args = parser.parse_args()
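A hypothetical invocation exercising the new flags (the script and config filenames here are illustrative):

python evaluate.py --config config.yaml --log_dir runs --experiment_name rord_eval

Passing --disable_tensorboard instead turns logging off regardless of the config.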
cfg = load_config(args.config)
@@ -108,6 +137,7 @@ if __name__ == "__main__":
paths_cfg = cfg.paths
matching_cfg = cfg.matching
eval_cfg = cfg.evaluation
logging_cfg = cfg.get("logging", None)
model_path = args.model_path or str(to_absolute_path(paths_cfg.model_path, config_dir))
val_dir = args.val_dir or str(to_absolute_path(paths_cfg.val_img_dir, config_dir))
@@ -115,12 +145,48 @@ if __name__ == "__main__":
templates_dir = args.templates_dir or str(to_absolute_path(paths_cfg.template_dir, config_dir))
iou_threshold = float(eval_cfg.iou_threshold)
use_tensorboard = False
log_dir = None
experiment_name = None
if logging_cfg is not None:
use_tensorboard = bool(logging_cfg.get("use_tensorboard", False))
log_dir = logging_cfg.get("log_dir", "runs")
experiment_name = logging_cfg.get("experiment_name", "default")
if args.disable_tensorboard:
use_tensorboard = False
if args.log_dir is not None:
log_dir = args.log_dir
if args.experiment_name is not None:
experiment_name = args.experiment_name
writer = None
if use_tensorboard and log_dir:
log_root = Path(log_dir).expanduser()
exp_folder = experiment_name or "default"
tb_path = log_root / "eval" / exp_folder
tb_path.parent.mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(tb_path.as_posix())
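For reference, a sketch of the logging config this block consumes. The three keys match the get() calls above; representing it with OmegaConf is an assumption based on the attribute-style access (cfg.paths, cfg.matching) used elsewhere in the script:

from omegaconf import OmegaConf

# Hypothetical logging section; only these three keys are read by the script.
cfg = OmegaConf.create({
    "logging": {
        "use_tensorboard": True,      # master switch; --disable_tensorboard overrides it
        "log_dir": "runs",            # eval logs land in <log_dir>/eval/<experiment_name>
        "experiment_name": "default",
    }
})

With these defaults, the SummaryWriter above writes to runs/eval/default.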
model = RoRD().cuda()
model.load_state_dict(torch.load(model_path))
results = evaluate(model, val_dir, annotations_dir, templates_dir, matching_cfg, iou_threshold)
results = evaluate(
model,
val_dir,
annotations_dir,
templates_dir,
matching_cfg,
iou_threshold,
summary_writer=writer,
global_step=0,
)
print("\n--- 评估结果 ---")
print(f" 精确率 (Precision): {results['precision']:.4f}")
print(f" 召回率 (Recall): {results['recall']:.4f}")
print(f" F1 分数 (F1 Score): {results['f1']:.4f}")
print(f" F1 分数 (F1 Score): {results['f1']:.4f}")
if writer:
writer.add_text("metadata/model_path", model_path)
writer.close()
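To browse the recorded scalars and text afterwards, the standard TensorBoard CLI can be pointed at the eval subtree, e.g. tensorboard --logdir runs/eval (assuming the default log_dir sketched above).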