fix function

Jiao77
2025-06-30 03:27:18 +08:00
parent 98f6709768
commit 35c3cb1420
4 changed files with 336 additions and 41 deletions

train.py

@@ -9,12 +9,31 @@ import numpy as np
import cv2
import os
import argparse
import logging
from datetime import datetime
# Import project modules
import config
from models.rord import RoRD
from utils.data_utils import get_transform
# Logging setup
def setup_logging(save_dir):
"""Set up training logging."""
if not os.path.exists(save_dir):
os.makedirs(save_dir)
log_file = os.path.join(save_dir, f'training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler()
]
)
return logging.getLogger(__name__)
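# Illustrative usage (not part of this commit): call once at startup; the logger
# then writes to both the console and a timestamped file under save_dir, e.g.:
#   logger = setup_logging("checkpoints")  # -> checkpoints/training_20250630_032718.log
#   logger.info("message goes to both the file and the console")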
# --- (Modified) Training-only dataset class ---
class ICLayoutTrainingDataset(Dataset):
def __init__(self, image_dir, patch_size=256, transform=None, scale_range=(1.0, 1.0)):
@@ -48,9 +67,28 @@ class ICLayoutTrainingDataset(Dataset):
patch = image.crop((x, y, x + crop_size, y + crop_size))
# 4. Resize the cropped patch back to the standard patch_size
-patch = patch.resize((self.patch_size, self.patch_size), Image.LANCZOS)
patch = patch.resize((self.patch_size, self.patch_size), Image.Resampling.LANCZOS)
# --- End of scale jitter ---
# --- New: additional data augmentation ---
# Brightness adjustment
if np.random.random() < 0.5:
brightness_factor = np.random.uniform(0.8, 1.2)
patch = patch.point(lambda x: int(x * brightness_factor))
# Contrast adjustment
if np.random.random() < 0.5:
contrast_factor = np.random.uniform(0.8, 1.2)
patch = patch.point(lambda x: int(((x - 128) * contrast_factor) + 128))
# Add Gaussian noise
if np.random.random() < 0.3:
patch_np = np.array(patch, dtype=np.float32)
noise = np.random.normal(0, 5, patch_np.shape)
patch_np = np.clip(patch_np + noise, 0, 255)
patch = Image.fromarray(patch_np.astype(np.uint8))
# --- End of additional augmentation ---
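# Note on the two point() calls above (a sketch, assuming 8-bit grayscale input):
# the lambdas can produce values outside [0, 255], and Pillow's handling of
# out-of-range lookup-table values is not guaranteed, so clamping explicitly is safer:
#   patch = patch.point(lambda v: min(255, max(0, int(v * brightness_factor))))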
patch_np = np.array(patch)
# Apply one of 8 discrete geometric transforms (this logic is unchanged)
@@ -79,37 +117,78 @@ class ICLayoutTrainingDataset(Dataset):
H_tensor = torch.from_numpy(H[:2, :]).float()
return patch, transformed_patch, H_tensor
-# --- Feature map warping & loss functions (unchanged) ---
# --- Feature map warping & loss functions (improved) ---
def warp_feature_map(feature_map, H_inv):
B, C, H, W = feature_map.size()
grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device)
return F.grid_sample(feature_map, grid, align_corners=False)
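# Quick sanity check for warp_feature_map (illustrative, not part of this commit):
# warping with the identity affine should approximately return the input.
#   fmap = torch.randn(2, 8, 32, 32)
#   eye = torch.eye(2, 3).unsqueeze(0).repeat(2, 1, 1)
#   assert torch.allclose(warp_feature_map(fmap, eye), fmap, atol=1e-4)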
def compute_detection_loss(det_original, det_rotated, H):
"""改进的检测损失使用BCE损失替代MSE"""
with torch.no_grad():
H_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))[:, :2, :]
warped_det_rotated = warp_feature_map(det_rotated, H_inv)
return F.mse_loss(det_original, warped_det_rotated)
# 使用BCE损失更适合二分类问题
bce_loss = F.binary_cross_entropy(det_original, warped_det_rotated)
# 添加平滑L1损失作为辅助
smooth_l1_loss = F.smooth_l1_loss(det_original, warped_det_rotated)
return bce_loss + 0.1 * smooth_l1_loss
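# Assumption behind the BCE term above: F.binary_cross_entropy requires both
# inputs in [0, 1], so the detection head is presumed to end in a sigmoid.
# A clamped variant avoids NaNs if that ever changes (illustrative sketch):
#   eps = 1e-6
#   bce_loss = F.binary_cross_entropy(det_original.clamp(eps, 1 - eps),
#                                     warped_det_rotated.clamp(eps, 1 - eps))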
def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
"""改进的描述子损失:使用更有效的采样策略"""
B, C, H_feat, W_feat = desc_original.size()
-num_samples = 100
-coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1
# Increase the number of sample points to improve training stability
num_samples = 200
# Use grid sampling instead of random sampling for a more uniform spatial distribution
h_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
w_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
h_grid, w_grid = torch.meshgrid(h_coords, w_coords, indexing='ij')
coords = torch.stack([h_grid.flatten(), w_grid.flatten()], dim=1).unsqueeze(0).repeat(B, 1, 1)
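# Note: int(np.sqrt(200)) == 14, so the grid above actually yields 14 * 14 = 196
# points; num_samples itself is only reused below to size the negative-candidate pool.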
# Sample the anchor points
anchor = F.grid_sample(desc_original, coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
-coords_hom = torch.cat([coords, torch.ones(B, num_samples, 1, device=coords.device)], dim=2)
# Compute the corresponding positive sample points
coords_hom = torch.cat([coords, torch.ones(B, coords.size(1), 1, device=coords.device)], dim=2)
M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0], device=H.device).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))
coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2]
positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
-neg_coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1
-negative = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
-triplet_loss = nn.TripletMarginLoss(margin=margin, p=2)
# Hard negative mining
with torch.no_grad():
# Sample a pool of candidate negatives
neg_coords = torch.rand(B, num_samples * 2, 2, device=desc_original.device) * 2 - 1
negative_candidates = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
# Pick the hardest negative for each anchor
anchor_expanded = anchor.unsqueeze(2).expand(-1, -1, negative_candidates.size(1), -1)
negative_candidates_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
distances = torch.norm(anchor_expanded - negative_candidates_expanded, dim=3)
hard_negative_indices = torch.argmin(distances, dim=2)
negative = torch.gather(negative_candidates, 1, hard_negative_indices.unsqueeze(2).expand(-1, -1, C))
# Use the improved triplet loss
triplet_loss = nn.TripletMarginLoss(margin=margin, p=2, reduction='mean')
return triplet_loss(anchor, positive, negative)
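# Two remarks on the mining block above (sketch only, not part of this commit):
# 1) The expand/norm pairwise distance materializes a (B, N, 2N, C) tensor;
#    torch.cdist yields the same hard-negative indices with far less memory:
#      distances = torch.cdist(anchor, negative_candidates, p=2)  # (B, N, 2N)
#      hard_negative_indices = distances.argmin(dim=2)
# 2) Negatives are drawn at random, so a candidate can land near an anchor's true
#    correspondence and act as a false negative; masking candidates that fall
#    within a small radius of coords_transformed would avoid this.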
# --- (Modified) Main function and command-line interface ---
def main(args):
print("--- 开始训练 RoRD 模型 ---")
print(f"训练参数: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
# 设置日志记录
logger = setup_logging(args.save_dir)
logger.info("--- 开始训练 RoRD 模型 ---")
logger.info(f"训练参数: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
logger.info(f"数据目录: {args.data_dir}")
logger.info(f"保存目录: {args.save_dir}")
transform = get_transform()
# Pass the scale-jitter range when constructing the dataset
dataset = ICLayoutTrainingDataset(
args.data_dir,
@@ -117,29 +196,145 @@ def main(args):
transform=transform,
scale_range=config.SCALE_JITTER_RANGE
)
-dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
logger.info(f"Dataset size: {len(dataset)}")
# Split into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
logger.info(f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset)}")
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
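# Reproducibility note (illustrative): random_split draws from the global RNG;
# passing a seeded generator keeps the train/val split stable across runs:
#   train_dataset, val_dataset = torch.utils.data.random_split(
#       dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))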
model = RoRD().cuda()
-optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
logger.info(f"Model parameter count: {sum(p.numel() for p in model.parameters()):,}")
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
# Add a learning-rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=5
)
# Early stopping
best_val_loss = float('inf')
patience_counter = 0
patience = 10
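# How the two patience values interact (as configured above): the scheduler halves
# the LR after 5 epochs without val-loss improvement, while early stopping only
# fires after 10, so the model gets at least one LR reduction before training stops.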
for epoch in range(args.epochs):
# Training phase
model.train()
-total_loss_val = 0
-for i, (original, rotated, H) in enumerate(dataloader):
total_train_loss = 0
total_det_loss = 0
total_desc_loss = 0
for i, (original, rotated, H) in enumerate(train_dataloader):
original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
det_original, desc_original = model(original)
det_rotated, desc_rotated = model(rotated)
-loss = compute_detection_loss(det_original, det_rotated, H) + compute_description_loss(desc_original, desc_rotated, H)
det_loss = compute_detection_loss(det_original, det_rotated, H)
desc_loss = compute_description_loss(desc_original, desc_rotated, H)
loss = det_loss + desc_loss
optimizer.zero_grad()
loss.backward()
# Gradient clipping to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
-total_loss_val += loss.item()
-print(f"--- Epoch {epoch+1} finished, average loss: {total_loss_val / len(dataloader):.4f} ---")
-if not os.path.exists(args.save_dir):
-os.makedirs(args.save_dir)
total_train_loss += loss.item()
total_det_loss += det_loss.item()
total_desc_loss += desc_loss.item()
if i % 10 == 0:
logger.info(f"Epoch {epoch+1}, Batch {i}, Total Loss: {loss.item():.4f}, "
f"Det Loss: {det_loss.item():.4f}, Desc Loss: {desc_loss.item():.4f}")
avg_train_loss = total_train_loss / len(train_dataloader)
avg_det_loss = total_det_loss / len(train_dataloader)
avg_desc_loss = total_desc_loss / len(train_dataloader)
# Validation phase
model.eval()
total_val_loss = 0
total_val_det_loss = 0
total_val_desc_loss = 0
with torch.no_grad():
for original, rotated, H in val_dataloader:
original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
det_original, desc_original = model(original)
det_rotated, desc_rotated = model(rotated)
val_det_loss = compute_detection_loss(det_original, det_rotated, H)
val_desc_loss = compute_description_loss(desc_original, desc_rotated, H)
val_loss = val_det_loss + val_desc_loss
total_val_loss += val_loss.item()
total_val_det_loss += val_det_loss.item()
total_val_desc_loss += val_desc_loss.item()
avg_val_loss = total_val_loss / len(val_dataloader)
avg_val_det_loss = total_val_det_loss / len(val_dataloader)
avg_val_desc_loss = total_val_desc_loss / len(val_dataloader)
# Step the LR scheduler on the validation loss
scheduler.step(avg_val_loss)
logger.info(f"--- Epoch {epoch+1} 完成 ---")
logger.info(f"训练 - Total: {avg_train_loss:.4f}, Det: {avg_det_loss:.4f}, Desc: {avg_desc_loss:.4f}")
logger.info(f"验证 - Total: {avg_val_loss:.4f}, Det: {avg_val_det_loss:.4f}, Desc: {avg_val_desc_loss:.4f}")
logger.info(f"学习率: {optimizer.param_groups[0]['lr']:.2e}")
# 早停检查
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
patience_counter = 0
# Save the best model
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
save_path = os.path.join(args.save_dir, 'rord_model_best.pth')
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_val_loss': best_val_loss,
'config': {
'learning_rate': args.lr,
'batch_size': args.batch_size,
'epochs': args.epochs
}
}, save_path)
logger.info(f"最佳模型已保存至: {save_path}")
else:
patience_counter += 1
if patience_counter >= patience:
logger.info(f"早停触发!{patience} 个epoch没有改善")
break
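# Restoring the best checkpoint later (illustrative sketch using the keys saved above):
#   ckpt = torch.load(os.path.join(args.save_dir, 'rord_model_best.pth'))
#   model.load_state_dict(ckpt['model_state_dict'])
#   optimizer.load_state_dict(ckpt['optimizer_state_dict'])
#   start_epoch = ckpt['epoch'] + 1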
# Save the final model
save_path = os.path.join(args.save_dir, 'rord_model_final.pth')
-torch.save(model.state_dict(), save_path)
-print(f"Model saved to: {save_path}")
torch.save({
'epoch': args.epochs,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'final_val_loss': avg_val_loss,
'config': {
'learning_rate': args.lr,
'batch_size': args.batch_size,
'epochs': args.epochs
}
}, save_path)
logger.info(f"最终模型已保存至: {save_path}")
logger.info("训练完成!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train the RoRD model")
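# Typical invocation (flag names inferred from the args used above; values illustrative):
#   python train.py --data_dir ./data/layouts --save_dir ./checkpoints \
#       --epochs 50 --batch_size 8 --lr 1e-4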