add docs/loss_function.md

This commit is contained in:
Jiao77
2025-07-20 16:06:33 +08:00
parent bfcd63725b
commit eae29ba502

103
docs/loss_function.md Normal file
View File

@@ -0,0 +1,103 @@
# RoRD 模型训练损失函数详解
本文档详细描述了 RoRDRobust Layout Representation and Detection模型训练过程中使用的损失函数设计。
## 1. 检测损失Detection Loss
### 数学公式
$$L_{\text{det}} = \text{BCE}(\text{det}_{\text{original}}, \text{warp}(\text{det}_{\text{rotated}}, H^{-1})) + 0.1 \times \text{SmoothL1}(\text{det}_{\text{original}}, \text{warp}(\text{det}_{\text{rotated}}, H^{-1}))$$
### 组成说明
- **BCE损失**:二元交叉熵损失,适用于二分类检测任务
- 衡量原始检测图与变换后检测图之间的差异
- 公式:
$$\text{BCE}(y, \hat{y}) = -[y \cdot \log(\hat{y}) + (1-y) \cdot \log(1-\hat{y})]$$
- **Smooth L1损失**平滑L1损失对异常值更鲁棒
- 公式:
$$\text{SmoothL1}(x) = \begin{cases}
0.5x^2 & \text{if } |x| < 1 \\
|x| - 0.5 & \text{otherwise}
\end{cases}$$
- 作为BCE损失的辅助正则项
- **权重比例**
- BCE损失权重 1.0主导损失
- Smooth L1损失权重 0.1辅助正则
### 空间变换
- **warp操作**使用逆变换矩阵H⁻¹对特征图进行空间变换对齐
- **实现**通过`F.affine_grid``F.grid_sample`完成
## 2. 描述子损失Descriptor Loss
### Triplet Loss公式
$$L_{\text{desc}} = \max\left(0, \|f(a) - f(p)\|_2^2 - \|f(a) - f(n)\|_2^2 + \text{margin}\right)$$
### 符号定义
- **a** (anchor)原始图像的描述子特征
- **p** (positive)变换后图像对应位置的描述子特征
- **n** (negative)困难负样本的描述子特征
- **margin**边界参数默认值为1.0
- **f(·)**描述子特征提取函数
### 采样策略
#### 正样本采样
- **采样方法**均匀网格采样
- **采样点数**200个点
- **空间分布**在特征图上均匀分布确保训练稳定性
#### 困难负样本挖掘
1. **候选生成**随机生成负样本坐标点
2. **距离计算**计算anchor与所有负候选的距离
3. **选择策略**选择距离最近的负样本作为困难负样本
4. **计算优化**使用`torch.gather`高效选择
### 实现细节
- **特征维度**128维描述子向量
- **归一化**使用InstanceNorm进行特征归一化
- **距离度量**L2范数欧氏距离
- **损失函数**`nn.TripletMarginLoss(margin=1.0, p=2)`
## 3. 总损失函数
### 最终公式
$$L_{\text{total}} = L_{\text{det}} + L_{\text{desc}}$$
### 设计特点
- **无权重平衡**两个损失直接相加依靠网络自动学习平衡
- **端到端训练**检测和描述任务联合优化
- **多任务学习**同时学习几何变换不变性和特征描述能力
## 4. 训练策略
### 损失优化
- **优化器**Adam优化器
- **学习率**初始1e-3使用ReduceLROnPlateau调度
- **梯度裁剪**max_norm=1.0,防止梯度爆炸
### 验证指标
- **检测损失**验证集上的检测任务性能
- **描述子损失**验证集上的特征匹配性能
- **总损失**两个损失的加权和
## 5. 实现代码位置
- **检测损失**`train.py::compute_detection_loss()`第126-138行
- **描述子损失**`train.py::compute_description_loss()`第140-178行
- **总损失**`train.py::main()`第242行
## 6. 数学符号对照表
| 符号 | 含义 | 维度 |
|------|------|------|
| det_original | 原始图像检测图 | (B, 1, H, W) |
| det_rotated | 变换图像检测图 | (B, 1, H, W) |
| desc_original | 原始图像描述子 | (B, 128, H, W) |
| desc_rotated | 变换图像描述子 | (B, 128, H, W) |
| H | 几何变换矩阵 | (B, 3, 3) |
| margin | Triplet Loss边界 | 标量 |
| B | 批次大小 | 标量 |
| C | 特征维度 | 128 |
| H, W | 特征图高宽 | 标量 |