一个目标实时检测的模型

This commit is contained in:
jiao77
2025-03-31 14:49:04 +08:00
parent 956805997e
commit a5c63ad0de
6 changed files with 188 additions and 151 deletions

View File

@@ -1,51 +1,70 @@
import faiss
import numpy as np
import torch import torch
from models.rotation_cnn import RotationInvariantNet, get_rotational_features import cv2
from data_units import layout_to_tensor, tile_layout import numpy as np
from models.superpoint_custom import SuperPointCustom
def main(): def get_keypoints_from_heatmap(semi, threshold=0.015):
# 配置参数(需根据实际调整) semi = semi.squeeze().cpu().numpy() # [65, H/8, W/8]
block_size = 64 # 分块尺寸 prob = cv2.softmax(semi, axis=0)[:-1] # [64, H/8, W/8]
target_module_path = "target.png" prob = prob.reshape(8, 8, semi.shape[1], semi.shape[2])
large_layout_path = "layout_large.png" prob = prob.transpose(0, 2, 1, 3).reshape(8*semi.shape[1], 8*semi.shape[2]) # [H, W]
keypoints = []
for y in range(prob.shape[0]):
for x in range(prob.shape[1]):
if prob[y, x] > threshold:
keypoints.append(cv2.KeyPoint(x, y, 1))
return keypoints
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def get_descriptors_from_map(desc, keypoints):
model = RotationInvariantNet().to(device) desc = desc.squeeze().cpu().numpy() # [256, H/8, W/8]
model.load_state_dict(torch.load("rotation_cnn.pth")) descriptors = []
scale = 8
for kp in keypoints:
x, y = int(kp.pt[0] / scale), int(kp.pt[1] / scale)
if 0 <= x < desc.shape[2] and 0 <= y < desc.shape[1]:
descriptors.append(desc[:, y, x])
return np.array(descriptors)
def match_and_estimate(layout_path, module_path, model_path, num_channels, device='cuda'):
model = SuperPointCustom(num_channels=num_channels).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval() model.eval()
# 预处理目标模块与大版图 layout = np.load(layout_path) # [C, H, W]
target_tensor = layout_to_tensor(target_module_path, (block_size, block_size)) module = np.load(module_path) # [C, H, W]
target_feat = get_rotational_features(model, torch.tensor(target_tensor).to(device)) layout_tensor = torch.from_numpy(layout).float().unsqueeze(0).to(device)
module_tensor = torch.from_numpy(module).float().unsqueeze(0).to(device)
large_layout = layout_to_tensor(large_layout_path) with torch.no_grad():
tiles = tile_layout(large_layout) semi_layout, desc_layout = model(layout_tensor)
semi_module, desc_module = model(module_tensor)
kp_layout = get_keypoints_from_heatmap(semi_layout)
desc_layout = get_descriptors_from_map(desc_layout, kp_layout)
kp_module = get_keypoints_from_heatmap(semi_module)
desc_module = get_descriptors_from_map(desc_module, kp_module)
# 构建特征索引使用Faiss加速 bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
index = faiss.IndexFlatL2(64) # 特征维度由模型决定 matches = bf.match(desc_module, desc_layout)
features_db = [] matches = sorted(matches, key=lambda x: x.distance)
for (x, y, tile) in tiles:
feat = get_rotational_features(model, torch.tensor(tile).to(device))
features_db.append(feat)
index.add(np.stack(features_db))
# 检索相似区域 src_pts = np.float32([kp_module[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
D, I = index.search(target_feat[np.newaxis, :], k=10) dst_pts = np.float32([kp_layout[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
H, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
for idx in I[0]: h, w = module.shape[1], module.shape[2]
x, y, _ = tiles[idx] corners = np.float32([[0, 0], [w, 0], [w, h], [0, h]]).reshape(-1, 1, 2)
transformed_corners = cv2.perspectiveTransform(corners, H)
x_min, y_min = np.min(transformed_corners, axis=0).ravel().astype(int)
x_max, y_max = np.max(transformed_corners, axis=0).ravel().astype(int)
theta = np.arctan2(H[1, 0], H[0, 0]) * 180 / np.pi
# 计算最佳匹配角度的显式计算 print(f"Matched region: [{x_min}, {y_min}, {x_max}, {y_max}], Rotation: {theta:.2f} degrees")
min_angle, min_dist = 90, float('inf') return x_min, y_min, x_max, y_max, theta
target_vec = target_feat
feat = features_db[idx]
for a in [0, 1, 2, 3]: # 代表0°、90°、180°、270°
rotated_feat = np.rot90(feat.reshape(block_size, block_size), k=a)
dist = np.linalg.norm(target_vec - rotated_feat.flatten())
if dist < min_dist:
min_dist, min_angle = dist, a * 90
print(f"坐标({x},{y}), 最佳旋转方向{min_angle}度,距离: {min_dist}")
if __name__ == "__main__": if __name__ == "__main__":
main() layout_path = "data/large_layout.npy"
module_path = "data/small_module.npy"
model_path = "superpoint_custom_model.pth"
num_channels = 3 # 替换为实际通道数
match_and_estimate(layout_path, module_path, model_path, num_channels)

View File

@@ -1,31 +0,0 @@
import torch
import torch.nn as nn
class RotationInvariantNet(nn.Module):
"""轻量级旋转不变特征提取网络"""
def __init__(self, input_channels=1):
super().__init__()
self.cnn = nn.Sequential(
# 基础卷积层
nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2), # 下采样
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=2), # 更大感受野
nn.AdaptiveAvgPool2d((4,4)), # 全局池化获取全局特征调整输出尺寸为4x4
nn.Flatten(), # 展平为一维向量
nn.Linear(64*16, 128) # 增加全连接层以降低维度到128
)
def forward(self, x):
return self.cnn(x)
def get_rotational_features(model, input_image):
"""计算输入图像所有旋转角度的特征平均值"""
rotations = [0, 90, 180, 270]
features_list = []
for angle in rotations:
rotated_img = torch.rot90(input_image, k=angle//90, dims=[2,3])
feat = model(rotated_img.unsqueeze(0))
features_list.append(feat)
return torch.mean(torch.stack(features_list), dim=0).detach().numpy()

View File

@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class SuperPointCustom(nn.Module):
def __init__(self, num_channels=3): # num_channels 为版图通道数
super(SuperPointCustom, self).__init__()
self.relu = nn.ReLU(inplace=True)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
c1, c2, c3, c4, c5, d1 = 64, 64, 128, 128, 256, 256
# 编码器
self.conv1a = nn.Conv2d(num_channels, c1, kernel_size=3, stride=1, padding=1)
self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
# 检测头
self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) # 65 = 8x8 + dustbin
# 描述符头
self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
self.convDb = nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0)
def forward(self, x):
# 编码器
x = self.relu(self.conv1a(x))
x = self.relu(self.conv1b(x))
x = self.pool(x)
x = self.relu(self.conv2a(x))
x = self.relu(self.conv2b(x))
x = self.pool(x)
x = self.relu(self.conv3a(x))
x = self.relu(self.conv3b(x))
x = self.pool(x)
x = self.relu(self.conv4a(x))
x = self.relu(self.conv4b(x))
# 检测头
cPa = self.relu(self.convPa(x))
semi = self.convPb(cPa) # [B, 65, H/8, W/8]
# 描述符头
cDa = self.relu(self.convDa(x))
desc = self.convDb(cDa) # [B, 256, H/8, W/8]
desc = F.normalize(desc, p=2, dim=1) # L2归一化
return semi, desc

131
train.py
View File

@@ -1,99 +1,68 @@
import os import os
import torch import torch
from torch import nn, optim import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split from torch.utils.data import Dataset, DataLoader
import numpy as np import numpy as np
from datetime import datetime from models.superpoint_custom import SuperPointCustom
import argparse from utils.data_augmentation import generate_training_pair
# 导入项目模块(根据你的路径调整) class BinaryDataset(Dataset):
from models.rotation_cnn import RotationInvariantCNN # 模型实现 def __init__(self, image_dir, patch_size, num_channels):
from data_units import LayoutDataset, layout_transforms # 数据集和预处理函数 self.image_dir = image_dir
self.patch_size = patch_size
self.num_channels = num_channels
self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)
if f.endswith('.npy')]
# 设置随机种子(可选) def __len__(self):
torch.manual_seed(42) return len(self.image_paths)
np.random.seed(42)
def main(): def __getitem__(self, idx):
"""训练流程""" img_path = self.image_paths[idx]
# 解析命令行参数 image = np.load(img_path) # [C, H, W]
parser = argparse.ArgumentParser(description="Train Rotation-Invariant Layout Matcher") patch, warped_patch, H = generate_training_pair(image, self.patch_size)
parser.add_argument("--data_dir", type=str, default="./data/train/", help="训练数据目录") patch = torch.from_numpy(patch).float()
parser.add_argument("--val_split", type=float, default=0.2, help="验证集比例") warped_patch = torch.from_numpy(warped_patch).float()
parser.add_argument("--batch_size", type=int, default=16, help="批量大小") return patch, warped_patch, torch.from_numpy(H).float()
parser.add_argument("--epochs", type=int, default=50, help="训练轮次")
parser.add_argument("--lr", type=float, default=1e-3, help="学习率")
parser.add_argument("--model_save_dir", type=str, default="./models/", help="模型保存路径")
args = parser.parse_args()
# 创建输出目录 def simple_detector_loss(semi, semi_w, H, device):
os.makedirs(args.model_save_dir, exist_ok=True) return F.mse_loss(semi, semi_w) # 简化版,实际需更复杂实现
# 数据加载 def simple_descriptor_loss(desc, desc_w, H, device):
dataset = LayoutDataset(root_dir=args.data_dir, transform=layout_transforms()) return F.mse_loss(desc, desc_w) # 简化版,实际需更复杂实现
total_samples = len(dataset)
val_size = int(total_samples * args.val_split)
train_size = total_samples - val_size
# 划分训练集和验证集 def train():
train_dataset, val_dataset = random_split( device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset, batch_size = 8
[train_size, val_size], learning_rate = 0.001
generator=torch.Generator().manual_seed(42) num_epochs = 10
) image_dir = 'data/train_images' # 替换为实际路径
patch_size = 256
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) num_channels = 3 # 替换为实际通道数
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
# 初始化模型、损失函数和优化器 dataset = BinaryDataset(image_dir, patch_size, num_channels)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
model = RotationInvariantCNN().to(device) # 根据你的模型结构调整参数
criterion = nn.CrossEntropyLoss() # 分类任务示例,根据任务类型选择损失函数
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# 训练循环 model = SuperPointCustom(num_channels=num_channels).to(device)
best_val_loss = float("inf") optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(1, args.epochs + 1):
for epoch in range(num_epochs):
model.train() model.train()
train_loss = 0.0 total_loss = 0
for batch_idx, (data, targets) in enumerate(train_loader): for patch, warped_patch, H in dataloader:
data, targets = data.to(device), targets.to(device) patch, warped_patch, H = patch.to(device), warped_patch.to(device), H.to(device)
semi, desc = model(patch)
# 前向传播 semi_w, desc_w = model(warped_patch)
outputs = model(data) det_loss = simple_detector_loss(semi, semi_w, H, device)
loss = criterion(outputs, targets) desc_loss = simple_descriptor_loss(desc, desc_w, H, device)
loss = det_loss + desc_loss
# 反向传播和优化
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
total_loss += loss.item()
train_loss += loss.item() print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}")
if (batch_idx + 1) % 10 == 0:
print(f"Epoch [{epoch}/{args.epochs}] Batch {batch_idx+1}/{len(train_loader)} Loss: {loss.item():.4f}")
# 验证
model.eval()
val_loss = 0.0
with torch.no_grad():
for data, targets in val_loader:
data, targets = data.to(device), targets.to(device)
outputs = model(data)
loss = criterion(outputs, targets)
val_loss += loss.item()
avg_train_loss = train_loss / len(train_loader)
avg_val_loss = val_loss / len(val_loader)
print(f"Epoch {epoch} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
# 保存最佳模型
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
torch.save(model.state_dict(), os.path.join(args.model_save_dir, f"best_model_{datetime.now().strftime('%Y%m%d%H%M')}.pth"))
print("训练完成!") torch.save(model.state_dict(), 'superpoint_custom_model.pth')
if __name__ == "__main__": if __name__ == "__main__":
main() train()

0
utils/__init__.py Normal file
View File

View File

@@ -0,0 +1,33 @@
import numpy as np
import cv2
def get_random_patch(image, patch_size):
h, w = image.shape[1:3]
x = np.random.randint(0, w - patch_size)
y = np.random.randint(0, h - patch_size)
return image[:, y:y+patch_size, x:x+patch_size]
def get_random_homography(max_rotation=30, max_translation=20):
theta = np.random.uniform(-max_rotation, max_rotation) * np.pi / 180.0
tx = np.random.uniform(-max_translation, max_translation)
ty = np.random.uniform(-max_translation, max_translation)
cos_theta = np.cos(theta)
sin_theta = np.sin(theta)
H = np.array([
[cos_theta, -sin_theta, tx],
[sin_theta, cos_theta, ty],
[0, 0, 1]
])
return H
def apply_homography_to_image(image, H, output_size):
warped = np.zeros_like(image)
for c in range(image.shape[0]):
warped[c] = cv2.warpPerspective(image[c], H, output_size, flags=cv2.INTER_NEAREST)
return warped
def generate_training_pair(image, patch_size):
patch = get_random_patch(image, patch_size)
H = get_random_homography()
warped_patch = apply_homography_to_image(patch, H, (patch_size, patch_size))
return patch, warped_patch, H