一个目标实时检测的模型
This commit is contained in:
97
inference.py
97
inference.py
@@ -1,51 +1,70 @@
|
|||||||
import faiss
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from models.rotation_cnn import RotationInvariantNet, get_rotational_features
|
import cv2
|
||||||
from data_units import layout_to_tensor, tile_layout
|
import numpy as np
|
||||||
|
from models.superpoint_custom import SuperPointCustom
|
||||||
|
|
||||||
def main():
|
def get_keypoints_from_heatmap(semi, threshold=0.015):
|
||||||
# 配置参数(需根据实际调整)
|
semi = semi.squeeze().cpu().numpy() # [65, H/8, W/8]
|
||||||
block_size = 64 # 分块尺寸
|
prob = cv2.softmax(semi, axis=0)[:-1] # [64, H/8, W/8]
|
||||||
target_module_path = "target.png"
|
prob = prob.reshape(8, 8, semi.shape[1], semi.shape[2])
|
||||||
large_layout_path = "layout_large.png"
|
prob = prob.transpose(0, 2, 1, 3).reshape(8*semi.shape[1], 8*semi.shape[2]) # [H, W]
|
||||||
|
keypoints = []
|
||||||
|
for y in range(prob.shape[0]):
|
||||||
|
for x in range(prob.shape[1]):
|
||||||
|
if prob[y, x] > threshold:
|
||||||
|
keypoints.append(cv2.KeyPoint(x, y, 1))
|
||||||
|
return keypoints
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
def get_descriptors_from_map(desc, keypoints):
|
||||||
model = RotationInvariantNet().to(device)
|
desc = desc.squeeze().cpu().numpy() # [256, H/8, W/8]
|
||||||
model.load_state_dict(torch.load("rotation_cnn.pth"))
|
descriptors = []
|
||||||
|
scale = 8
|
||||||
|
for kp in keypoints:
|
||||||
|
x, y = int(kp.pt[0] / scale), int(kp.pt[1] / scale)
|
||||||
|
if 0 <= x < desc.shape[2] and 0 <= y < desc.shape[1]:
|
||||||
|
descriptors.append(desc[:, y, x])
|
||||||
|
return np.array(descriptors)
|
||||||
|
|
||||||
|
def match_and_estimate(layout_path, module_path, model_path, num_channels, device='cuda'):
|
||||||
|
model = SuperPointCustom(num_channels=num_channels).to(device)
|
||||||
|
model.load_state_dict(torch.load(model_path, map_location=device))
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# 预处理目标模块与大版图
|
layout = np.load(layout_path) # [C, H, W]
|
||||||
target_tensor = layout_to_tensor(target_module_path, (block_size, block_size))
|
module = np.load(module_path) # [C, H, W]
|
||||||
target_feat = get_rotational_features(model, torch.tensor(target_tensor).to(device))
|
layout_tensor = torch.from_numpy(layout).float().unsqueeze(0).to(device)
|
||||||
|
module_tensor = torch.from_numpy(module).float().unsqueeze(0).to(device)
|
||||||
|
|
||||||
large_layout = layout_to_tensor(large_layout_path)
|
with torch.no_grad():
|
||||||
tiles = tile_layout(large_layout)
|
semi_layout, desc_layout = model(layout_tensor)
|
||||||
|
semi_module, desc_module = model(module_tensor)
|
||||||
|
|
||||||
|
kp_layout = get_keypoints_from_heatmap(semi_layout)
|
||||||
|
desc_layout = get_descriptors_from_map(desc_layout, kp_layout)
|
||||||
|
kp_module = get_keypoints_from_heatmap(semi_module)
|
||||||
|
desc_module = get_descriptors_from_map(desc_module, kp_module)
|
||||||
|
|
||||||
# 构建特征索引(使用Faiss加速)
|
bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
|
||||||
index = faiss.IndexFlatL2(64) # 特征维度由模型决定
|
matches = bf.match(desc_module, desc_layout)
|
||||||
features_db = []
|
matches = sorted(matches, key=lambda x: x.distance)
|
||||||
for (x, y, tile) in tiles:
|
|
||||||
feat = get_rotational_features(model, torch.tensor(tile).to(device))
|
|
||||||
features_db.append(feat)
|
|
||||||
index.add(np.stack(features_db))
|
|
||||||
|
|
||||||
# 检索相似区域
|
src_pts = np.float32([kp_module[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
|
||||||
D, I = index.search(target_feat[np.newaxis, :], k=10)
|
dst_pts = np.float32([kp_layout[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
|
||||||
|
H, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
|
||||||
|
|
||||||
for idx in I[0]:
|
h, w = module.shape[1], module.shape[2]
|
||||||
x, y, _ = tiles[idx]
|
corners = np.float32([[0, 0], [w, 0], [w, h], [0, h]]).reshape(-1, 1, 2)
|
||||||
|
transformed_corners = cv2.perspectiveTransform(corners, H)
|
||||||
|
x_min, y_min = np.min(transformed_corners, axis=0).ravel().astype(int)
|
||||||
|
x_max, y_max = np.max(transformed_corners, axis=0).ravel().astype(int)
|
||||||
|
theta = np.arctan2(H[1, 0], H[0, 0]) * 180 / np.pi
|
||||||
|
|
||||||
# 计算最佳匹配角度的显式计算
|
print(f"Matched region: [{x_min}, {y_min}, {x_max}, {y_max}], Rotation: {theta:.2f} degrees")
|
||||||
min_angle, min_dist = 90, float('inf')
|
return x_min, y_min, x_max, y_max, theta
|
||||||
target_vec = target_feat
|
|
||||||
feat = features_db[idx]
|
|
||||||
for a in [0, 1, 2, 3]: # 代表0°、90°、180°、270°
|
|
||||||
rotated_feat = np.rot90(feat.reshape(block_size, block_size), k=a)
|
|
||||||
dist = np.linalg.norm(target_vec - rotated_feat.flatten())
|
|
||||||
if dist < min_dist:
|
|
||||||
min_dist, min_angle = dist, a * 90
|
|
||||||
|
|
||||||
print(f"坐标({x},{y}), 最佳旋转方向{min_angle}度,距离: {min_dist}")
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
layout_path = "data/large_layout.npy"
|
||||||
|
module_path = "data/small_module.npy"
|
||||||
|
model_path = "superpoint_custom_model.pth"
|
||||||
|
num_channels = 3 # 替换为实际通道数
|
||||||
|
match_and_estimate(layout_path, module_path, model_path, num_channels)
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
class RotationInvariantNet(nn.Module):
|
|
||||||
"""轻量级旋转不变特征提取网络"""
|
|
||||||
def __init__(self, input_channels=1):
|
|
||||||
super().__init__()
|
|
||||||
self.cnn = nn.Sequential(
|
|
||||||
# 基础卷积层
|
|
||||||
nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.MaxPool2d(2), # 下采样
|
|
||||||
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Conv2d(64, 64, kernel_size=3, stride=2), # 更大感受野
|
|
||||||
nn.AdaptiveAvgPool2d((4,4)), # 全局池化获取全局特征,调整输出尺寸为4x4
|
|
||||||
nn.Flatten(), # 展平为一维向量
|
|
||||||
nn.Linear(64*16, 128) # 增加全连接层以降低维度到128
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.cnn(x)
|
|
||||||
def get_rotational_features(model, input_image):
|
|
||||||
"""计算输入图像所有旋转角度的特征平均值"""
|
|
||||||
rotations = [0, 90, 180, 270]
|
|
||||||
features_list = []
|
|
||||||
for angle in rotations:
|
|
||||||
rotated_img = torch.rot90(input_image, k=angle//90, dims=[2,3])
|
|
||||||
feat = model(rotated_img.unsqueeze(0))
|
|
||||||
features_list.append(feat)
|
|
||||||
return torch.mean(torch.stack(features_list), dim=0).detach().numpy()
|
|
||||||
47
models/superpoint_custom.py
Normal file
47
models/superpoint_custom.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class SuperPointCustom(nn.Module):
|
||||||
|
def __init__(self, num_channels=3): # num_channels 为版图通道数
|
||||||
|
super(SuperPointCustom, self).__init__()
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||||
|
c1, c2, c3, c4, c5, d1 = 64, 64, 128, 128, 256, 256
|
||||||
|
# 编码器
|
||||||
|
self.conv1a = nn.Conv2d(num_channels, c1, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
|
||||||
|
# 检测头
|
||||||
|
self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) # 65 = 8x8 + dustbin
|
||||||
|
# 描述符头
|
||||||
|
self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.convDb = nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# 编码器
|
||||||
|
x = self.relu(self.conv1a(x))
|
||||||
|
x = self.relu(self.conv1b(x))
|
||||||
|
x = self.pool(x)
|
||||||
|
x = self.relu(self.conv2a(x))
|
||||||
|
x = self.relu(self.conv2b(x))
|
||||||
|
x = self.pool(x)
|
||||||
|
x = self.relu(self.conv3a(x))
|
||||||
|
x = self.relu(self.conv3b(x))
|
||||||
|
x = self.pool(x)
|
||||||
|
x = self.relu(self.conv4a(x))
|
||||||
|
x = self.relu(self.conv4b(x))
|
||||||
|
# 检测头
|
||||||
|
cPa = self.relu(self.convPa(x))
|
||||||
|
semi = self.convPb(cPa) # [B, 65, H/8, W/8]
|
||||||
|
# 描述符头
|
||||||
|
cDa = self.relu(self.convDa(x))
|
||||||
|
desc = self.convDb(cDa) # [B, 256, H/8, W/8]
|
||||||
|
desc = F.normalize(desc, p=2, dim=1) # L2归一化
|
||||||
|
return semi, desc
|
||||||
131
train.py
131
train.py
@@ -1,99 +1,68 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from torch import nn, optim
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import DataLoader, random_split
|
from torch.utils.data import Dataset, DataLoader
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datetime import datetime
|
from models.superpoint_custom import SuperPointCustom
|
||||||
import argparse
|
from utils.data_augmentation import generate_training_pair
|
||||||
|
|
||||||
# 导入项目模块(根据你的路径调整)
|
class BinaryDataset(Dataset):
|
||||||
from models.rotation_cnn import RotationInvariantCNN # 模型实现
|
def __init__(self, image_dir, patch_size, num_channels):
|
||||||
from data_units import LayoutDataset, layout_transforms # 数据集和预处理函数
|
self.image_dir = image_dir
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)
|
||||||
|
if f.endswith('.npy')]
|
||||||
|
|
||||||
# 设置随机种子(可选)
|
def __len__(self):
|
||||||
torch.manual_seed(42)
|
return len(self.image_paths)
|
||||||
np.random.seed(42)
|
|
||||||
|
|
||||||
def main():
|
def __getitem__(self, idx):
|
||||||
"""训练流程"""
|
img_path = self.image_paths[idx]
|
||||||
# 解析命令行参数
|
image = np.load(img_path) # [C, H, W]
|
||||||
parser = argparse.ArgumentParser(description="Train Rotation-Invariant Layout Matcher")
|
patch, warped_patch, H = generate_training_pair(image, self.patch_size)
|
||||||
parser.add_argument("--data_dir", type=str, default="./data/train/", help="训练数据目录")
|
patch = torch.from_numpy(patch).float()
|
||||||
parser.add_argument("--val_split", type=float, default=0.2, help="验证集比例")
|
warped_patch = torch.from_numpy(warped_patch).float()
|
||||||
parser.add_argument("--batch_size", type=int, default=16, help="批量大小")
|
return patch, warped_patch, torch.from_numpy(H).float()
|
||||||
parser.add_argument("--epochs", type=int, default=50, help="训练轮次")
|
|
||||||
parser.add_argument("--lr", type=float, default=1e-3, help="学习率")
|
|
||||||
parser.add_argument("--model_save_dir", type=str, default="./models/", help="模型保存路径")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# 创建输出目录
|
def simple_detector_loss(semi, semi_w, H, device):
|
||||||
os.makedirs(args.model_save_dir, exist_ok=True)
|
return F.mse_loss(semi, semi_w) # 简化版,实际需更复杂实现
|
||||||
|
|
||||||
# 数据加载
|
def simple_descriptor_loss(desc, desc_w, H, device):
|
||||||
dataset = LayoutDataset(root_dir=args.data_dir, transform=layout_transforms())
|
return F.mse_loss(desc, desc_w) # 简化版,实际需更复杂实现
|
||||||
total_samples = len(dataset)
|
|
||||||
val_size = int(total_samples * args.val_split)
|
|
||||||
train_size = total_samples - val_size
|
|
||||||
|
|
||||||
# 划分训练集和验证集
|
def train():
|
||||||
train_dataset, val_dataset = random_split(
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
dataset,
|
batch_size = 8
|
||||||
[train_size, val_size],
|
learning_rate = 0.001
|
||||||
generator=torch.Generator().manual_seed(42)
|
num_epochs = 10
|
||||||
)
|
image_dir = 'data/train_images' # 替换为实际路径
|
||||||
|
patch_size = 256
|
||||||
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
|
num_channels = 3 # 替换为实际通道数
|
||||||
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
|
|
||||||
|
|
||||||
# 初始化模型、损失函数和优化器
|
dataset = BinaryDataset(image_dir, patch_size, num_channels)
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||||
model = RotationInvariantCNN().to(device) # 根据你的模型结构调整参数
|
|
||||||
criterion = nn.CrossEntropyLoss() # 分类任务示例,根据任务类型选择损失函数
|
|
||||||
optimizer = optim.Adam(model.parameters(), lr=args.lr)
|
|
||||||
|
|
||||||
# 训练循环
|
model = SuperPointCustom(num_channels=num_channels).to(device)
|
||||||
best_val_loss = float("inf")
|
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
||||||
for epoch in range(1, args.epochs + 1):
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
model.train()
|
model.train()
|
||||||
train_loss = 0.0
|
total_loss = 0
|
||||||
for batch_idx, (data, targets) in enumerate(train_loader):
|
for patch, warped_patch, H in dataloader:
|
||||||
data, targets = data.to(device), targets.to(device)
|
patch, warped_patch, H = patch.to(device), warped_patch.to(device), H.to(device)
|
||||||
|
semi, desc = model(patch)
|
||||||
# 前向传播
|
semi_w, desc_w = model(warped_patch)
|
||||||
outputs = model(data)
|
det_loss = simple_detector_loss(semi, semi_w, H, device)
|
||||||
loss = criterion(outputs, targets)
|
desc_loss = simple_descriptor_loss(desc, desc_w, H, device)
|
||||||
|
loss = det_loss + desc_loss
|
||||||
# 反向传播和优化
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
total_loss += loss.item()
|
||||||
train_loss += loss.item()
|
print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}")
|
||||||
|
|
||||||
if (batch_idx + 1) % 10 == 0:
|
|
||||||
print(f"Epoch [{epoch}/{args.epochs}] Batch {batch_idx+1}/{len(train_loader)} Loss: {loss.item():.4f}")
|
|
||||||
|
|
||||||
# 验证
|
|
||||||
model.eval()
|
|
||||||
val_loss = 0.0
|
|
||||||
with torch.no_grad():
|
|
||||||
for data, targets in val_loader:
|
|
||||||
data, targets = data.to(device), targets.to(device)
|
|
||||||
outputs = model(data)
|
|
||||||
loss = criterion(outputs, targets)
|
|
||||||
val_loss += loss.item()
|
|
||||||
|
|
||||||
avg_train_loss = train_loss / len(train_loader)
|
|
||||||
avg_val_loss = val_loss / len(val_loader)
|
|
||||||
|
|
||||||
print(f"Epoch {epoch} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
|
|
||||||
|
|
||||||
# 保存最佳模型
|
|
||||||
if avg_val_loss < best_val_loss:
|
|
||||||
best_val_loss = avg_val_loss
|
|
||||||
torch.save(model.state_dict(), os.path.join(args.model_save_dir, f"best_model_{datetime.now().strftime('%Y%m%d%H%M')}.pth"))
|
|
||||||
|
|
||||||
print("训练完成!")
|
torch.save(model.state_dict(), 'superpoint_custom_model.pth')
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
train()
|
||||||
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
33
utils/data_augmentation.py
Normal file
33
utils/data_augmentation.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
def get_random_patch(image, patch_size):
|
||||||
|
h, w = image.shape[1:3]
|
||||||
|
x = np.random.randint(0, w - patch_size)
|
||||||
|
y = np.random.randint(0, h - patch_size)
|
||||||
|
return image[:, y:y+patch_size, x:x+patch_size]
|
||||||
|
|
||||||
|
def get_random_homography(max_rotation=30, max_translation=20):
|
||||||
|
theta = np.random.uniform(-max_rotation, max_rotation) * np.pi / 180.0
|
||||||
|
tx = np.random.uniform(-max_translation, max_translation)
|
||||||
|
ty = np.random.uniform(-max_translation, max_translation)
|
||||||
|
cos_theta = np.cos(theta)
|
||||||
|
sin_theta = np.sin(theta)
|
||||||
|
H = np.array([
|
||||||
|
[cos_theta, -sin_theta, tx],
|
||||||
|
[sin_theta, cos_theta, ty],
|
||||||
|
[0, 0, 1]
|
||||||
|
])
|
||||||
|
return H
|
||||||
|
|
||||||
|
def apply_homography_to_image(image, H, output_size):
|
||||||
|
warped = np.zeros_like(image)
|
||||||
|
for c in range(image.shape[0]):
|
||||||
|
warped[c] = cv2.warpPerspective(image[c], H, output_size, flags=cv2.INTER_NEAREST)
|
||||||
|
return warped
|
||||||
|
|
||||||
|
def generate_training_pair(image, patch_size):
|
||||||
|
patch = get_random_patch(image, patch_size)
|
||||||
|
H = get_random_homography()
|
||||||
|
warped_patch = apply_homography_to_image(patch, H, (patch_size, patch_size))
|
||||||
|
return patch, warped_patch, H
|
||||||
Reference in New Issue
Block a user