第二次大修

This commit is contained in:
Jiao77
2025-06-08 15:38:56 +08:00
parent 53ef1ec99c
commit f0b2e1b605
10 changed files with 315 additions and 508 deletions

0
models/__init__.py Normal file
View File

View File

@@ -1,16 +1,25 @@
# models/rord.py
import torch
import torch.nn as nn
from torchvision import models
class RoRD(nn.Module):
def __init__(self):
"""
修复后的 RoRD 模型。
- 实现了共享骨干网络,以提高计算效率和减少内存占用。
- 移除了冗余的 descriptor_head_vanilla。
"""
super(RoRD, self).__init__()
# 检测骨干网络VGG-16 直到 relu5_3层 0 到 29
self.backbone_det = models.vgg16(pretrained=True).features[:30]
# 描述骨干网络VGG-16 直到 relu4_3层 0 到 22
self.backbone_desc = models.vgg16(pretrained=True).features[:23]
# 检测头:输出关键点概率图
vgg16_features = models.vgg16(pretrained=True).features
# 共享骨干网络
self.slice1 = vgg16_features[:23] # 到 relu4_3
self.slice2 = vgg16_features[23:30] # 从 relu4_3 到 relu5_3
# 检测头
self.detection_head = nn.Sequential(
nn.Conv2d(512, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
@@ -18,16 +27,8 @@ class RoRD(nn.Module):
nn.Sigmoid()
)
# 普通描述子头D2-Net 风格)
self.descriptor_head_vanilla = nn.Sequential(
nn.Conv2d(512, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 128, kernel_size=1),
nn.InstanceNorm2d(128)
)
# RoRD 描述子头(旋转鲁棒)
self.descriptor_head_rord = nn.Sequential(
# 描述子头
self.descriptor_head = nn.Sequential(
nn.Conv2d(512, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 128, kernel_size=1),
@@ -35,13 +36,14 @@ class RoRD(nn.Module):
)
def forward(self, x):
# 检测分支
features_det = self.backbone_det(x)
detection = self.detection_head(features_det)
# 共享特征提取
features_shared = self.slice1(x)
# 描述分支
features_desc = self.backbone_desc(x)
desc_vanilla = self.descriptor_head_vanilla(features_desc)
desc_rord = self.descriptor_head_rord(features_desc)
# 描述分支
descriptors = self.descriptor_head(features_shared)
return detection, desc_vanilla, desc_rord
# 检测器分支
features_det = self.slice2(features_shared)
detection_map = self.detection_head(features_det)
return detection_map, descriptors