Finish experiment tracking and evaluation
This commit adds TensorBoard logging to train.py (writer setup, per-batch and per-epoch scalar logging, and CLI flags to override the logging configuration).
This commit is contained in:
56
train.py
56
train.py
@@ -8,6 +8,7 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from data.ic_dataset import ICLayoutTrainingDataset
|
||||
from losses import compute_detection_loss, compute_description_loss
|
||||
@@ -45,12 +46,39 @@ def main(args):
|
||||
patch_size = int(cfg.training.patch_size)
|
||||
scale_range = tuple(float(x) for x in cfg.training.scale_jitter_range)
|
||||
|
||||
logging_cfg = cfg.get("logging", None)
|
||||
use_tensorboard = False
|
||||
log_dir = None
|
||||
experiment_name = None
|
||||
|
||||
if logging_cfg is not None:
|
||||
use_tensorboard = bool(logging_cfg.get("use_tensorboard", False))
|
||||
log_dir = logging_cfg.get("log_dir", "runs")
|
||||
experiment_name = logging_cfg.get("experiment_name", "default")
|
||||
|
||||
if args.disable_tensorboard:
|
||||
use_tensorboard = False
|
||||
if args.log_dir is not None:
|
||||
log_dir = args.log_dir
|
||||
if args.experiment_name is not None:
|
||||
experiment_name = args.experiment_name
|
||||
|
||||
writer = None
|
||||
if use_tensorboard and log_dir:
|
||||
log_root = Path(log_dir).expanduser()
|
||||
experiment_folder = experiment_name or "default"
|
||||
tb_path = log_root / "train" / experiment_folder
|
||||
tb_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
writer = SummaryWriter(tb_path.as_posix())
|
||||
|
||||
logger = setup_logging(save_dir)
|
||||
|
||||
logger.info("--- 开始训练 RoRD 模型 ---")
|
||||
logger.info(f"训练参数: Epochs={epochs}, Batch Size={batch_size}, LR={lr}")
|
||||
logger.info(f"数据目录: {data_dir}")
|
||||
logger.info(f"保存目录: {save_dir}")
|
||||
if writer:
|
||||
logger.info(f"TensorBoard 日志目录: {tb_path}")
|
||||
|
||||
transform = get_transform()
|
||||
|
||||
@@ -69,6 +97,8 @@ def main(args):
|
||||
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
|
||||
|
||||
logger.info(f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset)}")
|
||||
if writer:
|
||||
writer.add_text("dataset/info", f"train={len(train_dataset)}, val={len(val_dataset)}")
|
||||
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
|
||||
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
|
||||
@@ -115,6 +145,14 @@ def main(args):
|
||||
total_train_loss += loss.item()
|
||||
total_det_loss += det_loss.item()
|
||||
total_desc_loss += desc_loss.item()
|
||||
|
||||
if writer:
|
||||
num_batches = len(train_dataloader) if len(train_dataloader) > 0 else 1
|
||||
global_step = epoch * num_batches + i
|
||||
writer.add_scalar("train/loss_total", loss.item(), global_step)
|
||||
writer.add_scalar("train/loss_det", det_loss.item(), global_step)
|
||||
writer.add_scalar("train/loss_desc", desc_loss.item(), global_step)
|
||||
writer.add_scalar("train/lr", optimizer.param_groups[0]['lr'], global_step)
|
||||
|
||||
if i % 10 == 0:
|
||||
logger.info(f"Epoch {epoch+1}, Batch {i}, Total Loss: {loss.item():.4f}, "
|
||||
@@ -123,6 +161,10 @@ def main(args):
|
||||
avg_train_loss = total_train_loss / len(train_dataloader)
|
||||
avg_det_loss = total_det_loss / len(train_dataloader)
|
||||
avg_desc_loss = total_desc_loss / len(train_dataloader)
|
||||
if writer:
|
||||
writer.add_scalar("epoch/train_loss_total", avg_train_loss, epoch)
|
||||
writer.add_scalar("epoch/train_loss_det", avg_det_loss, epoch)
|
||||
writer.add_scalar("epoch/train_loss_desc", avg_desc_loss, epoch)
|
||||
|
||||
# 验证阶段
|
||||
model.eval()
|
||||
@@ -156,6 +198,11 @@ def main(args):
|
||||
logger.info(f"训练 - Total: {avg_train_loss:.4f}, Det: {avg_det_loss:.4f}, Desc: {avg_desc_loss:.4f}")
|
||||
logger.info(f"验证 - Total: {avg_val_loss:.4f}, Det: {avg_val_det_loss:.4f}, Desc: {avg_val_desc_loss:.4f}")
|
||||
logger.info(f"学习率: {optimizer.param_groups[0]['lr']:.2e}")
|
||||
if writer:
|
||||
writer.add_scalar("epoch/val_loss_total", avg_val_loss, epoch)
|
||||
writer.add_scalar("epoch/val_loss_det", avg_val_det_loss, epoch)
|
||||
writer.add_scalar("epoch/val_loss_desc", avg_val_desc_loss, epoch)
|
||||
writer.add_scalar("epoch/lr", optimizer.param_groups[0]['lr'], epoch)
|
||||
|
||||
# 早停检查
|
||||
if avg_val_loss < best_val_loss:
|
||||
@@ -179,6 +226,8 @@ def main(args):
|
||||
}
|
||||
}, save_path)
|
||||
logger.info(f"最佳模型已保存至: {save_path}")
|
||||
if writer:
|
||||
writer.add_scalar("checkpoint/best_val_loss", best_val_loss, epoch)
|
||||
else:
|
||||
patience_counter += 1
|
||||
if patience_counter >= patience:
|
||||
@@ -202,6 +251,10 @@ def main(args):
|
||||
logger.info(f"最终模型已保存至: {save_path}")
|
||||
logger.info("训练完成!")
|
||||
|
||||
if writer:
|
||||
writer.add_scalar("final/val_loss", avg_val_loss, epochs - 1)
|
||||
writer.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="训练 RoRD 模型")
|
||||
parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="YAML 配置文件路径")
|
||||
@@ -210,4 +263,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--epochs', type=int, default=None, help="训练轮数,若未提供则使用配置文件中的值")
|
||||
parser.add_argument('--batch_size', type=int, default=None, help="批次大小,若未提供则使用配置文件中的值")
|
||||
parser.add_argument('--lr', type=float, default=None, help="学习率,若未提供则使用配置文件中的值")
|
||||
parser.add_argument('--log_dir', type=str, default=None, help="TensorBoard 日志根目录,覆盖配置文件中的设置")
|
||||
parser.add_argument('--experiment_name', type=str, default=None, help="TensorBoard 实验名称,覆盖配置文件中的设置")
|
||||
parser.add_argument('--disable_tensorboard', action='store_true', help="禁用 TensorBoard 日志记录")
|
||||
main(parser.parse_args())
|
||||
Reference in New Issue
Block a user