common update

This commit is contained in:
Jiao77
2026-02-11 21:41:40 +08:00
parent f4e04f9b3c
commit ed8270b0f3
33 changed files with 1227 additions and 124 deletions

49
main.py
View File

@@ -4,6 +4,7 @@ from torch.utils.data import random_split
from src.utils.config_loader import load_config, merge_configs
from src.utils.logging import get_logger
from src.utils.seed import set_seed
from src.data.dataset import LayoutDataset
from torch_geometric.data import DataLoader
from src.models.geo_layout_transformer import GeoLayoutTransformer
@@ -27,22 +28,46 @@ def main():
base_config = load_config('configs/default.yaml')
task_config = load_config(args.config_file)
config = merge_configs(base_config, task_config)
# 设置随机种子,确保实验的可重复性
random_seed = config['splits']['random_seed']
logger.info(f"正在设置随机种子: {random_seed}")
set_seed(random_seed)
# 加载数据
logger.info(f"{args.data_dir} 加载数据集")
dataset = LayoutDataset(root=args.data_dir)
# TODO: 实现更完善的数据集划分逻辑
# 这是一个简化的数据加载方式。在实际应用中,您需要将数据集划分为训练集、验证集和测试集。
# 例如:
# train_size = int(0.8 * len(dataset))
# val_size = len(dataset) - train_size
# train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# train_loader = DataLoader(train_dataset, batch_size=config['training']['batch_size'], shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=config['training']['batch_size'], shuffle=False)
train_loader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=True)
val_loader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=False)
# 实现数据集划分逻辑
logger.info("正在划分数据集...")
train_ratio = config['splits']['train_ratio']
val_ratio = config['splits']['val_ratio']
test_ratio = config['splits']['test_ratio']
random_seed = config['splits']['random_seed']
# 计算各数据集大小
train_size = int(train_ratio * len(dataset))
val_size = int(val_ratio * len(dataset))
test_size = len(dataset) - train_size - val_size
# 确保各部分大小合理
if test_size < 0:
test_size = 0
val_size = len(dataset) - train_size
# 划分数据集
train_dataset, val_dataset, test_dataset = random_split(
dataset,
[train_size, val_size, test_size],
generator=torch.Generator().manual_seed(random_seed)
)
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=config['training']['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config['training']['batch_size'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=config['training']['batch_size'], shuffle=False)
logger.info(f"数据集划分完成: 训练集 {len(train_dataset)}, 验证集 {len(val_dataset)}, 测试集 {len(test_dataset)}")
# 初始化模型
logger.info("正在初始化模型...")
@@ -63,7 +88,7 @@ def main():
elif args.mode == 'eval':
logger.info("进入评估模式...")
evaluator = Evaluator(model)
evaluator.evaluate(val_loader)
evaluator.evaluate(test_loader)
if __name__ == "__main__":
main()