From a5c63ad0de5d8894c36fab645d02f83b7c367b8e Mon Sep 17 00:00:00 2001 From: jiao77 Date: Mon, 31 Mar 2025 14:49:04 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=80=E4=B8=AA=E7=9B=AE=E6=A0=87=E5=AE=9E?= =?UTF-8?q?=E6=97=B6=E6=A3=80=E6=B5=8B=E7=9A=84=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- inference.py | 97 +++++++++++++++----------- models/rotation_cnn.py | 31 --------- models/superpoint_custom.py | 47 +++++++++++++ train.py | 131 ++++++++++++++---------------------- utils/__init__.py | 0 utils/data_augmentation.py | 33 +++++++++ 6 files changed, 188 insertions(+), 151 deletions(-) delete mode 100644 models/rotation_cnn.py create mode 100644 models/superpoint_custom.py create mode 100644 utils/__init__.py create mode 100644 utils/data_augmentation.py diff --git a/inference.py b/inference.py index 325d1a3..8f0b352 100644 --- a/inference.py +++ b/inference.py @@ -1,51 +1,70 @@ -import faiss -import numpy as np import torch -from models.rotation_cnn import RotationInvariantNet, get_rotational_features -from data_units import layout_to_tensor, tile_layout +import cv2 +import numpy as np +from models.superpoint_custom import SuperPointCustom -def main(): - # 配置参数(需根据实际调整) - block_size = 64 # 分块尺寸 - target_module_path = "target.png" - large_layout_path = "layout_large.png" +def get_keypoints_from_heatmap(semi, threshold=0.015): + semi = semi.squeeze().cpu().numpy() # [65, H/8, W/8] + prob = cv2.softmax(semi, axis=0)[:-1] # [64, H/8, W/8] + prob = prob.reshape(8, 8, semi.shape[1], semi.shape[2]) + prob = prob.transpose(0, 2, 1, 3).reshape(8*semi.shape[1], 8*semi.shape[2]) # [H, W] + keypoints = [] + for y in range(prob.shape[0]): + for x in range(prob.shape[1]): + if prob[y, x] > threshold: + keypoints.append(cv2.KeyPoint(x, y, 1)) + return keypoints - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = RotationInvariantNet().to(device) - model.load_state_dict(torch.load("rotation_cnn.pth")) +def get_descriptors_from_map(desc, keypoints): + desc = desc.squeeze().cpu().numpy() # [256, H/8, W/8] + descriptors = [] + scale = 8 + for kp in keypoints: + x, y = int(kp.pt[0] / scale), int(kp.pt[1] / scale) + if 0 <= x < desc.shape[2] and 0 <= y < desc.shape[1]: + descriptors.append(desc[:, y, x]) + return np.array(descriptors) + +def match_and_estimate(layout_path, module_path, model_path, num_channels, device='cuda'): + model = SuperPointCustom(num_channels=num_channels).to(device) + model.load_state_dict(torch.load(model_path, map_location=device)) model.eval() - # 预处理目标模块与大版图 - target_tensor = layout_to_tensor(target_module_path, (block_size, block_size)) - target_feat = get_rotational_features(model, torch.tensor(target_tensor).to(device)) + layout = np.load(layout_path) # [C, H, W] + module = np.load(module_path) # [C, H, W] + layout_tensor = torch.from_numpy(layout).float().unsqueeze(0).to(device) + module_tensor = torch.from_numpy(module).float().unsqueeze(0).to(device) - large_layout = layout_to_tensor(large_layout_path) - tiles = tile_layout(large_layout) + with torch.no_grad(): + semi_layout, desc_layout = model(layout_tensor) + semi_module, desc_module = model(module_tensor) + + kp_layout = get_keypoints_from_heatmap(semi_layout) + desc_layout = get_descriptors_from_map(desc_layout, kp_layout) + kp_module = get_keypoints_from_heatmap(semi_module) + desc_module = get_descriptors_from_map(desc_module, kp_module) - # 构建特征索引(使用Faiss加速) - index = faiss.IndexFlatL2(64) # 特征维度由模型决定 - features_db = [] - for (x, y, tile) in tiles: - feat = get_rotational_features(model, torch.tensor(tile).to(device)) - features_db.append(feat) - index.add(np.stack(features_db)) + bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True) + matches = bf.match(desc_module, desc_layout) + matches = sorted(matches, key=lambda x: x.distance) - # 检索相似区域 - D, I = index.search(target_feat[np.newaxis, :], k=10) + src_pts = np.float32([kp_module[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2) + dst_pts = np.float32([kp_layout[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2) + H, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0) - for idx in I[0]: - x, y, _ = tiles[idx] + h, w = module.shape[1], module.shape[2] + corners = np.float32([[0, 0], [w, 0], [w, h], [0, h]]).reshape(-1, 1, 2) + transformed_corners = cv2.perspectiveTransform(corners, H) + x_min, y_min = np.min(transformed_corners, axis=0).ravel().astype(int) + x_max, y_max = np.max(transformed_corners, axis=0).ravel().astype(int) + theta = np.arctan2(H[1, 0], H[0, 0]) * 180 / np.pi - # 计算最佳匹配角度的显式计算 - min_angle, min_dist = 90, float('inf') - target_vec = target_feat - feat = features_db[idx] - for a in [0, 1, 2, 3]: # 代表0°、90°、180°、270° - rotated_feat = np.rot90(feat.reshape(block_size, block_size), k=a) - dist = np.linalg.norm(target_vec - rotated_feat.flatten()) - if dist < min_dist: - min_dist, min_angle = dist, a * 90 + print(f"Matched region: [{x_min}, {y_min}, {x_max}, {y_max}], Rotation: {theta:.2f} degrees") + return x_min, y_min, x_max, y_max, theta - print(f"坐标({x},{y}), 最佳旋转方向{min_angle}度,距离: {min_dist}") if __name__ == "__main__": - main() \ No newline at end of file + layout_path = "data/large_layout.npy" + module_path = "data/small_module.npy" + model_path = "superpoint_custom_model.pth" + num_channels = 3 # 替换为实际通道数 + match_and_estimate(layout_path, module_path, model_path, num_channels) \ No newline at end of file diff --git a/models/rotation_cnn.py b/models/rotation_cnn.py deleted file mode 100644 index 226cbde..0000000 --- a/models/rotation_cnn.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -import torch.nn as nn - -class RotationInvariantNet(nn.Module): - """轻量级旋转不变特征提取网络""" - def __init__(self, input_channels=1): - super().__init__() - self.cnn = nn.Sequential( - # 基础卷积层 - nn.Conv2d(input_channels, 32, kernel_size=3, padding=1), - nn.ReLU(), - nn.MaxPool2d(2), # 下采样 - nn.Conv2d(32, 64, kernel_size=3, padding=1), - nn.ReLU(), - nn.Conv2d(64, 64, kernel_size=3, stride=2), # 更大感受野 - nn.AdaptiveAvgPool2d((4,4)), # 全局池化获取全局特征,调整输出尺寸为4x4 - nn.Flatten(), # 展平为一维向量 - nn.Linear(64*16, 128) # 增加全连接层以降低维度到128 - ) - - def forward(self, x): - return self.cnn(x) -def get_rotational_features(model, input_image): - """计算输入图像所有旋转角度的特征平均值""" - rotations = [0, 90, 180, 270] - features_list = [] - for angle in rotations: - rotated_img = torch.rot90(input_image, k=angle//90, dims=[2,3]) - feat = model(rotated_img.unsqueeze(0)) - features_list.append(feat) - return torch.mean(torch.stack(features_list), dim=0).detach().numpy() \ No newline at end of file diff --git a/models/superpoint_custom.py b/models/superpoint_custom.py new file mode 100644 index 0000000..6dbb9b8 --- /dev/null +++ b/models/superpoint_custom.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class SuperPointCustom(nn.Module): + def __init__(self, num_channels=3): # num_channels 为版图通道数 + super(SuperPointCustom, self).__init__() + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + c1, c2, c3, c4, c5, d1 = 64, 64, 128, 128, 256, 256 + # 编码器 + self.conv1a = nn.Conv2d(num_channels, c1, kernel_size=3, stride=1, padding=1) + self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) + self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) + self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) + self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) + self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) + self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) + self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) + # 检测头 + self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) # 65 = 8x8 + dustbin + # 描述符头 + self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + self.convDb = nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + # 编码器 + x = self.relu(self.conv1a(x)) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + # 检测头 + cPa = self.relu(self.convPa(x)) + semi = self.convPb(cPa) # [B, 65, H/8, W/8] + # 描述符头 + cDa = self.relu(self.convDa(x)) + desc = self.convDb(cDa) # [B, 256, H/8, W/8] + desc = F.normalize(desc, p=2, dim=1) # L2归一化 + return semi, desc \ No newline at end of file diff --git a/train.py b/train.py index 097dd69..2a57519 100644 --- a/train.py +++ b/train.py @@ -1,99 +1,68 @@ import os import torch -from torch import nn, optim -from torch.utils.data import DataLoader, random_split +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader import numpy as np -from datetime import datetime -import argparse +from models.superpoint_custom import SuperPointCustom +from utils.data_augmentation import generate_training_pair -# 导入项目模块(根据你的路径调整) -from models.rotation_cnn import RotationInvariantCNN # 模型实现 -from data_units import LayoutDataset, layout_transforms # 数据集和预处理函数 +class BinaryDataset(Dataset): + def __init__(self, image_dir, patch_size, num_channels): + self.image_dir = image_dir + self.patch_size = patch_size + self.num_channels = num_channels + self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) + if f.endswith('.npy')] -# 设置随机种子(可选) -torch.manual_seed(42) -np.random.seed(42) + def __len__(self): + return len(self.image_paths) -def main(): - """训练流程""" - # 解析命令行参数 - parser = argparse.ArgumentParser(description="Train Rotation-Invariant Layout Matcher") - parser.add_argument("--data_dir", type=str, default="./data/train/", help="训练数据目录") - parser.add_argument("--val_split", type=float, default=0.2, help="验证集比例") - parser.add_argument("--batch_size", type=int, default=16, help="批量大小") - parser.add_argument("--epochs", type=int, default=50, help="训练轮次") - parser.add_argument("--lr", type=float, default=1e-3, help="学习率") - parser.add_argument("--model_save_dir", type=str, default="./models/", help="模型保存路径") - args = parser.parse_args() + def __getitem__(self, idx): + img_path = self.image_paths[idx] + image = np.load(img_path) # [C, H, W] + patch, warped_patch, H = generate_training_pair(image, self.patch_size) + patch = torch.from_numpy(patch).float() + warped_patch = torch.from_numpy(warped_patch).float() + return patch, warped_patch, torch.from_numpy(H).float() - # 创建输出目录 - os.makedirs(args.model_save_dir, exist_ok=True) +def simple_detector_loss(semi, semi_w, H, device): + return F.mse_loss(semi, semi_w) # 简化版,实际需更复杂实现 - # 数据加载 - dataset = LayoutDataset(root_dir=args.data_dir, transform=layout_transforms()) - total_samples = len(dataset) - val_size = int(total_samples * args.val_split) - train_size = total_samples - val_size +def simple_descriptor_loss(desc, desc_w, H, device): + return F.mse_loss(desc, desc_w) # 简化版,实际需更复杂实现 - # 划分训练集和验证集 - train_dataset, val_dataset = random_split( - dataset, - [train_size, val_size], - generator=torch.Generator().manual_seed(42) - ) - - train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) - val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4) +def train(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + batch_size = 8 + learning_rate = 0.001 + num_epochs = 10 + image_dir = 'data/train_images' # 替换为实际路径 + patch_size = 256 + num_channels = 3 # 替换为实际通道数 - # 初始化模型、损失函数和优化器 - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = RotationInvariantCNN().to(device) # 根据你的模型结构调整参数 - criterion = nn.CrossEntropyLoss() # 分类任务示例,根据任务类型选择损失函数 - optimizer = optim.Adam(model.parameters(), lr=args.lr) + dataset = BinaryDataset(image_dir, patch_size, num_channels) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) - # 训练循环 - best_val_loss = float("inf") - for epoch in range(1, args.epochs + 1): + model = SuperPointCustom(num_channels=num_channels).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) + + for epoch in range(num_epochs): model.train() - train_loss = 0.0 - for batch_idx, (data, targets) in enumerate(train_loader): - data, targets = data.to(device), targets.to(device) - - # 前向传播 - outputs = model(data) - loss = criterion(outputs, targets) - - # 反向传播和优化 + total_loss = 0 + for patch, warped_patch, H in dataloader: + patch, warped_patch, H = patch.to(device), warped_patch.to(device), H.to(device) + semi, desc = model(patch) + semi_w, desc_w = model(warped_patch) + det_loss = simple_detector_loss(semi, semi_w, H, device) + desc_loss = simple_descriptor_loss(desc, desc_w, H, device) + loss = det_loss + desc_loss optimizer.zero_grad() loss.backward() optimizer.step() - - train_loss += loss.item() - - if (batch_idx + 1) % 10 == 0: - print(f"Epoch [{epoch}/{args.epochs}] Batch {batch_idx+1}/{len(train_loader)} Loss: {loss.item():.4f}") - - # 验证 - model.eval() - val_loss = 0.0 - with torch.no_grad(): - for data, targets in val_loader: - data, targets = data.to(device), targets.to(device) - outputs = model(data) - loss = criterion(outputs, targets) - val_loss += loss.item() - - avg_train_loss = train_loss / len(train_loader) - avg_val_loss = val_loss / len(val_loader) - - print(f"Epoch {epoch} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}") - - # 保存最佳模型 - if avg_val_loss < best_val_loss: - best_val_loss = avg_val_loss - torch.save(model.state_dict(), os.path.join(args.model_save_dir, f"best_model_{datetime.now().strftime('%Y%m%d%H%M')}.pth")) + total_loss += loss.item() + print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}") - print("训练完成!") + torch.save(model.state_dict(), 'superpoint_custom_model.pth') if __name__ == "__main__": - main() \ No newline at end of file + train() \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/data_augmentation.py b/utils/data_augmentation.py new file mode 100644 index 0000000..cc05310 --- /dev/null +++ b/utils/data_augmentation.py @@ -0,0 +1,33 @@ +import numpy as np +import cv2 + +def get_random_patch(image, patch_size): + h, w = image.shape[1:3] + x = np.random.randint(0, w - patch_size) + y = np.random.randint(0, h - patch_size) + return image[:, y:y+patch_size, x:x+patch_size] + +def get_random_homography(max_rotation=30, max_translation=20): + theta = np.random.uniform(-max_rotation, max_rotation) * np.pi / 180.0 + tx = np.random.uniform(-max_translation, max_translation) + ty = np.random.uniform(-max_translation, max_translation) + cos_theta = np.cos(theta) + sin_theta = np.sin(theta) + H = np.array([ + [cos_theta, -sin_theta, tx], + [sin_theta, cos_theta, ty], + [0, 0, 1] + ]) + return H + +def apply_homography_to_image(image, H, output_size): + warped = np.zeros_like(image) + for c in range(image.shape[0]): + warped[c] = cv2.warpPerspective(image[c], H, output_size, flags=cv2.INTER_NEAREST) + return warped + +def generate_training_pair(image, patch_size): + patch = get_random_patch(image, patch_size) + H = get_random_homography() + warped_patch = apply_homography_to_image(patch, H, (patch_size, patch_size)) + return patch, warped_patch, H \ No newline at end of file