initial commit

This commit is contained in:
jiao77
2025-03-25 01:42:26 +08:00
commit 88ca482d5d
5 changed files with 208 additions and 0 deletions

0
models/__init__.py Normal file
View File

30
models/rotation_cnn.py Normal file
View File

@@ -0,0 +1,30 @@
import torch
import torch.nn as nn
class RotationInvariantNet(nn.Module):
"""轻量级旋转不变特征提取网络"""
def __init__(self, input_channels=1, num_features=64):
super().__init__()
self.cnn = nn.Sequential(
# 基础卷积层
nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2), # 下采样
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d((1,1)) # 全局池化获取全局特征
)
def forward(self, x):
features = self.cnn(x)
return torch.flatten(features, 1) # 展平为特征向量
def get_rotational_features(model, input_image):
"""计算输入图像所有旋转角度的特征平均值"""
rotations = [0, 90, 180, 270]
features_list = []
for angle in rotations:
rotated_img = torch.rot90(input_image, k=angle//90, dims=[2,3])
feat = model(rotated_img.unsqueeze(0))
features_list.append(feat)
return torch.mean(torch.stack(features_list), dim=0).detach().numpy()