添加数据增强方案以及扩散生成模型的想法

This commit is contained in:
Jiao77
2025-10-20 21:14:03 +08:00
parent d6d00cf088
commit 08f488f0d8
22 changed files with 1903 additions and 190 deletions

View File

@@ -91,7 +91,12 @@ class RoRD(nn.Module):
# 默认各层通道VGG 对齐)
c2_ch, c3_ch, c4_ch = 128, 256, 512
if backbone_name == "resnet34":
res = models.resnet34(weights=models.ResNet34_Weights.DEFAULT if pretrained else None)
# 构建骨干并按需手动加载权重,便于打印加载摘要
if pretrained:
res = models.resnet34(weights=None)
self._summarize_pretrained_load(res, models.ResNet34_Weights.DEFAULT, "resnet34")
else:
res = models.resnet34(weights=None)
self.backbone = nn.Sequential(
res.conv1, res.bn1, res.relu, res.maxpool,
res.layer1, res.layer2, res.layer3, res.layer4,
@@ -102,14 +107,23 @@ class RoRD(nn.Module):
# 选择 layer2/layer3/layer4 作为 C2/C3/C4
c2_ch, c3_ch, c4_ch = 128, 256, 512
elif backbone_name == "efficientnet_b0":
eff = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT if pretrained else None)
if pretrained:
eff = models.efficientnet_b0(weights=None)
self._summarize_pretrained_load(eff, models.EfficientNet_B0_Weights.DEFAULT, "efficientnet_b0")
else:
eff = models.efficientnet_b0(weights=None)
self.backbone = eff.features
self._backbone_raw = eff
out_channels_backbone = 1280
# 选择 features[2]/[3]/[6] 作为 C2/C3/C4约 24/40/192
c2_ch, c3_ch, c4_ch = 24, 40, 192
else:
vgg16_features = models.vgg16(weights=models.VGG16_Weights.DEFAULT if pretrained else None).features
if pretrained:
vgg = models.vgg16(weights=None)
self._summarize_pretrained_load(vgg, models.VGG16_Weights.DEFAULT, "vgg16")
else:
vgg = models.vgg16(weights=None)
vgg16_features = vgg.features
# VGG16 特征各阶段索引conv & relu 层序列)
# relu2_2 索引 8relu3_3 索引 15relu4_3 索引 22
self.features = vgg16_features
@@ -263,4 +277,33 @@ class RoRD(nn.Module):
x = feats[6](x); c4 = x
return c2, c3, c4
raise RuntimeError(f"Unsupported backbone for FPN: {self.backbone_name}")
raise RuntimeError(f"Unsupported backbone for FPN: {self.backbone_name}")
# --- Utils ---
def _summarize_pretrained_load(self, torch_model: nn.Module, weights_enum, arch_name: str) -> None:
"""手动加载 torchvision 预训练权重并打印加载摘要。
- 使用 strict=False 以兼容可能的键差异,打印 missing/unexpected keys。
- 输出参数量统计,便于快速核对加载情况。
"""
try:
state_dict = weights_enum.get_state_dict(progress=False)
except Exception:
# 回退:若权重枚举不支持 get_state_dict则跳过摘要通常已在构造器中加载
print(f"[Pretrained] {arch_name}: skip summary (weights enum lacks get_state_dict)")
return
incompatible = torch_model.load_state_dict(state_dict, strict=False)
total_params = sum(p.numel() for p in torch_model.parameters())
trainable_params = sum(p.numel() for p in torch_model.parameters() if p.requires_grad)
missing = list(getattr(incompatible, 'missing_keys', []))
unexpected = list(getattr(incompatible, 'unexpected_keys', []))
try:
matched = len(state_dict) - len(unexpected)
except Exception:
matched = 0
print(f"[Pretrained] {arch_name}: ImageNet weights loaded (strict=False)")
print(f" params: total={total_params/1e6:.2f}M, trainable={trainable_params/1e6:.2f}M")
print(f" keys: matched≈{matched} | missing={len(missing)} | unexpected={len(unexpected)}")
if missing and len(missing) <= 10:
print(f" missing: {missing}")
if unexpected and len(unexpected) <= 10:
print(f" unexpected: {unexpected}")