import os

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np

from models.superpoint_custom import SuperPointCustom
from utils.data_augmentation import generate_training_pair


class BinaryDataset(Dataset):
    """Dataset over the ``.npy`` files of a directory.

    Each item is produced by loading one file and passing it through
    ``generate_training_pair``, which yields a patch, a warped patch and a
    third array ``H``.
    NOTE(review): ``H`` is presumably the homography relating the two
    patches -- confirm against ``utils.data_augmentation``.
    """

    def __init__(self, image_dir: str, patch_size: int, num_channels: int):
        """Collect the ``.npy`` files found directly inside *image_dir*.

        Args:
            image_dir: Directory that is listed (non-recursively) for files
                whose name ends with ``.npy``.
            patch_size: Forwarded to ``generate_training_pair`` for every item.
            num_channels: Stored on the instance; not used by this class.

        Raises:
            OSError: Propagated from ``os.listdir`` when *image_dir* cannot
                be listed.
        """
        self.image_dir = image_dir
        self.patch_size = patch_size
        self.num_channels = num_channels
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs and machines.
        self.image_paths = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.endswith('.npy')
        )

    def __len__(self) -> int:
        """Return the number of ``.npy`` files that were found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load file *idx* and return ``(patch, warped_patch, H)`` as float tensors."""
        img_path = self.image_paths[idx]
        # Original author's note: the array layout is [C, H, W]. Nothing here
        # checks that -- TODO confirm against the data on disk.
        image = np.load(img_path)
        patch, warped_patch, H = generate_training_pair(image, self.patch_size)
        return (
            torch.from_numpy(patch).float(),
            torch.from_numpy(warped_patch).float(),
            torch.from_numpy(H).float(),
        )


def simple_detector_loss(semi, semi_w, H, device):
    """Return the MSE between the detector outputs of the two patches.

    Simplified placeholder: a real implementation needs to be more complex.
    *H* and *device* are accepted but not used.
    """
    return F.mse_loss(semi, semi_w)


def simple_descriptor_loss(desc, desc_w, H, device):
    """Return the MSE between the descriptor outputs of the two patches.

    Simplified placeholder: a real implementation needs to be more complex.
    *H* and *device* are accepted but not used.
    """
    return F.mse_loss(desc, desc_w)


def train(
    image_dir: str = 'data/train_images',
    patch_size: int = 256,
    num_channels: int = 3,
    batch_size: int = 8,
    learning_rate: float = 0.001,
    num_epochs: int = 10,
    output_path: str = 'superpoint_custom_model.pth',
) -> None:
    """Train ``SuperPointCustom`` on patch pairs and save its ``state_dict``.

    Calling ``train()`` with no arguments uses the values that were
    previously hard-coded in this function.

    Args:
        image_dir: Directory with the training ``.npy`` files; replace the
            default with the actual path.
        patch_size: Patch size handed to the dataset.
        num_channels: Channel count handed to the dataset and the model;
            replace the default with the actual number of channels.
        batch_size: Number of samples per batch.
        learning_rate: Learning rate of the Adam optimizer.
        num_epochs: Number of passes over the dataset.
        output_path: File the model ``state_dict`` is written to once all
            epochs have finished.

    Raises:
        ValueError: If *image_dir* contains no ``.npy`` files.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = BinaryDataset(image_dir, patch_size, num_channels)
    if len(dataset) == 0:
        # Fail early with a clear message; otherwise the run breaks later, at
        # the latest when the epoch average divides by len(dataloader) == 0.
        raise ValueError(f"No .npy files found in {image_dir!r}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    num_batches = len(dataloader)

    model = SuperPointCustom(num_channels=num_channels).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for patch, warped_patch, H in dataloader:
            patch = patch.to(device)
            warped_patch = warped_patch.to(device)
            H = H.to(device)

            # The same network processes both views of the pair.
            semi, desc = model(patch)
            semi_w, desc_w = model(warped_patch)

            det_loss = simple_detector_loss(semi, semi_w, H, device)
            desc_loss = simple_descriptor_loss(desc, desc_w, H, device)
            loss = det_loss + desc_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Reported value is the mean of the per-batch losses for this epoch.
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / num_batches:.4f}")

    torch.save(model.state_dict(), output_path)


if __name__ == "__main__":
    train()