diff --git a/docs/loss_function.md b/docs/loss_function.md new file mode 100644 index 0000000..9aa5e32 --- /dev/null +++ b/docs/loss_function.md @@ -0,0 +1,103 @@ +# RoRD 模型训练损失函数详解 + +本文档详细描述了 RoRD(Robust Layout Representation and Detection)模型训练过程中使用的损失函数设计。 + +## 1. 检测损失(Detection Loss) + +### 数学公式 +$$L_{\text{det}} = \text{BCE}(\text{det}_{\text{original}}, \text{warp}(\text{det}_{\text{rotated}}, H^{-1})) + 0.1 \times \text{SmoothL1}(\text{det}_{\text{original}}, \text{warp}(\text{det}_{\text{rotated}}, H^{-1}))$$ + +### 组成说明 +- **BCE损失**:二元交叉熵损失,适用于二分类检测任务 + - 衡量原始检测图与变换后检测图之间的差异 + - 公式: +$$\text{BCE}(y, \hat{y}) = -[y \cdot \log(\hat{y}) + (1-y) \cdot \log(1-\hat{y})]$$ + +- **Smooth L1损失**:平滑L1损失,对异常值更鲁棒 + - 公式: +$$\text{SmoothL1}(x) = \begin{cases} +0.5x^2 & \text{if } |x| < 1 \\ +|x| - 0.5 & \text{otherwise} +\end{cases}$$ + - 作为BCE损失的辅助正则项 + +- **权重比例**: + - BCE损失:权重 1.0(主导损失) + - Smooth L1损失:权重 0.1(辅助正则) + +### 空间变换 +- **warp操作**:使用逆变换矩阵H⁻¹对特征图进行空间变换对齐 +- **实现**:通过`F.affine_grid`和`F.grid_sample`完成 + +## 2. 描述子损失(Descriptor Loss) + +### Triplet Loss公式 +$$L_{\text{desc}} = \max\left(0, \|f(a) - f(p)\|_2^2 - \|f(a) - f(n)\|_2^2 + \text{margin}\right)$$ + +### 符号定义 +- **a** (anchor):原始图像的描述子特征 +- **p** (positive):变换后图像对应位置的描述子特征 +- **n** (negative):困难负样本的描述子特征 +- **margin**:边界参数,默认值为1.0 +- **f(·)**:描述子特征提取函数 + +### 采样策略 + +#### 正样本采样 +- **采样方法**:均匀网格采样 +- **采样点数**:200个点 +- **空间分布**:在特征图上均匀分布,确保训练稳定性 + +#### 困难负样本挖掘 +1. **候选生成**:随机生成负样本坐标点 +2. **距离计算**:计算anchor与所有负候选的距离 +3. **选择策略**:选择距离最近的负样本作为困难负样本 +4. **计算优化**:使用`torch.gather`高效选择 + +### 实现细节 +- **特征维度**:128维描述子向量 +- **归一化**:使用InstanceNorm进行特征归一化 +- **距离度量**:L2范数(欧氏距离) +- **损失函数**:`nn.TripletMarginLoss(margin=1.0, p=2)` + +## 3. 总损失函数 + +### 最终公式 +$$L_{\text{total}} = L_{\text{det}} + L_{\text{desc}}$$ + +### 设计特点 +- **无权重平衡**:两个损失直接相加,依靠网络自动学习平衡 +- **端到端训练**:检测和描述任务联合优化 +- **多任务学习**:同时学习几何变换不变性和特征描述能力 + +## 4. 训练策略 + +### 损失优化 +- **优化器**:Adam优化器 +- **学习率**:初始1e-3,使用ReduceLROnPlateau调度 +- **梯度裁剪**:max_norm=1.0,防止梯度爆炸 + +### 验证指标 +- **检测损失**:验证集上的检测任务性能 +- **描述子损失**:验证集上的特征匹配性能 +- **总损失**:两个损失的加权和 + +## 5. 实现代码位置 + +- **检测损失**:`train.py::compute_detection_loss()`(第126-138行) +- **描述子损失**:`train.py::compute_description_loss()`(第140-178行) +- **总损失**:`train.py::main()`(第242行) + +## 6. 数学符号对照表 + +| 符号 | 含义 | 维度 | +|------|------|------| +| det_original | 原始图像检测图 | (B, 1, H, W) | +| det_rotated | 变换图像检测图 | (B, 1, H, W) | +| desc_original | 原始图像描述子 | (B, 128, H, W) | +| desc_rotated | 变换图像描述子 | (B, 128, H, W) | +| H | 几何变换矩阵 | (B, 3, 3) | +| margin | Triplet Loss边界 | 标量 | +| B | 批次大小 | 标量 | +| C | 特征维度 | 128 | +| H, W | 特征图高宽 | 标量 | \ No newline at end of file