Files
RoRD-Layout-Recognation/docs/loss_function.md
2025-07-20 16:06:33 +08:00

3.7 KiB
Raw Blame History

RoRD 模型训练损失函数详解

本文档详细描述了 RoRDRobust Layout Representation and Detection模型训练过程中使用的损失函数设计。

1. 检测损失Detection Loss

数学公式

L_{\text{det}} = \text{BCE}(\text{det}_{\text{original}}, \text{warp}(\text{det}_{\text{rotated}}, H^{-1})) + 0.1 \times \text{SmoothL1}(\text{det}_{\text{original}}, \text{warp}(\text{det}_{\text{rotated}}, H^{-1}))

组成说明

  • BCE损失:二元交叉熵损失,适用于二分类检测任务
    • 衡量原始检测图与变换后检测图之间的差异
    • 公式:
\text{BCE}(y, \hat{y}) = -[y \cdot \log(\hat{y}) + (1-y) \cdot \log(1-\hat{y})]
  • Smooth L1损失平滑L1损失对异常值更鲁棒

    • 公式: $$\text{SmoothL1}(x) = \begin{cases} 0.5x^2 & \text{if } |x| < 1 \ |x| - 0.5 & \text{otherwise} \end{cases}$$
    • 作为BCE损失的辅助正则项
  • 权重比例

    • BCE损失权重 1.0(主导损失)
    • Smooth L1损失权重 0.1(辅助正则)

空间变换

  • warp操作使用逆变换矩阵H⁻¹对特征图进行空间变换对齐
  • 实现:通过F.affine_gridF.grid_sample完成

2. 描述子损失Descriptor Loss

Triplet Loss公式

L_{\text{desc}} = \max\left(0, \|f(a) - f(p)\|_2^2 - \|f(a) - f(n)\|_2^2 + \text{margin}\right)

符号定义

  • a (anchor):原始图像的描述子特征
  • p (positive):变换后图像对应位置的描述子特征
  • n (negative):困难负样本的描述子特征
  • margin边界参数默认值为1.0
  • f(·):描述子特征提取函数

采样策略

正样本采样

  • 采样方法:均匀网格采样
  • 采样点数200个点
  • 空间分布:在特征图上均匀分布,确保训练稳定性

困难负样本挖掘

  1. 候选生成:随机生成负样本坐标点
  2. 距离计算计算anchor与所有负候选的距离
  3. 选择策略:选择距离最近的负样本作为困难负样本
  4. 计算优化:使用torch.gather高效选择

实现细节

  • 特征维度128维描述子向量
  • 归一化使用InstanceNorm进行特征归一化
  • 距离度量L2范数欧氏距离
  • 损失函数nn.TripletMarginLoss(margin=1.0, p=2)

3. 总损失函数

最终公式

L_{\text{total}} = L_{\text{det}} + L_{\text{desc}}

设计特点

  • 无权重平衡:两个损失直接相加,依靠网络自动学习平衡
  • 端到端训练:检测和描述任务联合优化
  • 多任务学习:同时学习几何变换不变性和特征描述能力

4. 训练策略

损失优化

  • 优化器Adam优化器
  • 学习率初始1e-3使用ReduceLROnPlateau调度
  • 梯度裁剪max_norm=1.0,防止梯度爆炸

验证指标

  • 检测损失:验证集上的检测任务性能
  • 描述子损失:验证集上的特征匹配性能
  • 总损失:两个损失的加权和

5. 实现代码位置

  • 检测损失train.py::compute_detection_loss()第126-138行
  • 描述子损失train.py::compute_description_loss()第140-178行
  • 总损失train.py::main()第242行

6. 数学符号对照表

符号 含义 维度
det_original 原始图像检测图 (B, 1, H, W)
det_rotated 变换图像检测图 (B, 1, H, W)
desc_original 原始图像描述子 (B, 128, H, W)
desc_rotated 变换图像描述子 (B, 128, H, W)
H 几何变换矩阵 (B, 3, 3)
margin Triplet Loss边界 标量
B 批次大小 标量
C 特征维度 128
H, W 特征图高宽 标量