common update

This commit is contained in:
Jiao77
2026-02-11 21:41:40 +08:00
parent f4e04f9b3c
commit ed8270b0f3
33 changed files with 1227 additions and 124 deletions

199
tests/test_model_run.py Normal file
View File

@@ -0,0 +1,199 @@
#!/usr/bin/env python3
"""
测试脚本,用于验证模型是否可以正常跑通,不需要真实数据
- 生成随机图数据
- 加载模型配置
- 初始化模型
- 运行前向传播和反向传播
- 验证模型是否可以正常工作
"""
import os
import sys
import torch
from torch_geometric.data import Data, Batch
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.utils.config_loader import load_config
from src.models.geo_layout_transformer import GeoLayoutTransformer
from src.engine.trainer import Trainer
from src.engine.self_supervised import SelfSupervisedTrainer
from src.utils.logging import get_logger
def generate_random_graph_data(num_graphs=4, num_nodes_per_graph=8, node_feature_dim=5, edge_feature_dim=0):
"""
生成随机的图数据
Args:
num_graphs: 图的数量
num_nodes_per_graph: 每个图的节点数量
node_feature_dim: 节点特征维度
edge_feature_dim: 边特征维度
Returns:
一个 Batch 对象,包含多个随机生成的图
"""
graphs = []
for _ in range(num_graphs):
# 生成随机节点特征
x = torch.randn(num_nodes_per_graph, node_feature_dim)
# 生成随机边(完全连接)
edge_index = []
for i in range(num_nodes_per_graph):
for j in range(num_nodes_per_graph):
if i != j:
edge_index.append([i, j])
edge_index = torch.tensor(edge_index, dtype=torch.long).t()
# 生成随机标签
y = torch.randn(1, 1) # 假设是图级别的标签
# 创建图数据
graph = Data(x=x, edge_index=edge_index, y=y)
graphs.append(graph)
# 构建批次
batch = Batch.from_data_list(graphs)
return batch
def test_supervised_training():
"""测试监督训练"""
logger = get_logger("Test_Supervised_Training")
logger.info("=== 测试监督训练 ===")
# 加载配置
config = load_config('configs/default.yaml')
# 生成随机数据
batch = generate_random_graph_data()
logger.info(f"生成的批次数据: {batch}")
logger.info(f"批次大小: {batch.num_graphs}")
logger.info(f"总节点数: {batch.num_nodes}")
logger.info(f"总边数: {batch.num_edges}")
# 初始化模型
logger.info("初始化模型...")
model = GeoLayoutTransformer(config)
logger.info("模型初始化成功")
# 初始化训练器
logger.info("初始化训练器...")
trainer = Trainer(model, config)
logger.info("训练器初始化成功")
# 测试前向传播
logger.info("测试前向传播...")
with torch.no_grad():
# 先测试 GNN 编码器
gnn_output = model.gnn_encoder(batch)
logger.info(f"GNN 编码器输出形状: {gnn_output.shape}")
# 测试形状重塑
num_graphs = batch.num_graphs
nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
logger.info(f"每个图的节点数: {nodes_per_graph}")
reshaped_embeddings = gnn_output.view(num_graphs, nodes_per_graph[0], -1)
logger.info(f"重塑后的嵌入形状: {reshaped_embeddings.shape}")
# 测试 Transformer 核心
transformer_output = model.transformer_core(reshaped_embeddings)
logger.info(f"Transformer 输出形状: {transformer_output.shape}")
# 测试完整模型
output = model(batch)
logger.info(f"前向传播成功,输出形状: {output.shape}")
# 测试反向传播
logger.info("测试反向传播...")
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
optimizer.zero_grad()
output = model(batch)
# 对输出进行全局池化,得到图级别的表示
# 从 [batch_size, seq_len, hidden_dim] 变为 [batch_size, hidden_dim]
graph_output = output.mean(dim=1)
# 使用 MSE 损失,只比较前 1 个维度(与 batch.y 形状匹配)
loss = torch.nn.MSELoss()(graph_output[:, :1], batch.y)
loss.backward()
optimizer.step()
logger.info(f"反向传播成功,损失值: {loss.item()}")
logger.info("监督训练测试完成,模型可以正常工作!")
def test_self_supervised_training():
"""测试自监督训练"""
logger = get_logger("Test_Self_Supervised_Training")
logger.info("\n=== 测试自监督训练 ===")
# 加载配置
config = load_config('configs/default.yaml')
# 生成随机数据
batch = generate_random_graph_data()
logger.info(f"生成的批次数据: {batch}")
logger.info(f"批次大小: {batch.num_graphs}")
logger.info(f"总节点数: {batch.num_nodes}")
logger.info(f"总边数: {batch.num_edges}")
# 初始化模型
logger.info("初始化模型...")
model = GeoLayoutTransformer(config)
logger.info("模型初始化成功")
# 初始化自监督训练器
logger.info("初始化自监督训练器...")
trainer = SelfSupervisedTrainer(model, config)
logger.info("自监督训练器初始化成功")
# 测试前向传播
logger.info("测试前向传播...")
with torch.no_grad():
# 测试 GNN 编码器
gnn_output = model.gnn_encoder(batch)
logger.info(f"GNN 编码器输出形状: {gnn_output.shape}")
# 测试 Transformer 核心
num_graphs = batch.num_graphs
nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
if not torch.all(nodes_per_graph == nodes_per_graph[0]):
logger.warning("批次中图形的节点数不一致,使用第一个图形的节点数")
nodes_per_graph = nodes_per_graph[0]
gnn_output_reshaped = gnn_output.view(num_graphs, nodes_per_graph, -1)
transformer_output = model.transformer_core(gnn_output_reshaped)
logger.info(f"Transformer 核心输出形状: {transformer_output.shape}")
# 测试完整模型前向传播
logger.info("测试完整模型前向传播...")
with torch.no_grad():
output = model(batch)
logger.info(f"完整模型前向传播成功,输出形状: {output.shape}")
logger.info("自监督训练测试完成,模型可以正常工作!")
def main():
"""主函数"""
logger = get_logger("Test_Model_Run")
logger.info("开始测试模型是否可以正常跑通...")
try:
# 测试监督训练
test_supervised_training()
# 测试自监督训练
test_self_supervised_training()
logger.info("\n✅ 所有测试通过,模型可以正常跑通!")
logger.info("模型已准备就绪,可以使用真实数据进行训练。")
except Exception as e:
logger.error(f"❌ 测试失败: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()