200 lines
6.9 KiB
Python
200 lines
6.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
测试脚本,用于验证模型是否可以正常跑通,不需要真实数据
|
|
- 生成随机图数据
|
|
- 加载模型配置
|
|
- 初始化模型
|
|
- 运行前向传播和反向传播
|
|
- 验证模型是否可以正常工作
|
|
"""
|
|
import os
|
|
import sys
|
|
import torch
|
|
from torch_geometric.data import Data, Batch
|
|
|
|
# 添加项目根目录到 Python 路径
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from src.utils.config_loader import load_config
|
|
from src.models.geo_layout_transformer import GeoLayoutTransformer
|
|
from src.engine.trainer import Trainer
|
|
from src.engine.self_supervised import SelfSupervisedTrainer
|
|
from src.utils.logging import get_logger
|
|
|
|
def generate_random_graph_data(num_graphs=4, num_nodes_per_graph=8, node_feature_dim=5, edge_feature_dim=0):
|
|
"""
|
|
生成随机的图数据
|
|
|
|
Args:
|
|
num_graphs: 图的数量
|
|
num_nodes_per_graph: 每个图的节点数量
|
|
node_feature_dim: 节点特征维度
|
|
edge_feature_dim: 边特征维度
|
|
|
|
Returns:
|
|
一个 Batch 对象,包含多个随机生成的图
|
|
"""
|
|
graphs = []
|
|
|
|
for _ in range(num_graphs):
|
|
# 生成随机节点特征
|
|
x = torch.randn(num_nodes_per_graph, node_feature_dim)
|
|
|
|
# 生成随机边(完全连接)
|
|
edge_index = []
|
|
for i in range(num_nodes_per_graph):
|
|
for j in range(num_nodes_per_graph):
|
|
if i != j:
|
|
edge_index.append([i, j])
|
|
edge_index = torch.tensor(edge_index, dtype=torch.long).t()
|
|
|
|
# 生成随机标签
|
|
y = torch.randn(1, 1) # 假设是图级别的标签
|
|
|
|
# 创建图数据
|
|
graph = Data(x=x, edge_index=edge_index, y=y)
|
|
graphs.append(graph)
|
|
|
|
# 构建批次
|
|
batch = Batch.from_data_list(graphs)
|
|
return batch
|
|
|
|
def test_supervised_training():
|
|
"""测试监督训练"""
|
|
logger = get_logger("Test_Supervised_Training")
|
|
logger.info("=== 测试监督训练 ===")
|
|
|
|
# 加载配置
|
|
config = load_config('configs/default.yaml')
|
|
|
|
# 生成随机数据
|
|
batch = generate_random_graph_data()
|
|
logger.info(f"生成的批次数据: {batch}")
|
|
logger.info(f"批次大小: {batch.num_graphs}")
|
|
logger.info(f"总节点数: {batch.num_nodes}")
|
|
logger.info(f"总边数: {batch.num_edges}")
|
|
|
|
# 初始化模型
|
|
logger.info("初始化模型...")
|
|
model = GeoLayoutTransformer(config)
|
|
logger.info("模型初始化成功")
|
|
|
|
# 初始化训练器
|
|
logger.info("初始化训练器...")
|
|
trainer = Trainer(model, config)
|
|
logger.info("训练器初始化成功")
|
|
|
|
# 测试前向传播
|
|
logger.info("测试前向传播...")
|
|
with torch.no_grad():
|
|
# 先测试 GNN 编码器
|
|
gnn_output = model.gnn_encoder(batch)
|
|
logger.info(f"GNN 编码器输出形状: {gnn_output.shape}")
|
|
|
|
# 测试形状重塑
|
|
num_graphs = batch.num_graphs
|
|
nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
|
|
logger.info(f"每个图的节点数: {nodes_per_graph}")
|
|
reshaped_embeddings = gnn_output.view(num_graphs, nodes_per_graph[0], -1)
|
|
logger.info(f"重塑后的嵌入形状: {reshaped_embeddings.shape}")
|
|
|
|
# 测试 Transformer 核心
|
|
transformer_output = model.transformer_core(reshaped_embeddings)
|
|
logger.info(f"Transformer 输出形状: {transformer_output.shape}")
|
|
|
|
# 测试完整模型
|
|
output = model(batch)
|
|
logger.info(f"前向传播成功,输出形状: {output.shape}")
|
|
|
|
# 测试反向传播
|
|
logger.info("测试反向传播...")
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
|
|
optimizer.zero_grad()
|
|
output = model(batch)
|
|
|
|
# 对输出进行全局池化,得到图级别的表示
|
|
# 从 [batch_size, seq_len, hidden_dim] 变为 [batch_size, hidden_dim]
|
|
graph_output = output.mean(dim=1)
|
|
|
|
# 使用 MSE 损失,只比较前 1 个维度(与 batch.y 形状匹配)
|
|
loss = torch.nn.MSELoss()(graph_output[:, :1], batch.y)
|
|
loss.backward()
|
|
optimizer.step()
|
|
logger.info(f"反向传播成功,损失值: {loss.item()}")
|
|
|
|
logger.info("监督训练测试完成,模型可以正常工作!")
|
|
|
|
def test_self_supervised_training():
|
|
"""测试自监督训练"""
|
|
logger = get_logger("Test_Self_Supervised_Training")
|
|
logger.info("\n=== 测试自监督训练 ===")
|
|
|
|
# 加载配置
|
|
config = load_config('configs/default.yaml')
|
|
|
|
# 生成随机数据
|
|
batch = generate_random_graph_data()
|
|
logger.info(f"生成的批次数据: {batch}")
|
|
logger.info(f"批次大小: {batch.num_graphs}")
|
|
logger.info(f"总节点数: {batch.num_nodes}")
|
|
logger.info(f"总边数: {batch.num_edges}")
|
|
|
|
# 初始化模型
|
|
logger.info("初始化模型...")
|
|
model = GeoLayoutTransformer(config)
|
|
logger.info("模型初始化成功")
|
|
|
|
# 初始化自监督训练器
|
|
logger.info("初始化自监督训练器...")
|
|
trainer = SelfSupervisedTrainer(model, config)
|
|
logger.info("自监督训练器初始化成功")
|
|
|
|
# 测试前向传播
|
|
logger.info("测试前向传播...")
|
|
with torch.no_grad():
|
|
# 测试 GNN 编码器
|
|
gnn_output = model.gnn_encoder(batch)
|
|
logger.info(f"GNN 编码器输出形状: {gnn_output.shape}")
|
|
|
|
# 测试 Transformer 核心
|
|
num_graphs = batch.num_graphs
|
|
nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
|
|
if not torch.all(nodes_per_graph == nodes_per_graph[0]):
|
|
logger.warning("批次中图形的节点数不一致,使用第一个图形的节点数")
|
|
nodes_per_graph = nodes_per_graph[0]
|
|
|
|
gnn_output_reshaped = gnn_output.view(num_graphs, nodes_per_graph, -1)
|
|
transformer_output = model.transformer_core(gnn_output_reshaped)
|
|
logger.info(f"Transformer 核心输出形状: {transformer_output.shape}")
|
|
|
|
# 测试完整模型前向传播
|
|
logger.info("测试完整模型前向传播...")
|
|
with torch.no_grad():
|
|
output = model(batch)
|
|
logger.info(f"完整模型前向传播成功,输出形状: {output.shape}")
|
|
|
|
logger.info("自监督训练测试完成,模型可以正常工作!")
|
|
|
|
def main():
|
|
"""主函数"""
|
|
logger = get_logger("Test_Model_Run")
|
|
logger.info("开始测试模型是否可以正常跑通...")
|
|
|
|
try:
|
|
# 测试监督训练
|
|
test_supervised_training()
|
|
|
|
# 测试自监督训练
|
|
test_self_supervised_training()
|
|
|
|
logger.info("\n✅ 所有测试通过,模型可以正常跑通!")
|
|
logger.info("模型已准备就绪,可以使用真实数据进行训练。")
|
|
except Exception as e:
|
|
logger.error(f"❌ 测试失败: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
sys.exit(1)
|
|
|
|
if __name__ == "__main__":
|
|
main()
|