增加角度显示计算,优化CNN架构
This commit is contained in:
0
ai_layout_match/dataset.py
Normal file
0
ai_layout_match/dataset.py
Normal file
0
ai_layout_match/models.py
Normal file
0
ai_layout_match/models.py
Normal file
0
ai_layout_match/utils.py
Normal file
0
ai_layout_match/utils.py
Normal file
@@ -19,10 +19,10 @@ def layout_to_tensor(layout_path, target_size=(256, 256)):
|
|||||||
img = img.resize(target_size, resample=Image.BILINEAR)
|
img = img.resize(target_size, resample=Image.BILINEAR)
|
||||||
return np.array(img) / 255.0 # 归一化到[0,1]
|
return np.array(img) / 255.0 # 归一化到[0,1]
|
||||||
|
|
||||||
def tile_layout(large_layout, block_size=64):
|
def tile_layout(large_layout, block_size=64, overlap_ratio=0.5):
|
||||||
"""将大版图分割为小块(滑动窗口方式)"""
|
"""将大版图分割为小块(滑动窗口方式)"""
|
||||||
height, width = large_layout.shape
|
height, width = large_layout.shape
|
||||||
stride = block_size // 2 # 步长设置重叠区域
|
stride = int(block_size * (1 - overlap_ratio)) # 步长设置重叠区域
|
||||||
tiles = []
|
tiles = []
|
||||||
for y in range(0, height - block_size +1, stride):
|
for y in range(0, height - block_size +1, stride):
|
||||||
for x in range(0, width - block_size +1, stride):
|
for x in range(0, width - block_size +1, stride):
|
||||||
|
|||||||
33
inference.py
33
inference.py
@@ -1,15 +1,8 @@
|
|||||||
import faiss
|
import faiss
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
# 导入 models.rotation_cnn 模块中的 RotationInvariantNet 类
|
from models.rotation_cnn import RotationInvariantNet, get_rotational_features
|
||||||
from models.rotation_cnn import RotationInvariantNet
|
from data_units import layout_to_tensor, tile_layout
|
||||||
|
|
||||||
from models.rotation_cnn import get_rotational_features
|
|
||||||
|
|
||||||
# 导入 data_utils 中的 layout_to_tensor 函数(假设该函数存在)
|
|
||||||
from data_units import layout_to_tensor # 如果 data_utils.py 存在此函数
|
|
||||||
|
|
||||||
from data_units import tile_layout
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# 配置参数(需根据实际调整)
|
# 配置参数(需根据实际调整)
|
||||||
@@ -17,7 +10,6 @@ def main():
|
|||||||
target_module_path = "target.png"
|
target_module_path = "target.png"
|
||||||
large_layout_path = "layout_large.png"
|
large_layout_path = "layout_large.png"
|
||||||
|
|
||||||
# 加载模型
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
model = RotationInvariantNet().to(device)
|
model = RotationInvariantNet().to(device)
|
||||||
model.load_state_dict(torch.load("rotation_cnn.pth"))
|
model.load_state_dict(torch.load("rotation_cnn.pth"))
|
||||||
@@ -33,16 +25,27 @@ def main():
|
|||||||
# 构建特征索引(使用Faiss加速)
|
# 构建特征索引(使用Faiss加速)
|
||||||
index = faiss.IndexFlatL2(64) # 特征维度由模型决定
|
index = faiss.IndexFlatL2(64) # 特征维度由模型决定
|
||||||
features_db = []
|
features_db = []
|
||||||
for (x,y,tile) in tiles:
|
for (x, y, tile) in tiles:
|
||||||
feat = get_rotational_features(model, torch.tensor(tile).to(device))
|
feat = get_rotational_features(model, torch.tensor(tile).to(device))
|
||||||
features_db.append(feat)
|
features_db.append(feat)
|
||||||
index.add(np.stack(features_db))
|
index.add(np.stack(features_db))
|
||||||
|
|
||||||
# 检索相似区域
|
# 检索相似区域
|
||||||
D, I = index.search(target_feat[np.newaxis,:], k=10)
|
D, I = index.search(target_feat[np.newaxis, :], k=10)
|
||||||
for idx in I[0]:
|
|
||||||
x,y,_ = tiles[idx]
|
|
||||||
print(f"匹配区域坐标: ({x}, {y}), 相似度: {D[0][idx]}")
|
|
||||||
|
|
||||||
|
for idx in I[0]:
|
||||||
|
x, y, _ = tiles[idx]
|
||||||
|
|
||||||
|
# 计算最佳匹配角度的显式计算
|
||||||
|
min_angle, min_dist = 90, float('inf')
|
||||||
|
target_vec = target_feat
|
||||||
|
feat = features_db[idx]
|
||||||
|
for a in [0, 1, 2, 3]: # 代表0°、90°、180°、270°
|
||||||
|
rotated_feat = np.rot90(feat.reshape(block_size, block_size), k=a)
|
||||||
|
dist = np.linalg.norm(target_vec - rotated_feat.flatten())
|
||||||
|
if dist < min_dist:
|
||||||
|
min_dist, min_angle = dist, a * 90
|
||||||
|
|
||||||
|
print(f"坐标({x},{y}), 最佳旋转方向{min_angle}度,距离: {min_dist}")
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
@@ -3,7 +3,7 @@ import torch.nn as nn
|
|||||||
|
|
||||||
class RotationInvariantNet(nn.Module):
|
class RotationInvariantNet(nn.Module):
|
||||||
"""轻量级旋转不变特征提取网络"""
|
"""轻量级旋转不变特征提取网络"""
|
||||||
def __init__(self, input_channels=1, num_features=64):
|
def __init__(self, input_channels=1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cnn = nn.Sequential(
|
self.cnn = nn.Sequential(
|
||||||
# 基础卷积层
|
# 基础卷积层
|
||||||
@@ -12,13 +12,14 @@ class RotationInvariantNet(nn.Module):
|
|||||||
nn.MaxPool2d(2), # 下采样
|
nn.MaxPool2d(2), # 下采样
|
||||||
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.AdaptiveAvgPool2d((1,1)) # 全局池化获取全局特征
|
nn.Conv2d(64, 64, kernel_size=3, stride=2), # 更大感受野
|
||||||
|
nn.AdaptiveAvgPool2d((4,4)), # 全局池化获取全局特征,调整输出尺寸为4x4
|
||||||
|
nn.Flatten(), # 展平为一维向量
|
||||||
|
nn.Linear(64*16, 128) # 增加全连接层以降低维度到128
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
features = self.cnn(x)
|
return self.cnn(x)
|
||||||
return torch.flatten(features, 1) # 展平为特征向量
|
|
||||||
|
|
||||||
def get_rotational_features(model, input_image):
|
def get_rotational_features(model, input_image):
|
||||||
"""计算输入图像所有旋转角度的特征平均值"""
|
"""计算输入图像所有旋转角度的特征平均值"""
|
||||||
rotations = [0, 90, 180, 270]
|
rotations = [0, 90, 180, 270]
|
||||||
|
|||||||
Reference in New Issue
Block a user