# main.py
import argparse

import torch
from torch.utils.data import random_split
from torch_geometric.data import DataLoader

from src.data.dataset import LayoutDataset
from src.engine.evaluator import Evaluator
from src.engine.self_supervised import SelfSupervisedTrainer
from src.engine.trainer import Trainer
from src.models.geo_layout_transformer import GeoLayoutTransformer
from src.utils.config_loader import load_config, merge_configs
from src.utils.logging import get_logger
from src.utils.seed import set_seed
|
def main():
    """CLI entry point for the Geo-Layout Transformer.

    Parses command-line arguments, merges the base and task-specific
    configs, seeds RNGs, builds train/val/test data loaders, constructs
    the model (optionally restoring a checkpoint), and dispatches to
    the pretrain / train / eval engine selected by ``--mode``.
    """
    parser = argparse.ArgumentParser(description="Geo-Layout Transformer 的主脚本。")
    parser.add_argument("--config-file", required=True, help="特定于任务的配置文件的路径。")
    parser.add_argument("--mode", choices=["train", "eval", "pretrain"], required=True, help="脚本运行模式。")
    parser.add_argument("--data-dir", required=True, help="已处理图数据的目录。")
    parser.add_argument("--checkpoint-path", help="要加载的模型检查点的路径。")
    args = parser.parse_args()

    logger = get_logger("Main")

    # Load the base config first, then let the task-specific config
    # override it.
    logger.info("正在加载配置...")
    base_config = load_config('configs/default.yaml')
    task_config = load_config(args.config_file)
    config = merge_configs(base_config, task_config)

    # Fix the random seed so experiments are reproducible.
    random_seed = config['splits']['random_seed']
    logger.info(f"正在设置随机种子: {random_seed}")
    set_seed(random_seed)

    # Load the processed graph dataset.
    logger.info(f"从 {args.data_dir} 加载数据集")
    dataset = LayoutDataset(root=args.data_dir)

    # Split the dataset and wrap each split in a DataLoader.
    train_loader, val_loader, test_loader = _build_loaders(
        dataset, config, random_seed, logger
    )

    # Initialize the model and optionally restore a checkpoint.
    logger.info("正在初始化模型...")
    model = GeoLayoutTransformer(config)
    if args.checkpoint_path:
        logger.info(f"从 {args.checkpoint_path} 加载模型检查点")
        # map_location='cpu' keeps loading device-agnostic; the engine is
        # responsible for moving the model to the target device afterwards.
        model.load_state_dict(torch.load(args.checkpoint_path, map_location='cpu'))

    # Dispatch to the engine selected by --mode.
    if args.mode == 'pretrain':
        logger.info("进入自监督预训练模式...")
        trainer = SelfSupervisedTrainer(model, config)
        trainer.run(train_loader)
    elif args.mode == 'train':
        logger.info("进入监督训练模式...")
        trainer = Trainer(model, config)
        trainer.run(train_loader, val_loader)
    elif args.mode == 'eval':
        logger.info("进入评估模式...")
        evaluator = Evaluator(model)
        evaluator.evaluate(test_loader)


def _build_loaders(dataset, config, random_seed, logger):
    """Split *dataset* into train/val/test and build a DataLoader per split.

    Split sizes come from the ``splits`` section of *config*; the test
    split absorbs integer-rounding remainders. Returns
    ``(train_loader, val_loader, test_loader)``.
    """
    logger.info("正在划分数据集...")
    splits = config['splits']
    train_size = int(splits['train_ratio'] * len(dataset))
    val_size = int(splits['val_ratio'] * len(dataset))
    test_size = len(dataset) - train_size - val_size

    # Guard against ratios that sum past 1.0: drop the test split and
    # shrink validation so the three sizes still cover the dataset.
    if test_size < 0:
        test_size = 0
        val_size = len(dataset) - train_size

    # Seeded generator keeps the split deterministic across runs.
    train_dataset, val_dataset, test_dataset = random_split(
        dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(random_seed)
    )

    batch_size = config['training']['batch_size']
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    logger.info(f"数据集划分完成: 训练集 {len(train_dataset)}, 验证集 {len(val_dataset)}, 测试集 {len(test_dataset)}")
    return train_loader, val_loader, test_loader
|
|
|
|
if __name__ == "__main__":
|
|
main()
|