3.7 KiB
3.7 KiB
RoRD 模型训练损失函数详解
本文档详细描述了 RoRD(Robust Layout Representation and Detection)模型训练过程中使用的损失函数设计。
1. 检测损失(Detection Loss)
数学公式
L_{\text{det}} = \text{BCE}(\text{det}_{\text{original}}, \text{warp}(\text{det}_{\text{rotated}}, H^{-1})) + 0.1 \times \text{SmoothL1}(\text{det}_{\text{original}}, \text{warp}(\text{det}_{\text{rotated}}, H^{-1}))
组成说明
- BCE损失:二元交叉熵损失,适用于二分类检测任务
- 衡量原始检测图与变换后检测图之间的差异
- 公式:
\text{BCE}(y, \hat{y}) = -[y \cdot \log(\hat{y}) + (1-y) \cdot \log(1-\hat{y})]
-
Smooth L1损失:平滑L1损失,对异常值更鲁棒
- 公式: $$\text{SmoothL1}(x) = \begin{cases} 0.5x^2 & \text{if } |x| < 1 \ |x| - 0.5 & \text{otherwise} \end{cases}$$
- 作为BCE损失的辅助正则项
-
权重比例:
- BCE损失:权重 1.0(主导损失)
- Smooth L1损失:权重 0.1(辅助正则)
空间变换
- warp操作:使用逆变换矩阵H⁻¹对特征图进行空间变换对齐
- 实现:通过
F.affine_grid和F.grid_sample完成
2. 描述子损失(Descriptor Loss)
Triplet Loss公式
L_{\text{desc}} = \max\left(0, \|f(a) - f(p)\|_2^2 - \|f(a) - f(n)\|_2^2 + \text{margin}\right)
符号定义
- a (anchor):原始图像的描述子特征
- p (positive):变换后图像对应位置的描述子特征
- n (negative):困难负样本的描述子特征
- margin:边界参数,默认值为1.0
- f(·):描述子特征提取函数
采样策略
正样本采样
- 采样方法:均匀网格采样
- 采样点数:200个点
- 空间分布:在特征图上均匀分布,确保训练稳定性
困难负样本挖掘
- 候选生成:随机生成负样本坐标点
- 距离计算:计算anchor与所有负候选的距离
- 选择策略:选择距离最近的负样本作为困难负样本
- 计算优化:使用
torch.gather高效选择
实现细节
- 特征维度:128维描述子向量
- 归一化:使用InstanceNorm进行特征归一化
- 距离度量:L2范数(欧氏距离)
- 损失函数:
nn.TripletMarginLoss(margin=1.0, p=2)
3. 总损失函数
最终公式
L_{\text{total}} = L_{\text{det}} + L_{\text{desc}}
设计特点
- 无权重平衡:两个损失直接相加,依靠网络自动学习平衡
- 端到端训练:检测和描述任务联合优化
- 多任务学习:同时学习几何变换不变性和特征描述能力
4. 训练策略
损失优化
- 优化器:Adam优化器
- 学习率:初始1e-3,使用ReduceLROnPlateau调度
- 梯度裁剪:max_norm=1.0,防止梯度爆炸
验证指标
- 检测损失:验证集上的检测任务性能
- 描述子损失:验证集上的特征匹配性能
- 总损失:两个损失的加权和
5. 实现代码位置
- 检测损失:
train.py::compute_detection_loss()(第126-138行) - 描述子损失:
train.py::compute_description_loss()(第140-178行) - 总损失:
train.py::main()(第242行)
6. 数学符号对照表
| 符号 | 含义 | 维度 |
|---|---|---|
| det_original | 原始图像检测图 | (B, 1, H, W) |
| det_rotated | 变换图像检测图 | (B, 1, H, W) |
| desc_original | 原始图像描述子 | (B, 128, H, W) |
| desc_rotated | 变换图像描述子 | (B, 128, H, W) |
| H | 几何变换矩阵 | (B, 3, 3) |
| margin | Triplet Loss边界 | 标量 |
| B | 批次大小 | 标量 |
| C | 特征维度 | 128 |
| H, W | 特征图高宽 | 标量 |