common update

This commit is contained in:
Jiao77
2026-02-11 21:41:40 +08:00
parent f4e04f9b3c
commit ed8270b0f3
33 changed files with 1227 additions and 124 deletions

16
TODO.md
View File

@@ -49,37 +49,37 @@
## 优先级清单(可执行项)
### P0立即优先
- [ ] 数据集切分与 DataLoader 管线
- [x] 数据集切分与 DataLoader 管线
-`main.py` 引入可配置的 train/val/test 切分比例与随机种子;支持从目录/清单载入各 split。
-`configs/default.yaml` 增加 `splits` 字段;更新 `README*` 用法说明。
- [ ] 监督训练工程化
- [x] 监督训练工程化
-`trainer.py` 补充验证阶段与最佳模型保存(`torch.save` 至指定路径)。
- 引入学习率调度器(如 StepLR/CosineAnnealingWarmRestarts)与早停策略。
- 支持 class weights/focal loss`trainer.py` 增加 `focal_loss` 实现并在配置选择。
- [ ] 自监督预训练修复
- [x] 自监督预训练修复
- 明确 batch 内每图的 patch 序列映射:根据 `batch.ptr` 逐图生成 mask 索引,避免跨图混淆。
- 将掩码作用在输入特征/图结构层而非已池化的图级嵌入或增加“节点级→patch 聚合→重建头”。
-`transformer_core` 或单独模块增加重建头MLP以回归原 patch 表征;提供单元测试。
### P1高优
- [ ] 任务头与损失扩展
- [x] 任务头与损失扩展
-`task_heads.py` 增加多标签分类、回归头增添可插拔的池化CLS token/Mean/Max/Attention Pool
-`trainer.py` 支持多任务训练配置(不同 head/loss 的加权)。
- [ ] 训练与日志可观测性
- [x] 训练与日志可观测性
- 增加 TensorBoard/CSVLogger记录 epoch 指标、学习率、耗时;保存 `config``git` 提交信息。
- 固定随机种子PyTorch/NumPy/环境变量),在 `utils` 中提供 `set_seed()` 并在入口调用。
- [ ] 可复现实验与最小数据
- [x] 可复现实验与最小数据
- 提供最小 GDS 示例与对应的 processed `.pt` 小样,便于 CI 与用户快速体验。
-`scripts/` 增加一键跑通的小样流程脚本preprocess→train→eval
### P2中优
- [ ] 大图/性能优化
- [x] 大图/性能优化
- 引入混合精度(`torch.cuda.amp`)、梯度累积、可选更小 batch,监控显存。
- 探索 GraphSAINT/Cluster-GCN 等大图训练策略,并与当前 patch 划分结合。
- [ ] I/O 与生态集成
- `klayout` Python API 的可选集成与安装脚本说明;解析 OASIS 的路径补全与测试。
-`graph_constructor.py` 为边策略加入可学习/基于几何关系的拓展(如跨层连接边)。
- [ ] 可解释性与可视化
- [x] 可解释性与可视化
- 完成 `scripts/visualize_attention.py`:注册 Hook 提取注意力/特征图,绘图并保存到 `docs/`
-`Data.node_meta` 基础上支持几何叠加可视化patch bbox 与局部多边形)。

View File

@@ -22,7 +22,7 @@ model:
hidden_dim: 128
output_dim: 256 # Dimension of the patch embedding
num_layers: 4
gnn_type: "rgat" # 'rgat', 'gcn', 'graphsage'
gnn_type: "gat" # 'gat', 'gcn', 'graphsage'
# Transformer Backbone
transformer:
@@ -42,9 +42,25 @@ training:
optimizer: "adamw"
loss_function: "bce" # 'bce', 'focal_loss'
weight_decay: 0.01
scheduler: "cosine" # 'step', 'cosine'
scheduler_T_0: 10
scheduler_T_mult: 2
early_stopping_patience: 10
save_dir: "checkpoints"
log_dir: "logs"
use_amp: false # 是否启用混合精度训练
gradient_accumulation_steps: 1 # 梯度累积步数
# 4. Data Splits
splits:
train_ratio: 0.8
val_ratio: 0.1
test_ratio: 0.1
random_seed: 42
# 4. Self-Supervised Pre-training
pretraining:
mask_ratio: 0.15
epochs: 200
learning_rate: 0.0005
early_stopping_patience: 10

View File

@@ -0,0 +1,102 @@
#!/usr/bin/env python3
"""
生成示例数据的脚本
- 创建一个简单的 GDS 文件
- 使用 preprocess_gds.py 处理它,生成示例数据集
"""
import os
import sys
import gdstk
import numpy as np
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def create_simple_gds(output_file):
    """Create a minimal GDS layout with a few rectangles on three layers.

    The layout is a single "TOP" cell holding a metal-1 plate, a via landing
    in its center, and a metal-2 plate above it — just enough geometry for a
    smoke-test dataset.
    """
    library = gdstk.Library("simple_layout")
    cell = library.new_cell("TOP")

    # (lower-left, upper-right, layer) for each rectangle.
    shapes = [
        ((0, 0), (10, 10), 1),  # layer 1: metal 1
        ((4, 4), (6, 6), 2),    # layer 2: via
        ((2, 2), (8, 8), 3),    # layer 3: metal 2
    ]
    for lower_left, upper_right, layer in shapes:
        cell.add(gdstk.rectangle(lower_left, upper_right, layer=layer, datatype=0))

    library.write_gds(output_file)
    print(f"已创建 GDS 文件: {output_file}")
def preprocess_sample_data(gds_file, output_dir):
    """Run scripts/preprocess_gds.py on *gds_file*, writing into *output_dir*.

    Builds the preprocessing command line (patch geometry plus one
    --layer-mapping argument per layer) and reports the subprocess outcome.
    """
    import subprocess

    os.makedirs(output_dir, exist_ok=True)

    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    script_path = os.path.join(project_root, "scripts", "preprocess_gds.py")

    # "layer/datatype" string -> integer feature index fed to the preprocessor.
    layer_mapping = {
        "1/0": 0,  # metal layer 1
        "2/0": 1,  # via layer
        "3/0": 2,  # metal layer 2
    }

    cmd = [
        sys.executable, script_path,
        "--gds-file", gds_file,
        "--output-dir", output_dir,
        "--patch-size", "5.0",
        "--patch-stride", "2.5",
    ]
    for layer_str, idx in layer_mapping.items():
        cmd += ["--layer-mapping", f"{layer_str}:{idx}"]

    print(f"运行预处理命令: {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True)

    if result.returncode == 0:
        print("预处理成功完成!")
        print("输出:")
        print(result.stdout)
    else:
        print("预处理失败!")
        print("错误:")
        print(result.stderr)
def main():
    """Generate the sample GDS file, then preprocess it into a tiny dataset."""
    examples_dir = os.path.dirname(os.path.abspath(__file__))
    gds_file = os.path.join(examples_dir, "simple_layout.gds")
    output_dir = os.path.join(examples_dir, "processed_data")

    create_simple_gds(gds_file)            # step 1: build the layout
    preprocess_sample_data(gds_file, output_dir)  # step 2: run the preprocessor

    print("\n示例数据生成完成!")
    print(f"GDS 文件: {gds_file}")
    print(f"处理后的数据: {output_dir}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,89 @@
#!/usr/bin/env python3
"""
一键运行的小样流程脚本
- 生成示例数据
- 训练模型
- 评估模型
"""
import os
import sys
import subprocess
import time
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def run_command(cmd, cwd=None):
    """Execute *cmd*, echo its stdout/stderr, and abort the process on failure.

    Returns the CompletedProcess when the command succeeds; otherwise prints
    the return code and terminates with exit status 1.
    """
    print(f"\n运行命令: {' '.join(cmd)}")
    completed = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)

    print("输出:")
    print(completed.stdout)
    if completed.stderr:
        print("错误:")
        print(completed.stderr)

    if completed.returncode != 0:
        print(f"命令执行失败,返回码: {completed.returncode}")
        sys.exit(1)

    return completed
def generate_sample_data():
    """Run generate_sample_data.py; return the processed-data directory path."""
    print("\n=== 步骤 1: 生成示例数据 ===")
    here = os.path.dirname(os.path.abspath(__file__))
    run_command([sys.executable, os.path.join(here, "generate_sample_data.py")])
    return os.path.join(here, "processed_data")
def train_model(data_dir):
    """Launch main.py in training mode against *data_dir*."""
    print("\n=== 步骤 2: 训练模型 ===")
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    run_command([
        sys.executable,
        os.path.join(project_root, "main.py"),
        "--config-file", os.path.join(project_root, "configs", "hotspot_detection.yaml"),
        "--mode", "train",
        "--data-dir", data_dir,
    ])
def evaluate_model(data_dir):
    """Launch main.py in evaluation mode against *data_dir*."""
    print("\n=== 步骤 3: 评估模型 ===")
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    run_command([
        sys.executable,
        os.path.join(project_root, "main.py"),
        "--config-file", os.path.join(project_root, "configs", "hotspot_detection.yaml"),
        "--mode", "eval",
        "--data-dir", data_dir,
    ])
def main():
    """End-to-end smoke pipeline: generate sample data, train, then evaluate."""
    started = time.time()

    print("Geo-Layout Transformer 小样流程")
    print("==============================")

    data_dir = generate_sample_data()  # step 1: sample data
    train_model(data_dir)              # step 2: training
    evaluate_model(data_dir)           # step 3: evaluation

    elapsed = time.time() - started
    print("\n=== 流程完成 ===")
    print(f"总耗时: {elapsed:.2f}")
    print("示例流程已成功运行!")


if __name__ == "__main__":
    main()

File diff suppressed because one or more lines are too long

49
main.py
View File

@@ -4,6 +4,7 @@ from torch.utils.data import random_split
from src.utils.config_loader import load_config, merge_configs
from src.utils.logging import get_logger
from src.utils.seed import set_seed
from src.data.dataset import LayoutDataset
from torch_geometric.data import DataLoader
from src.models.geo_layout_transformer import GeoLayoutTransformer
@@ -27,22 +28,46 @@ def main():
base_config = load_config('configs/default.yaml')
task_config = load_config(args.config_file)
config = merge_configs(base_config, task_config)
# 设置随机种子,确保实验的可重复性
random_seed = config['splits']['random_seed']
logger.info(f"正在设置随机种子: {random_seed}")
set_seed(random_seed)
# 加载数据
logger.info(f"{args.data_dir} 加载数据集")
dataset = LayoutDataset(root=args.data_dir)
# TODO: 实现更完善的数据集划分逻辑
# 这是一个简化的数据加载方式。在实际应用中,您需要将数据集划分为训练集、验证集和测试集。
# 例如:
# train_size = int(0.8 * len(dataset))
# val_size = len(dataset) - train_size
# train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# train_loader = DataLoader(train_dataset, batch_size=config['training']['batch_size'], shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=config['training']['batch_size'], shuffle=False)
train_loader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=True)
val_loader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=False)
# 实现数据集划分逻辑
logger.info("正在划分数据集...")
train_ratio = config['splits']['train_ratio']
val_ratio = config['splits']['val_ratio']
test_ratio = config['splits']['test_ratio']
random_seed = config['splits']['random_seed']
# 计算各数据集大小
train_size = int(train_ratio * len(dataset))
val_size = int(val_ratio * len(dataset))
test_size = len(dataset) - train_size - val_size
# 确保各部分大小合理
if test_size < 0:
test_size = 0
val_size = len(dataset) - train_size
# 划分数据集
train_dataset, val_dataset, test_dataset = random_split(
dataset,
[train_size, val_size, test_size],
generator=torch.Generator().manual_seed(random_seed)
)
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=config['training']['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config['training']['batch_size'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=config['training']['batch_size'], shuffle=False)
logger.info(f"数据集划分完成: 训练集 {len(train_dataset)}, 验证集 {len(val_dataset)}, 测试集 {len(test_dataset)}")
# 初始化模型
logger.info("正在初始化模型...")
@@ -63,7 +88,7 @@ def main():
elif args.mode == 'eval':
logger.info("进入评估模式...")
evaluator = Evaluator(model)
evaluator.evaluate(val_loader)
evaluator.evaluate(test_loader)
if __name__ == "__main__":
main()

View File

@@ -10,6 +10,7 @@ dependencies = [
"pandas>=2.3.2",
"pyyaml>=6.0.2",
"scikit-learn>=1.7.1",
"tensorboard>=2.20.0",
"torch>=2.8.0",
"torch-geometric>=2.6.1",
"torchvision>=0.23.0",

View File

@@ -3,6 +3,7 @@ import argparse
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import os
from src.utils.config_loader import load_config
from src.models.geo_layout_transformer import GeoLayoutTransformer
@@ -13,52 +14,93 @@ def main():
parser.add_argument("--config-file", required=True, help="模型配置文件的路径。")
parser.add_argument("--model-path", required=True, help="已训练模型检查点的路径。")
parser.add_argument("--patch-data", required=True, help="区块数据样本(.pt 文件)的路径。")
parser.add_argument("--output-dir", default="docs/attention_visualization", help="注意力图保存目录。")
parser.add_argument("--layer-index", type=int, default=0, help="要可视化的 Transformer 层索引。")
parser.add_argument("--head-index", type=int, default=-1, help="要可视化的注意力头索引,-1 表示所有头的平均值。")
args = parser.parse_args()
logger = get_logger("Attention_Visualizer")
logger.info("这是一个用于注意力可视化的占位符脚本。")
logger.info("完整的实现需要加载一个训练好的模型、一个数据样本,然后提取注意力权重。")
# 确保输出目录存在
os.makedirs(args.output_dir, exist_ok=True)
# 1. 加载配置和模型
# logger.info("正在加载模型...")
# config = load_config(args.config_file)
# model = GeoLayoutTransformer(config)
# model.load_state_dict(torch.load(args.model_path))
# model.eval()
logger.info("正在加载模型...")
config = load_config(args.config_file)
model = GeoLayoutTransformer(config)
model.load_state_dict(torch.load(args.model_path, map_location=torch.device('cpu')))
model.eval()
# 2. 加载一个数据样本
# logger.info(f"正在加载数据样本从 {args.patch_data}")
# sample_data = torch.load(args.patch_data)
logger.info(f"正在加载数据样本从 {args.patch_data}")
sample_data = torch.load(args.patch_data)
# 3. 注册钩子Hook到模型中以提取注意力权重
# 这是一个复杂的过程,需要访问 nn.MultiheadAttention 模块的前向传播过程。
# attention_weights = []
# def hook(module, input, output):
# # output[1] 是注意力权重
# attention_weights.append(output[1])
# model.transformer_core.transformer_encoder.layers[0].self_attn.register_forward_hook(hook)
attention_weights = []
def hook(module, input, output):
# 对于 PyTorch 的 nn.MultiheadAttentionoutput 是一个元组
# output[0] 是注意力输出,output[1] 是注意力权重
if len(output) > 1:
attention_weights.append(output[1])
# 获取指定层的自注意力模块
if hasattr(model.transformer_core.transformer_encoder, 'layers'):
layer = model.transformer_core.transformer_encoder.layers[args.layer_index]
if hasattr(layer, 'self_attn'):
layer.self_attn.register_forward_hook(hook)
logger.info(f"已注册钩子到 Transformer 层 {args.layer_index} 的自注意力模块")
else:
logger.error("找不到自注意力模块")
return
else:
logger.error("找不到 Transformer 层")
return
# 4. 运行一次前向传播以获取权重
# logger.info("正在运行前向传播...")
# with torch.no_grad():
# # 模型需要修改以支持返回注意力权重,或者通过钩子获取
# _ = model(sample_data)
logger.info("正在运行前向传播...")
with torch.no_grad():
_ = model(sample_data)
# 5. 绘制注意力图
# if attention_weights:
# logger.info("正在绘制注意力图...")
# # attention_weights[0] 的形状是 [batch_size, num_heads, seq_len, seq_len]
# # 我们取第一项,并在所有头上取平均值
# avg_attention = attention_weights[0][0].mean(dim=0).cpu().numpy()
# plt.figure(figsize=(10, 10))
# sns.heatmap(avg_attention, cmap='viridis')
# plt.title("区块之间的平均注意力图")
# plt.xlabel("区块索引")
# plt.ylabel("区块索引")
# plt.show()
# else:
# logger.warning("未能提取注意力权重。")
if attention_weights:
logger.info("正在绘制注意力图...")
# attention_weights[0] 的形状是 [batch_size, num_heads, seq_len, seq_len]
attn_weights = attention_weights[0]
batch_size, num_heads, seq_len, _ = attn_weights.shape
logger.info(f"注意力权重形状: batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}")
# 选择第一个样本
sample_attn = attn_weights[0]
if args.head_index == -1:
# 计算所有头的平均值
avg_attention = sample_attn.mean(dim=0).cpu().numpy()
plt.figure(figsize=(12, 10))
sns.heatmap(avg_attention, cmap='viridis', square=True, vmin=0, vmax=1)
plt.title(f"所有注意力头的平均注意力图 (Layer {args.layer_index})")
plt.xlabel("区块索引")
plt.ylabel("区块索引")
output_file = os.path.join(args.output_dir, f"attention_layer_{args.layer_index}_avg.png")
plt.savefig(output_file, bbox_inches='tight', dpi=150)
logger.info(f"已保存平均注意力图到 {output_file}")
else:
# 可视化指定的注意力头
if 0 <= args.head_index < num_heads:
head_attention = sample_attn[args.head_index].cpu().numpy()
plt.figure(figsize=(12, 10))
sns.heatmap(head_attention, cmap='viridis', square=True, vmin=0, vmax=1)
plt.title(f"注意力头 {args.head_index} 的注意力图 (Layer {args.layer_index})")
plt.xlabel("区块索引")
plt.ylabel("区块索引")
output_file = os.path.join(args.output_dir, f"attention_layer_{args.layer_index}_head_{args.head_index}.png")
plt.savefig(output_file, bbox_inches='tight', dpi=150)
logger.info(f"已保存注意力头 {args.head_index} 的注意力图到 {output_file}")
else:
logger.error(f"注意力头索引 {args.head_index} 超出范围,有效范围是 0-{num_heads-1}")
else:
logger.warning("未能提取注意力权重。")
if __name__ == "__main__":
main()

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -3,7 +3,10 @@ import torch
import torch.nn as nn
from torch.optim import AdamW
from torch_geometric.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from ..utils.logging import get_logger
import os
import time
class SelfSupervisedTrainer:
"""处理自监督预训练循环(掩码版图建模)。"""
@@ -16,53 +19,179 @@ class SelfSupervisedTrainer:
# 使用均方误差损失来重建嵌入向量
self.criterion = nn.MSELoss()
# 初始化可学习的 [MASK] 嵌入
self.mask_embedding = nn.Parameter(torch.randn(config['model']['gnn']['output_dim']))
# 将其添加到模型参数中,使其可被优化
self.model.register_parameter('mask_embedding', self.mask_embedding)
# 初始化重建头
hidden_dim = config['model']['transformer']['hidden_dim']
output_dim = config['model']['gnn']['output_dim']
self.reconstruction_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim)
)
# 确保保存目录存在
self.save_dir = config.get('save_dir', 'checkpoints')
os.makedirs(self.save_dir, exist_ok=True)
# 初始化 TensorBoard 日志记录器
self.log_dir = config.get('log_dir', 'logs/pretrain')
os.makedirs(self.log_dir, exist_ok=True)
self.writer = SummaryWriter(log_dir=self.log_dir)
# 初始化早停相关变量
self.best_loss = float('inf')
self.patience = config['pretraining'].get('early_stopping_patience', 10)
self.counter = 0
self.early_stop = False
# 初始化混合精度训练
self.use_amp = config['training'].get('use_amp', False)
self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
# 初始化梯度累积
self.gradient_accumulation_steps = config['training'].get('gradient_accumulation_steps', 1)
if self.gradient_accumulation_steps > 1:
self.logger.info(f"启用梯度累积,累积步数: {self.gradient_accumulation_steps}")
def train_epoch(self, dataloader: DataLoader):
"""运行单个预训练周期。"""
self.model.train()
self.reconstruction_head.train()
total_loss = 0
mask_ratio = self.config['pretraining']['mask_ratio']
for batch in dataloader:
self.optimizer.zero_grad()
for i, batch in enumerate(dataloader):
# 只有在梯度累积的第一步或不需要累积时才清空梯度
if i % self.gradient_accumulation_steps == 0:
self.optimizer.zero_grad()
# 1. 获取原始的区块嵌入(作为重建的目标)
with torch.no_grad():
# 使用混合精度训练
if self.use_amp:
with torch.cuda.amp.autocast():
# 1. 获取原始的区块嵌入(作为重建的目标)
original_embeddings = self.model.gnn_encoder(batch)
# 2. 根据 batch.ptr 逐图生成 mask 索引,避免跨图混淆
num_graphs = batch.num_graphs
nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
# 确保所有图的节点数相同
if not torch.all(nodes_per_graph == nodes_per_graph[0]):
self.logger.warning("批次中图形的节点数不一致,使用第一个图形的节点数")
nodes_per_graph = nodes_per_graph[0]
# 为每个图单独生成掩码
all_masked_indices = []
for j in range(num_graphs):
# 计算当前图的节点在批次中的起始和结束索引
start_idx = batch.ptr[j]
end_idx = batch.ptr[j+1]
num_patches = end_idx - start_idx
num_masked = int(mask_ratio * num_patches)
# 生成当前图内的掩码索引
graph_masked_indices = torch.randperm(num_patches)[:num_masked] + start_idx
all_masked_indices.append(graph_masked_indices)
# 合并所有图的掩码索引
masked_indices = torch.cat(all_masked_indices)
# 3. 创建损坏的嵌入
corrupted_embeddings = original_embeddings.clone()
# 使用可学习的 [MASK] 嵌入
corrupted_embeddings[masked_indices] = self.mask_embedding.to(corrupted_embeddings.device)
# 4. 为 Transformer 重塑形状
corrupted_embeddings = corrupted_embeddings.view(num_graphs, nodes_per_graph, -1)
# 5. 将损坏的嵌入传入 Transformer 进行编码
encoded_embeddings = self.model.transformer_core(corrupted_embeddings)
# 6. 通过重建头生成重建的嵌入
reconstructed_embeddings = self.reconstruction_head(encoded_embeddings)
# 7. 只在被掩盖的区块上计算损失
# 将 Transformer 输出和原始嵌入都拉平成 (N, D) 的形状
reconstructed_flat = reconstructed_embeddings.view(-1, original_embeddings.size(1))
# 只选择被掩盖的那些进行比较
loss = self.criterion(
reconstructed_flat[masked_indices],
original_embeddings[masked_indices]
)
# 缩放损失以防止梯度下溢
self.scaler.scale(loss).backward()
# 只有在累积步数达到设定值时才更新权重
if (i + 1) % self.gradient_accumulation_steps == 0:
# 取消缩放并更新权重
self.scaler.step(self.optimizer)
self.scaler.update()
else:
# 标准训练流程
# 1. 获取原始的区块嵌入(作为重建的目标)
original_embeddings = self.model.gnn_encoder(batch)
# 2. 创建掩码并损坏输入
num_patches = original_embeddings.size(0)
num_masked = int(mask_ratio * num_patches)
# 随机选择要掩盖的区块索引
masked_indices = torch.randperm(num_patches)[:num_masked]
# 2. 根据 batch.ptr 逐图生成 mask 索引,避免跨图混淆
num_graphs = batch.num_graphs
nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
# 确保所有图的节点数相同
if not torch.all(nodes_per_graph == nodes_per_graph[0]):
self.logger.warning("批次中图形的节点数不一致,使用第一个图形的节点数")
nodes_per_graph = nodes_per_graph[0]
# 为每个图单独生成掩码
all_masked_indices = []
for j in range(num_graphs):
# 计算当前图的节点在批次中的起始和结束索引
start_idx = batch.ptr[j]
end_idx = batch.ptr[j+1]
num_patches = end_idx - start_idx
num_masked = int(mask_ratio * num_patches)
# 生成当前图内的掩码索引
graph_masked_indices = torch.randperm(num_patches)[:num_masked] + start_idx
all_masked_indices.append(graph_masked_indices)
# 合并所有图的掩码索引
masked_indices = torch.cat(all_masked_indices)
# 3. 创建损坏的嵌入
corrupted_embeddings = original_embeddings.clone()
# 使用可学习的 [MASK] 嵌入
corrupted_embeddings[masked_indices] = self.mask_embedding.to(corrupted_embeddings.device)
# 4. 为 Transformer 重塑形状
corrupted_embeddings = corrupted_embeddings.view(num_graphs, nodes_per_graph, -1)
# 5. 将损坏的嵌入传入 Transformer 进行编码
encoded_embeddings = self.model.transformer_core(corrupted_embeddings)
# 6. 通过重建头生成重建的嵌入
reconstructed_embeddings = self.reconstruction_head(encoded_embeddings)
# 7. 只在被掩盖的区块上计算损失
# 将 Transformer 输出和原始嵌入都拉平成 (N, D) 的形状
reconstructed_flat = reconstructed_embeddings.view(-1, original_embeddings.size(1))
# 只选择被掩盖的那些进行比较
loss = self.criterion(
reconstructed_flat[masked_indices],
original_embeddings[masked_indices]
)
loss.backward()
# 只有在累积步数达到设定值时才更新权重
if (i + 1) % self.gradient_accumulation_steps == 0:
# 更新权重
self.optimizer.step()
# 创建一个损坏的嵌入副本
# 这是一个简化的方法。更稳健的方法是直接在批次数据中掩盖特征。
# 在这个占位符中,我们直接掩盖嵌入向量。
corrupted_embeddings = original_embeddings.clone()
# 创建一个可学习的 [MASK] 嵌入
mask_embedding = nn.Parameter(torch.randn(original_embeddings.size(1), device=original_embeddings.device))
corrupted_embeddings[masked_indices] = mask_embedding
# 3. 为 Transformer 重塑形状
num_graphs = batch.num_graphs
nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
corrupted_embeddings = corrupted_embeddings.view(num_graphs, nodes_per_graph[0], -1)
# 4. 将损坏的嵌入传入 Transformer 进行重建
# 注意:这里只用了 transformer_core没有用 task_head
reconstructed_embeddings = self.model.transformer_core(corrupted_embeddings)
# 5. 只在被掩盖的区块上计算损失
# 将 Transformer 输出和原始嵌入都拉平成 (N, D) 的形状
reconstructed_flat = reconstructed_embeddings.view(-1, original_embeddings.size(1))
# 只选择被掩盖的那些进行比较
loss = self.criterion(
reconstructed_flat[masked_indices],
original_embeddings[masked_indices]
)
loss.backward()
self.optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
@@ -72,7 +201,63 @@ class SelfSupervisedTrainer:
def run(self, train_loader: DataLoader):
"""运行完整的预训练流程。"""
self.logger.info("开始自监督预训练...")
start_time = time.time()
for epoch in range(self.config['pretraining']['epochs']):
if self.early_stop:
self.logger.info("早停触发,停止预训练。")
break
epoch_start_time = time.time()
self.logger.info(f"周期 {epoch+1}/{self.config['pretraining']['epochs']}")
self.train_epoch(train_loader)
current_loss = self.train_epoch(train_loader)
# 记录学习率
current_lr = self.optimizer.param_groups[0]['lr']
# 记录到 TensorBoard
self.writer.add_scalar('Loss/pretrain', current_loss, epoch)
self.writer.add_scalar('Learning Rate', current_lr, epoch)
# 计算周期耗时
epoch_time = time.time() - epoch_start_time
self.writer.add_scalar('Time/epoch', epoch_time, epoch)
self.logger.info(f"周期耗时: {epoch_time:.2f}")
# 检查是否需要保存最佳模型
if current_loss < self.best_loss:
self.best_loss = current_loss
self.counter = 0
# 保存最佳模型
save_path = os.path.join(self.save_dir, 'best_pretrain_model.pth')
torch.save({
'model_state_dict': self.model.state_dict(),
'reconstruction_head_state_dict': self.reconstruction_head.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'best_loss': self.best_loss
}, save_path)
self.logger.info(f"保存最佳预训练模型到 {save_path}")
else:
self.counter += 1
if self.counter >= self.patience:
self.early_stop = True
self.logger.info(f"预训练损失连续 {self.patience} 个周期未改善,触发早停。")
# 计算总训练耗时
total_time = time.time() - start_time
self.logger.info(f"总预训练耗时: {total_time:.2f}")
# 保存最后一个模型
save_path = os.path.join(self.save_dir, 'last_pretrain_model.pth')
torch.save({
'model_state_dict': self.model.state_dict(),
'reconstruction_head_state_dict': self.reconstruction_head.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict()
}, save_path)
self.logger.info(f"保存最后一个预训练模型到 {save_path}")
# 关闭 TensorBoard SummaryWriter
self.writer.close()
self.logger.info("预训练完成。")
self.logger.info(f"最佳预训练损失: {self.best_loss:.4f}")

View File

@@ -2,8 +2,34 @@
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import StepLR, CosineAnnealingWarmRestarts
from torch_geometric.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from ..utils.logging import get_logger
from .evaluator import Evaluator
import os
import time
class FocalLoss(nn.Module):
    """Binary focal loss on top of BCE-with-logits (Lin et al., 2017).

    Each element's BCE loss is scaled by alpha * (1 - p_t)**gamma, where
    p_t = exp(-bce) is the predicted probability of the true class. This
    down-weights easy examples so training focuses on hard ones — useful
    for class-imbalanced problems.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha              # global scaling factor
        self.gamma = gamma              # focusing exponent
        self.reduction = reduction      # 'mean' | 'sum' | anything else = none
        # Element-wise BCE; reduction is applied after the focal scaling.
        self.bce_with_logits = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        per_element_bce = self.bce_with_logits(inputs, targets)
        p_t = torch.exp(-per_element_bce)
        scaled = self.alpha * (1.0 - p_t) ** self.gamma * per_element_bce

        if self.reduction == 'mean':
            return scaled.mean()
        if self.reduction == 'sum':
            return scaled.sum()
        return scaled
class Trainer:
"""处理(监督学习)训练循环。"""
@@ -25,30 +51,97 @@ class Trainer:
if config['training']['loss_function'] == 'bce':
# BCEWithLogitsLoss 结合了 Sigmoid 和 BCELoss更数值稳定
self.criterion = nn.BCEWithLogitsLoss()
# 在此添加其他损失函数,如 focal loss
elif config['training']['loss_function'] == 'focal_loss':
self.criterion = FocalLoss()
else:
raise ValueError(f"不支持的损失函数: {config['training']['loss_function']}")
# 初始化学习率调度器
self.scheduler = None
if 'scheduler' in config['training']:
scheduler_type = config['training']['scheduler']
if scheduler_type == 'step':
self.scheduler = StepLR(self.optimizer, step_size=config['training'].get('scheduler_step_size', 30), gamma=config['training'].get('scheduler_gamma', 0.1))
elif scheduler_type == 'cosine':
self.scheduler = CosineAnnealingWarmRestarts(self.optimizer, T_0=config['training'].get('scheduler_T_0', 10), T_mult=config['training'].get('scheduler_T_mult', 2))
# 初始化评估器
self.evaluator = Evaluator(model)
# 初始化早停相关变量
self.best_val_score = -float('inf')
self.patience = config['training'].get('early_stopping_patience', 10)
self.counter = 0
self.early_stop = False
# 确保保存目录存在
self.save_dir = config.get('save_dir', 'checkpoints')
os.makedirs(self.save_dir, exist_ok=True)
# 初始化 TensorBoard 日志记录器
self.log_dir = config.get('log_dir', 'logs')
os.makedirs(self.log_dir, exist_ok=True)
self.writer = SummaryWriter(log_dir=self.log_dir)
# 初始化混合精度训练
self.use_amp = config['training'].get('use_amp', False)
self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
# 初始化梯度累积
self.gradient_accumulation_steps = config['training'].get('gradient_accumulation_steps', 1)
if self.gradient_accumulation_steps > 1:
self.logger.info(f"启用梯度累积,累积步数: {self.gradient_accumulation_steps}")
def train_epoch(self, dataloader: DataLoader):
"""运行单个训练周期epoch"""
self.model.train() # 将模型设置为训练模式
total_loss = 0
for batch in dataloader:
self.optimizer.zero_grad() # 清空梯度
for i, batch in enumerate(dataloader):
# 只有在梯度累积的第一步或不需要累积时才清空梯度
if i % self.gradient_accumulation_steps == 0:
self.optimizer.zero_grad()
# 前向传播
output = self.model(batch)
# 准备目标标签
# 假设标签在图级别,并且需要调整形状以匹配输出
target = batch.y.view_as(output)
# 使用混合精度训练
if self.use_amp:
with torch.cuda.amp.autocast():
# 前向传播
output = self.model(batch)
# 准备目标标签
# 假设标签在图级别,并且需要调整形状以匹配输出
target = batch.y.view_as(output)
# 计算损失
loss = self.criterion(output, target)
# 反向传播
loss.backward()
# 更新权重
self.optimizer.step()
# 计算损失
loss = self.criterion(output, target)
# 缩放损失以防止梯度下溢
self.scaler.scale(loss).backward()
# 只有在累积步数达到设定值时才更新权重
if (i + 1) % self.gradient_accumulation_steps == 0:
# 取消缩放并更新权重
self.scaler.step(self.optimizer)
self.scaler.update()
else:
# 标准训练流程
# 前向传播
output = self.model(batch)
# 准备目标标签
# 假设标签在图级别,并且需要调整形状以匹配输出
target = batch.y.view_as(output)
# 计算损失
loss = self.criterion(output, target)
# 反向传播
loss.backward()
# 只有在累积步数达到设定值时才更新权重
if (i + 1) % self.gradient_accumulation_steps == 0:
# 更新权重
self.optimizer.step()
total_loss += loss.item()
@@ -56,11 +149,79 @@ class Trainer:
self.logger.info(f"训练损失: {avg_loss:.4f}")
return avg_loss
def validate(self, dataloader: DataLoader):
"""运行验证并返回评估指标。"""
self.model.eval() # 将模型设置为评估模式
metrics = self.evaluator.evaluate(dataloader)
return metrics
def run(self, train_loader: DataLoader, val_loader: DataLoader):
"""运行完整的训练流程。"""
self.logger.info("开始训练...")
start_time = time.time()
for epoch in range(self.config['training']['epochs']):
if self.early_stop:
self.logger.info("早停触发,停止训练。")
break
epoch_start_time = time.time()
self.logger.info(f"周期 {epoch+1}/{self.config['training']['epochs']}")
self.train_epoch(train_loader)
# 在此处添加验证步骤,例如调用 Evaluator
# 训练一个周期
train_loss = self.train_epoch(train_loader)
# 验证
self.logger.info("正在验证...")
val_metrics = self.validate(val_loader)
# 更新学习率调度器
current_lr = self.optimizer.param_groups[0]['lr']
if self.scheduler:
self.scheduler.step()
new_lr = self.optimizer.param_groups[0]['lr']
self.logger.info(f"学习率从 {current_lr:.6f} 调整为 {new_lr:.6f}")
current_lr = new_lr
else:
self.logger.info(f"当前学习率: {current_lr:.6f}")
# 记录到 TensorBoard
self.writer.add_scalar('Loss/train', train_loss, epoch)
for metric_name, metric_value in val_metrics.items():
self.writer.add_scalar(f'Metrics/{metric_name}', metric_value, epoch)
self.writer.add_scalar('Learning Rate', current_lr, epoch)
# 计算周期耗时
epoch_time = time.time() - epoch_start_time
self.writer.add_scalar('Time/epoch', epoch_time, epoch)
self.logger.info(f"周期耗时: {epoch_time:.2f}")
# 检查是否需要保存最佳模型
val_score = val_metrics.get('f1', val_metrics.get('accuracy', -1))
if val_score > self.best_val_score:
self.best_val_score = val_score
self.counter = 0
# 保存最佳模型
save_path = os.path.join(self.save_dir, 'best_model.pth')
torch.save(self.model.state_dict(), save_path)
self.logger.info(f"保存最佳模型到 {save_path}")
else:
self.counter += 1
if self.counter >= self.patience:
self.early_stop = True
self.logger.info(f"验证性能连续 {self.patience} 个周期未改善,触发早停。")
# 计算总训练耗时
total_time = time.time() - start_time
self.logger.info(f"总训练耗时: {total_time:.2f}")
# 保存最后一个模型
save_path = os.path.join(self.save_dir, 'last_model.pth')
torch.save(self.model.state_dict(), save_path)
self.logger.info(f"保存最后一个模型到 {save_path}")
# 关闭 TensorBoard SummaryWriter
self.writer.close()
self.logger.info("训练完成。")
self.logger.info(f"最佳验证分数: {self.best_val_score:.4f}")

Binary file not shown.

Binary file not shown.

View File

@@ -3,7 +3,7 @@ import torch
import torch.nn as nn
from .gnn_encoder import GNNEncoder
from .transformer_core import TransformerCore
from .task_heads import ClassificationHead, MatchingHead
from .task_heads import ClassificationHead, MultiLabelClassificationHead, RegressionHead, MatchingHead
class GeoLayoutTransformer(nn.Module):
"""完整的 Geo-Layout Transformer 模型。"""
@@ -38,16 +38,34 @@ class GeoLayoutTransformer(nn.Module):
self.task_head = None
if 'task_head' in config['model']:
head_config = config['model']['task_head']
pooling_type = head_config.get('pooling_type', 'mean')
if head_config['type'] == 'classification':
self.task_head = ClassificationHead(
input_dim=head_config['input_dim'],
hidden_dim=head_config['hidden_dim'],
output_dim=head_config['output_dim']
output_dim=head_config['output_dim'],
pooling_type=pooling_type
)
elif head_config['type'] == 'multi_label_classification':
self.task_head = MultiLabelClassificationHead(
input_dim=head_config['input_dim'],
hidden_dim=head_config['hidden_dim'],
output_dim=head_config['output_dim'],
pooling_type=pooling_type
)
elif head_config['type'] == 'regression':
self.task_head = RegressionHead(
input_dim=head_config['input_dim'],
hidden_dim=head_config['hidden_dim'],
output_dim=head_config['output_dim'],
pooling_type=pooling_type
)
elif head_config['type'] == 'matching':
self.task_head = MatchingHead(
input_dim=head_config['input_dim'],
output_dim=head_config['output_dim']
output_dim=head_config['output_dim'],
pooling_type=pooling_type
)
# 可在此处添加其他任务头

View File

@@ -48,15 +48,14 @@ class GNNEncoder(nn.Module):
data: 一个 PyTorch Geometric 的 Data 或 Batch 对象。
Returns:
一个代表区块的图级别嵌入的张量。
一个代表节点级别嵌入的张量。
"""
x, edge_index, batch = data.x, data.edge_index, data.batch
x, edge_index = data.x, data.edge_index
# 通过所有 GNN 层
for layer in self.layers:
x = layer(x, edge_index)
x = torch.relu(x)
# 全局池化以获得图级别的嵌入
graph_embedding = self.readout(x, batch)
return graph_embedding
# 返回节点级别的嵌入,不进行全局池化
return x

View File

@@ -2,11 +2,44 @@
import torch
import torch.nn as nn
class PoolingLayer(nn.Module):
    """Pluggable sequence-pooling layer.

    Supported strategies:
      - 'mean': average over the sequence dimension
      - 'max' : element-wise maximum over the sequence dimension
      - 'cls' : take the first token (a la BERT's [CLS] token)
      - 'attention': content-based soft attention — a learned scorer maps each
        token embedding to a scalar logit; a softmax over the sequence turns
        the logits into pooling weights.
    """

    def __init__(self, pooling_type: str = 'mean'):
        super(PoolingLayer, self).__init__()
        self.pooling_type = pooling_type
        if pooling_type == 'attention':
            # Score each token from its own embedding. LazyLinear infers the
            # input feature size on first forward, so the constructor does not
            # need to know hidden_dim. (The previous version scored a constant
            # all-ones tensor through Linear(1, 1), so every position received
            # the same logit and the softmax was uniform — i.e. "attention"
            # pooling silently degenerated to mean pooling.)
            self.attention = nn.LazyLinear(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [batch_size, seq_len, hidden_dim]

        Returns:
            Pooled tensor of shape [batch_size, hidden_dim]

        Raises:
            ValueError: if the configured pooling type is unknown.
        """
        if self.pooling_type == 'mean':
            return torch.mean(x, dim=1)
        elif self.pooling_type == 'max':
            return torch.max(x, dim=1)[0]
        elif self.pooling_type == 'cls':
            # First token acts as the [CLS] summary token.
            return x[:, 0, :]
        elif self.pooling_type == 'attention':
            # [B, S, 1] logits -> softmax over the sequence -> weighted sum.
            weights = self.attention(x).softmax(dim=1)
            return (x * weights).sum(dim=1)
        else:
            raise ValueError(f"不支持的池化类型: {self.pooling_type}")
class ClassificationHead(nn.Module):
"""一个用于分类任务的简单多层感知机MLP任务头。"""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, pooling_type: str = 'mean'):
super(ClassificationHead, self).__init__()
self.pooling = PoolingLayer(pooling_type)
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim, output_dim)
@@ -19,9 +52,60 @@ class ClassificationHead(nn.Module):
Returns:
最终的分类 logits。
"""
# 我们可以取第一个 token类似 [CLS])的嵌入,或者进行平均池化
# 为简单起见,我们假设在序列维度上进行平均池化
x_pooled = torch.mean(x, dim=1)
# 使用指定的池化方法
x_pooled = self.pooling(x)
out = self.fc1(x_pooled)
out = self.relu(out)
out = self.fc2(out)
return out
class MultiLabelClassificationHead(nn.Module):
    """Task head for multi-label classification: pooling followed by a
    two-layer MLP that emits one logit per label."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, pooling_type: str = 'mean'):
        super(MultiLabelClassificationHead, self).__init__()
        # Pluggable sequence pooling, then an MLP projecting to label logits.
        self.pooling = PoolingLayer(pooling_type)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input tensor from the Transformer backbone.
        Returns:
            multi-label classification logits (one independent logit per label).
        """
        pooled = self.pooling(x)
        return self.fc2(self.relu(self.fc1(pooled)))
class RegressionHead(nn.Module):
"""用于回归任务的任务头。"""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, pooling_type: str = 'mean'):
super(RegressionHead, self).__init__()
self.pooling = PoolingLayer(pooling_type)
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim, output_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: 来自 Transformer 骨干网络的输入张量。
Returns:
最终的回归输出。
"""
# 使用指定的池化方法
x_pooled = self.pooling(x)
out = self.fc1(x_pooled)
out = self.relu(out)
@@ -31,8 +115,9 @@ class ClassificationHead(nn.Module):
class MatchingHead(nn.Module):
"""用于学习版图匹配的相似性嵌入的任务头。"""
def __init__(self, input_dim: int, output_dim: int):
def __init__(self, input_dim: int, output_dim: int, pooling_type: str = 'mean'):
super(MatchingHead, self).__init__()
self.pooling = PoolingLayer(pooling_type)
self.projection = nn.Linear(input_dim, output_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -43,8 +128,8 @@ class MatchingHead(nn.Module):
Returns:
代表整个输入图(例如一个 IP 模块)的单个嵌入向量。
"""
# 全局平均池化,为整个序列获取一个单一的向量
graph_embedding = torch.mean(x, dim=1)
# 使用指定的池化方法
graph_embedding = self.pooling(x)
# 投影到最终的嵌入空间
similarity_embedding = self.projection(graph_embedding)
# 对嵌入进行 L2 归一化,以便使用余弦相似度

6
src/utils/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
# src/utils/__init__.py
from .config_loader import load_config, merge_configs
from .logging import get_logger
from .seed import set_seed
__all__ = ['load_config', 'merge_configs', 'get_logger', 'set_seed']

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

33
src/utils/seed.py Normal file
View File

@@ -0,0 +1,33 @@
# src/utils/seed.py
import random
import numpy as np
import torch
import os
def set_seed(seed: int = 42, deterministic: bool = False):
    """
    设置随机种子,确保实验的可重复性。

    Seeds Python's, NumPy's and PyTorch's RNGs and pins Python hash
    randomization via ``PYTHONHASHSEED``.

    Args:
        seed: 随机种子值 (seed value applied to every RNG).
        deterministic: when True, additionally force cuDNN into deterministic
            mode (slower, but removes run-to-run variance in GPU convolution
            kernels). The original left these lines commented out with a
            contradictory note; they are now an explicit opt-in.
    """
    # Python built-in RNG
    random.seed(seed)
    # NumPy RNG
    np.random.seed(seed)
    # PyTorch CPU RNG
    torch.manual_seed(seed)
    # Safe without a GPU: these are no-ops when CUDA is unavailable.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # multi-GPU environments
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # Pin hash randomization for processes spawned after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    print(f"随机种子已设置为: {seed}")

199
tests/test_model_run.py Normal file
View File

@@ -0,0 +1,199 @@
#!/usr/bin/env python3
"""
测试脚本,用于验证模型是否可以正常跑通,不需要真实数据
- 生成随机图数据
- 加载模型配置
- 初始化模型
- 运行前向传播和反向传播
- 验证模型是否可以正常工作
"""
import os
import sys
import torch
from torch_geometric.data import Data, Batch
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.utils.config_loader import load_config
from src.models.geo_layout_transformer import GeoLayoutTransformer
from src.engine.trainer import Trainer
from src.engine.self_supervised import SelfSupervisedTrainer
from src.utils.logging import get_logger
def generate_random_graph_data(num_graphs=4, num_nodes_per_graph=8, node_feature_dim=5, edge_feature_dim=0):
    """
    生成随机的图数据 (random fully-connected graphs for smoke tests).

    Args:
        num_graphs: 图的数量
        num_nodes_per_graph: 每个图的节点数量
        edge_feature_dim: 边特征维度. FIX: previously this parameter was
            accepted but silently ignored; when > 0 random ``edge_attr``
            features are now attached (default 0 preserves old behavior).
        node_feature_dim: 节点特征维度
    Returns:
        一个 Batch 对象,包含多个随机生成的图
    """
    graphs = []
    for _ in range(num_graphs):
        # Random node features
        x = torch.randn(num_nodes_per_graph, node_feature_dim)
        # Fully connected, no self-loops: every ordered pair (i, j), i != j.
        pairs = [[i, j]
                 for i in range(num_nodes_per_graph)
                 for j in range(num_nodes_per_graph)
                 if i != j]
        edge_index = torch.tensor(pairs, dtype=torch.long).t()
        # Random graph-level label (shape [1, 1])
        y = torch.randn(1, 1)
        graph = Data(x=x, edge_index=edge_index, y=y)
        if edge_feature_dim > 0:
            # One random feature vector per directed edge.
            graph.edge_attr = torch.randn(edge_index.size(1), edge_feature_dim)
        graphs.append(graph)
    # Collate the individual graphs into a single PyG batch.
    return Batch.from_data_list(graphs)
def test_supervised_training():
    """测试监督训练 — smoke-test the supervised path: stage-by-stage forward
    pass, then one full backward pass and optimizer step on random graphs."""
    logger = get_logger("Test_Supervised_Training")
    logger.info("=== 测试监督训练 ===")
    # Load configuration
    config = load_config('configs/default.yaml')
    # Generate random data
    batch = generate_random_graph_data()
    logger.info(f"生成的批次数据: {batch}")
    logger.info(f"批次大小: {batch.num_graphs}")
    logger.info(f"总节点数: {batch.num_nodes}")
    logger.info(f"总边数: {batch.num_edges}")
    # Build the model
    logger.info("初始化模型...")
    model = GeoLayoutTransformer(config)
    logger.info("模型初始化成功")
    # Build the trainer
    logger.info("初始化训练器...")
    trainer = Trainer(model, config)
    logger.info("训练器初始化成功")
    # Stage-by-stage forward pass, gradients disabled
    logger.info("测试前向传播...")
    with torch.no_grad():
        # GNN encoder first
        gnn_output = model.gnn_encoder(batch)
        logger.info(f"GNN 编码器输出形状: {gnn_output.shape}")
        # Reshape node embeddings into [num_graphs, nodes_per_graph, dim]
        num_graphs = batch.num_graphs
        nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
        logger.info(f"每个图的节点数: {nodes_per_graph}")
        # CONSISTENCY FIX: the uniform reshape below silently assumed
        # equal-sized graphs; warn like test_self_supervised_training does.
        if not torch.all(nodes_per_graph == nodes_per_graph[0]):
            logger.warning("批次中图形的节点数不一致,使用第一个图形的节点数")
        reshaped_embeddings = gnn_output.view(num_graphs, nodes_per_graph[0], -1)
        logger.info(f"重塑后的嵌入形状: {reshaped_embeddings.shape}")
        # Transformer core
        transformer_output = model.transformer_core(reshaped_embeddings)
        logger.info(f"Transformer 输出形状: {transformer_output.shape}")
        # Full end-to-end forward
        output = model(batch)
        logger.info(f"前向传播成功,输出形状: {output.shape}")
    # One backward pass + optimizer step (gradients enabled here)
    logger.info("测试反向传播...")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    optimizer.zero_grad()
    output = model(batch)
    # Pool [batch_size, seq_len, hidden_dim] -> [batch_size, hidden_dim]
    graph_output = output.mean(dim=1)
    # MSE against batch.y; only the first channel matches y's shape [B, 1]
    loss = torch.nn.MSELoss()(graph_output[:, :1], batch.y)
    loss.backward()
    optimizer.step()
    logger.info(f"反向传播成功,损失值: {loss.item()}")
    logger.info("监督训练测试完成,模型可以正常工作!")
def test_self_supervised_training():
    """测试自监督训练 — smoke-test the self-supervised path (forward passes
    only; no loss/backward is exercised here)."""
    logger = get_logger("Test_Self_Supervised_Training")
    logger.info("\n=== 测试自监督训练 ===")
    # Load configuration
    config = load_config('configs/default.yaml')
    # Generate random data
    batch = generate_random_graph_data()
    logger.info(f"生成的批次数据: {batch}")
    logger.info(f"批次大小: {batch.num_graphs}")
    logger.info(f"总节点数: {batch.num_nodes}")
    logger.info(f"总边数: {batch.num_edges}")
    # Build the model
    logger.info("初始化模型...")
    model = GeoLayoutTransformer(config)
    logger.info("模型初始化成功")
    # Build the self-supervised trainer (constructed only; not stepped here)
    logger.info("初始化自监督训练器...")
    trainer = SelfSupervisedTrainer(model, config)
    logger.info("自监督训练器初始化成功")
    # Stage-by-stage forward pass, gradients disabled
    logger.info("测试前向传播...")
    with torch.no_grad():
        # GNN encoder: node-level embeddings
        gnn_output = model.gnn_encoder(batch)
        logger.info(f"GNN 编码器输出形状: {gnn_output.shape}")
        # Per-graph node counts derived from the batch pointer vector
        num_graphs = batch.num_graphs
        nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
        if not torch.all(nodes_per_graph == nodes_per_graph[0]):
            logger.warning("批次中图形的节点数不一致,使用第一个图形的节点数")
        # Collapse to a single count; the reshape below assumes equal-sized
        # graphs (true for generate_random_graph_data's output)
        nodes_per_graph = nodes_per_graph[0]
        gnn_output_reshaped = gnn_output.view(num_graphs, nodes_per_graph, -1)
        transformer_output = model.transformer_core(gnn_output_reshaped)
        logger.info(f"Transformer 核心输出形状: {transformer_output.shape}")
    # Full end-to-end forward pass
    logger.info("测试完整模型前向传播...")
    with torch.no_grad():
        output = model(batch)
        logger.info(f"完整模型前向传播成功,输出形状: {output.shape}")
    logger.info("自监督训练测试完成,模型可以正常工作!")
def main():
    """Entry point: run every smoke test in sequence and exit non-zero with a
    traceback if any of them fails."""
    logger = get_logger("Test_Model_Run")
    logger.info("开始测试模型是否可以正常跑通...")
    try:
        # Supervised first, then self-supervised.
        for smoke_test in (test_supervised_training, test_self_supervised_training):
            smoke_test()
        logger.info("\n✅ 所有测试通过,模型可以正常跑通!")
        logger.info("模型已准备就绪,可以使用真实数据进行训练。")
    except Exception as e:
        logger.error(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()

143
uv.lock generated
View File

@@ -1,8 +1,16 @@
# uv.lock
version = 1
revision = 2
requires-python = ">=3.12"
[[package]]
name = "absl-py"
version = "2.4.0"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/64/c7/8de93764ad66968d19329a7e0c147a2bb3c7054c554d4a119111b8f9440f/absl_py-2.4.0.tar.gz", hash = "sha256:8c6af82722b35cf71e0f4d1d47dcaebfff286e27110a99fc359349b247dfb5d4", size = 116543, upload-time = "2026-01-28T10:17:05.322Z" }
wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl", hash = "sha256:88476fd881ca8aab94ffa78b7b6c632a782ab3ba1cd19c9bd423abc4fb4cd28d", size = 135750, upload-time = "2026-01-28T10:17:04.19Z" },
]
[[package]]
name = "aiohappyeyeballs"
version = "2.6.1"
@@ -234,20 +242,35 @@ sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e8/b5/a12ef182943856
wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/51/a8/cff9bd17789b41c2f09f4e4b574fd05c5860e0b5602b3708b6c1de4b6c82/gdstk-0.9.61-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:61f0ee05cdce9b4163ea812cbf2e2f5d8d01a293fa118ff98348280306bd91d6", size = 923143, upload-time = "2025-08-28T10:16:42.55Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/9b/13/c97316d18510e2dcb3aeba0e13f4c6d7ad34884be62006b910f382bdfbc6/gdstk-0.9.61-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fff1b104b6775e4c27ab2751b3f4ac6c1ce86a4e9afd5e5535ac4acefa6a7a07", size = 475724, upload-time = "2025-08-28T10:16:44.434Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/19/49/40aed8aae97054b08a0063a583ef55e3ab0e335441d6d339615d8593892a/gdstk-0.9.61-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5218f8c5ab13b6e979665c0a7dc1272768003a1cb7add0682483837f7485faed", size = 536850, upload-time = "2025-10-20T11:25:50.578Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/35/7c/4d324ae83dac2ba15fcc61d688019ad79b8938c027d2467415cb804f8058/gdstk-0.9.61-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4e79f3881d3b3666a600efd5b2c131454507f69d3c9b9eaf383d106cfbd6e7bc", size = 600640, upload-time = "2025-08-28T10:16:45.627Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ef/cd/addb66d3740a654495e413fb28d6c4e4d9ac94bc0c25dc31f2868fb6e18c/gdstk-0.9.61-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e90a6e24c2145320e53e953a59c6297fd25c17c6ef098fa8602e64e64a5390ea", size = 536849, upload-time = "2025-08-28T10:16:47.039Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d7/37/43fb416068f0722f1a9ccc67fe6a8f78c93ef3bb9cdf37a2edd276316b23/gdstk-0.9.61-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3a49401cbd26c5a17a4152d1befa73efb21af694524557bf09d15f4c8a874e6", size = 540029, upload-time = "2025-10-20T11:26:06.539Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/f0/d4a24c1b6636454812b4990dc590d72b0a37f185122c6fa19d4b0d22a90e/gdstk-0.9.61-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:8738ac63bbe29dcb5abae6a19d207c4e0857f9dc1bd405c85af8a87f0dcfb348", size = 535749, upload-time = "2025-08-28T10:16:48.157Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/85/67/ec1ef9f67ac26554d55f4e667e602698fb7f520ef89dfe6cb4550152eec1/gdstk-0.9.61-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:23bb023a49f3321673d0e32cdce2e2705a51d9e12328c928723ded49af970520", size = 1711671, upload-time = "2025-08-28T10:16:49.963Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/89/73/4fd73bbc7500e383500a9edca4c09ce38bffd44eefb256ce17a41c526dbe/gdstk-0.9.61-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:81c2f19cab89623d1f56848e7a16e2fab82a93c61c8f7aa73f5ff59840b60c0f", size = 1535161, upload-time = "2025-08-28T10:16:51.181Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a0/b1/d740423bd7436c522295a09cadc0e21d59afa418185cffffce46a6eb85b0/gdstk-0.9.61-cp312-cp312-win_amd64.whl", hash = "sha256:4474f015ecc228b210165287cb7eea65639ea6308f60105cb49e970079bddc2b", size = 500139, upload-time = "2025-08-28T10:16:52.496Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/07/9a/7ea7b7a295e029542d4aeb252d01c1bcc8e724df26155f9b5f432b02d02a/gdstk-0.9.61-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:3beeae846fc523c7e3a01c47edcd3b7dd83c29650e56b82a371e528f9cb0ec3e", size = 923024, upload-time = "2025-08-28T10:16:53.613Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b8/3d/5aa9b1a4665259702e9f17e03a9a114a873df25c9ba2c9e782ff25fb11e9/gdstk-0.9.61-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:575a21639b31e2fab4d9e918468b8b40a58183028db563e5963be594bff1403d", size = 475687, upload-time = "2025-08-28T10:16:54.886Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/80/13/ec783d8de5d9b4e51763102cac6da124f16747e4c73f166c36a105065008/gdstk-0.9.61-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:90d54b48223dcbb8257769faaa87542d12a749d8486e8d1187a45d06e9422859", size = 536872, upload-time = "2025-10-20T11:25:52.902Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/20/4a/365ca49b76ee3d70a0d044fcc44272c85754fd763e0b3f3e9f498b8bf4a1/gdstk-0.9.61-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35405bed95542a0b10f343b165ce0ad80740bf8127a4507565ec74222e6ec8d3", size = 600630, upload-time = "2025-08-28T10:16:56.326Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/62/3e/815fd2977ff1d885ad87d0a54deb19926fd025933325c7a27625c1c0a0c3/gdstk-0.9.61-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b311ddf8982995b52ac3bf3b32a6cf6d918afc4e66dea527d531e8af73896231", size = 536873, upload-time = "2025-08-28T10:16:57.516Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b0/8b/1ba0abc4fb3c60015d23894aa9d093f473fdc337584f9c1d7afe96d6f9f5/gdstk-0.9.61-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6dcbfc60fba92d10f1c7d612b5409c343fcaf2a380640e9fb01c504ca948b412", size = 540029, upload-time = "2025-10-20T11:26:09.727Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/74/72/cc46f132741e541995ede7fccf9820f105fb2296ab70192bd27de56190f2/gdstk-0.9.61-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:fab67ccdd8029ef7eb873f8c98f875dc2665a5e45af7cf3d2a7a0f401826a1d3", size = 535763, upload-time = "2025-08-28T10:16:58.621Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/5b/db/72196721fedfc38cf21158e5a436d73b41ea244b78c1053462a35d1e42cb/gdstk-0.9.61-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:5852749e203d6978e06d02f8ef9e29ce4512cb1aedeb62c37b8e8b2c10c4f529", size = 1711669, upload-time = "2025-08-28T10:16:59.838Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/36/ec/eeebcc95c2741e1f39da08c95fdcc56b3d0f5305ad732d8a66213b6fa0b8/gdstk-0.9.61-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7ee38a54c799e77dbe219266f765bbd3b2906b62bc7b6fb64b1387e6db3dd187", size = 1535165, upload-time = "2025-08-28T10:17:01.027Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/cb/a8/653335b1ec13306b023a9aa1e2072e8bab5c0a5376c138066006504198c3/gdstk-0.9.61-cp313-cp313-win_amd64.whl", hash = "sha256:6abb396873b2660dd7863d664b3822f00547bf7f216af27be9f1f812bc5e8027", size = 500117, upload-time = "2025-08-28T10:17:02.459Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/33/52/28cd8720357d6b892ac19684e4af57f7264ba8eea0ea8078e17f0476408c/gdstk-0.9.61-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:a674af8be5cf1f8ea9f6c5b5f165f797d7e2ed74cbca68b4a22adb92b515fb35", size = 916091, upload-time = "2025-10-20T11:25:35.497Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/fe/f1/c55c7b7b0158a8540716a4fea3aaf0288726122afc0e9af1f0564b79605b/gdstk-0.9.61-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:38ec0b7285d6c9bf8cbc279731dc0d314633cda2ce9e6f9053554b3e5f004fcd", size = 474548, upload-time = "2025-10-20T11:25:38.086Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ab/83/23907af9b349d79af54c4a79ac7972b9a52f8cee48b366ee9635cf9a32d9/gdstk-0.9.61-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3b63a77b57fb441c8017217aaf1e8b13d93cbee822031a8e2826adb716e01dd4", size = 537072, upload-time = "2025-10-20T11:25:55.473Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/2c/e7/e930928f9a03b896f51b6e975dbb652cffa81c61338609d6289c3f48313c/gdstk-0.9.61-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7fae6eee627e837d1405b47d381ccd33dbba85473b1bb3822bdc8ae41dbc0dc", size = 540132, upload-time = "2025-10-20T11:26:11.379Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/6d/9d/56afbb84cccb07751f8532591bfc3a1608833943446b51978b52daaaa2b1/gdstk-0.9.61-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:9e396694cac24bd87d0e38c37e6740d9ba0c13f6c9f2211a871d62288430f069", size = 1578033, upload-time = "2025-10-20T11:26:18.707Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/58/14/c854d3ef2b0ece30c0cb4c034d3b7f58425086383a229c960d813a163861/gdstk-0.9.61-cp314-cp314-win_amd64.whl", hash = "sha256:7ea0c1200dc53b794e9c0cc6fe3ea51e49113dfdd9c3109e1961cda3cc2197c7", size = 513072, upload-time = "2025-10-20T11:25:21.089Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/be/0a/5e9c3be1327d36556a5e8d16c6696614fdc1df0d28c0629b829785245d76/gdstk-0.9.61-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:616dd1c3e7aea4a98aeb03db7cf76a853d134c54690790eaa25c63eede7b869a", size = 925091, upload-time = "2025-10-20T11:25:40.772Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/6a/4a/a72de67ea1c217537ae537d7e96b6a0b3c7427177bd1439c89e7617ad3f0/gdstk-0.9.61-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b0e898202fbb7fd4c39f8404831415a0aa0445656342102c4e77d4a7c2c15a1d", size = 477723, upload-time = "2025-10-20T11:25:44.535Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/0c/94/5034a646ee3bda58210cd377b241582161c8683489210cba762d4f7f06d1/gdstk-0.9.61-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:29bb862a1a814f5bbd6f8bbc2f99e1163df9e6307071cb6e11251dbe7542feb5", size = 541773, upload-time = "2025-10-20T11:25:57.407Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/bc/90/3e5883b528dcc7c4daa74925276a09ff274004518deea48c859b9215120b/gdstk-0.9.61-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c6c2a08d82a683aff50dc63f2943ed805d32d46bd984cbd4ac9cf876146d0ef9", size = 544525, upload-time = "2025-10-20T11:26:13.877Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/46/0b/549a0e72982b87011013705cdb552e2acfdb882ceaeb41a3fe020e981ec8/gdstk-0.9.61-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3ba52f95763052a6968583942e6531ceca20c14c762d44fe2bd887445e2f73b6", size = 1582069, upload-time = "2025-10-20T11:26:21.189Z" },
]
[[package]]
@@ -260,6 +283,7 @@ dependencies = [
{ name = "pandas" },
{ name = "pyyaml" },
{ name = "scikit-learn" },
{ name = "tensorboard" },
{ name = "torch" },
{ name = "torch-geometric" },
{ name = "torchvision" },
@@ -272,11 +296,53 @@ requires-dist = [
{ name = "pandas", specifier = ">=2.3.2" },
{ name = "pyyaml", specifier = ">=6.0.2" },
{ name = "scikit-learn", specifier = ">=1.7.1" },
{ name = "tensorboard", specifier = ">=2.20.0" },
{ name = "torch", specifier = ">=2.8.0" },
{ name = "torch-geometric", specifier = ">=2.6.1" },
{ name = "torchvision", specifier = ">=0.23.0" },
]
[[package]]
name = "grpcio"
version = "1.78.0"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
dependencies = [
{ name = "typing-extensions" },
]
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/06/8a/3d098f35c143a89520e568e6539cc098fcd294495910e359889ce8741c84/grpcio-1.78.0.tar.gz", hash = "sha256:7382b95189546f375c174f53a5fa873cef91c4b8005faa05cc5b3beea9c4f1c5", size = 12852416, upload-time = "2026-02-06T09:57:18.093Z" }
wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/4e/f4/7384ed0178203d6074446b3c4f46c90a22ddf7ae0b3aee521627f54cfc2a/grpcio-1.78.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:f9ab915a267fc47c7e88c387a3a28325b58c898e23d4995f765728f4e3dedb97", size = 5913985, upload-time = "2026-02-06T09:55:26.832Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/81/ed/be1caa25f06594463f685b3790b320f18aea49b33166f4141bfdc2bfb236/grpcio-1.78.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3f8904a8165ab21e07e58bf3e30a73f4dffc7a1e0dbc32d51c61b5360d26f43e", size = 11811853, upload-time = "2026-02-06T09:55:29.224Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/24/a7/f06d151afc4e64b7e3cc3e872d331d011c279aaab02831e40a81c691fb65/grpcio-1.78.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:859b13906ce098c0b493af92142ad051bf64c7870fa58a123911c88606714996", size = 6475766, upload-time = "2026-02-06T09:55:31.825Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/8a/a8/4482922da832ec0082d0f2cc3a10976d84a7424707f25780b82814aafc0a/grpcio-1.78.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:b2342d87af32790f934a79c3112641e7b27d63c261b8b4395350dad43eff1dc7", size = 7170027, upload-time = "2026-02-06T09:55:34.7Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/54/bf/f4a3b9693e35d25b24b0b39fa46d7d8a3c439e0a3036c3451764678fec20/grpcio-1.78.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:12a771591ae40bc65ba67048fa52ef4f0e6db8279e595fd349f9dfddeef571f9", size = 6690766, upload-time = "2026-02-06T09:55:36.902Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/c7/b9/521875265cc99fe5ad4c5a17010018085cae2810a928bf15ebe7d8bcd9cc/grpcio-1.78.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:185dea0d5260cbb2d224c507bf2a5444d5abbb1fa3594c1ed7e4c709d5eb8383", size = 7266161, upload-time = "2026-02-06T09:55:39.824Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/05/86/296a82844fd40a4ad4a95f100b55044b4f817dece732bf686aea1a284147/grpcio-1.78.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:51b13f9aed9d59ee389ad666b8c2214cc87b5de258fa712f9ab05f922e3896c6", size = 8253303, upload-time = "2026-02-06T09:55:42.353Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f3/e4/ea3c0caf5468537f27ad5aab92b681ed7cc0ef5f8c9196d3fd42c8c2286b/grpcio-1.78.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fd5f135b1bd58ab088930b3c613455796dfa0393626a6972663ccdda5b4ac6ce", size = 7698222, upload-time = "2026-02-06T09:55:44.629Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d7/47/7f05f81e4bb6b831e93271fb12fd52ba7b319b5402cbc101d588f435df00/grpcio-1.78.0-cp312-cp312-win32.whl", hash = "sha256:94309f498bcc07e5a7d16089ab984d42ad96af1d94b5a4eb966a266d9fcabf68", size = 4066123, upload-time = "2026-02-06T09:55:47.644Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ad/e7/d6914822c88aa2974dbbd10903d801a28a19ce9cd8bad7e694cbbcf61528/grpcio-1.78.0-cp312-cp312-win_amd64.whl", hash = "sha256:9566fe4ababbb2610c39190791e5b829869351d14369603702e890ef3ad2d06e", size = 4797657, upload-time = "2026-02-06T09:55:49.86Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/05/a9/8f75894993895f361ed8636cd9237f4ab39ef87fd30db17467235ed1c045/grpcio-1.78.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:ce3a90455492bf8bfa38e56fbbe1dbd4f872a3d8eeaf7337dc3b1c8aa28c271b", size = 5920143, upload-time = "2026-02-06T09:55:52.035Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/55/06/0b78408e938ac424100100fd081189451b472236e8a3a1f6500390dc4954/grpcio-1.78.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:2bf5e2e163b356978b23652c4818ce4759d40f4712ee9ec5a83c4be6f8c23a3a", size = 11803926, upload-time = "2026-02-06T09:55:55.494Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/88/93/b59fe7832ff6ae3c78b813ea43dac60e295fa03606d14d89d2e0ec29f4f3/grpcio-1.78.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8f2ac84905d12918e4e55a16da17939eb63e433dc11b677267c35568aa63fc84", size = 6478628, upload-time = "2026-02-06T09:55:58.533Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ed/df/e67e3734527f9926b7d9c0dde6cd998d1d26850c3ed8eeec81297967ac67/grpcio-1.78.0-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:b58f37edab4a3881bc6c9bca52670610e0c9ca14e2ea3cf9debf185b870457fb", size = 7173574, upload-time = "2026-02-06T09:56:01.786Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a6/62/cc03fffb07bfba982a9ec097b164e8835546980aec25ecfa5f9c1a47e022/grpcio-1.78.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:735e38e176a88ce41840c21bb49098ab66177c64c82426e24e0082500cc68af5", size = 6692639, upload-time = "2026-02-06T09:56:04.529Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/bf/9a/289c32e301b85bdb67d7ec68b752155e674ee3ba2173a1858f118e399ef3/grpcio-1.78.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:2045397e63a7a0ee7957c25f7dbb36ddc110e0cfb418403d110c0a7a68a844e9", size = 7268838, upload-time = "2026-02-06T09:56:08.397Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/0e/79/1be93f32add280461fa4773880196572563e9c8510861ac2da0ea0f892b6/grpcio-1.78.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a9f136fbafe7ccf4ac7e8e0c28b31066e810be52d6e344ef954a3a70234e1702", size = 8251878, upload-time = "2026-02-06T09:56:10.914Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/65/65/793f8e95296ab92e4164593674ae6291b204bb5f67f9d4a711489cd30ffa/grpcio-1.78.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:748b6138585379c737adc08aeffd21222abbda1a86a0dca2a39682feb9196c20", size = 7695412, upload-time = "2026-02-06T09:56:13.593Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/1c/9f/1e233fe697ecc82845942c2822ed06bb522e70d6771c28d5528e4c50f6a4/grpcio-1.78.0-cp313-cp313-win32.whl", hash = "sha256:271c73e6e5676afe4fc52907686670c7cea22ab2310b76a59b678403ed40d670", size = 4064899, upload-time = "2026-02-06T09:56:15.601Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/4d/27/d86b89e36de8a951501fb06a0f38df19853210f341d0b28f83f4aa0ffa08/grpcio-1.78.0-cp313-cp313-win_amd64.whl", hash = "sha256:f2d4e43ee362adfc05994ed479334d5a451ab7bc3f3fee1b796b8ca66895acb4", size = 4797393, upload-time = "2026-02-06T09:56:17.882Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/29/f2/b56e43e3c968bfe822fa6ce5bca10d5c723aa40875b48791ce1029bb78c7/grpcio-1.78.0-cp314-cp314-linux_armv7l.whl", hash = "sha256:e87cbc002b6f440482b3519e36e1313eb5443e9e9e73d6a52d43bd2004fcfd8e", size = 5920591, upload-time = "2026-02-06T09:56:20.758Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/5d/81/1f3b65bd30c334167bfa8b0d23300a44e2725ce39bba5b76a2460d85f745/grpcio-1.78.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:c41bc64626db62e72afec66b0c8a0da76491510015417c127bfc53b2fe6d7f7f", size = 11813685, upload-time = "2026-02-06T09:56:24.315Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/0e/1c/bbe2f8216a5bd3036119c544d63c2e592bdf4a8ec6e4a1867592f4586b26/grpcio-1.78.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8dfffba826efcf366b1e3ccc37e67afe676f290e13a3b48d31a46739f80a8724", size = 6487803, upload-time = "2026-02-06T09:56:27.367Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/16/5c/a6b2419723ea7ddce6308259a55e8e7593d88464ce8db9f4aa857aba96fa/grpcio-1.78.0-cp314-cp314-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:74be1268d1439eaaf552c698cdb11cd594f0c49295ae6bb72c34ee31abbe611b", size = 7173206, upload-time = "2026-02-06T09:56:29.876Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/df/1e/b8801345629a415ea7e26c83d75eb5dbe91b07ffe5210cc517348a8d4218/grpcio-1.78.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:be63c88b32e6c0f1429f1398ca5c09bc64b0d80950c8bb7807d7d7fb36fb84c7", size = 6693826, upload-time = "2026-02-06T09:56:32.305Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/34/84/0de28eac0377742679a510784f049738a80424b17287739fc47d63c2439e/grpcio-1.78.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:3c586ac70e855c721bda8f548d38c3ca66ac791dc49b66a8281a1f99db85e452", size = 7277897, upload-time = "2026-02-06T09:56:34.915Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ca/9c/ad8685cfe20559a9edb66f735afdcb2b7d3de69b13666fdfc542e1916ebd/grpcio-1.78.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:35eb275bf1751d2ffbd8f57cdbc46058e857cf3971041521b78b7db94bdaf127", size = 8252404, upload-time = "2026-02-06T09:56:37.553Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/3c/05/33a7a4985586f27e1de4803887c417ec7ced145ebd069bc38a9607059e2b/grpcio-1.78.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:207db540302c884b8848036b80db352a832b99dfdf41db1eb554c2c2c7800f65", size = 7696837, upload-time = "2026-02-06T09:56:40.173Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/73/77/7382241caf88729b106e49e7d18e3116216c778e6a7e833826eb96de22f7/grpcio-1.78.0-cp314-cp314-win32.whl", hash = "sha256:57bab6deef2f4f1ca76cc04565df38dc5713ae6c17de690721bdf30cb1e0545c", size = 4142439, upload-time = "2026-02-06T09:56:43.258Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/48/b2/b096ccce418882fbfda4f7496f9357aaa9a5af1896a9a7f60d9f2b275a06/grpcio-1.78.0-cp314-cp314-win_amd64.whl", hash = "sha256:dce09d6116df20a96acfdbf85e4866258c3758180e8c49845d6ba8248b6d0bbb", size = 4929852, upload-time = "2026-02-06T09:56:45.885Z" },
]
[[package]]
name = "idna"
version = "3.10"
@@ -307,6 +373,15 @@ wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" },
]
[[package]]
name = "markdown"
version = "3.10.2"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2b/f4/69fa6ed85ae003c2378ffa8f6d2e3234662abd02c10d216c0ba96081a238/markdown-3.10.2.tar.gz", hash = "sha256:994d51325d25ad8aa7ce4ebaec003febcce822c3f8c911e3b17c52f7f589f950", size = 368805, upload-time = "2026-02-09T14:57:26.942Z" }
wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/de/1f/77fa3081e4f66ca3576c896ae5d31c3002ac6607f9747d2e3aa49227e464/markdown-3.10.2-py3-none-any.whl", hash = "sha256:e91464b71ae3ee7afd3017d9f358ef0baf158fd9a298db92f1d4761133824c36", size = 108180, upload-time = "2026-02-09T14:57:25.787Z" },
]
[[package]]
name = "markupsafe"
version = "3.0.2"
@@ -615,6 +690,15 @@ wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" },
]
[[package]]
name = "packaging"
version = "26.0"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/65/ee/299d360cdc32edc7d2cf530f3accf79c4fca01e96ffc950d8a52213bd8e4/packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4", size = 143416, upload-time = "2026-01-21T20:50:39.064Z" }
wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" },
]
[[package]]
name = "pandas"
version = "2.3.2"
@@ -772,6 +856,21 @@ wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/cc/35/cc0aaecf278bb4575b8555f2b137de5ab821595ddae9da9d3cd1da4072c7/propcache-0.3.2-py3-none-any.whl", hash = "sha256:98f1ec44fb675f5052cccc8e609c46ed23a35a1cfd18545ad4e29002d858a43f", size = 12663, upload-time = "2025-06-09T22:56:04.484Z" },
]
[[package]]
name = "protobuf"
version = "6.33.5"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ba/25/7c72c307aafc96fa87062aa6291d9f7c94836e43214d43722e86037aac02/protobuf-6.33.5.tar.gz", hash = "sha256:6ddcac2a081f8b7b9642c09406bc6a4290128fce5f471cddd165960bb9119e5c", size = 444465, upload-time = "2026-01-29T21:51:33.494Z" }
wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b1/79/af92d0a8369732b027e6d6084251dd8e782c685c72da161bd4a2e00fbabb/protobuf-6.33.5-cp310-abi3-win32.whl", hash = "sha256:d71b040839446bac0f4d162e758bea99c8251161dae9d0983a3b88dee345153b", size = 425769, upload-time = "2026-01-29T21:51:21.751Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/55/75/bb9bc917d10e9ee13dee8607eb9ab963b7cf8be607c46e7862c748aa2af7/protobuf-6.33.5-cp310-abi3-win_amd64.whl", hash = "sha256:3093804752167bcab3998bec9f1048baae6e29505adaf1afd14a37bddede533c", size = 437118, upload-time = "2026-01-29T21:51:24.022Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a2/6b/e48dfc1191bc5b52950246275bf4089773e91cb5ba3592621723cdddca62/protobuf-6.33.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:a5cb85982d95d906df1e2210e58f8e4f1e3cdc088e52c921a041f9c9a0386de5", size = 427766, upload-time = "2026-01-29T21:51:25.413Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/4e/b1/c79468184310de09d75095ed1314b839eb2f72df71097db9d1404a1b2717/protobuf-6.33.5-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:9b71e0281f36f179d00cbcb119cb19dec4d14a81393e5ea220f64b286173e190", size = 324638, upload-time = "2026-01-29T21:51:26.423Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/c5/f5/65d838092fd01c44d16037953fd4c2cc851e783de9b8f02b27ec4ffd906f/protobuf-6.33.5-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:8afa18e1d6d20af15b417e728e9f60f3aa108ee76f23c3b2c07a2c3b546d3afd", size = 339411, upload-time = "2026-01-29T21:51:27.446Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/9b/53/a9443aa3ca9ba8724fdfa02dd1887c1bcd8e89556b715cfbacca6b63dbec/protobuf-6.33.5-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:cbf16ba3350fb7b889fca858fb215967792dc125b35c7976ca4818bee3521cf0", size = 323465, upload-time = "2026-01-29T21:51:28.925Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/57/bf/2086963c69bdac3d7cff1cc7ff79b8ce5ea0bec6797a017e1be338a46248/protobuf-6.33.5-py3-none-any.whl", hash = "sha256:69915a973dd0f60f31a08b8318b73eab2bd6a392c79184b3612226b0a3f8ec02", size = 170687, upload-time = "2026-01-29T21:51:32.557Z" },
]
[[package]]
name = "psutil"
version = "7.0.0"
@@ -973,6 +1072,36 @@ wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" },
]
[[package]]
name = "tensorboard"
version = "2.20.0"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
dependencies = [
{ name = "absl-py" },
{ name = "grpcio" },
{ name = "markdown" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "pillow" },
{ name = "protobuf" },
{ name = "setuptools" },
{ name = "tensorboard-data-server" },
{ name = "werkzeug" },
]
wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/9c/d9/a5db55f88f258ac669a92858b70a714bbbd5acd993820b41ec4a96a4d77f/tensorboard-2.20.0-py3-none-any.whl", hash = "sha256:9dc9f978cb84c0723acf9a345d96c184f0293d18f166bb8d59ee098e6cfaaba6", size = 5525680, upload-time = "2025-07-17T19:20:49.638Z" },
]
[[package]]
name = "tensorboard-data-server"
version = "0.7.2"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/7a/13/e503968fefabd4c6b2650af21e110aa8466fe21432cd7c43a84577a89438/tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b7/85/dabeaf902892922777492e1d253bb7e1264cadce3cea932f7ff599e53fea/tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530" },
]
[[package]]
name = "threadpoolctl"
version = "3.6.0"
@@ -1120,6 +1249,18 @@ wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" },
]
[[package]]
name = "werkzeug"
version = "3.1.5"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
dependencies = [
{ name = "markupsafe" },
]
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/5a/70/1469ef1d3542ae7c2c7b72bd5e3a4e6ee69d7978fa8a3af05a38eca5becf/werkzeug-3.1.5.tar.gz", hash = "sha256:6a548b0e88955dd07ccb25539d7d0cc97417ee9e179677d22c7041c8f078ce67", size = 864754, upload-time = "2026-01-08T17:49:23.247Z" }
wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ad/e4/8d97cca767bcc1be76d16fb76951608305561c6e056811587f36cb1316a8/werkzeug-3.1.5-py3-none-any.whl", hash = "sha256:5111e36e91086ece91f93268bb39b4a35c1e6f1feac762c9c822ded0a4e322dc", size = 225025, upload-time = "2026-01-08T17:49:21.859Z" },
]
[[package]]
name = "yarl"
version = "1.20.1"