151 lines
6.6 KiB
Python
151 lines
6.6 KiB
Python
# train.py
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from PIL import Image
|
|
import numpy as np
|
|
import cv2
|
|
import os
|
|
import argparse
|
|
|
|
# Project-local modules
|
|
import config
|
|
from models.rord import RoRD
|
|
from utils.data_utils import get_transform
|
|
|
|
# --- (modified) Training-only dataset class ---
|
|
class ICLayoutTrainingDataset(Dataset):
    """Self-supervised training pairs cut from IC-layout PNG images.

    Each item is a random square patch taken from one image, a copy of that
    patch under one of eight discrete transforms (rotation by 0/90/180/270
    degrees, optionally preceded by a horizontal mirror), and the affine
    matrix relating the two.

    Args:
        image_dir: Directory scanned (non-recursively) for ``.png`` files.
        patch_size: Side length in pixels of the returned square patches.
        transform: Optional callable applied to both PIL patches.
        scale_range: ``(low, high)`` bounds of the scale-jitter factor. A
            factor ``s`` crops a ``patch_size / s`` window and resizes it
            back to ``patch_size``.

    Raises:
        ValueError: If ``scale_range`` contains a non-positive bound.
    """

    def __init__(self, image_dir, patch_size=256, transform=None, scale_range=(1.0, 1.0)):
        # A zero or negative scale would later divide by zero or produce a
        # negative crop window, so reject it up front with a clear message.
        if min(scale_range) <= 0:
            raise ValueError(f"scale_range bounds must be positive, got {scale_range!r}")

        self.image_dir = image_dir
        # os.listdir order is arbitrary; sorting keeps index -> file stable
        # across runs and machines.
        self.image_paths = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.endswith('.png')
        )
        self.patch_size = patch_size
        self.transform = transform
        self.scale_range = scale_range

    def __len__(self):
        """Return the number of PNG images found in ``image_dir``."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(patch, transformed_patch, H_tensor)`` for one image.

        ``H_tensor`` is a float 2x3 affine matrix in pixel coordinates of the
        patch. It is the matrix handed to ``cv2.warpPerspective`` (top two
        rows), i.e. it maps a position in ``patch`` to the matching position
        in ``transformed_patch``.
        """
        # The context manager releases the file handle once pixels are decoded.
        with Image.open(self.image_paths[index]) as raw:
            image = raw.convert('L')

        patch = self._random_scaled_crop(image)
        homography = self._random_homography()

        warped_np = cv2.warpPerspective(
            np.array(patch), homography, (self.patch_size, self.patch_size)
        )
        transformed_patch = Image.fromarray(warped_np)

        if self.transform:
            patch = self.transform(patch)
            transformed_patch = self.transform(transformed_patch)

        H_tensor = torch.from_numpy(homography[:2, :]).float()
        return patch, transformed_patch, H_tensor

    def _random_scaled_crop(self, image):
        """Scale jitter: crop a random square window and resize it to patch_size."""
        width, height = image.size
        scale = np.random.uniform(self.scale_range[0], self.scale_range[1])

        # Never crop beyond the image, and never collapse to an empty window.
        crop_size = max(1, min(int(self.patch_size / scale), width, height))

        x = np.random.randint(0, width - crop_size + 1)
        y = np.random.randint(0, height - crop_size + 1)
        window = image.crop((x, y, x + crop_size, y + crop_size))
        return window.resize((self.patch_size, self.patch_size), Image.LANCZOS)

    def _random_homography(self):
        """Draw one of the 8 rotation/mirror transforms as a float32 3x3 matrix."""
        theta_deg = np.random.choice([0, 90, 180, 270])
        is_mirrored = np.random.choice([True, False])

        # Pixel centres run from 0 to patch_size - 1, so the patch centre is
        # (patch_size - 1) / 2. Using patch_size / 2 would send pixel 0 to
        # index patch_size (outside the image) and leave a blank border line.
        cx = cy = (self.patch_size - 1) / 2.0

        rotation = np.vstack(
            [cv2.getRotationMatrix2D((cx, cy), float(theta_deg), 1), [0, 0, 1]]
        )
        if not is_mirrored:
            return rotation.astype(np.float32)

        # Horizontal mirror about the vertical centre line: x -> 2 * cx - x.
        mirror = np.array([[-1, 0, 2 * cx], [0, 1, 0], [0, 0, 1]])
        # The mirror is applied first, then the rotation.
        return (rotation @ mirror).astype(np.float32)
|
|
|
|
# --- Feature-map warping and loss functions (unchanged) ---
|
|
def warp_feature_map(feature_map, H_inv):
    """Resample ``feature_map`` through a batch of 2x3 affine matrices.

    Args:
        feature_map: Tensor of shape (B, C, H, W).
        H_inv: Tensor of shape (B, 2, 3), forwarded as ``theta`` to
            ``F.affine_grid``. For every *output* location it gives the
            *input* location to read, in normalized [-1, 1] coordinates
            (``align_corners=False``).

    Returns:
        Tensor with the same shape as ``feature_map``.
    """
    # Build the grid directly on the feature map's device and dtype so that
    # grid_sample never sees mismatched operands.
    theta = H_inv.to(device=feature_map.device, dtype=feature_map.dtype)
    grid = F.affine_grid(theta, feature_map.size(), align_corners=False)
    return F.grid_sample(feature_map, grid, align_corners=False)


def compute_detection_loss(det_original, det_rotated, H):
    """MSE between the original detection map and the re-aligned rotated one.

    Args:
        det_original: Detection map of the original patch, (B, C, H, W).
        det_rotated: Detection map of the transformed patch, same shape.
        H: (B, 2, 3) affine mapping original-frame positions to
            transformed-frame positions, in the normalized [-1, 1]
            coordinates used by ``F.affine_grid``.

    Returns:
        Scalar loss tensor.
    """
    # To bring det_rotated back into the original frame we need, for each
    # original-frame location p, the transformed-frame location holding the
    # same content. That location is H @ p, so H itself is the sampling
    # matrix. Its inverse is only correct for self-inverse transforms and
    # mis-aligns 90/270 degree rotations. Dropping the inverse also removes
    # the CPU/GPU mismatch of the hand-built [0, 0, 1] row.
    theta = H.detach()
    warped_det_rotated = warp_feature_map(det_rotated, theta)
    return F.mse_loss(det_original, warped_det_rotated)
|
|
|
|
def compute_description_loss(desc_original, desc_rotated, H, margin=1.0, num_samples=100):
    """Triplet margin loss over randomly sampled descriptor locations.

    Anchors are read from ``desc_original`` at random points, positives from
    ``desc_rotated`` at the same points carried through ``H``, and negatives
    from ``desc_rotated`` at independent random points.

    Args:
        desc_original: Descriptor map of the original patch, (B, C, H, W).
        desc_rotated: Descriptor map of the transformed patch, same shape.
        H: (B, 2, 3) affine mapping original-frame positions to
            transformed-frame positions, in the normalized [-1, 1]
            coordinates used by ``F.grid_sample``.
        margin: Triplet margin.
        num_samples: Number of anchor points drawn per batch element.

    Returns:
        Scalar loss tensor.
    """
    batch_size = desc_original.size(0)
    device, dtype = desc_original.device, desc_original.dtype

    def sample(desc, coords):
        # (B, N, 2) coords -> (B, N, C) descriptors.
        picked = F.grid_sample(desc, coords.unsqueeze(1), align_corners=False)
        return picked.squeeze(2).transpose(1, 2)

    def random_coords():
        # Uniform over the full normalized extent [-1, 1).
        return torch.rand(batch_size, num_samples, 2, device=device, dtype=dtype) * 2 - 1

    anchor_coords = random_coords()
    anchor = sample(desc_original, anchor_coords)

    # Content at original-frame point p sits at H @ p in the transformed
    # frame, so positives are read there. The inverse would only be right for
    # self-inverse transforms. Using H directly also avoids concatenating a
    # CPU-built [0, 0, 1] row onto a GPU tensor.
    ones = torch.ones(batch_size, num_samples, 1, device=device, dtype=dtype)
    coords_hom = torch.cat([anchor_coords, ones], dim=2)
    forward = H.to(device=device, dtype=dtype)
    positive_coords = coords_hom @ forward.transpose(1, 2)
    positive = sample(desc_rotated, positive_coords)

    negative = sample(desc_rotated, random_coords())

    return F.triplet_margin_loss(anchor, positive, negative, margin=margin, p=2)
|
|
|
|
# --- (modified) Main function and command-line interface ---
|
|
def _pixel_affine_to_normalized(H, size):
    """Re-express 2x3 pixel-space affines in normalized [-1, 1] coordinates.

    ``F.affine_grid`` / ``F.grid_sample`` (``align_corners=False``) place
    pixel index ``i`` of a ``size``-wide image at ``(2 * i + 1) / size - 1``.
    The result is ``N @ H @ N^-1`` with ``N`` being that pixel -> normalized
    map, for a square image of side ``size``.

    Args:
        H: Tensor of shape (B, 2, 3) acting on pixel coordinates.
        size: Side length in pixels of the square image ``H`` refers to.

    Returns:
        Tensor of shape (B, 2, 3) acting on normalized coordinates.
    """
    s = float(size)
    to_norm = H.new_tensor([
        [2.0 / s, 0.0, 1.0 / s - 1.0],
        [0.0, 2.0 / s, 1.0 / s - 1.0],
        [0.0, 0.0, 1.0],
    ])
    to_pixel = H.new_tensor([
        [s / 2.0, 0.0, (s - 1.0) / 2.0],
        [0.0, s / 2.0, (s - 1.0) / 2.0],
        [0.0, 0.0, 1.0],
    ])
    # new_tensor keeps H's device and dtype, so the concat cannot mix devices.
    bottom = H.new_tensor([[0.0, 0.0, 1.0]]).expand(H.shape[0], 1, 3)
    H_3x3 = torch.cat([H, bottom], dim=1)
    return (to_norm @ H_3x3 @ to_pixel)[:, :2, :]


def main(args):
    """Train a RoRD model on layout patches and save its final weights.

    Args:
        args: Namespace providing ``data_dir``, ``save_dir``, ``epochs``,
            ``batch_size`` and ``lr``.
    """
    print("--- 开始训练 RoRD 模型 ---")
    print(f"训练参数: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")

    transform = get_transform()
    # The scale-jitter range is handed to the dataset at construction time.
    dataset = ICLayoutTrainingDataset(
        args.data_dir,
        patch_size=config.PATCH_SIZE,
        transform=transform,
        scale_range=config.SCALE_JITTER_RANGE,
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)

    # Use the GPU when present, otherwise run on CPU instead of crashing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RoRD().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        model.train()
        total_loss_val = 0.0
        for original, rotated, H in dataloader:
            original = original.to(device)
            rotated = rotated.to(device)
            # The dataset builds H with OpenCV in patch pixel coordinates,
            # while the losses sample with affine_grid/grid_sample, which
            # work in normalized coordinates. Convert once per batch.
            H = _pixel_affine_to_normalized(H.to(device), config.PATCH_SIZE)

            det_original, desc_original = model(original)
            det_rotated, desc_rotated = model(rotated)

            detection_loss = compute_detection_loss(det_original, det_rotated, H)
            description_loss = compute_description_loss(desc_original, desc_rotated, H)
            loss = detection_loss + description_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss_val += loss.item()

        # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
        print(f"--- Epoch {epoch+1} 完成, 平均 Loss: {total_loss_val / max(len(dataloader), 1):.4f} ---")

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.save_dir, exist_ok=True)
    save_path = os.path.join(args.save_dir, 'rord_model_final.pth')
    torch.save(model.state_dict(), save_path)
    print(f"模型已保存至: {save_path}")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: every option falls back to the matching
    # value from the project config module.
    cli = argparse.ArgumentParser(description="训练 RoRD 模型")
    option_specs = (
        ('--data_dir', str, config.LAYOUT_DIR),
        ('--save_dir', str, config.SAVE_DIR),
        ('--epochs', int, config.NUM_EPOCHS),
        ('--batch_size', int, config.BATCH_SIZE),
        ('--lr', float, config.LEARNING_RATE),
    )
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)

    main(cli.parse_args())