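"""Training script for SuperPointCustom on multi-channel .npy images.

Each sample is a (patch, warped_patch, H) pair produced by
utils.data_augmentation.generate_training_pair; the detector and descriptor
heads are trained with simplified placeholder losses.
"""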
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from models.superpoint_custom import SuperPointCustom
from utils.data_augmentation import generate_training_pair


class BinaryDataset(Dataset):
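    """Loads .npy images from image_dir and yields (patch, warped_patch, H) pairs."""
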
    def __init__(self, image_dir, patch_size, num_channels):
        self.image_dir = image_dir
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.image_paths = [os.path.join(image_dir, f)
                            for f in os.listdir(image_dir)
                            if f.endswith('.npy')]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = np.load(img_path)  # [C, H, W]
        patch, warped_patch, H = generate_training_pair(image, self.patch_size)
        patch = torch.from_numpy(patch).float()
        warped_patch = torch.from_numpy(warped_patch).float()
        return patch, warped_patch, torch.from_numpy(H).float()


def simple_detector_loss(semi, semi_w, H, device):
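    # NOTE: placeholder loss; H and device are currently unused. A full
    # SuperPoint-style detector loss would warp the keypoint targets with the
    # homography H before comparing the two heatmaps.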
    return F.mse_loss(semi, semi_w)  # simplified; a real implementation is more involved


def simple_descriptor_loss(desc, desc_w, H, device):
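    # NOTE: placeholder loss; a full descriptor loss would use H to build pixel
    # correspondences between the two views (e.g. a hinge/contrastive formulation).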
    return F.mse_loss(desc, desc_w)  # simplified; a real implementation is more involved


def train():
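    """Train SuperPointCustom on homography-warped patch pairs and save the weights."""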
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 8
    learning_rate = 0.001
    num_epochs = 10
    image_dir = 'data/train_images'  # replace with the actual path
    patch_size = 256
    num_channels = 3  # replace with the actual number of channels

    dataset = BinaryDataset(image_dir, patch_size, num_channels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = SuperPointCustom(num_channels=num_channels).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for patch, warped_patch, H in dataloader:
            patch, warped_patch, H = patch.to(device), warped_patch.to(device), H.to(device)

            # Forward both views of the pair through the same network.
            semi, desc = model(patch)
            semi_w, desc_w = model(warped_patch)

            det_loss = simple_detector_loss(semi, semi_w, H, device)
            desc_loss = simple_descriptor_loss(desc, desc_w, H, device)
            loss = det_loss + desc_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}")

    torch.save(model.state_dict(), 'superpoint_custom_model.pth')


if __name__ == "__main__":
    train()