common update
This commit is contained in:
49
main.py
49
main.py
@@ -4,6 +4,7 @@ from torch.utils.data import random_split
|
||||
|
||||
from src.utils.config_loader import load_config, merge_configs
|
||||
from src.utils.logging import get_logger
|
||||
from src.utils.seed import set_seed
|
||||
from src.data.dataset import LayoutDataset
|
||||
from torch_geometric.data import DataLoader
|
||||
from src.models.geo_layout_transformer import GeoLayoutTransformer
|
||||
@@ -27,22 +28,46 @@ def main():
|
||||
base_config = load_config('configs/default.yaml')
|
||||
task_config = load_config(args.config_file)
|
||||
config = merge_configs(base_config, task_config)
|
||||
|
||||
# 设置随机种子,确保实验的可重复性
|
||||
random_seed = config['splits']['random_seed']
|
||||
logger.info(f"正在设置随机种子: {random_seed}")
|
||||
set_seed(random_seed)
|
||||
|
||||
# 加载数据
|
||||
logger.info(f"从 {args.data_dir} 加载数据集")
|
||||
dataset = LayoutDataset(root=args.data_dir)
|
||||
|
||||
# TODO: 实现更完善的数据集划分逻辑
|
||||
# 这是一个简化的数据加载方式。在实际应用中,您需要将数据集划分为训练集、验证集和测试集。
|
||||
# 例如:
|
||||
# train_size = int(0.8 * len(dataset))
|
||||
# val_size = len(dataset) - train_size
|
||||
# train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
|
||||
# train_loader = DataLoader(train_dataset, batch_size=config['training']['batch_size'], shuffle=True)
|
||||
# val_loader = DataLoader(val_dataset, batch_size=config['training']['batch_size'], shuffle=False)
|
||||
|
||||
train_loader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=True)
|
||||
val_loader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=False)
|
||||
# 实现数据集划分逻辑
|
||||
logger.info("正在划分数据集...")
|
||||
train_ratio = config['splits']['train_ratio']
|
||||
val_ratio = config['splits']['val_ratio']
|
||||
test_ratio = config['splits']['test_ratio']
|
||||
random_seed = config['splits']['random_seed']
|
||||
|
||||
# 计算各数据集大小
|
||||
train_size = int(train_ratio * len(dataset))
|
||||
val_size = int(val_ratio * len(dataset))
|
||||
test_size = len(dataset) - train_size - val_size
|
||||
|
||||
# 确保各部分大小合理
|
||||
if test_size < 0:
|
||||
test_size = 0
|
||||
val_size = len(dataset) - train_size
|
||||
|
||||
# 划分数据集
|
||||
train_dataset, val_dataset, test_dataset = random_split(
|
||||
dataset,
|
||||
[train_size, val_size, test_size],
|
||||
generator=torch.Generator().manual_seed(random_seed)
|
||||
)
|
||||
|
||||
# 创建数据加载器
|
||||
train_loader = DataLoader(train_dataset, batch_size=config['training']['batch_size'], shuffle=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=config['training']['batch_size'], shuffle=False)
|
||||
test_loader = DataLoader(test_dataset, batch_size=config['training']['batch_size'], shuffle=False)
|
||||
|
||||
logger.info(f"数据集划分完成: 训练集 {len(train_dataset)}, 验证集 {len(val_dataset)}, 测试集 {len(test_dataset)}")
|
||||
|
||||
# 初始化模型
|
||||
logger.info("正在初始化模型...")
|
||||
@@ -63,7 +88,7 @@ def main():
|
||||
elif args.mode == 'eval':
|
||||
logger.info("进入评估模式...")
|
||||
evaluator = Evaluator(model)
|
||||
evaluator.evaluate(val_loader)
|
||||
evaluator.evaluate(test_loader)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user