# RoRD 模型训练损失函数详解 本文档详细描述了 RoRD(Robust Layout Representation and Detection)模型训练过程中使用的损失函数设计。 ## 1. 检测损失(Detection Loss) ### 数学公式 $$L_{\text{det}} = \text{BCE}(\text{det}_{\text{original}}, \text{warp}(\text{det}_{\text{rotated}}, H^{-1})) + 0.1 \times \text{SmoothL1}(\text{det}_{\text{original}}, \text{warp}(\text{det}_{\text{rotated}}, H^{-1}))$$ ### 组成说明 - **BCE损失**:二元交叉熵损失,适用于二分类检测任务 - 衡量原始检测图与变换后检测图之间的差异 - 公式: $$\text{BCE}(y, \hat{y}) = -[y \cdot \log(\hat{y}) + (1-y) \cdot \log(1-\hat{y})]$$ - **Smooth L1损失**:平滑L1损失,对异常值更鲁棒 - 公式: $$\text{SmoothL1}(x) = \begin{cases} 0.5x^2 & \text{if } |x| < 1 \\ |x| - 0.5 & \text{otherwise} \end{cases}$$ - 作为BCE损失的辅助正则项 - **权重比例**: - BCE损失:权重 1.0(主导损失) - Smooth L1损失:权重 0.1(辅助正则) ### 空间变换 - **warp操作**:使用逆变换矩阵H⁻¹对特征图进行空间变换对齐 - **实现**:通过`F.affine_grid`和`F.grid_sample`完成 ## 2. 描述子损失(Descriptor Loss) ### Triplet Loss公式 $$L_{\text{desc}} = \max\left(0, \|f(a) - f(p)\|_2^2 - \|f(a) - f(n)\|_2^2 + \text{margin}\right)$$ ### 符号定义 - **a** (anchor):原始图像的描述子特征 - **p** (positive):变换后图像对应位置的描述子特征 - **n** (negative):困难负样本的描述子特征 - **margin**:边界参数,默认值为1.0 - **f(·)**:描述子特征提取函数 ### 采样策略 #### 正样本采样 - **采样方法**:均匀网格采样 - **采样点数**:200个点 - **空间分布**:在特征图上均匀分布,确保训练稳定性 #### 困难负样本挖掘 1. **候选生成**:随机生成负样本坐标点 2. **距离计算**:计算anchor与所有负候选的距离 3. **选择策略**:选择距离最近的负样本作为困难负样本 4. **计算优化**:使用`torch.gather`高效选择 ### 实现细节 - **特征维度**:128维描述子向量 - **归一化**:使用InstanceNorm进行特征归一化 - **距离度量**:L2范数(欧氏距离) - **损失函数**:`nn.TripletMarginLoss(margin=1.0, p=2)` ## 3. 总损失函数 ### 最终公式 $$L_{\text{total}} = L_{\text{det}} + L_{\text{desc}}$$ ### 设计特点 - **无权重平衡**:两个损失直接相加,依靠网络自动学习平衡 - **端到端训练**:检测和描述任务联合优化 - **多任务学习**:同时学习几何变换不变性和特征描述能力 ## 4. 训练策略 ### 损失优化 - **优化器**:Adam优化器 - **学习率**:初始1e-3,使用ReduceLROnPlateau调度 - **梯度裁剪**:max_norm=1.0,防止梯度爆炸 ### 验证指标 - **检测损失**:验证集上的检测任务性能 - **描述子损失**:验证集上的特征匹配性能 - **总损失**:两个损失的加权和 ## 5. 实现代码位置 - **检测损失**:`train.py::compute_detection_loss()`(第126-138行) - **描述子损失**:`train.py::compute_description_loss()`(第140-178行) - **总损失**:`train.py::main()`(第242行) ## 6. 数学符号对照表 | 符号 | 含义 | 维度 | |------|------|------| | det_original | 原始图像检测图 | (B, 1, H, W) | | det_rotated | 变换图像检测图 | (B, 1, H, W) | | desc_original | 原始图像描述子 | (B, 128, H, W) | | desc_rotated | 变换图像描述子 | (B, 128, H, W) | | H | 几何变换矩阵 | (B, 3, 3) | | margin | Triplet Loss边界 | 标量 | | B | 批次大小 | 标量 | | C | 特征维度 | 128 | | H, W | 特征图高宽 | 标量 |