finish Experiment Tracking and Evaluation
This commit is contained in:
72
evaluate.py
72
evaluate.py
@@ -7,6 +7,7 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from match import match_template_multiscale
|
||||
from models.rord import RoRD
|
||||
@@ -23,7 +24,16 @@ def compute_iou(box1, box2):
|
||||
return inter_area / union_area if union_area > 0 else 0
|
||||
|
||||
# --- (已修改) 评估函数 ---
|
||||
def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir, matching_cfg, iou_threshold):
|
||||
def evaluate(
|
||||
model,
|
||||
val_dataset_dir,
|
||||
val_annotations_dir,
|
||||
template_dir,
|
||||
matching_cfg,
|
||||
iou_threshold,
|
||||
summary_writer: SummaryWriter | None = None,
|
||||
global_step: int = 0,
|
||||
):
|
||||
model.eval()
|
||||
all_tp, all_fp, all_fn = 0, 0, 0
|
||||
|
||||
@@ -33,6 +43,13 @@ def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir, matching
|
||||
template_paths = [os.path.join(template_dir, f) for f in os.listdir(template_dir) if f.endswith('.png')]
|
||||
layout_image_names = [f for f in os.listdir(val_dataset_dir) if f.endswith('.png')]
|
||||
|
||||
if summary_writer:
|
||||
summary_writer.add_text(
|
||||
"dataset/info",
|
||||
f"layouts={len(layout_image_names)}, templates={len(template_paths)}",
|
||||
global_step,
|
||||
)
|
||||
|
||||
# (已修改) 循环遍历验证集中的每个版图文件
|
||||
for layout_name in layout_image_names:
|
||||
print(f"\n正在评估版图: {layout_name}")
|
||||
@@ -92,6 +109,15 @@ def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir, matching
|
||||
precision = all_tp / (all_tp + all_fp) if (all_tp + all_fp) > 0 else 0
|
||||
recall = all_tp / (all_tp + all_fn) if (all_tp + all_fn) > 0 else 0
|
||||
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
||||
|
||||
if summary_writer:
|
||||
summary_writer.add_scalar("metrics/precision", precision, global_step)
|
||||
summary_writer.add_scalar("metrics/recall", recall, global_step)
|
||||
summary_writer.add_scalar("metrics/f1", f1, global_step)
|
||||
summary_writer.add_scalar("counts/true_positive", all_tp, global_step)
|
||||
summary_writer.add_scalar("counts/false_positive", all_fp, global_step)
|
||||
summary_writer.add_scalar("counts/false_negative", all_fn, global_step)
|
||||
|
||||
return {'precision': precision, 'recall': recall, 'f1': f1}
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -101,6 +127,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--val_dir', type=str, default=None, help="验证图像目录,若未提供则使用配置文件中的路径")
|
||||
parser.add_argument('--annotations_dir', type=str, default=None, help="验证标注目录,若未提供则使用配置文件中的路径")
|
||||
parser.add_argument('--templates_dir', type=str, default=None, help="模板目录,若未提供则使用配置文件中的路径")
|
||||
parser.add_argument('--log_dir', type=str, default=None, help="TensorBoard 日志根目录,覆盖配置文件设置")
|
||||
parser.add_argument('--experiment_name', type=str, default=None, help="TensorBoard 实验名称,覆盖配置文件设置")
|
||||
parser.add_argument('--disable_tensorboard', action='store_true', help="禁用 TensorBoard 记录")
|
||||
args = parser.parse_args()
|
||||
|
||||
cfg = load_config(args.config)
|
||||
@@ -108,6 +137,7 @@ if __name__ == "__main__":
|
||||
paths_cfg = cfg.paths
|
||||
matching_cfg = cfg.matching
|
||||
eval_cfg = cfg.evaluation
|
||||
logging_cfg = cfg.get("logging", None)
|
||||
|
||||
model_path = args.model_path or str(to_absolute_path(paths_cfg.model_path, config_dir))
|
||||
val_dir = args.val_dir or str(to_absolute_path(paths_cfg.val_img_dir, config_dir))
|
||||
@@ -115,12 +145,48 @@ if __name__ == "__main__":
|
||||
templates_dir = args.templates_dir or str(to_absolute_path(paths_cfg.template_dir, config_dir))
|
||||
iou_threshold = float(eval_cfg.iou_threshold)
|
||||
|
||||
use_tensorboard = False
|
||||
log_dir = None
|
||||
experiment_name = None
|
||||
if logging_cfg is not None:
|
||||
use_tensorboard = bool(logging_cfg.get("use_tensorboard", False))
|
||||
log_dir = logging_cfg.get("log_dir", "runs")
|
||||
experiment_name = logging_cfg.get("experiment_name", "default")
|
||||
|
||||
if args.disable_tensorboard:
|
||||
use_tensorboard = False
|
||||
if args.log_dir is not None:
|
||||
log_dir = args.log_dir
|
||||
if args.experiment_name is not None:
|
||||
experiment_name = args.experiment_name
|
||||
|
||||
writer = None
|
||||
if use_tensorboard and log_dir:
|
||||
log_root = Path(log_dir).expanduser()
|
||||
exp_folder = experiment_name or "default"
|
||||
tb_path = log_root / "eval" / exp_folder
|
||||
tb_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
writer = SummaryWriter(tb_path.as_posix())
|
||||
|
||||
model = RoRD().cuda()
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
|
||||
results = evaluate(model, val_dir, annotations_dir, templates_dir, matching_cfg, iou_threshold)
|
||||
results = evaluate(
|
||||
model,
|
||||
val_dir,
|
||||
annotations_dir,
|
||||
templates_dir,
|
||||
matching_cfg,
|
||||
iou_threshold,
|
||||
summary_writer=writer,
|
||||
global_step=0,
|
||||
)
|
||||
|
||||
print("\n--- 评估结果 ---")
|
||||
print(f" 精确率 (Precision): {results['precision']:.4f}")
|
||||
print(f" 召回率 (Recall): {results['recall']:.4f}")
|
||||
print(f" F1 分数 (F1 Score): {results['f1']:.4f}")
|
||||
print(f" F1 分数 (F1 Score): {results['f1']:.4f}")
|
||||
|
||||
if writer:
|
||||
writer.add_text("metadata/model_path", model_path)
|
||||
writer.close()
|
||||
Reference in New Issue
Block a user