complete code struction update

This commit is contained in:
Jiao77
2025-09-25 20:20:24 +08:00
parent e0b250e77f
commit 8c9926c815
10 changed files with 480 additions and 290 deletions

View File

@@ -1,29 +1,34 @@
# config.py
"""Legacy config shim loading values from YAML."""
from __future__ import annotations
from pathlib import Path
from omegaconf import OmegaConf
_BASE_CONFIG_PATH = Path(__file__).resolve().parent / "configs" / "base_config.yaml"
_CFG = OmegaConf.load(_BASE_CONFIG_PATH)
# --- 训练参数 ---
LEARNING_RATE = 5e-5 # 降低学习率,提高训练稳定性
BATCH_SIZE = 8 # 增加批次大小,提高训练效率
NUM_EPOCHS = 50 # 增加训练轮数
PATCH_SIZE = 256
# (优化) 训练时尺度抖动范围 - 缩小范围提高稳定性
SCALE_JITTER_RANGE = (0.8, 1.2)
LEARNING_RATE = float(_CFG.training.learning_rate)
BATCH_SIZE = int(_CFG.training.batch_size)
NUM_EPOCHS = int(_CFG.training.num_epochs)
PATCH_SIZE = int(_CFG.training.patch_size)
SCALE_JITTER_RANGE = tuple(float(x) for x in _CFG.training.scale_jitter_range)
# --- 匹配与评估参数 ---
KEYPOINT_THRESHOLD = 0.5
RANSAC_REPROJ_THRESHOLD = 5.0
MIN_INLIERS = 15
IOU_THRESHOLD = 0.5
# (新增) 推理时模板匹配的图像金字塔尺度
PYRAMID_SCALES = [0.75, 1.0, 1.5]
# (新增) 推理时处理大版图的滑动窗口参数
INFERENCE_WINDOW_SIZE = 1024
INFERENCE_STRIDE = 768 # 小于INFERENCE_WINDOW_SIZE以保证重叠
KEYPOINT_THRESHOLD = float(_CFG.matching.keypoint_threshold)
RANSAC_REPROJ_THRESHOLD = float(_CFG.matching.ransac_reproj_threshold)
MIN_INLIERS = int(_CFG.matching.min_inliers)
PYRAMID_SCALES = [float(s) for s in _CFG.matching.pyramid_scales]
INFERENCE_WINDOW_SIZE = int(_CFG.matching.inference_window_size)
INFERENCE_STRIDE = int(_CFG.matching.inference_stride)
IOU_THRESHOLD = float(_CFG.evaluation.iou_threshold)
# --- 文件路径 ---
# (路径保持不变, 请根据您的环境修改)
LAYOUT_DIR = 'path/to/layouts'
SAVE_DIR = 'path/to/save'
VAL_IMG_DIR = 'path/to/val/images'
VAL_ANN_DIR = 'path/to/val/annotations'
TEMPLATE_DIR = 'path/to/templates'
MODEL_PATH = 'path/to/save/model_final.pth'
LAYOUT_DIR = str((_BASE_CONFIG_PATH.parent / _CFG.paths.layout_dir).resolve()) if not Path(_CFG.paths.layout_dir).is_absolute() else _CFG.paths.layout_dir
SAVE_DIR = str((_BASE_CONFIG_PATH.parent / _CFG.paths.save_dir).resolve()) if not Path(_CFG.paths.save_dir).is_absolute() else _CFG.paths.save_dir
VAL_IMG_DIR = str((_BASE_CONFIG_PATH.parent / _CFG.paths.val_img_dir).resolve()) if not Path(_CFG.paths.val_img_dir).is_absolute() else _CFG.paths.val_img_dir
VAL_ANN_DIR = str((_BASE_CONFIG_PATH.parent / _CFG.paths.val_ann_dir).resolve()) if not Path(_CFG.paths.val_ann_dir).is_absolute() else _CFG.paths.val_ann_dir
TEMPLATE_DIR = str((_BASE_CONFIG_PATH.parent / _CFG.paths.template_dir).resolve()) if not Path(_CFG.paths.template_dir).is_absolute() else _CFG.paths.template_dir
MODEL_PATH = str((_BASE_CONFIG_PATH.parent / _CFG.paths.model_path).resolve()) if not Path(_CFG.paths.model_path).is_absolute() else _CFG.paths.model_path