Files
LayoutMatch/models/rotation_cnn.py
2025-03-26 22:33:36 +08:00

31 lines
1.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
class RotationInvariantNet(nn.Module):
    """Lightweight CNN that maps an image to a fixed-length feature vector.

    The convolutional stack is an ordinary CNN and is not rotation-invariant
    by itself; approximate invariance is obtained by averaging its output
    over rotated copies of the input (see ``get_rotational_features``).

    Args:
        input_channels: Number of channels of the input image.
        feature_dim: Length of the output feature vector (128 by default,
            matching the original hard-coded size).

    Input is expected as ``(N, input_channels, H, W)``.  ``H`` and ``W`` must
    be at least 6: the 2x2 max-pool halves them and the following unpadded
    3x3 convolution needs at least 3 pixels per side.
    """

    def __init__(self, input_channels: int = 1, feature_dim: int = 128) -> None:
        super().__init__()
        # NOTE: layer order (and therefore the indices inside ``self.cnn``)
        # is kept stable so that existing checkpoints keep loading.
        self.cnn = nn.Sequential(
            # Basic convolution block.
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # downsample by 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            # Strided, unpadded convolution: downsamples again and enlarges
            # the receptive field.
            nn.Conv2d(64, 64, kernel_size=3, stride=2),
            # Pool to a fixed 4x4 grid so the linear layer's input size does
            # not depend on the image resolution.
            nn.AdaptiveAvgPool2d((4, 4)),
            nn.Flatten(),  # (N, 64, 4, 4) -> (N, 64 * 16)
            nn.Linear(64 * 16, feature_dim),  # project to the feature size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``(N, feature_dim)`` feature tensor for a batch ``x``."""
        return self.cnn(x)
def get_rotational_features(model, input_image):
    """Average a model's features over the four 90-degree rotations of an image.

    Args:
        model: Callable (typically an ``nn.Module``) that maps a batched
            image tensor ``(N, C, H, W)`` to a feature tensor.  The caller is
            responsible for putting it in the desired mode (e.g. ``eval()``).
        input_image: ``torch.Tensor`` holding either a single image
            ``(C, H, W)`` or an already batched ``(N, C, H, W)`` tensor.

    Returns:
        numpy.ndarray: Element-wise mean of the features computed for the
        0, 90, 180 and 270 degree rotations.  The leading axis is the batch
        axis (length 1 when a single image was passed).

    Raises:
        ValueError: If ``input_image`` is neither 3-D nor 4-D.
    """
    # Normalize to a 4-D batch exactly once.  The previous code rotated over
    # dims (2, 3) (which requires a 4-D tensor) and then added another batch
    # axis, producing a 5-D tensor that 2-D convolutions reject.
    if input_image.dim() == 3:
        batch = input_image.unsqueeze(0)
    elif input_image.dim() == 4:
        batch = input_image
    else:
        raise ValueError(
            "input_image must have shape (C, H, W) or (N, C, H, W), "
            f"got {tuple(input_image.shape)}"
        )

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        # k quarter-turns over the spatial axes (H, W); k = 0..3 corresponds
        # to 0, 90, 180 and 270 degrees.
        features = [
            model(torch.rot90(batch, k=quarter_turns, dims=(2, 3)))
            for quarter_turns in range(4)
        ]
        mean_features = torch.stack(features).mean(dim=0)

    # ``.cpu()`` is required before ``.numpy()`` when the model runs on GPU.
    return mean_features.detach().cpu().numpy()