# Files
# LayoutMatch/train.py
# 2025-03-31 14:49:04 +08:00
#
# 68 lines
# 2.6 KiB
# Python

import os
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from models.superpoint_custom import SuperPointCustom
from utils.data_augmentation import generate_training_pair
class BinaryDataset(Dataset):
    """Dataset of ``.npy`` images that yields self-supervised training pairs.

    Each item is built on the fly: the image is loaded with ``np.load`` and
    passed to ``generate_training_pair`` together with ``patch_size``. That
    helper returns ``(patch, warped_patch, H)``. ``H`` is presumably the
    transform relating the two patches; confirm in
    ``utils.data_augmentation``.

    Args:
        image_dir: Directory scanned (non-recursively) for ``*.npy`` files.
        patch_size: Forwarded to ``generate_training_pair``.
        num_channels: Expected channel count. It is stored on the instance
            but not validated here.
            NOTE(review): confirm the stored images match the model's input.
    """

    def __init__(self, image_dir, patch_size, num_channels):
        self.image_dir = image_dir
        self.patch_size = patch_size
        self.num_channels = num_channels
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sort them so the index -> file mapping is reproducible across runs
        # and machines.
        self.image_paths = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.endswith('.npy')
        )

    def __len__(self):
        """Return the number of ``.npy`` files found in ``image_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx``; return ``(patch, warped_patch, H)`` as float32 tensors."""
        # Layout assumed to be [C, H, W] (per the original note). TODO confirm.
        image = np.load(self.image_paths[idx])
        patch, warped_patch, H = generate_training_pair(image, self.patch_size)
        return (
            self._to_float_tensor(patch),
            self._to_float_tensor(warped_patch),
            self._to_float_tensor(H),
        )

    @staticmethod
    def _to_float_tensor(array):
        """Convert a NumPy array to a float32 tensor.

        ``torch.from_numpy`` rejects arrays with negative strides (for example
        the result of an ``arr[..., ::-1]`` flip). Non-contiguous inputs are
        therefore copied into C-contiguous memory first. Contiguous arrays are
        passed through without a copy.
        """
        if not array.flags.c_contiguous:
            array = np.ascontiguousarray(array)
        return torch.from_numpy(array).float()
def simple_detector_loss(semi, semi_w, H, device):
    """Placeholder detector loss: mean squared error between the two detector outputs.

    This is a simplified version; a real implementation needs to be more
    involved.

    NOTE(review): ``H`` and ``device`` are accepted but unused, so ``semi_w``
    is compared to ``semi`` without being aligned through ``H``.
    TODO: replace with a proper homography-aware detector loss.
    """
    del H, device  # unused by this simplified version
    return F.mse_loss(input=semi, target=semi_w, reduction='mean')
def simple_descriptor_loss(desc, desc_w, H, device):
    """Placeholder descriptor loss: mean squared error between the two descriptor maps.

    This is a simplified version; a real implementation needs to be more
    involved.

    NOTE(review): ``H`` and ``device`` are accepted but unused, so
    ``desc_w`` is compared to ``desc`` without any correspondence derived
    from ``H``. TODO: replace with a proper descriptor loss.
    """
    del H, device  # unused by this simplified version
    return F.mse_loss(input=desc, target=desc_w, reduction='mean')
def _train_one_epoch(model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader``; return the mean per-batch loss."""
    model.train()
    total_loss = 0.0
    for patch, warped_patch, H in dataloader:
        patch = patch.to(device)
        warped_patch = warped_patch.to(device)
        H = H.to(device)
        # Both views go through the same network (same weights).
        semi, desc = model(patch)
        semi_w, desc_w = model(warped_patch)
        det_loss = simple_detector_loss(semi, semi_w, H, device)
        desc_loss = simple_descriptor_loss(desc, desc_w, H, device)
        loss = det_loss + desc_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)


def train(
    *,
    image_dir='data/train_images',
    patch_size=256,
    num_channels=3,
    batch_size=8,
    learning_rate=0.001,
    num_epochs=10,
    save_path='superpoint_custom_model.pth',
    device=None,
):
    """Train ``SuperPointCustom`` on pairs produced by ``BinaryDataset``.

    All settings are keyword-only. Their defaults are the values previously
    hard-coded here, so a bare ``train()`` call behaves as before. The mean
    loss is printed after each epoch. The final ``state_dict`` is saved to
    ``save_path``.

    Args:
        image_dir: Directory of ``.npy`` training images. The default is a
            placeholder; replace it with the real path.
        patch_size: Patch size passed to ``BinaryDataset``.
        num_channels: Input channel count for the dataset and the model.
            The default is a placeholder; replace it with the real count.
        batch_size: Mini-batch size for the ``DataLoader``.
        learning_rate: Adam learning rate.
        num_epochs: Number of passes over the dataset.
        save_path: File the trained ``state_dict`` is written to.
        device: Torch device (or device string). ``None`` selects CUDA when
            available, otherwise CPU.

    Raises:
        ValueError: If ``image_dir`` contains no ``.npy`` files.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)

    dataset = BinaryDataset(image_dir, patch_size, num_channels)
    if len(dataset) == 0:
        # Fail early with an actionable message. Otherwise
        # DataLoader(shuffle=True) raises an obscure sampler error.
        raise ValueError(f"No .npy training images found in {image_dir!r}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = SuperPointCustom(num_channels=num_channels).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        avg_loss = _train_one_epoch(model, dataloader, optimizer, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    torch.save(model.state_dict(), save_path)
if __name__ == "__main__":
    # Script entry point: run training with its default configuration.
    train()