# train.py
import argparse
import logging
import os
from datetime import datetime
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from data.ic_dataset import ICLayoutTrainingDataset
from losses import compute_detection_loss, compute_description_loss
from models.rord import RoRD
from utils.config_loader import load_config, to_absolute_path
from utils.data_utils import get_transform


# Logging setup
def setup_logging(save_dir):
    """Set up training logging to both a file and the console."""
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    log_file = os.path.join(save_dir, f'training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger(__name__)


# --- (Modified) Main function and command-line interface ---
def main(args):
    cfg = load_config(args.config)
    config_dir = Path(args.config).resolve().parent

    # CLI arguments take precedence; otherwise fall back to the YAML config
    data_dir = args.data_dir or str(to_absolute_path(cfg.paths.layout_dir, config_dir))
    save_dir = args.save_dir or str(to_absolute_path(cfg.paths.save_dir, config_dir))
    epochs = args.epochs if args.epochs is not None else int(cfg.training.num_epochs)
    batch_size = args.batch_size if args.batch_size is not None else int(cfg.training.batch_size)
    lr = args.lr if args.lr is not None else float(cfg.training.learning_rate)
    patch_size = int(cfg.training.patch_size)
    scale_range = tuple(float(x) for x in cfg.training.scale_jitter_range)

    # Optional TensorBoard logging configuration
    logging_cfg = cfg.get("logging", None)
    use_tensorboard = False
    log_dir = None
    experiment_name = None
    if logging_cfg is not None:
        use_tensorboard = bool(logging_cfg.get("use_tensorboard", False))
        log_dir = logging_cfg.get("log_dir", "runs")
        experiment_name = logging_cfg.get("experiment_name", "default")
    if args.disable_tensorboard:
        use_tensorboard = False
    if args.log_dir is not None:
        log_dir = args.log_dir
    if args.experiment_name is not None:
        experiment_name = args.experiment_name

    writer = None
    if use_tensorboard and log_dir:
        log_root = Path(log_dir).expanduser()
        experiment_folder = experiment_name or "default"
        tb_path = log_root / "train" / experiment_folder
        tb_path.parent.mkdir(parents=True, exist_ok=True)
        writer = SummaryWriter(tb_path.as_posix())

    logger = setup_logging(save_dir)
    logger.info("--- Starting RoRD model training ---")
    logger.info(f"Training parameters: Epochs={epochs}, Batch Size={batch_size}, LR={lr}")
    logger.info(f"Data directory: {data_dir}")
    logger.info(f"Save directory: {save_dir}")
    if writer:
        logger.info(f"TensorBoard log directory: {tb_path}")

    transform = get_transform()
    dataset = ICLayoutTrainingDataset(
        data_dir,
        patch_size=patch_size,
        transform=transform,
        scale_range=scale_range,
    )
    logger.info(f"Dataset size: {len(dataset)}")

    # Split into training and validation sets
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    logger.info(f"Training set size: {len(train_dataset)}, validation set size: {len(val_dataset)}")
    if writer:
        writer.add_text("dataset/info", f"train={len(train_dataset)}, val={len(val_dataset)}")

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    model = RoRD().cuda()
    logger.info(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    # Early stopping
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 10

    for epoch in range(epochs):
        # Training phase
        model.train()
        total_train_loss = 0
        total_det_loss = 0
        total_desc_loss = 0

        for i, (original, rotated, H) in enumerate(train_dataloader):
            original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()

            det_original, desc_original = model(original)
            det_rotated, desc_rotated = model(rotated)

            det_loss = compute_detection_loss(det_original, det_rotated, H)
            desc_loss = compute_description_loss(desc_original, desc_rotated, H)
            loss = det_loss + desc_loss

            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping to guard against exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_train_loss += loss.item()
            total_det_loss += det_loss.item()
            total_desc_loss += desc_loss.item()

            if writer:
                num_batches = len(train_dataloader) if len(train_dataloader) > 0 else 1
                global_step = epoch * num_batches + i
                writer.add_scalar("train/loss_total", loss.item(), global_step)
                writer.add_scalar("train/loss_det", det_loss.item(), global_step)
                writer.add_scalar("train/loss_desc", desc_loss.item(), global_step)
                writer.add_scalar("train/lr", optimizer.param_groups[0]['lr'], global_step)

            if i % 10 == 0:
                logger.info(f"Epoch {epoch+1}, Batch {i}, Total Loss: {loss.item():.4f}, "
                            f"Det Loss: {det_loss.item():.4f}, Desc Loss: {desc_loss.item():.4f}")

        avg_train_loss = total_train_loss / len(train_dataloader)
        avg_det_loss = total_det_loss / len(train_dataloader)
        avg_desc_loss = total_desc_loss / len(train_dataloader)

        if writer:
            writer.add_scalar("epoch/train_loss_total", avg_train_loss, epoch)
            writer.add_scalar("epoch/train_loss_det", avg_det_loss, epoch)
            writer.add_scalar("epoch/train_loss_desc", avg_desc_loss, epoch)

        # Validation phase
        model.eval()
        total_val_loss = 0
        total_val_det_loss = 0
        total_val_desc_loss = 0

        with torch.no_grad():
            for original, rotated, H in val_dataloader:
                original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()

                det_original, desc_original = model(original)
                det_rotated, desc_rotated = model(rotated)

                val_det_loss = compute_detection_loss(det_original, det_rotated, H)
                val_desc_loss = compute_description_loss(desc_original, desc_rotated, H)
                val_loss = val_det_loss + val_desc_loss

                total_val_loss += val_loss.item()
                total_val_det_loss += val_det_loss.item()
                total_val_desc_loss += val_desc_loss.item()

        avg_val_loss = total_val_loss / len(val_dataloader)
        avg_val_det_loss = total_val_det_loss / len(val_dataloader)
        avg_val_desc_loss = total_val_desc_loss / len(val_dataloader)

        # Learning rate scheduling
        scheduler.step(avg_val_loss)

        logger.info(f"--- Epoch {epoch+1} complete ---")
        logger.info(f"Train - Total: {avg_train_loss:.4f}, Det: {avg_det_loss:.4f}, Desc: {avg_desc_loss:.4f}")
        logger.info(f"Val   - Total: {avg_val_loss:.4f}, Det: {avg_val_det_loss:.4f}, Desc: {avg_val_desc_loss:.4f}")
        logger.info(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")

        if writer:
            writer.add_scalar("epoch/val_loss_total", avg_val_loss, epoch)
            writer.add_scalar("epoch/val_loss_det", avg_val_det_loss, epoch)
            writer.add_scalar("epoch/val_loss_desc", avg_val_desc_loss, epoch)
            writer.add_scalar("epoch/lr", optimizer.param_groups[0]['lr'], epoch)

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0

            # Save the best model
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            save_path = os.path.join(save_dir, 'rord_model_best.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'config': {
                    'learning_rate': lr,
                    'batch_size': batch_size,
                    'epochs': epochs,
                    'config_path': str(Path(args.config).resolve()),
                }
            }, save_path)
            logger.info(f"Best model saved to: {save_path}")
            if writer:
                writer.add_scalar("checkpoint/best_val_loss", best_val_loss, epoch)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(f"Early stopping triggered! No improvement for {patience} epochs")
                break

    # Save the final model
    save_path = os.path.join(save_dir, 'rord_model_final.pth')
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'final_val_loss': avg_val_loss,
        'config': {
            'learning_rate': lr,
            'batch_size': batch_size,
            'epochs': epochs,
            'config_path': str(Path(args.config).resolve()),
        }
    }, save_path)
    logger.info(f"Final model saved to: {save_path}")
    logger.info("Training complete!")

    if writer:
        writer.add_scalar("final/val_loss", avg_val_loss, epochs - 1)
        writer.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the RoRD model")
    parser.add_argument('--config', type=str, default="configs/base_config.yaml",
                        help="Path to the YAML config file")
    parser.add_argument('--data_dir', type=str, default=None,
                        help="Training data directory; falls back to the config file if not provided")
    parser.add_argument('--save_dir', type=str, default=None,
                        help="Model save directory; falls back to the config file if not provided")
    parser.add_argument('--epochs', type=int, default=None,
                        help="Number of training epochs; falls back to the config file if not provided")
    parser.add_argument('--batch_size', type=int, default=None,
                        help="Batch size; falls back to the config file if not provided")
    parser.add_argument('--lr', type=float, default=None,
                        help="Learning rate; falls back to the config file if not provided")
    parser.add_argument('--log_dir', type=str, default=None,
                        help="TensorBoard log root directory; overrides the config file setting")
    parser.add_argument('--experiment_name', type=str, default=None,
                        help="TensorBoard experiment name; overrides the config file setting")
    parser.add_argument('--disable_tensorboard', action='store_true',
                        help="Disable TensorBoard logging")
    main(parser.parse_args())
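
# Example invocation, shown here only as an illustration of the CLI defined above;
# the numeric values and the experiment name are placeholders, not defaults enforced
# by this script. Any flag left unset falls back to the corresponding YAML entry.
#
#   python train.py --config configs/base_config.yaml \
#       --epochs 50 --batch_size 8 --lr 1e-4 \
#       --log_dir runs --experiment_name rord_baseline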