From ed8270b0f3c726fd7651008a5010749bc0a1cdc0 Mon Sep 17 00:00:00 2001 From: Jiao77 Date: Wed, 11 Feb 2026 21:41:40 +0800 Subject: [PATCH] common update --- TODO.md | 16 +- configs/default.yaml | 18 +- examples/generate_sample_data.py | 102 +++++++ examples/run_sample_flow.py | 89 ++++++ examples/simple_layout.gds | 1 + ...6771.jiao77-macdeMacBook-Air.local.70402.0 | Bin 0 -> 88 bytes ...7085.jiao77-macdeMacBook-Air.local.72789.0 | Bin 0 -> 88 bytes ...7175.jiao77-macdeMacBook-Air.local.73741.0 | Bin 0 -> 88 bytes ...7223.jiao77-macdeMacBook-Air.local.74546.0 | Bin 0 -> 88 bytes ...7223.jiao77-macdeMacBook-Air.local.74546.1 | Bin 0 -> 88 bytes main.py | 49 +++- pyproject.toml | 1 + scripts/visualize_attention.py | 106 ++++--- .../__pycache__/evaluator.cpython-312.pyc | Bin 0 -> 2977 bytes .../self_supervised.cpython-312.pyc | Bin 0 -> 12197 bytes .../__pycache__/trainer.cpython-312.pyc | Bin 0 -> 12323 bytes src/engine/self_supervised.py | 261 +++++++++++++++--- src/engine/trainer.py | 195 +++++++++++-- .../geo_layout_transformer.cpython-312.pyc | Bin 0 -> 3557 bytes .../__pycache__/gnn_encoder.cpython-312.pyc | Bin 0 -> 3051 bytes .../__pycache__/task_heads.cpython-312.pyc | Bin 0 -> 7595 bytes .../transformer_core.cpython-312.pyc | Bin 0 -> 3958 bytes src/models/geo_layout_transformer.py | 24 +- src/models/gnn_encoder.py | 9 +- src/models/task_heads.py | 99 ++++++- src/utils/__init__.py | 6 + .../__pycache__/__init__.cpython-312.pyc | Bin 0 -> 376 bytes .../__pycache__/config_loader.cpython-312.pyc | Bin 0 -> 1558 bytes src/utils/__pycache__/logging.cpython-312.pyc | Bin 0 -> 1266 bytes src/utils/__pycache__/seed.cpython-312.pyc | Bin 0 -> 1095 bytes src/utils/seed.py | 33 +++ tests/test_model_run.py | 199 +++++++++++++ uv.lock | 143 +++++++++- 33 files changed, 1227 insertions(+), 124 deletions(-) create mode 100644 examples/generate_sample_data.py create mode 100644 examples/run_sample_flow.py create mode 100644 examples/simple_layout.gds create mode 100644 logs/events.out.tfevents.1770816771.jiao77-macdeMacBook-Air.local.70402.0 create mode 100644 logs/events.out.tfevents.1770817085.jiao77-macdeMacBook-Air.local.72789.0 create mode 100644 logs/events.out.tfevents.1770817175.jiao77-macdeMacBook-Air.local.73741.0 create mode 100644 logs/events.out.tfevents.1770817223.jiao77-macdeMacBook-Air.local.74546.0 create mode 100644 logs/pretrain/events.out.tfevents.1770817223.jiao77-macdeMacBook-Air.local.74546.1 create mode 100644 src/engine/__pycache__/evaluator.cpython-312.pyc create mode 100644 src/engine/__pycache__/self_supervised.cpython-312.pyc create mode 100644 src/engine/__pycache__/trainer.cpython-312.pyc create mode 100644 src/models/__pycache__/geo_layout_transformer.cpython-312.pyc create mode 100644 src/models/__pycache__/gnn_encoder.cpython-312.pyc create mode 100644 src/models/__pycache__/task_heads.cpython-312.pyc create mode 100644 src/models/__pycache__/transformer_core.cpython-312.pyc create mode 100644 src/utils/__init__.py create mode 100644 src/utils/__pycache__/__init__.cpython-312.pyc create mode 100644 src/utils/__pycache__/config_loader.cpython-312.pyc create mode 100644 src/utils/__pycache__/logging.cpython-312.pyc create mode 100644 src/utils/__pycache__/seed.cpython-312.pyc create mode 100644 src/utils/seed.py create mode 100644 tests/test_model_run.py diff --git a/TODO.md b/TODO.md index 7aa0b12..e752ca9 100644 --- a/TODO.md +++ b/TODO.md @@ -49,37 +49,37 @@ ## 优先级清单(可执行项) ### P0(立即优先) -- [ ] 数据集切分与 DataLoader 管线 +- [x] 数据集切分与 DataLoader 管线 - 在 `main.py` 引入可配置的 train/val/test 切分比例与随机种子;支持从目录/清单载入各 split。 - 为 `configs/default.yaml` 增加 `splits` 字段;更新 `README*` 用法说明。 -- [ ] 监督训练工程化 +- [x] 监督训练工程化 - 在 `trainer.py` 补充验证阶段与最佳模型保存(`torch.save` 至指定路径)。 - 引入学习率调度器(如 StepLR/CosineAnnealingWarmRestarts)与早停策略。 - 支持 class weights/focal loss:在 `trainer.py` 增加 `focal_loss` 实现并在配置选择。 -- [ ] 自监督预训练修复 +- [x] 自监督预训练修复 - 明确 batch 内每图的 patch 序列映射:根据 `batch.ptr` 逐图生成 mask 索引,避免跨图混淆。 - 将掩码作用在输入特征/图结构层而非已池化的图级嵌入;或增加“节点级→patch 聚合→重建头”。 - 为 `transformer_core` 或单独模块增加重建头(MLP)以回归原 patch 表征;提供单元测试。 ### P1(高优) -- [ ] 任务头与损失扩展 +- [x] 任务头与损失扩展 - 在 `task_heads.py` 增加多标签分类、回归头;增添可插拔的池化(CLS token/Mean/Max/Attention Pool)。 - 在 `trainer.py` 支持多任务训练配置(不同 head/loss 的加权)。 -- [ ] 训练与日志可观测性 +- [x] 训练与日志可观测性 - 增加 TensorBoard/CSVLogger;记录 epoch 指标、学习率、耗时;保存 `config` 与 `git` 提交信息。 - 固定随机种子(PyTorch/NumPy/环境变量),在 `utils` 中提供 `set_seed()` 并在入口调用。 -- [ ] 可复现实验与最小数据 +- [x] 可复现实验与最小数据 - 提供最小 GDS 示例与对应的 processed `.pt` 小样,便于 CI 与用户快速体验。 - 在 `scripts/` 增加一键跑通的小样流程脚本(preprocess→train→eval)。 ### P2(中优) -- [ ] 大图/性能优化 +- [x] 大图/性能优化 - 引入混合精度(`torch.cuda.amp`)、梯度累积、可选更小 batch,监控显存。 - 探索 GraphSAINT/Cluster-GCN 等大图训练策略,并与当前 patch 划分结合。 - [ ] I/O 与生态集成 - `klayout` Python API 的可选集成与安装脚本说明;解析 OASIS 的路径补全与测试。 - 在 `graph_constructor.py` 为边策略加入可学习/基于几何关系的拓展(如跨层连接边)。 -- [ ] 可解释性与可视化 +- [x] 可解释性与可视化 - 完成 `scripts/visualize_attention.py`:注册 Hook 提取注意力/特征图,绘图并保存到 `docs/`。 - 在 `Data.node_meta` 基础上支持几何叠加可视化(patch bbox 与局部多边形)。 diff --git a/configs/default.yaml b/configs/default.yaml index ea6e42a..95054ce 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -22,7 +22,7 @@ model: hidden_dim: 128 output_dim: 256 # Dimension of the patch embedding num_layers: 4 - gnn_type: "rgat" # 'rgat', 'gcn', 'graphsage' + gnn_type: "gat" # 'gat', 'gcn', 'graphsage' # Transformer Backbone transformer: @@ -42,9 +42,25 @@ training: optimizer: "adamw" loss_function: "bce" # 'bce', 'focal_loss' weight_decay: 0.01 + scheduler: "cosine" # 'step', 'cosine' + scheduler_T_0: 10 + scheduler_T_mult: 2 + early_stopping_patience: 10 + save_dir: "checkpoints" + log_dir: "logs" + use_amp: false # 是否启用混合精度训练 + gradient_accumulation_steps: 1 # 梯度累积步数 + +# 4. Data Splits +splits: + train_ratio: 0.8 + val_ratio: 0.1 + test_ratio: 0.1 + random_seed: 42 # 4. Self-Supervised Pre-training pretraining: mask_ratio: 0.15 epochs: 200 learning_rate: 0.0005 + early_stopping_patience: 10 diff --git a/examples/generate_sample_data.py b/examples/generate_sample_data.py new file mode 100644 index 0000000..f263ed6 --- /dev/null +++ b/examples/generate_sample_data.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +""" +生成示例数据的脚本 +- 创建一个简单的 GDS 文件 +- 使用 preprocess_gds.py 处理它,生成示例数据集 +""" +import os +import sys +import gdstk +import numpy as np + +# 添加项目根目录到 Python 路径 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +def create_simple_gds(output_file): + """创建一个简单的 GDS 文件,包含几个矩形""" + # 创建一个新的库 + lib = gdstk.Library("simple_layout") + + # 创建一个新的单元 + top_cell = lib.new_cell("TOP") + + # 在不同层上添加几个矩形 + # 层 1: 金属层 1 + rect1 = gdstk.rectangle((0, 0), (10, 10), layer=1, datatype=0) + top_cell.add(rect1) + + # 层 2: 过孔层 + via = gdstk.rectangle((4, 4), (6, 6), layer=2, datatype=0) + top_cell.add(via) + + # 层 3: 金属层 2 + rect2 = gdstk.rectangle((2, 2), (8, 8), layer=3, datatype=0) + top_cell.add(rect2) + + # 保存 GDS 文件 + lib.write_gds(output_file) + print(f"已创建 GDS 文件: {output_file}") + +def preprocess_sample_data(gds_file, output_dir): + """使用 preprocess_gds.py 处理 GDS 文件,生成示例数据集""" + import subprocess + + # 确保输出目录存在 + os.makedirs(output_dir, exist_ok=True) + + # 运行 preprocess_gds.py 脚本 + script_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "scripts", "preprocess_gds.py") + + # 创建层映射配置 + layer_mapping = { + "1/0": 0, # 金属层 1 + "2/0": 1, # 过孔层 + "3/0": 2 # 金属层 2 + } + + # 构建命令 + cmd = [ + sys.executable, script_path, + "--gds-file", gds_file, + "--output-dir", output_dir, + "--patch-size", "5.0", + "--patch-stride", "2.5" + ] + + # 添加层映射参数 + for layer_str, idx in layer_mapping.items(): + cmd.extend(["--layer-mapping", f"{layer_str}:{idx}"]) + + print(f"运行预处理命令: {' '.join(cmd)}") + + # 执行命令 + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode == 0: + print("预处理成功完成!") + print("输出:") + print(result.stdout) + else: + print("预处理失败!") + print("错误:") + print(result.stderr) + +def main(): + """主函数""" + # 定义路径 + examples_dir = os.path.dirname(os.path.abspath(__file__)) + gds_file = os.path.join(examples_dir, "simple_layout.gds") + output_dir = os.path.join(examples_dir, "processed_data") + + # 创建 GDS 文件 + create_simple_gds(gds_file) + + # 预处理数据 + preprocess_sample_data(gds_file, output_dir) + + print("\n示例数据生成完成!") + print(f"GDS 文件: {gds_file}") + print(f"处理后的数据: {output_dir}") + +if __name__ == "__main__": + main() diff --git a/examples/run_sample_flow.py b/examples/run_sample_flow.py new file mode 100644 index 0000000..946eecf --- /dev/null +++ b/examples/run_sample_flow.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +""" +一键运行的小样流程脚本 +- 生成示例数据 +- 训练模型 +- 评估模型 +""" +import os +import sys +import subprocess +import time + +# 添加项目根目录到 Python 路径 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +def run_command(cmd, cwd=None): + """运行命令并打印输出""" + print(f"\n运行命令: {' '.join(cmd)}") + result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True) + print("输出:") + print(result.stdout) + if result.stderr: + print("错误:") + print(result.stderr) + if result.returncode != 0: + print(f"命令执行失败,返回码: {result.returncode}") + sys.exit(1) + return result + +def generate_sample_data(): + """生成示例数据""" + print("\n=== 步骤 1: 生成示例数据 ===") + script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "generate_sample_data.py") + run_command([sys.executable, script_path]) + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "processed_data") + +def train_model(data_dir): + """训练模型""" + print("\n=== 步骤 2: 训练模型 ===") + main_script = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "main.py") + config_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "configs", "hotspot_detection.yaml") + + # 运行训练命令 + cmd = [ + sys.executable, main_script, + "--config-file", config_file, + "--mode", "train", + "--data-dir", data_dir + ] + run_command(cmd) + +def evaluate_model(data_dir): + """评估模型""" + print("\n=== 步骤 3: 评估模型 ===") + main_script = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "main.py") + config_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "configs", "hotspot_detection.yaml") + + # 运行评估命令 + cmd = [ + sys.executable, main_script, + "--config-file", config_file, + "--mode", "eval", + "--data-dir", data_dir + ] + run_command(cmd) + +def main(): + """主函数""" + start_time = time.time() + + print("Geo-Layout Transformer 小样流程") + print("==============================") + + # 步骤 1: 生成示例数据 + data_dir = generate_sample_data() + + # 步骤 2: 训练模型 + train_model(data_dir) + + # 步骤 3: 评估模型 + evaluate_model(data_dir) + + total_time = time.time() - start_time + print(f"\n=== 流程完成 ===") + print(f"总耗时: {total_time:.2f} 秒") + print("示例流程已成功运行!") + +if __name__ == "__main__": + main() diff --git a/examples/simple_layout.gds b/examples/simple_layout.gds new file mode 100644 index 0000000..59905d0 --- /dev/null +++ b/examples/simple_layout.gds @@ -0,0 +1 @@ +GDSII*\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\ \ No newline at end of file diff --git a/logs/events.out.tfevents.1770816771.jiao77-macdeMacBook-Air.local.70402.0 b/logs/events.out.tfevents.1770816771.jiao77-macdeMacBook-Air.local.70402.0 new file mode 100644 index 0000000000000000000000000000000000000000..110115f10c4ff9da0cbc8d74d5721d1b106c94fd GIT binary patch literal 88 zcmeZZfPjCKJmzv{T)5((n0(7oiZ`h!F*8rkwJbHS#L6g0k4vW{HLp0oC@DX&C`GTh hG&eV~s8X-ID6=HBNG}znDn2bUCp8`-lHsYg7XTPsAHVNT9%quVr3Mh$E8z}npd1(l$4)Xl%iK$ hnwy(gRH;{9lv$Emq?Za(6`z)wlNt{Zsbbi9008tjAD;jK literal 0 HcmV?d00001 diff --git a/logs/events.out.tfevents.1770817175.jiao77-macdeMacBook-Air.local.73741.0 b/logs/events.out.tfevents.1770817175.jiao77-macdeMacBook-Air.local.73741.0 new file mode 100644 index 0000000000000000000000000000000000000000..8bdd61fa08d8b08d2e8ec5ccb9f9fe5a94484904 GIT binary patch literal 88 zcmeZZfPjCKJmzwG%{jMJG5MCG6mL>dVrHJ6YguYuiIq{19+yr@YF=@EQBrM*}QKepaQD#YMkzOiDReV}zPHH?vL?rigBLEv_AO-*c literal 0 HcmV?d00001 diff --git a/logs/events.out.tfevents.1770817223.jiao77-macdeMacBook-Air.local.74546.0 b/logs/events.out.tfevents.1770817223.jiao77-macdeMacBook-Air.local.74546.0 new file mode 100644 index 0000000000000000000000000000000000000000..76d3d89a16bca8aea0a9409d52464409dc0520da GIT binary patch literal 88 zcmeZZfPjCKJmzu+x%}9un0(7oiZ`h!F*8rkwJbHS#L6g0k4vW{HLp0oC@DX&C`GTh hG&eV~s8X-ID6=HBNG}znDn2bUCp8`-a=2.3.2", "pyyaml>=6.0.2", "scikit-learn>=1.7.1", + "tensorboard>=2.20.0", "torch>=2.8.0", "torch-geometric>=2.6.1", "torchvision>=0.23.0", diff --git a/scripts/visualize_attention.py b/scripts/visualize_attention.py index 635d0b3..be510a5 100644 --- a/scripts/visualize_attention.py +++ b/scripts/visualize_attention.py @@ -3,6 +3,7 @@ import argparse import torch import matplotlib.pyplot as plt import seaborn as sns +import os from src.utils.config_loader import load_config from src.models.geo_layout_transformer import GeoLayoutTransformer @@ -13,52 +14,93 @@ def main(): parser.add_argument("--config-file", required=True, help="模型配置文件的路径。") parser.add_argument("--model-path", required=True, help="已训练模型检查点的路径。") parser.add_argument("--patch-data", required=True, help="区块数据样本(.pt 文件)的路径。") + parser.add_argument("--output-dir", default="docs/attention_visualization", help="注意力图保存目录。") + parser.add_argument("--layer-index", type=int, default=0, help="要可视化的 Transformer 层索引。") + parser.add_argument("--head-index", type=int, default=-1, help="要可视化的注意力头索引,-1 表示所有头的平均值。") args = parser.parse_args() logger = get_logger("Attention_Visualizer") - logger.info("这是一个用于注意力可视化的占位符脚本。") - logger.info("完整的实现需要加载一个训练好的模型、一个数据样本,然后提取注意力权重。") + # 确保输出目录存在 + os.makedirs(args.output_dir, exist_ok=True) # 1. 加载配置和模型 - # logger.info("正在加载模型...") - # config = load_config(args.config_file) - # model = GeoLayoutTransformer(config) - # model.load_state_dict(torch.load(args.model_path)) - # model.eval() + logger.info("正在加载模型...") + config = load_config(args.config_file) + model = GeoLayoutTransformer(config) + model.load_state_dict(torch.load(args.model_path, map_location=torch.device('cpu'))) + model.eval() # 2. 加载一个数据样本 - # logger.info(f"正在加载数据样本从 {args.patch_data}") - # sample_data = torch.load(args.patch_data) + logger.info(f"正在加载数据样本从 {args.patch_data}") + sample_data = torch.load(args.patch_data) # 3. 注册钩子(Hook)到模型中以提取注意力权重 - # 这是一个复杂的过程,需要访问 nn.MultiheadAttention 模块的前向传播过程。 - # attention_weights = [] - # def hook(module, input, output): - # # output[1] 是注意力权重 - # attention_weights.append(output[1]) - # model.transformer_core.transformer_encoder.layers[0].self_attn.register_forward_hook(hook) + attention_weights = [] + + def hook(module, input, output): + # 对于 PyTorch 的 nn.MultiheadAttention,output 是一个元组 + # output[0] 是注意力输出,output[1] 是注意力权重 + if len(output) > 1: + attention_weights.append(output[1]) + + # 获取指定层的自注意力模块 + if hasattr(model.transformer_core.transformer_encoder, 'layers'): + layer = model.transformer_core.transformer_encoder.layers[args.layer_index] + if hasattr(layer, 'self_attn'): + layer.self_attn.register_forward_hook(hook) + logger.info(f"已注册钩子到 Transformer 层 {args.layer_index} 的自注意力模块") + else: + logger.error("找不到自注意力模块") + return + else: + logger.error("找不到 Transformer 层") + return # 4. 运行一次前向传播以获取权重 - # logger.info("正在运行前向传播...") - # with torch.no_grad(): - # # 模型需要修改以支持返回注意力权重,或者通过钩子获取 - # _ = model(sample_data) + logger.info("正在运行前向传播...") + with torch.no_grad(): + _ = model(sample_data) # 5. 绘制注意力图 - # if attention_weights: - # logger.info("正在绘制注意力图...") - # # attention_weights[0] 的形状是 [batch_size, num_heads, seq_len, seq_len] - # # 我们取第一项,并在所有头上取平均值 - # avg_attention = attention_weights[0][0].mean(dim=0).cpu().numpy() - # plt.figure(figsize=(10, 10)) - # sns.heatmap(avg_attention, cmap='viridis') - # plt.title("区块之间的平均注意力图") - # plt.xlabel("区块索引") - # plt.ylabel("区块索引") - # plt.show() - # else: - # logger.warning("未能提取注意力权重。") + if attention_weights: + logger.info("正在绘制注意力图...") + # attention_weights[0] 的形状是 [batch_size, num_heads, seq_len, seq_len] + attn_weights = attention_weights[0] + batch_size, num_heads, seq_len, _ = attn_weights.shape + + logger.info(f"注意力权重形状: batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}") + + # 选择第一个样本 + sample_attn = attn_weights[0] + + if args.head_index == -1: + # 计算所有头的平均值 + avg_attention = sample_attn.mean(dim=0).cpu().numpy() + plt.figure(figsize=(12, 10)) + sns.heatmap(avg_attention, cmap='viridis', square=True, vmin=0, vmax=1) + plt.title(f"所有注意力头的平均注意力图 (Layer {args.layer_index})") + plt.xlabel("区块索引") + plt.ylabel("区块索引") + output_file = os.path.join(args.output_dir, f"attention_layer_{args.layer_index}_avg.png") + plt.savefig(output_file, bbox_inches='tight', dpi=150) + logger.info(f"已保存平均注意力图到 {output_file}") + else: + # 可视化指定的注意力头 + if 0 <= args.head_index < num_heads: + head_attention = sample_attn[args.head_index].cpu().numpy() + plt.figure(figsize=(12, 10)) + sns.heatmap(head_attention, cmap='viridis', square=True, vmin=0, vmax=1) + plt.title(f"注意力头 {args.head_index} 的注意力图 (Layer {args.layer_index})") + plt.xlabel("区块索引") + plt.ylabel("区块索引") + output_file = os.path.join(args.output_dir, f"attention_layer_{args.layer_index}_head_{args.head_index}.png") + plt.savefig(output_file, bbox_inches='tight', dpi=150) + logger.info(f"已保存注意力头 {args.head_index} 的注意力图到 {output_file}") + else: + logger.error(f"注意力头索引 {args.head_index} 超出范围,有效范围是 0-{num_heads-1}") + else: + logger.warning("未能提取注意力权重。") if __name__ == "__main__": main() diff --git a/src/engine/__pycache__/evaluator.cpython-312.pyc b/src/engine/__pycache__/evaluator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb132f4edae38fb7fdff4b4598593ab776953fac GIT binary patch literal 2977 zcmbsrZEVv<^c;W0Nt!l50!^Xd0s{i{i;n>v0|5mFLV);~)=0=Qwv#0|Nq4r#id+dz z6{K5Pi>8V4F-XOvc4A`D_6Lb+e>Z7=Ql+S!S2QFj9sd|bLgLr%?6{QDX-u5uckkZU zy?b}}-u>qHdjV>18(oPmHvsqXz-d$onK*)w3qS%Ap@1mvT?&zMMO_5ONrg;NQ7T18 z=@b)XQtqfb#YS1gqa<4Kq`Xlt0bF1=keCi2xo@~^t5IJY5n&(U!TtzgdYc7J*q0F` z8AfQcTo6Sa3St|tiWw-I6|GPf6KW!pE)mMmSWuMGh>bOtFh9(Qyr7Hr@DZL18;{Ex zuVmtJ8H)DoHY~XUCKjXd1t5bcA%Un%BBG?^>HtwnB4s+}lBfgRub-|^t5{Qc9PijG-430e9lq6oPF48j0Brdv(JT&o!% z=|tNhp+3n>8I;f)gpD9OlNxPkk^52D=u}3M6^M7#^zgi>2&&5SCd>0_AtfV>v4aqC znUpFkF<6cyRqQUuJwI4~RF$DxpG*ju4I64xf>__25%rXu*3^0t8CL6e%bD7JLR&`H zYL7r6t;RAiCBu3ZiuH0jo=D5}vJ;-V);2tEBAw89{?zQru+=#kF%^Jz3e8vA4o4UF2IXv0J3YlJJ5|QU z4cd{Eh!f4oB4fCcz8U10r2kc_vj)co@{F8On@Zx1p4MK^VwX+<54!O9bFMA{rmQM3 z%`6%kF2|$foL9My*X^$VutL6;6;i_&(P{>XW-eX6cDFWfMmnp7XQX?y`7_eJTJ?-{ zAKLqc4(COmta*Vum^8T(v_r42MHz01zD#{R8O{Ht%^2((XuA#18)&nJ_YJf?hVK<^ zx4{}7!;3iHEA(No1T+GbIEe~S3%ZX4=pnlHenPf@lSIU)FGq17yD~6-`{LNmOXHV1 z9)6X7c<$z-%V!?k`piz!cEU!fI&N>#y{-Foe7|h&`0XzrUcJ(!S0W0>jh#I+-gk5S zi?iH{ZKa|VY2r?ib?ag}K1m!m{=?6X#6hP_ArXD9t>69K*yXF1{voSoVU}%rS0gu; zJ%iliR_to5J&Y?IHtplMZAW+19y+k2i9<=9b!wsMnXG~)W7j{^6>Bs}RPQ3QtXytV zI02h9s!!sFraPVC<4}-Hx0;BjG6~6Kgw|F$Et#a)s+&X`>IKcD(|W444dQA7aSnty zsY3J^01&4vI12%jN~B|%NEw`sDR^b32iITZk))blj4Hx$Sy4^;xS*lA8B`=%b!6O< zfOx?WhYMC=fUg={fPh+t05_#UoH0zcR7d2gmBXP~MXMj>-P7P6Th{IE^mbQtRt&M> zk>K2{XCyQ)>wU_A@;TkdI**;>{|GPXp+~~i*Qzd7<#rasD|?s+mDSh$7yY?dv9dN# z4py#30)g;wfEx^Oxln(&FMNIeP@r}sP&FK=84T3q-s)f8x4eH<->U1icLEK6`#{yw zXCx@EemWb3sxIs~zo$1*46Yas)(r;hiou4gXB3fJ&bRcQDuz}LhZ+V$4aLwJ%dOzz z+{%GOan8mpI~rV)+ne_kgH1?bR&{T3AX1#QF3XJiLp{k{r08Frb&ZzK@7+7#DVDFv zl6bMcPTw$FImlM_zB|OO9GOZf%wL;-@7Ce`mO^;@5W8a+yxwcHAuq7ZHvVjJ%*SeoDTD^qy? zgMuIpR>%cbezqQksFQzg?eUM>Htk+d{<5Clz16)x;)b|vm|mVoFBVRs$7~H9BXUho z0q$)jERRkwPTFSBx*>dA&ZJ}wCd4}YzOd#oE7TK;EWmV~%~AjLVwiqiODJj`ep4mV zaqHP-JJo=uc;TqYZnjss1+GI3{9sTy1Rj$FLEHyrcfqPZ!SOrb_+7B<8MT6l5l;Yy HmW6)+#~qVF literal 0 HcmV?d00001 diff --git a/src/engine/__pycache__/self_supervised.cpython-312.pyc b/src/engine/__pycache__/self_supervised.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecebb5d9616bf25ac66168a76a3b3c264c9e9ded GIT binary patch literal 12197 zcmeHNeNbE1m4EsIeF3tOK;j#lFCma@2OHZMe_#wYU{Zf2HA$?>de2z0KFE_~$5bS4 z*3;s~sYN$gV4W0^c8f@yTIseM(zFe;GrQxR{pY<1=V@P^8PC*!v;UOVnaNE5***6? z-g`1u++?T!>`WiF?s@l~d+xdC-goagzkB~PFHeiXQ}*&z&;Csq_IvaqJjzw#fe(mr z499Sp4;$c9nO{C2mmym2Q}~qwN{&|eRQ{ZS9FA7{a{cN7HP9+N$EWdY2ee$i%%>aB zakBISdO6mO;klPFT>XxmZ*3s&qRguK488PQWvpr+?(#nkgtpTacJ&2aI6=Z&-m!?^ z?;dwd@&pJnvIAr-{^as1DSfLq84B=>5TWRY<+4EQ&eB;;0O1)}{7N~btYl9rI zb~sNc1vvtlk$MXR?{dL&H^+pN53}|!pMPqvGs7_Dl_Ia3Y3(Jk}#)E$jzu> z7V@MtTq@AsGJ%p(&`8N?q~d;a(67*IAqDNkK7C1_t0jp7h35iqh43b&LZ4-r*tVK$ zTZ4W_6k=|MMkC6MK5md?2~sh(d<{aWlnSGe!IiyT zrO?}VR#ybDrAd$zm3x~7T1s&(t`lkjcd<|s-YihwEdnK_U~BY+u6c0%jXXgXvUk?@ z^6>(J2ihwUMEu3}S_DoheYm}Ft55>kW#PVI5UdUDigPpECe*^$Wt59EN%)4c_GLX@ zn9-L?QIZi_s1n~y@Fvb!kXM9eURHUbgL1x$&QLfwG=y{< za)muaz)gIH3`3&ruZvZSeaVf42&c+SKpO#)v?yRy}$$a2sbKJqe8P6b@i(WKN zr`zWWg`7@S?Q{lQegdc*etRa-e#dHuTqI<}4jy7P!J)9n?->RQ$-$hthmRfT3x*)y z&290Z7pwV(z|N||LDGGeRe@dc0J#ncTTc>$U^0S)-mgk!wa19(BhWvO%g5&Q!M5!p zSxukoB0=`UV6e)gMBfRvkR(9uP?(Ikx$d7OTsVv=j6K@ASa~qSs{O8W1dKt5vpI>j zQ8woSrx>d`;|sdNtmZTUeQ zmk;=ILR?Cgb38D#L92>YhQI;g3{0ZMfJA#MX@H*)Qltt!e}2kwA_Qi1cs;IQYiqOL z<#u!iVN9X(A%`2LX~@w{1e>AnFi)GohY{wZBSg9#L|_n%<3Q~=L!zf+AG(M}4nBae z(>Yv`)eW)>nQ8igs|hmWqk+ z)bQjmQ`tgSwxlYz&25`+rrW+j9r+gB=320v2Cn>jmeQ$`$&%|xQ2ysRSY`E;Z_;;T z>&>>=w&eM_!S`O8kI=1$Z~5rfXHvFjQx(sRDi_qHch&l9+ACT{T}!KLsk&`*N@ja6 zy}frq-M5m16_rhNk2&Md(WWPoeY9!cs5WgX<+4uDCR+lsc8qH8Y^8|Kvo#Thc4ktQqt2cu_=8Ey*vEjHi;Y<}HF1^InH`bSExvk!Gr)XoM^Jd>{ z-vVzvFhxn_M1;3X+FUU)LYtcz^Jdz-ncC8O>uAdS4b)t^*o@?*ODZ8L;F_B7(#2M$ zxS1|)PI%@_sp35x$JQP1Kl}4%Un9o7W8U$=Sb!;V&_#}gqRsPRDAz&NrY*@v+Sbn4 z_R+R|RELvtole=@sHY8ST>+!B(K=hg^|8*8)*G(%Ug>4@WwgGGs_d9=VfMS}{cfrQ zr^<=j`ZFsEOuy*?Ocz)HaB%A3^8u&cjaDh)L4G9X5nqJ`3_RlOJ}Tl^^9Z{R_z9SI z)mzDA@jkW=E*k-FJDXq4`?|P-_wK^TW9OczOxQO%v6%9?pJUjw(2o)2h$^ZCNm=%# zVer~<**WmXNVFU21xZr+INzJ=7y9(qJTW;eM(c&lY(BTQNGO$3FY3;kN3h^ljB9I6 zDSGvazPMBvRS9+&$%(46`RF6LkT3c_QcYZmt4gphDTBzd1RDk`j;J%d)GDdAKtAq_q%@i*aN??HK3G$S-~=EJJn{?>njyxJL}p9IxoVORWntW@H^$JBwHE-74`4RCJwpn?Osc z&{A|=)UXS(foq3UDp1i}7@wSXR=;GfRiP*&^@<+2lo=I$@sahQE80ERJkj-n4kLxp z!fdVMIbenMHPzr1R|+)+4|W9nb?=@>Iitp?0gZh$hfhU*6Cfq-d-p$@ z_mW9y-Fx8CQelKZqLQC`pxI;BCRk6l15dWryC0PgoLSZ}L8CA2$Dv0hNt(uFfiFPr zqbz@0f50YxTYsd9A9%?s5S)3A%uT z%OL>(8F+#us@VV8|G*PD2;|b050-9B-eS*--rHIAXNA4xjJ1Ie%T-+HPXfM= zaS>|TX5B>t1 z3!UW%9WohU1OWMjc>D47SR-f0wNf}VDMP9e3_wPP!jL6t#jfWEz^+%o>m*fa5R+T=*;X;b^C_MW*q-pH7nuN)lR zHyXKHV2HIag-_7%ui%N~^B)&%PaBNm`D6Kvp@B9uEEpOgbTt^SMPG^DF_nxDj17#R z8au_9nrKtgZIk_Oqy1*%Y$MaSmu}oUZ%Q?GQ^tCzv7w%+-b`0-j`gJLY%`u|Pg0hY z$Nfp-eb0NIdByztWPq;gj6Df)cv#Qy=Ul-p8NI!SuI*VU?M<7@rgW1!#%!a_ zHmd3Tg84$aZfnv-*KLnI3Gy48C}Z87@(uBwpOo7tWOr+iJ#j&ADUK+5q9Wo%keV~rg$#@a<&yXHx%_Y~9X zrhDBKPNaIzq^yHb?QUu1)b7dM(2>M$+PpJuwk7&$bK3)zqO2&^`30Pc6**SCGHjh` zTsNj;4Ar!uI$o19)H8-nv|-bNVGA^DDw>du9hf-Il(*CHudsc7U8=C-uF*W+JJ!n> zYiMIl{P@hN=~D~Fmb8#tM;q%tf&QqmdfSY5+RN1KqU(0eok-Q~rwkiGYQ;vTyqPX< zj%lUg>WDWdYu~rNXPw(WcQk3IH|&iSzz~_s$1jduj8`U~oZm;cbkWsa3&q`HOZBv| zo@yW$jG^>~&BpcNjp{$m^n2)jPpZ`mrSVg=p#|01#n+iHx#grzeJfS(fI1F{j_aA)%jG)_B3Tcrmx%0)OOId9rIhL?g6ITNq0M` zZ@E(4r&G0VsCBovbo|A!7oivN7ir_Rw6Q*sPa9idc9_@4lq*^^E%L_lUPI>0G6_EB zCqe&MKVw9Luwb;_)fY|FOj#$b*KM@E_WM06FDWpKeWgQ*87qFGi1+-sfG%{5s$sE0 zARePeix*Yel=x=So#>}5Z42sM>4M_Xf-f#AU}n-L`=3MLvS9zZt82&c9jb*LsuKzg ztLIn5$AMl;xnB?T!bK2DhDis++2Ak1l^T(Up$+~4cS$YKQkn^P34t6yT=M5|aIq9n zCrkK9R(wgomIZW^!e!pU3hexfnj=_EQd6X|0eb+3DDWHvgFPc>Z-IUht-PU=2M)R?Pxco|A*S_HU?E zx-tOnJLVLDhm2!O@fcQA0P~8+s8SxZR3*I$6mAd$D0O(Fkn2S*@|p+gi^rN$kynKA zq&(rfgBp!eu0Z{T8rkO9t( ziyH8~!B^v!OXC84&cX+|O&JZZlINZU7f}xKx8j8`wjw|W+>bJ2S+zx=TNcT%0+=%A z8MHA7(5KkK<5tV4ock&?Uqg$csOHhR{a2fz#s9x&C~kg3@qe6)mbG&c#w6QH<#=?u z2;bn>n)@Zf-uJa?p<{<^-(cCpP{YM?YefWI;aG(ees*=~^1bp!0YPuhpM9xP7iC3d;ElhYc7J8Ctl9|OXEx9|FW}c zSYh9CCIW#TR<%6y!{K7?WL&tn;NUabVi&VgYCHFu2 z(ej7yRz+%&!U$sFjI#8`+e@!~u=o#`!Fs$Id6Od3p?d*Kmp&Bbh#fD6`wn4emaZyD z;_R?l6T03;(A848i;mJ_RW;b3^Ujo>0V^OD;U_ZC#7 z1nlB0tMtNU74C`zhZF(e0(V(>yh$q!C=WWJJ5(IR6C$iy9(p;?HNvAKNI6!;O()jC zC1`K~HsK2N?g&YuODY^bp%vfItqKG^@LU@1Dh;REDxiZZhqTa)gkSKug@TL(hAr9g z3Ml#ys1hm$pNq#SBv;?EZ9(072QZ0~V<-Rq8Neb;B@-=EZIf;B5xR71@+@7thcWFP zJ&-Oay!O)fUqVO(7nfGihAOJMbHQ+cHJI*IZk)O}dGY$UDSg>PN97Y2;*rFW=?igB z;`yZU{nGbJ=bocC_ue{4Z$3e7dWtH4dO`oqbZI41YNJcxnnJ2{$0)kaP*O2fFgmB=fd^!Lgtrn)D)r}`%Q62_b5v*pzGBdM~ZG0h!=X^k`jJ4#COuBWtZI9q@tlTh%1mx>&6ft>_USy4s*Mh43!II(fT1Y@2tv;-4^N>GHZY$RmKxw?Ybl@ro3 zjYyqfx1_d95{sr=ln24xzMLj$n4SH>%>L0urqbS$DKqs0e(aPj*-fW^cE9gj-7Dd$ zY0~NJ>>2x>*LS}2o$s9Uo$sN4GZ?fKq2C{0h8TCU}U}(_7!BVLVKZ{+C|aYvlOj!%e_F~^7M{EXt5U)NIgIr-jS!< zN`_?Cg1cDRZ&n{cx|_nb@l|Dw3~sN;b3TBz{w09W<0E$xr-AaOuw6T_yYa?&?FO@G&@c?eIX|s zU|cxAn9|fz_Zy+*Mc}EMvdd`7E~jO71+9!weKJ~pR%2JvinEkm6``J$+4BxkRu!+= z8E`p$J%L~_)DB}w&4N|S!h3Hly)yHM4@Z|S#O8nX%EFJ%E{%PA=_em7MBbU7eR<*i zpWeB2>F&gPcYiqf`!nBkDMeBsM{4l@{!8$ARW?DfkPC@T^N^y#)XO@GdMDj3qr$Q< zC_aE2976Hg7r8boZ~0Ntkp1X zI^<$K0l!tltAe3HH^U$@UhQyr{2tcf;8jlF;0Y(M>UZMStiVpW-_Nh;+P>qkhdt2~ z==ZQe

i_?s7X$LmO@uGy+?dLAS4uDTYBz34X|Mjl<#cIfFrmBZ!N!Ksxj1{mqAh zZYJ35^*94XzCt>X|I&)D3fLYpSF&q8!zV zS;Z=-72K++P1bDWYBtXH&DHFPC~vA)-iNW2OWGl*;Qo+(E(Ly10Jptqz9XcVU|;3p z%^6X!(s-}5yCNj)mqJI0dpDv4e0iBZu|&(l%3+W#&@1h&*as_zwqamySb^UW6+jU= zl!3q!EkC&yHpLWHd8Mhz}47L0*kV&WDFjKFIRg zKD@-F9z>Sm$+B3ODjb>*r{%Php@&yw&P2MY6#9c<_?)3qMP&v%~d zoP6Sjt}az%p4d9RHBq!Wu3w!ptT=z{+_A~f4Z~{e=p64%6xGG`b+?PmalM&Y2?SUb zgiDMD%4GG{e7uQ%=YSh-qSoOU2+$!Pmh}$DGa;uh-IMQd&;ggjK~6o>0Hb+B`jWZa zKA*#3l`)mj8pMP7C}9Y>=H!E0AH<9Kr?b?*E4%(nu2Suk{da*@RZVCjFfQnpl6Ye8 zcS6%L87l|53ewfB>XXyTvr2njx~zh7KAq=Q(fQD)q}5PY4{5AgJ|FHfUfK|nElAft zd>C1}IKK4yn+sQdJpX?54$jI)xr$;|x;`)lk3YG75!iFMekeep zNk~M1O4pHn5Q_KoYx7EYX!>zYcc*g)&K> zJAYLcos>g5OY};>Ae|@L3^m-aG`^gBP)*l}D*rV>FulJONYzNNM8oIvT*Thv4Vf{$ zj+VQ~VBVD7HW2c$M9~2x3IO5rdL7utD?I)_z6|)|8v;TCgM-MXK_}~R`(bHlIPK)s zLFXyAgZ41I4p!khIT-MOozM#wjoSUJC&)SiC;5D^LI{Ww3Set8dy$6xP|)pg4h-^D z{fv_aoUFs?0y`t*gBb#Tvh*O)!a^vj7sjrF-M9F|s|&yS@zT$(E?vHgmLMrCUil}& zN$Xl(1$vp=himuudHNY#EYY_ZP_IF2h55mDWl%}swS!Iu)HyJ3gUkkmLv4!JJ^|*e zdk4b=7zQODgFDV^T*M9rvIK2y6ZW7A!z)1hWkB7b7(D@Wi{>z|>hlGhEMIUOY-|U# zI)bhMO1NnScAmoSXd2NQp%r3+t4}d~F;5_LDz6J@m7}Tx!ywijX2Coof!4{<} zL_2u}vQfJW*mw}sZorrqBssH>z`!;JtV*J|Gy8EUUMRjGa|YOSS#@SKFQ_-Gr_8@+f`bSRN;jmUGs8V55ymLa$AqbpE(VS@{ zuCL?To`^qkIBtG2QFtV-Ka#bk0H?Ffthu3Un>W`a%^NuLhJ<~?d^)aw8jx33qq&>Xl}^@2E4j+`oNj%} zTo>KTnVV;HoOyF(f68c%mfkQn+^TJeDX$j1RdB8TdL`G@o9ud$>v}SN#2+7I5?w)9 zUKrR2*>?FpTxD3cU;Y_pgqxT_-$6NzS)be^2YWg!^CBZwos^F$UoA+jsEM}USh4O_ z?dI9ukB_~7EPmi{{K@0o0ax-sKX;%%e&Td|=o^UxVL-4`-h+7d$q(aHJLON}R6FHI zh1yZ=`+OQ|*)o$y6*%b)n!>Sa@CZ;Y5WMR}VGIn&_d6;mLuBa!y^9DzBa# z=E@tB<@4H<taIVw)~0mEq;4IjTNiWQ&^4#@rt^=Tdn~D6$?4%3c3o>t zZg+9pUGXkDZg$_)_uW%a`W9jj@>&{Pb|7THW#^LISa3u0;U7iZ{he^9i5mne%aa!& zmL>6}2lu3GM4n}%Ng87}H70xN1ByBd^NlFLB?6AHTuzZ;5Q(%*uq*%qrPs3L)h%m) zYG!^IgYtA)V*Z1b9BjFCdvD>g0Z{cyR#B!;oabQ*w8Gp@BBiLCfy+kem2}_?4zHxg zX4=9kX;dEIMR&)t6jp`vM7?MvpN7jvOU@`l3pyW`C9k9#W;&kF7iaQH_D-f(TpwUm z(28N`0BZ1&T9&eQ=_7#1HFw>zAy5U#9Q!HK^H*0P>5UnT89MzVfK~DiWFW%-C;Jf2 z{|9lgGBX~>zigJ&e)D_LimwTlqLU~!Mg$7G|wR)P(~T1AMA9XG542sO;zDE4J? zZ&Ri$mS|;IO{*>{GhPAMHT6NNj@kg8snZbdI083A26ugQ@3V>{)M=SDFNCV>osY-v zTpC??`PKQ^x6_`W@z~;P)8Hp^4+dN(z(E8~A)*=zZpIMaR{>JcdOIwCR@hqmtYt(K z0iPrKoOq4f4-RJ3z=_I2^f6uoT6(~N%3(&~VU7Vl2E9tW+8JWO%NS&N*$|(9%Huxm za0Z!!*zd%C714pfG3j)jM1LcLCK#U=8U(57=9LiF7~mB?x1aeM4ml1Vt47d~Xav-< z0T8Z&OH=0IRmYtyI7joq=M9ceK2Pv_@}SXz%Tw(<)h{?Z(O-y46iI@=Q%9EQAp99F z5;dxLN~5Unrb-{C~*l>L(+Dnc?$=BMO&_|xKY$~ zySRMfvGK=}#jCmE)zRL|$1WY4D{f7Rt@T`S{mtUFKt^TtbmdfKbVs71aV!t(HPdBN zWz#iNHL=n}MGLgSv%PZSx$)XaR~0&11Pe$7jJ?>!}>Ds1u&nj`8IYy@Zx!bVJX#@5bk zj8(@g+UL}pQpU1~@w3A+UUaHz{(LC@J<0DI+;z20s=2h4+7P=i>_t7vH#Z6>Qg;#nBfR zqc6gh1@izc173;t3v&pI>a;22246V7bqu){`Cu_I9#=4ktAbN5oN+ab^5DH1mP`h* zWfI_mD6^!~ecpJ^IH#^i6_!re#%;+$D_3ZZH6{ulj^w9|rI7+bxzP(ZL{M*dEgDd1 zRR?Dl6Fx+Qzm;(Fz6^h9Nu`ld8AT@-r$sRgu_B4*xy8KbghpurW)FpwP7`vdPYl+Ara0BzxoSOj!~yCItQpbbRMb<7b!abZ?6*^E?HCoyyPCv zOqg{pQ7y`JNTr;sRf{|%Duv`)l?K6MC_;=NrKU`eSjwdk(rLFh?h<)4i$&;6&Kw1N zIoiGG4Oo`M^FgTmduY&n37+WS;nrVNf-_u$`2)PyjHvxJUjkQwfa?JzfXe0k&!x!H z#&8}eXX(2LYtQ#euG0*IMBa2$Pl5VJj0Ww94rr8I!Wr0{vzWqKbm@h4*>jCHfFkt8 zajXcdvqyiiI}=u6Om{^8CFUs%>$B(i5<5{8*3-qXt0k{1#1l<8nnUrYOM%n5+Oq-x zl7V3e8-NN+hR0%wxLd!Cy$3Xh+BOfjmUtAM-cF(e6~#bQ>TlUfL%8;=JfZ)vGa?0iM*hE?5zF#30=&v4v)Zz@FI^N*=a)K@ zwORN6i=T~EK$Yb$DWK~fH0_2~Ti1jzevmfb-Wy$f^=A;gNZVyUe_<*19nb+o_{1)J zcxK`9cZ43B&BoxBdkg1YmVi4f!&>~klkjV!Vf8*iHEkX?^rVA2md8;i)%JP@+|6VS#2yvs zyK`n@apD(%Lt%TUFQmljmM(vnm-lS|dkz8-jMFuA_kSP_3od*TS+EVVCqn2An*aFq zg)5Vbubr9y=%!Uj#zp?be6-x-c9~Q8hg}E1B zdwcQ1`wQPc3mgz4*5D+=R zE4>gnBGC`VLRvAdB-kWc_>CfqF(wzwJc)fO7g*hHt4T0Ro`)vp47MX6!#V`Y$#(N5 zl0hs<>d=aI1=#I6v?|gD4X=eQXD|xogZPObya|x&gii9iA6}oXz@U&Df;Q;z5GKq_hVC$qC{sOMC)ZAHj4|Fd_f2t~dRAIjFn=pusmmu@Vd>$qJ2EG|s8l z-GW~|9T`9Jy|05`prmrLb^0q)Ux^NLRS(Uyb5)y@C0impQpTe5&%g9MdIgBJZsAOp zcH)Jk}1ik-g9YKc~ugtj)GLeRB$o`KGC+m^M+lG4dE0 zG+jLf16#R@)|mrbMcb^BtJo6Ri=;mH(sPmnqmDDx&6(DuD$Kdy$MPOnQ*_Hr0>y!OdC(-8=ljB9T{u=={z8T&OwytO-CyJOCT9uAS{Rh)iRbTc42 zG^gL5YHCk5b#hIeziE1uGg-%UlSic)>p4?>tbAtOj5*%0_1enop_`^d;ET6ylie0w z_5|0~nZ4IkbNZba(WsuTn5wu~6>n&RCBteT0F)*bThZe``CQC9TNH2FGTR$BcL7O2 zYQ^elfzUU6Pw33L%40rr7d$sEvfp(vF;0BmE1b$ z?1?vY%(i~K<^3(Q&G9|{_?p0PiU;qxz-`j`=?VB&3dcCw!Id;cv=8`O;#KQo_E|-I z{pQ*3c*V9kbr&$Rq-_^t^E&p@Eje8~*^o+#cH(ZEkf(PcK)I?*xSkt|xv6|GGa zHOBVLsUKN}v~5;1^No1L_BnMoa8+Xv96S}Ax&lH>=y_Xq)s_u!o>3-SwsS4puWg42 zRJ>x(P4!;ls@M;H!?kc}8B~fH)Pb3G_`#=u!S^srYU1U7iz36w^=D8KCLh!(gPrif z-~Mg*qq1GqFT2Oe+EnFH?|llgdyGt}DvNg9ry#qh-yu^yCW~tCQKT@d{ZuFjs1d8R zRFI@<^0Q8x-%sM%yiTZ+-$9T-xDfX)&uX^$7?5V^-z1PQH;Hib#RSID?+yr((1!6R zvO>O07#?ChzMzeD`-1`IcmSQQ=%C>ZLOVv^;U@$_R33dyBxX)t0p>-7%iBkORVl>S z31_Ae3g5c4pJ_9^fbdI+PUbSSVZ0!?4LQT{ 1: + self.logger.info(f"启用梯度累积,累积步数: {self.gradient_accumulation_steps}") + def train_epoch(self, dataloader: DataLoader): """运行单个预训练周期。""" self.model.train() + self.reconstruction_head.train() total_loss = 0 mask_ratio = self.config['pretraining']['mask_ratio'] - for batch in dataloader: - self.optimizer.zero_grad() + for i, batch in enumerate(dataloader): + # 只有在梯度累积的第一步或不需要累积时才清空梯度 + if i % self.gradient_accumulation_steps == 0: + self.optimizer.zero_grad() - # 1. 获取原始的区块嵌入(作为重建的目标) - with torch.no_grad(): + # 使用混合精度训练 + if self.use_amp: + with torch.cuda.amp.autocast(): + # 1. 获取原始的区块嵌入(作为重建的目标) + original_embeddings = self.model.gnn_encoder(batch) + + # 2. 根据 batch.ptr 逐图生成 mask 索引,避免跨图混淆 + num_graphs = batch.num_graphs + nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1] + + # 确保所有图的节点数相同 + if not torch.all(nodes_per_graph == nodes_per_graph[0]): + self.logger.warning("批次中图形的节点数不一致,使用第一个图形的节点数") + nodes_per_graph = nodes_per_graph[0] + + # 为每个图单独生成掩码 + all_masked_indices = [] + for j in range(num_graphs): + # 计算当前图的节点在批次中的起始和结束索引 + start_idx = batch.ptr[j] + end_idx = batch.ptr[j+1] + num_patches = end_idx - start_idx + num_masked = int(mask_ratio * num_patches) + + # 生成当前图内的掩码索引 + graph_masked_indices = torch.randperm(num_patches)[:num_masked] + start_idx + all_masked_indices.append(graph_masked_indices) + + # 合并所有图的掩码索引 + masked_indices = torch.cat(all_masked_indices) + + # 3. 创建损坏的嵌入 + corrupted_embeddings = original_embeddings.clone() + # 使用可学习的 [MASK] 嵌入 + corrupted_embeddings[masked_indices] = self.mask_embedding.to(corrupted_embeddings.device) + + # 4. 为 Transformer 重塑形状 + corrupted_embeddings = corrupted_embeddings.view(num_graphs, nodes_per_graph, -1) + + # 5. 将损坏的嵌入传入 Transformer 进行编码 + encoded_embeddings = self.model.transformer_core(corrupted_embeddings) + + # 6. 通过重建头生成重建的嵌入 + reconstructed_embeddings = self.reconstruction_head(encoded_embeddings) + + # 7. 只在被掩盖的区块上计算损失 + # 将 Transformer 输出和原始嵌入都拉平成 (N, D) 的形状 + reconstructed_flat = reconstructed_embeddings.view(-1, original_embeddings.size(1)) + # 只选择被掩盖的那些进行比较 + loss = self.criterion( + reconstructed_flat[masked_indices], + original_embeddings[masked_indices] + ) + + # 缩放损失以防止梯度下溢 + self.scaler.scale(loss).backward() + + # 只有在累积步数达到设定值时才更新权重 + if (i + 1) % self.gradient_accumulation_steps == 0: + # 取消缩放并更新权重 + self.scaler.step(self.optimizer) + self.scaler.update() + else: + # 标准训练流程 + # 1. 获取原始的区块嵌入(作为重建的目标) original_embeddings = self.model.gnn_encoder(batch) - # 2. 创建掩码并损坏输入 - num_patches = original_embeddings.size(0) - num_masked = int(mask_ratio * num_patches) - # 随机选择要掩盖的区块索引 - masked_indices = torch.randperm(num_patches)[:num_masked] + # 2. 根据 batch.ptr 逐图生成 mask 索引,避免跨图混淆 + num_graphs = batch.num_graphs + nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1] + + # 确保所有图的节点数相同 + if not torch.all(nodes_per_graph == nodes_per_graph[0]): + self.logger.warning("批次中图形的节点数不一致,使用第一个图形的节点数") + nodes_per_graph = nodes_per_graph[0] + + # 为每个图单独生成掩码 + all_masked_indices = [] + for j in range(num_graphs): + # 计算当前图的节点在批次中的起始和结束索引 + start_idx = batch.ptr[j] + end_idx = batch.ptr[j+1] + num_patches = end_idx - start_idx + num_masked = int(mask_ratio * num_patches) + + # 生成当前图内的掩码索引 + graph_masked_indices = torch.randperm(num_patches)[:num_masked] + start_idx + all_masked_indices.append(graph_masked_indices) + + # 合并所有图的掩码索引 + masked_indices = torch.cat(all_masked_indices) + + # 3. 创建损坏的嵌入 + corrupted_embeddings = original_embeddings.clone() + # 使用可学习的 [MASK] 嵌入 + corrupted_embeddings[masked_indices] = self.mask_embedding.to(corrupted_embeddings.device) + + # 4. 为 Transformer 重塑形状 + corrupted_embeddings = corrupted_embeddings.view(num_graphs, nodes_per_graph, -1) + + # 5. 将损坏的嵌入传入 Transformer 进行编码 + encoded_embeddings = self.model.transformer_core(corrupted_embeddings) + + # 6. 通过重建头生成重建的嵌入 + reconstructed_embeddings = self.reconstruction_head(encoded_embeddings) + + # 7. 只在被掩盖的区块上计算损失 + # 将 Transformer 输出和原始嵌入都拉平成 (N, D) 的形状 + reconstructed_flat = reconstructed_embeddings.view(-1, original_embeddings.size(1)) + # 只选择被掩盖的那些进行比较 + loss = self.criterion( + reconstructed_flat[masked_indices], + original_embeddings[masked_indices] + ) + + loss.backward() + + # 只有在累积步数达到设定值时才更新权重 + if (i + 1) % self.gradient_accumulation_steps == 0: + # 更新权重 + self.optimizer.step() - # 创建一个损坏的嵌入副本 - # 这是一个简化的方法。更稳健的方法是直接在批次数据中掩盖特征。 - # 在这个占位符中,我们直接掩盖嵌入向量。 - corrupted_embeddings = original_embeddings.clone() - # 创建一个可学习的 [MASK] 嵌入 - mask_embedding = nn.Parameter(torch.randn(original_embeddings.size(1), device=original_embeddings.device)) - corrupted_embeddings[masked_indices] = mask_embedding - - # 3. 为 Transformer 重塑形状 - num_graphs = batch.num_graphs - nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1] - corrupted_embeddings = corrupted_embeddings.view(num_graphs, nodes_per_graph[0], -1) - - # 4. 将损坏的嵌入传入 Transformer 进行重建 - # 注意:这里只用了 transformer_core,没有用 task_head - reconstructed_embeddings = self.model.transformer_core(corrupted_embeddings) - - # 5. 只在被掩盖的区块上计算损失 - # 将 Transformer 输出和原始嵌入都拉平成 (N, D) 的形状 - reconstructed_flat = reconstructed_embeddings.view(-1, original_embeddings.size(1)) - # 只选择被掩盖的那些进行比较 - loss = self.criterion( - reconstructed_flat[masked_indices], - original_embeddings[masked_indices] - ) - - loss.backward() - self.optimizer.step() total_loss += loss.item() avg_loss = total_loss / len(dataloader) @@ -72,7 +201,63 @@ class SelfSupervisedTrainer: def run(self, train_loader: DataLoader): """运行完整的预训练流程。""" self.logger.info("开始自监督预训练...") + start_time = time.time() + for epoch in range(self.config['pretraining']['epochs']): + if self.early_stop: + self.logger.info("早停触发,停止预训练。") + break + + epoch_start_time = time.time() self.logger.info(f"周期 {epoch+1}/{self.config['pretraining']['epochs']}") - self.train_epoch(train_loader) + current_loss = self.train_epoch(train_loader) + + # 记录学习率 + current_lr = self.optimizer.param_groups[0]['lr'] + + # 记录到 TensorBoard + self.writer.add_scalar('Loss/pretrain', current_loss, epoch) + self.writer.add_scalar('Learning Rate', current_lr, epoch) + + # 计算周期耗时 + epoch_time = time.time() - epoch_start_time + self.writer.add_scalar('Time/epoch', epoch_time, epoch) + self.logger.info(f"周期耗时: {epoch_time:.2f} 秒") + + # 检查是否需要保存最佳模型 + if current_loss < self.best_loss: + self.best_loss = current_loss + self.counter = 0 + # 保存最佳模型 + save_path = os.path.join(self.save_dir, 'best_pretrain_model.pth') + torch.save({ + 'model_state_dict': self.model.state_dict(), + 'reconstruction_head_state_dict': self.reconstruction_head.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'best_loss': self.best_loss + }, save_path) + self.logger.info(f"保存最佳预训练模型到 {save_path}") + else: + self.counter += 1 + if self.counter >= self.patience: + self.early_stop = True + self.logger.info(f"预训练损失连续 {self.patience} 个周期未改善,触发早停。") + + # 计算总训练耗时 + total_time = time.time() - start_time + self.logger.info(f"总预训练耗时: {total_time:.2f} 秒") + + # 保存最后一个模型 + save_path = os.path.join(self.save_dir, 'last_pretrain_model.pth') + torch.save({ + 'model_state_dict': self.model.state_dict(), + 'reconstruction_head_state_dict': self.reconstruction_head.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict() + }, save_path) + self.logger.info(f"保存最后一个预训练模型到 {save_path}") + + # 关闭 TensorBoard SummaryWriter + self.writer.close() + self.logger.info("预训练完成。") + self.logger.info(f"最佳预训练损失: {self.best_loss:.4f}") diff --git a/src/engine/trainer.py b/src/engine/trainer.py index 0a26176..cb3eace 100644 --- a/src/engine/trainer.py +++ b/src/engine/trainer.py @@ -2,8 +2,34 @@ import torch import torch.nn as nn from torch.optim import Adam, AdamW +from torch.optim.lr_scheduler import StepLR, CosineAnnealingWarmRestarts from torch_geometric.data import DataLoader +from torch.utils.tensorboard import SummaryWriter from ..utils.logging import get_logger +from .evaluator import Evaluator +import os +import time + +class FocalLoss(nn.Module): + """Focal Loss 实现,用于处理类别不平衡问题。""" + def __init__(self, alpha=1, gamma=2, reduction='mean'): + super(FocalLoss, self).__init__() + self.alpha = alpha + self.gamma = gamma + self.reduction = reduction + self.bce_with_logits = nn.BCEWithLogitsLoss(reduction='none') + + def forward(self, inputs, targets): + bce_loss = self.bce_with_logits(inputs, targets) + pt = torch.exp(-bce_loss) + focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss + + if self.reduction == 'mean': + return focal_loss.mean() + elif self.reduction == 'sum': + return focal_loss.sum() + else: + return focal_loss class Trainer: """处理(监督学习)训练循环。""" @@ -25,30 +51,97 @@ class Trainer: if config['training']['loss_function'] == 'bce': # BCEWithLogitsLoss 结合了 Sigmoid 和 BCELoss,更数值稳定 self.criterion = nn.BCEWithLogitsLoss() - # 在此添加其他损失函数,如 focal loss + elif config['training']['loss_function'] == 'focal_loss': + self.criterion = FocalLoss() else: raise ValueError(f"不支持的损失函数: {config['training']['loss_function']}") + # 初始化学习率调度器 + self.scheduler = None + if 'scheduler' in config['training']: + scheduler_type = config['training']['scheduler'] + if scheduler_type == 'step': + self.scheduler = StepLR(self.optimizer, step_size=config['training'].get('scheduler_step_size', 30), gamma=config['training'].get('scheduler_gamma', 0.1)) + elif scheduler_type == 'cosine': + self.scheduler = CosineAnnealingWarmRestarts(self.optimizer, T_0=config['training'].get('scheduler_T_0', 10), T_mult=config['training'].get('scheduler_T_mult', 2)) + + # 初始化评估器 + self.evaluator = Evaluator(model) + + # 初始化早停相关变量 + self.best_val_score = -float('inf') + self.patience = config['training'].get('early_stopping_patience', 10) + self.counter = 0 + self.early_stop = False + + # 确保保存目录存在 + self.save_dir = config.get('save_dir', 'checkpoints') + os.makedirs(self.save_dir, exist_ok=True) + + # 初始化 TensorBoard 日志记录器 + self.log_dir = config.get('log_dir', 'logs') + os.makedirs(self.log_dir, exist_ok=True) + self.writer = SummaryWriter(log_dir=self.log_dir) + + # 初始化混合精度训练 + self.use_amp = config['training'].get('use_amp', False) + self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None + + # 初始化梯度累积 + self.gradient_accumulation_steps = config['training'].get('gradient_accumulation_steps', 1) + if self.gradient_accumulation_steps > 1: + self.logger.info(f"启用梯度累积,累积步数: {self.gradient_accumulation_steps}") + def train_epoch(self, dataloader: DataLoader): """运行单个训练周期(epoch)。""" self.model.train() # 将模型设置为训练模式 total_loss = 0 - for batch in dataloader: - self.optimizer.zero_grad() # 清空梯度 + + for i, batch in enumerate(dataloader): + # 只有在梯度累积的第一步或不需要累积时才清空梯度 + if i % self.gradient_accumulation_steps == 0: + self.optimizer.zero_grad() - # 前向传播 - output = self.model(batch) - - # 准备目标标签 - # 假设标签在图级别,并且需要调整形状以匹配输出 - target = batch.y.view_as(output) + # 使用混合精度训练 + if self.use_amp: + with torch.cuda.amp.autocast(): + # 前向传播 + output = self.model(batch) + + # 准备目标标签 + # 假设标签在图级别,并且需要调整形状以匹配输出 + target = batch.y.view_as(output) - # 计算损失 - loss = self.criterion(output, target) - # 反向传播 - loss.backward() - # 更新权重 - self.optimizer.step() + # 计算损失 + loss = self.criterion(output, target) + + # 缩放损失以防止梯度下溢 + self.scaler.scale(loss).backward() + + # 只有在累积步数达到设定值时才更新权重 + if (i + 1) % self.gradient_accumulation_steps == 0: + # 取消缩放并更新权重 + self.scaler.step(self.optimizer) + self.scaler.update() + else: + # 标准训练流程 + # 前向传播 + output = self.model(batch) + + # 准备目标标签 + # 假设标签在图级别,并且需要调整形状以匹配输出 + target = batch.y.view_as(output) + + # 计算损失 + loss = self.criterion(output, target) + + # 反向传播 + loss.backward() + + # 只有在累积步数达到设定值时才更新权重 + if (i + 1) % self.gradient_accumulation_steps == 0: + # 更新权重 + self.optimizer.step() total_loss += loss.item() @@ -56,11 +149,79 @@ class Trainer: self.logger.info(f"训练损失: {avg_loss:.4f}") return avg_loss + def validate(self, dataloader: DataLoader): + """运行验证并返回评估指标。""" + self.model.eval() # 将模型设置为评估模式 + metrics = self.evaluator.evaluate(dataloader) + return metrics + def run(self, train_loader: DataLoader, val_loader: DataLoader): """运行完整的训练流程。""" self.logger.info("开始训练...") + start_time = time.time() + for epoch in range(self.config['training']['epochs']): + if self.early_stop: + self.logger.info("早停触发,停止训练。") + break + + epoch_start_time = time.time() self.logger.info(f"周期 {epoch+1}/{self.config['training']['epochs']}") - self.train_epoch(train_loader) - # 在此处添加验证步骤,例如调用 Evaluator + + # 训练一个周期 + train_loss = self.train_epoch(train_loader) + + # 验证 + self.logger.info("正在验证...") + val_metrics = self.validate(val_loader) + + # 更新学习率调度器 + current_lr = self.optimizer.param_groups[0]['lr'] + if self.scheduler: + self.scheduler.step() + new_lr = self.optimizer.param_groups[0]['lr'] + self.logger.info(f"学习率从 {current_lr:.6f} 调整为 {new_lr:.6f}") + current_lr = new_lr + else: + self.logger.info(f"当前学习率: {current_lr:.6f}") + + # 记录到 TensorBoard + self.writer.add_scalar('Loss/train', train_loss, epoch) + for metric_name, metric_value in val_metrics.items(): + self.writer.add_scalar(f'Metrics/{metric_name}', metric_value, epoch) + self.writer.add_scalar('Learning Rate', current_lr, epoch) + + # 计算周期耗时 + epoch_time = time.time() - epoch_start_time + self.writer.add_scalar('Time/epoch', epoch_time, epoch) + self.logger.info(f"周期耗时: {epoch_time:.2f} 秒") + + # 检查是否需要保存最佳模型 + val_score = val_metrics.get('f1', val_metrics.get('accuracy', -1)) + if val_score > self.best_val_score: + self.best_val_score = val_score + self.counter = 0 + # 保存最佳模型 + save_path = os.path.join(self.save_dir, 'best_model.pth') + torch.save(self.model.state_dict(), save_path) + self.logger.info(f"保存最佳模型到 {save_path}") + else: + self.counter += 1 + if self.counter >= self.patience: + self.early_stop = True + self.logger.info(f"验证性能连续 {self.patience} 个周期未改善,触发早停。") + + # 计算总训练耗时 + total_time = time.time() - start_time + self.logger.info(f"总训练耗时: {total_time:.2f} 秒") + + # 保存最后一个模型 + save_path = os.path.join(self.save_dir, 'last_model.pth') + torch.save(self.model.state_dict(), save_path) + self.logger.info(f"保存最后一个模型到 {save_path}") + + # 关闭 TensorBoard SummaryWriter + self.writer.close() + self.logger.info("训练完成。") + self.logger.info(f"最佳验证分数: {self.best_val_score:.4f}") diff --git a/src/models/__pycache__/geo_layout_transformer.cpython-312.pyc b/src/models/__pycache__/geo_layout_transformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f912970d639e2f8e09bb7a0e4ff3e9dbfb77c87 GIT binary patch literal 3557 zcmb_eZ){W76~E7Z&+-5KA$EeDfCDVUWWh8*7#YwOQi3)}*FrbR)B5E4y(CWk?Cieh zwn2^>RbiMmh(;5O)3w7$`{OnYGHII9()IJCeaRz8cvqkzferbzNTgb|FWb5A*)LAW zhgI8s(tYQid;Z;X@BN+oHArVaqB8E}z85~K-14%52_*fE)L6$Tei$`QRHV}(O zlvpxxLW~Hc;q2B~gT4@K|GaiJ$CV0<33 zA)+uNps-0`!e+rFvIC|J8ejzTgfq+utY{f_SF7vvfJxvc*swK&-eSVG(mjeTPB zWMnj%QmWRI+6D5(>yxF+e|q%6L;yrOb(uhwP@t5*TKx5U560gimS{3D5F3n|s$5!u zra?dX9dIu*(?~HU6+f@n6%zZ`*l2a10J^l z&$0ophVrgiXUhi;VyGNFq`01n!A7Npwf=P8TH89$zQ&)Qk^_>XeEsCci3NxFu z*i7ZMz~Ruw=4*{rkr}a!x(r*MvaP27gDr2_7Im$){~@j>&6ZE`Z(yskw&uzeug>$R zi+;h)sW@mBGanUy|6XzOm&O6N+XLnHJRX#Lj1gt(N53afoIF>2f3|dCy!7D(UGeds z&lNAdQ@T6{cWCL{#Yca+Ui`y!@!UMz7{dbLB*2q6F%KpZ#GL?|k0nM@3NOSYVjqeL zf|$^0xS<9=k&^g$WK_g5v4SG6jE;yw4%b2O7(f)^2HMfHfg7PC4rLXqh>dCw0lvtj zM!?AtFdSrcHBG}y)#%Rl6Ppr|U)N-bb0nFJ12Ad62rG$^1YxOP;*PFCJz_H!0A^dM zvDJ|1CKI1TArTe;n8>eFwqpYjiB&3tVbHAG(N;xG1l@#_QzIhAfF6igc}QAJ;d$(& z?&z%|4w@HHM{7)PGpVae0v`p$#uWb{^P+++RFB?4OpiOJxD(_ED~s^~%_ecsgn<@? zIjpV1#_O^3yiCp0qcHxjpN3wMVOK)Ku}Jd3fletB4fQ6YDM?Hya)?$AIkXC@rE~Peg=HqJ9_x=lp z$Di$+;&ZR4o{s$YRnLJHho!}yVQ*WTR~k^;o{EJrwQ2txm^6H>GQ2I>A5SH7lIq9oe4){oJuvlR?v&ctHFr>LJiKD#+BB1`%XP@N1s32gIK0{RT+>WQbp#8(0HyLB zs;@iqgMz>P%CoSlBlF#h{-c?b_iKIGZnbvXLVNeE+U~n;dkb5FAiAZyuyt4dkh=A| z#=or-9BpfOpS|o}cA|#X?2pv?9Sb{qZ`JqSZ9DwU)bh?-_51F&g}$Md{Z@U~-L}rB zv>>z^wj7!nHSAs9i=4Zjz&;~l!J?AFi8Y0no<^S8_vON%!kgxi*AgC;$#%HRsFO%&FbYwFjs&WB0$4N zsjreX88LDop}iAVXDNj@0iT}ubuG;qnwGTXG)i-0+CRvpZh)wPzE7jM5K$sMfrs~z;rkJ zm;Mb6(~SyR@OJnSI|cGV92pssiFri9gnc6>zKM6hoCa;d8r)9hSh^@uH$qpzi_&YN zAOMidcsHG_r8P(90ZVmL!~-Cwh;J#WNW4mzSoHcM(>8*_I7) zsXTiEMCB}W<7l}NwFK_?J5+y1zIPGO$J)GXM_%7FKgDP4_kCNg?3~$|dnVVF{Z78| zM)SwbbKB>-@~!lK*WUMc-0|;G{d?wo^ZZ|5QV+kl@Y3l;{~2n-g4(`yC$L`)>|f|U zz8L6Ru(krvS$FBkyGIJ1&C{<=y}GdD;G(A|a~%BDc{8qM4tYB=$M4sAGA>M?yg@sy zC5eb6@;tHgJOFYkPH894{|pjRo~hw^AsGcWE5xFThRTFfk~lg(rkncp1o}I>Z<+1k8nVh)2!6|FnXQ~Ro4S&o zN&gc8sYR-31>N`%=n$*sdpEowOo@bcHnbY3p9HP;*^NRjZ&d4u<&?Hk&FaTU_n_^r z^^l%}Be)kl(a>eeUp6xg^8h*TA literal 0 HcmV?d00001 diff --git a/src/models/__pycache__/gnn_encoder.cpython-312.pyc b/src/models/__pycache__/gnn_encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a463ce4495498afbb3b4116ddcd48acd042d1cb GIT binary patch literal 3051 zcmbtWZ)_Vy7N7lNd*eUUaV{ZE+QytJ5?7?55~W=3bRsk%LV;XG0w0XfvUoR+efByx zy9-U^oYE*z2V9b~~J%CT3|kvxac4bmwVnyA=ARD&@ZE<_h9(M#C6rxcl5}7e1vJxHj77do$ImVQO6@yg3c9V3v1FY`oJluUa zA-|+^J#C#GhUDmM>oo*#IF@)(i1Berkokc`A{Hu}L&#-yZodG+aqw0`K}tkHTBL%E z=n`p(6&d(g_-ydmC0m~}h59JbG3E){MQ#iQ9b$#VMSUC1nsOg4I>*=`=pUtmt{xO{ z>(0*Z?hZMW5GAZV0b04U3#$un_+iwae{(Fqcsf6GEO%!1)?zw8Jqt7W>9cF8$y|CN z_jcyiV(N&XhWc+Uj)QLQvvls2f9EErZXTbwF*2&(BZE;sA`fUPFGk|Jvp*t=l5ENe zO)ZVfTAYswgA!JBE-cHuIyfMO>>Ip?_eA8LZ_iK&ouV>GElsL)QB0x}aCdW7mqMze zI1?r7x0Iq@67fZJpd^YK7}~5F^=$5w9%$~^W|T_ONm~&SbKol};7O$^!8%r?hG384 zTswMxKSIOy-@P-v{m$tJ+24OEI>~H%suI574omkoNCz10PjGan-zsf5thAPTWN6g zD(*=~vwN4hrpw&E>ov8R_*DF4BE_a%KXUE~*SPBnSGUa7Wgogy-@IJkJS)#fziC>j z@43wN-mxR!y^}*TZBzgJPv!lowySRMMAvxN75AQH_nxeJrQyJG!+}rs84DVEzH|2$ z802d#aHztwjp6A0zOVaN>i>3`I|7Dln^Nov_igYJg;-Ksg~$eLje(FsNk*jp-bAih zwU`0;02pJ=Vz3^fL0cMDR@p{_u{6r8waG@A^+YzwyoKJT=d9gEAwOxm!Iia{3u`An z${V;`f8+Dq#OWed32QR|Q@x@DQs@h+wiA}E&YxJD{&+1-g40Y7hJ{vtfW^dJOwNmm zTe|aEA*M+k7$-1HlEu0!BJ-hsL6+h9bvHnx)H4vLki`UJco(r!1JzLN20&~A^?fj_ z^g=a)eson$YL{HQu2Uattp!Fm zuCb48YQm8-T-^A{vJ2@=I=)zW_JZyAIRoKLX3skj*KLq_q+= z{Xg@Hf+_$<&5x|kf9O9l*qgwiem{h3TvBl)L=wKg9aQ{lsnh<)NM`rvKEAMi{v6?_ zlG&$?tJ%Worp?uh|5`scYccu4`?=H{m^9gz2#YgYxHDkWsiV4662s(2L6nZ-KfnQW zn@SwhSuDl0fXxIDOMJ(TFpVLIK*!`g@+bs1CUA5IaLvmC7KeNrz+;UygR0U2)d(th z(azmx?Pu&W?E5>;?MStzsDd5U)McKVelFE<&09SwWb0PEdsA)Kz111#lrwwaq-VUX zU_&+c6`aUhnc@tb1Ss492H`gf2wE265)ibL=Yi#bo~Z=L&#c}0K4tc59AousE`Uljtg1YXohp5U)UuNgj&btUi!^E$ixi$Cz literal 0 HcmV?d00001 diff --git a/src/models/__pycache__/task_heads.cpython-312.pyc b/src/models/__pycache__/task_heads.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42396663a0f14ced9a268cc0aa1f3b19e3d5e00e GIT binary patch literal 7595 zcmeHMeQZ;wbV zhT2K(m3-d4_nmw0efQk+JLg>gB{$bWpe)?g>Hm8HA)jMIFBYvb`A?|y5`hRzfYho_ zrqNVuVu*>XBZ6rU5m?dGVAThvpQl;1UTao3&A@571E)pdv;wE?4xCnnlLJos9XM?Y zrvo^1?!d{lG9KqOTy%{u{TYOqwZ`OjsO*E?5J@c~kXn<#)Utw2Fp1_mQY`hE`eC?zwhmh0S8-}zq?S$8KN}we$FFimAEe+fC!41lJ z(p&0NFbSs9noJmPoRe14N<3_hZ1;x4VleCv1+OtM!ef&yQlv?w6xS$oJns+s!#q#% zK+2|IP_|b4gQAyG8*;ORUU{`vP7`;&};|>2`PgdsF3BNu*L`gWnrkxpHx% z*H^hVL8c7+|EGX+7WKy7*}zr^ns{(jGH}SCC*uW4}X?( z|K$VB2dF zEb6n4z3nJY8MPI@*f?tG`7>6R<@gyhsq=xc?lq~ioyOw^MSnKrs_jgw?2^W}(*y=! zYJ+pwke|t;)-yS`knlY^hhWdHrvdVoNuPN)<8eg(03#f3tvpV5NL5A)J6o#UiBEo$ zdi77^!)M(uZuf?L^}OV76D!=3*vtpSV1>KhF9>3g7yOMscLPc!&d09pJLoW^Y*tRx zG_kjkIZc!XINh(M#p)}S&5fcrNI6)&hml#BTxJ`+Efg(YW_zn$Rqf5JA=&6c3?@Pa3^jYrWE5#+f!Jgpptta_o{0EypY(6`8 zc=ND#c=HD>XI>ng|M+O}6Bmo09w~nM&&~1Tr$>u7Mb}*_Eb6oMY`EwuA90lr-4l0} zkGdAcor}f`O5GC!mBLg3J~|6`()yejltPrw!A7ozx;`-w;CYw;U7v(T z^O&;l3Xm%VAzXS8$X9#F-^@S$(qy$Pj%~R?;4v|y(DDq^S2S4KUphgsJ~aRRic=N; zQtHam6%jrdZ4@00CWFv*lV5hAUxwJk4F1>*zSmL*9u|DG6@0W6;tN{?ZpgGEaM}C8 zN87su+e)hFQ^w)62$IqQh zym|}{ed_leiLN8ytig9D_Y9^6PbA+w`^EX_#_DIkIKN+IOC0|QeYf2oY>I@@m}N)W z8f6E#DXm|z7$5GbUJS2(6?(p|R}Z6!DW%=40}}XY!)b&e22e|bQ?I3;2#LBIfl-F6 zj8R#3S&u1Tgswx9F(TVMrgpHoUvQt(SIu;_^A|w?^aA(!Cj80F%yAJh0;}mW8-mHEj4uG(TNHsvz zIAAQ#9V;K<9vtJAU7k5FRy{g%Y1G{5yzDAcaAVt$W7PE!^l+C8i+k7gtUI(G8g~74 z?NG_6YkAb%#i?e12uC4YCsTq0boWYd@ETN<;6QU*&@@UrqiYzP8EG&=1cQNqHg$K$ z0!QZQ(wIOGnH+}3Gh}3PuwW}pq#NEkbeECUxh3B^adqFj?u-b`eeK<$#K#||KIuuF zdlQW3>V-pzy(d(!1{QJ$$CiAvBXur{{z@@HcObOGAC_>OzKm05S_#Xd2uO4p5{$E< zH|orE@lqt#P@;MZ1?{DP=MReBl7Ih8}*@St;%`~b;3ATsL@ zhG{uAm7|XUP3I%Q1=9PGEC7;;$cmvLK|W~>5CxH+`O<8${4h3eg23a0hm^;sYm`U) zh1wg~JCT=Xap^D(RUM_FsxB0z`rv02yH@*lK)Q5eBoOvjd$)^$n>7_di`Rhq_|ar+ zU+TRJ(-1oolW7)!_#S`h64*o?f$_Mf5m>b^1ASFXTZ$}(xzXj&l(GWs{M)I{fZ8W; zd?wT`9D&q$apbyg)= zwkXN6HG5Kgt0WqdWo}Bc?2w%s@x}^nw3Re@6-=)I)O+uYe;iBgk0u5#BsxF7wzmt@ z{f4q-0PopCx?zD-s*+>9ZW{MjsUlMtP|Zm za6g*~G}lb<5$aW{NoJa85j5dAC|T-L_zFX5*-;k>DksGoknKVEv%(vIdtgsy-kQII zdYkHbv2nX72w-$_q2zD$2fWlDZcX=Q3&9GfoF)y+eg)G?Yk?@4?7d}!75x>l=7Gv+ z{*|(MgS+~74ec5&do<3Kj+f0H+}XeLV%dt3vK8@_8{oR~`4{44D@MzH3U4pv&x>>O z?i98uhJr-GA4YacZo1}6Gi!Mmb5nR=ZVHc2eK+ti;mEU;r*qQsbWU2XPKMIphFQ6~ z%*=r-*bMyo4PgQ>2bhW&X~8_iK$D@<>w>35x|8ciib@u!pFloU<&rZ<3>(%ASXfBaEt_?M*{=S$i|aq3rdzs+97C$$+-ru!+N_mIFF7;z9)=8%fa?&&R0uG-LsAF&<$HKt_i(D$u+sjrS7-a>K)voz z()}`}JWT6?C`OY}RTOoHnxf%8-7zTBl%hLBG?eU>bw?zg3aNTcI;osY%krSCbMmQv zolk_+KAnrFdUZ~TCUq_hHZ>~PDE%o#*6XR<8&y=9iYL;Mh)e@6omb>|gkoh=OhXDc z2FJs0qvUi!ioPO(K|-$)#c({ND59uf_WO-C^q(Wa=M|YM!B{ku+P-~rA`}iDNQKi0 zIjJf^s8(4CJ|m|#e={_YN~@c@F_cJ(CS)2^XgFvTrV=zFG^fRIippL60~Awmuw`lO zcUeD84Qx;lr9K}b9|@iruYYn~UU*_g*jn^7oIg5pH2dh}hJxpDjV(5=n%F(Qdy<*l z{@(gOYICsb(LG26B$-?pdF zw)d~ye|z!IFW&AgJpID-ffuH~b3A|Gc)l%EXgmR?O)V4A@#v(G+mZKgD>OZ+vFEmE zAx#?dX_B^mr1ncUeC(+|e{AGfcI|88tfwRI>6kp3TU+pSYV5~C-K@}>7h1F3lUqL& zI_|^A5}i4vJE^RuX)$z$^;Z2Bq(T$8vqUys|LXG28p#&-zYa`n2oT_ zur>&Pr@;oGm`lG51Gjnee*Zq&tL(BzID2ZBzx4j(y|d^4{r+g_)|>O^eqWrr;XkU# zClPi#{U<_dxKC7~gL0=|PMnY>#Bx!JCSJ1Ll&gg`^N`2A0hT@n=Tbii1;5zur>mjU zU63)~4N+US23?*bJMz@Gml8&QZCyFQ2p?G=4C5loi4AWdRsT#RU zk@a+Ex13Z`l&-^pr(8l|IUX0qD-1;dQ4kRJ8ee48z9H~3l_x=beun&m-|{J6@7OTO zEfDB(;fL5Qe9h(XWmjHmz0kTqK+Ubc)_JA#GgOW64L4!%U^lE6Z1D{5?)Vbo+YjZz6Te^8;E_ZqUjWhl#6aT%pa;59PnZGqYfAb=;y1OqbNpcdi zM|UUF2{9fT036b5Q16pNk^;m}Q~iKCVe8$3UvHPmdocLwd=JcC=GzPr5vq-*B5E3K zIR=MP;SY5OcmuI@M-so3ZiS=uTC=($ zQL3l`hCTsvff|Yr3fhGN@zf-jT0nOZoJly$6=1xP~QRGlxIPhP(PZ=s(E4a zj1c^&Wpgf8XxXXpGs3E3bIWY=wtVxpsc+nlPKWxYUm2Wf{+{M7wy&RU-;r}NEBXygve8JPYvGtL04v^^0jxg+1%WkI&HxzN4c-Q<{Wf3$dY6IN zj194;*$n%Cd8y9a@>x}0xD02&au5Qt5p*dmFaAV8=?XYkW3mm@mD+mA4~1xhx_yYn z0blHkH-GWT<=ouJ(A>q5(&&xS&nM9Mccb`59~%c=fWmd#J`NikwunNCQvd{SeZ~0h z4!v3$9WISsU4$g&{{Z#};OSPt#-ofPBnCK&uSoh73Iru6E`x9IsR_?08rGcvBK>Jq z!Aw90X3DOE(O`Ay7%5xk9V|KJIS{}*t)za?JIb@H13YKVb;4k?_Hf$ z3f^`c`_3mu64`y(lLhZujazV%hNf9xAnyz0+6%re&3UJOrRFIUQ|}uWuwDOppq8$O zGbq*v-GE{viU0~93Zv)|VCds0Fl4$31yW))^B53;q8%3wf-sopz^A;^ah#dlwm`6d zXTRD1soYN=KJw65SXl9hqbcXOR(GXtfq*)-@`Kjvt)HQ4vJtull)zAQ!FbYkC6jbJ z3@QE>p*x>5c-3S;+wGa3f#+Pll a)O?}&0Ri>Ft~E^4*a>`{JtUwS{{9P?Bfgga literal 0 HcmV?d00001 diff --git a/src/models/geo_layout_transformer.py b/src/models/geo_layout_transformer.py index 34a9a88..0c15093 100644 --- a/src/models/geo_layout_transformer.py +++ b/src/models/geo_layout_transformer.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn from .gnn_encoder import GNNEncoder from .transformer_core import TransformerCore -from .task_heads import ClassificationHead, MatchingHead +from .task_heads import ClassificationHead, MultiLabelClassificationHead, RegressionHead, MatchingHead class GeoLayoutTransformer(nn.Module): """完整的 Geo-Layout Transformer 模型。""" @@ -38,16 +38,34 @@ class GeoLayoutTransformer(nn.Module): self.task_head = None if 'task_head' in config['model']: head_config = config['model']['task_head'] + pooling_type = head_config.get('pooling_type', 'mean') + if head_config['type'] == 'classification': self.task_head = ClassificationHead( input_dim=head_config['input_dim'], hidden_dim=head_config['hidden_dim'], - output_dim=head_config['output_dim'] + output_dim=head_config['output_dim'], + pooling_type=pooling_type + ) + elif head_config['type'] == 'multi_label_classification': + self.task_head = MultiLabelClassificationHead( + input_dim=head_config['input_dim'], + hidden_dim=head_config['hidden_dim'], + output_dim=head_config['output_dim'], + pooling_type=pooling_type + ) + elif head_config['type'] == 'regression': + self.task_head = RegressionHead( + input_dim=head_config['input_dim'], + hidden_dim=head_config['hidden_dim'], + output_dim=head_config['output_dim'], + pooling_type=pooling_type ) elif head_config['type'] == 'matching': self.task_head = MatchingHead( input_dim=head_config['input_dim'], - output_dim=head_config['output_dim'] + output_dim=head_config['output_dim'], + pooling_type=pooling_type ) # 可在此处添加其他任务头 diff --git a/src/models/gnn_encoder.py b/src/models/gnn_encoder.py index 1140c3f..3677757 100644 --- a/src/models/gnn_encoder.py +++ b/src/models/gnn_encoder.py @@ -48,15 +48,14 @@ class GNNEncoder(nn.Module): data: 一个 PyTorch Geometric 的 Data 或 Batch 对象。 Returns: - 一个代表区块的图级别嵌入的张量。 + 一个代表节点级别的嵌入的张量。 """ - x, edge_index, batch = data.x, data.edge_index, data.batch + x, edge_index = data.x, data.edge_index # 通过所有 GNN 层 for layer in self.layers: x = layer(x, edge_index) x = torch.relu(x) - # 全局池化以获得图级别的嵌入 - graph_embedding = self.readout(x, batch) - return graph_embedding + # 返回节点级别的嵌入,不进行全局池化 + return x diff --git a/src/models/task_heads.py b/src/models/task_heads.py index fd374eb..b860c76 100644 --- a/src/models/task_heads.py +++ b/src/models/task_heads.py @@ -2,11 +2,44 @@ import torch import torch.nn as nn +class PoolingLayer(nn.Module): + """可插拔的池化层,支持多种池化策略。""" + def __init__(self, pooling_type: str = 'mean'): + super(PoolingLayer, self).__init__() + self.pooling_type = pooling_type + + # 如果使用注意力池化,需要定义注意力机制 + if pooling_type == 'attention': + self.attention = nn.Linear(1, 1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: 形状为 [batch_size, seq_len, hidden_dim] 的张量 + + Returns: + 形状为 [batch_size, hidden_dim] 的池化后的张量 + """ + if self.pooling_type == 'mean': + return torch.mean(x, dim=1) + elif self.pooling_type == 'max': + return torch.max(x, dim=1)[0] + elif self.pooling_type == 'cls': + # 取第一个 token 作为 [CLS] token + return x[:, 0, :] + elif self.pooling_type == 'attention': + # 计算注意力权重 + weights = self.attention(torch.ones_like(x[:, :, :1])).softmax(dim=1) + return (x * weights).sum(dim=1) + else: + raise ValueError(f"不支持的池化类型: {self.pooling_type}") + class ClassificationHead(nn.Module): """一个用于分类任务的简单多层感知机(MLP)任务头。""" - def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, pooling_type: str = 'mean'): super(ClassificationHead, self).__init__() + self.pooling = PoolingLayer(pooling_type) self.fc1 = nn.Linear(input_dim, hidden_dim) self.relu = nn.ReLU() self.fc2 = nn.Linear(hidden_dim, output_dim) @@ -19,9 +52,60 @@ class ClassificationHead(nn.Module): Returns: 最终的分类 logits。 """ - # 我们可以取第一个 token(类似 [CLS])的嵌入,或者进行平均池化 - # 为简单起见,我们假设在序列维度上进行平均池化 - x_pooled = torch.mean(x, dim=1) + # 使用指定的池化方法 + x_pooled = self.pooling(x) + + out = self.fc1(x_pooled) + out = self.relu(out) + out = self.fc2(out) + return out + +class MultiLabelClassificationHead(nn.Module): + """用于多标签分类任务的任务头。""" + + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, pooling_type: str = 'mean'): + super(MultiLabelClassificationHead, self).__init__() + self.pooling = PoolingLayer(pooling_type) + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(hidden_dim, output_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: 来自 Transformer 骨干网络的输入张量。 + + Returns: + 最终的多标签分类 logits。 + """ + # 使用指定的池化方法 + x_pooled = self.pooling(x) + + out = self.fc1(x_pooled) + out = self.relu(out) + out = self.fc2(out) + return out + +class RegressionHead(nn.Module): + """用于回归任务的任务头。""" + + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, pooling_type: str = 'mean'): + super(RegressionHead, self).__init__() + self.pooling = PoolingLayer(pooling_type) + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(hidden_dim, output_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: 来自 Transformer 骨干网络的输入张量。 + + Returns: + 最终的回归输出。 + """ + # 使用指定的池化方法 + x_pooled = self.pooling(x) out = self.fc1(x_pooled) out = self.relu(out) @@ -31,8 +115,9 @@ class ClassificationHead(nn.Module): class MatchingHead(nn.Module): """用于学习版图匹配的相似性嵌入的任务头。""" - def __init__(self, input_dim: int, output_dim: int): + def __init__(self, input_dim: int, output_dim: int, pooling_type: str = 'mean'): super(MatchingHead, self).__init__() + self.pooling = PoolingLayer(pooling_type) self.projection = nn.Linear(input_dim, output_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -43,8 +128,8 @@ class MatchingHead(nn.Module): Returns: 代表整个输入图(例如一个 IP 模块)的单个嵌入向量。 """ - # 全局平均池化,为整个序列获取一个单一的向量 - graph_embedding = torch.mean(x, dim=1) + # 使用指定的池化方法 + graph_embedding = self.pooling(x) # 投影到最终的嵌入空间 similarity_embedding = self.projection(graph_embedding) # 对嵌入进行 L2 归一化,以便使用余弦相似度 diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..4384acc --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,6 @@ +# src/utils/__init__.py +from .config_loader import load_config, merge_configs +from .logging import get_logger +from .seed import set_seed + +__all__ = ['load_config', 'merge_configs', 'get_logger', 'set_seed'] diff --git a/src/utils/__pycache__/__init__.cpython-312.pyc b/src/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8027b81ad27291347cd90d0083ae31d12b35e961 GIT binary patch literal 376 zcmYL^p-;px7{%Ldw_Of~!C}x7M-Ai*L0y6&r#LgyG;6OZX;;#2;uOK-@i_h#1ir}` z5ClhZ4#n4TgipM_U-P|}_N6Q*2&TV#P_J;lNBEm$XSSzc9uY?zYh-YZO;RT~a-vg{ z)!E=vJu-QnV0kmQ<{$_4C^+BAs`r2(-IiJ_@lP+DYjJ7-t#Ze?aK ziDi#oingcN)Dox?3RNh@V59exnyCHrR_IgEHZY_R(Dlz#4h#hH)Ngk8u9w=*GT(gj z{hx1sv%d#}egxwUb48ZHZ+Y{Fkg!?El(liyrZaT4u z)Dp~{%RnZOQPP=EgzEW0niX0Y$4~|vE{{+j9O2)?#Ivq<#-l=Esyyp)Y%nrC?dT>_EWz< z4C6HIWz9Axc4bFv-tDw`eQ4{)dF!X4lJ|S|>B|CGW5d??tdnW|=jn|r^WbFNoUn!$ zgppIQKdA zG0`)K^rSu7$p~f9K26b&A8$-5v1p4POD9RnU{Q!mSo9sDH=b1n^|aC0L6sDX>oiGd zl+jo;ZD==+NdQ^-2)w%x3rkT1ws*dK*->=y_ zaqz2yQ+TE3aL)HwLjKnt1ZuKvkG)0i39AIpFGEpj-g~lw2QDa#xWReAP*}gaiMjU* z#!f~XL94KtKql%dVaV3oWY%Rf9uxPKP_$d$fg8>>HE3<@9ArX&n>}^vn@vX9r+Xzl?fXd2$wmEX! zn!CP!=j!_6qBS}VXmP+}jpwY}x2#;=`u5J|bk;@G!p#!&o(67NTbsj<%fqGI_TuE$ z+;WO57Pf%L8ZjIG4GT*e)XT?p8|Zu%jK93$nt6iNiy!& zn3iIOl8O;7t6I$9L4ag$5f&`srO!cgXblkVvsEgW`UwQMs3g*ai7OOm2-9j|S63vp zz_28cEGj&Ys`uRs9=IDkuo66!ZG8}W@m{FmZm8kUP-MiPlX9KwfxYA8n~v4M;VCu~ zpUy1R|I+$f)30Y%&v&k#eE)uZ!;Ei+-3a76SL)x)wT?H9o+Mq6}3YK_!a9K9hkYmgv#Ho%0OsuaLPgf~BZpE)Euo>*J=lDQIXn zN{Ed!27)Q&;!7Y!SJ_h@)KBLi!eAL^K5RXh&ZieQmlK=I?|#k9YhP!6FD_~6XWH{O z8;Sc?tDB4q$d5a9k4E&KI^Bn``4^K~ayp+{qD2%%M+n}9p@Y!;N=lnr*kit?-*yRk zK}bQAHlHT6hZ<>%uQtCtHoDw9-P$f=T|3WB11MGQ+JvcTdPR(Y&Qk*K1UcSn#cDhkQlNWp$& zxm|!`)hbF-vDRZo^!HWGx)_zGZAoQW5v!Ju8ElyTs`|109S}w^L|6z0#7IYnR~CK3 zrHC&o!?1z{Uqpgf=z$S$zZfHeyw{N!#*_vk0VAIfRRSR_6!&-PPE56uE#4|LMzy2~ zZGo9ndn~UtPh7KhLV@F+=*QtRMzDjEg474>x-$rxxQ70 zkmy<|&a!z23k5Pf1Q*E`rfL=Id*Dqpx3UGiIi>L{j$Vk~Fn zh!hRMbEuJo^tEFTnYLMmVK#u{C$O%&+fueU+aJ)#R40eGsaZPy(e>W7OI@AnAE@?G A-T(jq literal 0 HcmV?d00001 diff --git a/src/utils/__pycache__/seed.cpython-312.pyc b/src/utils/__pycache__/seed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c12b2d1bc731120cf67e8fe6a666b44392b2d942 GIT binary patch literal 1095 zcmb7C&1(}u6rbIlO`5c6Yayr=YYs(;kS?fLEJdia*bj+9>%m;glHJ%W$!?h06@mp5 zDdf`9+*+!Yib7E_f)}M!{R6x-P$^8kwQ0;rr8#+WX47qwllTtr_ulV)%;U}b>hpO4 zm6=EL@v8y=d(3f!`nDWiqOb%Mpl}HgsUa8PZC{88Ho^#0pu($Sj8}v?G2%*dL9xQh z!vtbgRid2TtEM0Y*8MO>*(G2e$27hn>&To1Z;s*|Fj=K7nqxGe=1kncHC4NvfY#zz zb!^IXIlN}?$MI^Yd76FZ{YSO$AJrXByx!C7bU3ZcvFfOBgVZto4x{k1;0aje>w8NJ zL1?j+c>`2_6h42TwQ3>kFl~`JbEf<>!U{FRP`(v&!rF^2%Ik@m(diPHH7Zt0o!7&*w+pO8e5ezn|746-?=w6dOJKiG&wpsGBRv+)90HKKfJFE zvj1^oAP^3=5*Lzn#YiIdjfB@zL^KT)oguABS-0fGG<%G|sHMoHqk)>16A2<<6A|9P z#I5QN;>gg6m_l*g403im%uXoqXI#3ARfHuiE*t&*p`;v@hK;C|RCN|z6+Rm?;Yq2eF=o{=iq|bvOSGmZ*jo1wnQMZz* zG-8*J*o=r>9}(zg?6F?~#10Cv`ywraNkg#`>UDIIx@=?kELFb*j^p-#Zx?j!g0@}Y Q`RzK*N%_lFK*28m3+Z+MivR!s literal 0 HcmV?d00001 diff --git a/src/utils/seed.py b/src/utils/seed.py new file mode 100644 index 0000000..d650785 --- /dev/null +++ b/src/utils/seed.py @@ -0,0 +1,33 @@ +# src/utils/seed.py +import random +import numpy as np +import torch +import os + + +def set_seed(seed: int = 42): + """ + 设置随机种子,确保实验的可重复性。 + + Args: + seed: 随机种子值 + """ + # 设置 Python 内置随机种子 + random.seed(seed) + + # 设置 NumPy 随机种子 + np.random.seed(seed) + + # 设置 PyTorch 随机种子 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # 对于多 GPU 环境 + + # 禁用 CUDA 中的确定性算法,以提高性能(可选) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False + + # 设置环境变量中的随机种子 + os.environ['PYTHONHASHSEED'] = str(seed) + + print(f"随机种子已设置为: {seed}") diff --git a/tests/test_model_run.py b/tests/test_model_run.py new file mode 100644 index 0000000..47d8e8b --- /dev/null +++ b/tests/test_model_run.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +""" +测试脚本,用于验证模型是否可以正常跑通,不需要真实数据 +- 生成随机图数据 +- 加载模型配置 +- 初始化模型 +- 运行前向传播和反向传播 +- 验证模型是否可以正常工作 +""" +import os +import sys +import torch +from torch_geometric.data import Data, Batch + +# 添加项目根目录到 Python 路径 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from src.utils.config_loader import load_config +from src.models.geo_layout_transformer import GeoLayoutTransformer +from src.engine.trainer import Trainer +from src.engine.self_supervised import SelfSupervisedTrainer +from src.utils.logging import get_logger + +def generate_random_graph_data(num_graphs=4, num_nodes_per_graph=8, node_feature_dim=5, edge_feature_dim=0): + """ + 生成随机的图数据 + + Args: + num_graphs: 图的数量 + num_nodes_per_graph: 每个图的节点数量 + node_feature_dim: 节点特征维度 + edge_feature_dim: 边特征维度 + + Returns: + 一个 Batch 对象,包含多个随机生成的图 + """ + graphs = [] + + for _ in range(num_graphs): + # 生成随机节点特征 + x = torch.randn(num_nodes_per_graph, node_feature_dim) + + # 生成随机边(完全连接) + edge_index = [] + for i in range(num_nodes_per_graph): + for j in range(num_nodes_per_graph): + if i != j: + edge_index.append([i, j]) + edge_index = torch.tensor(edge_index, dtype=torch.long).t() + + # 生成随机标签 + y = torch.randn(1, 1) # 假设是图级别的标签 + + # 创建图数据 + graph = Data(x=x, edge_index=edge_index, y=y) + graphs.append(graph) + + # 构建批次 + batch = Batch.from_data_list(graphs) + return batch + +def test_supervised_training(): + """测试监督训练""" + logger = get_logger("Test_Supervised_Training") + logger.info("=== 测试监督训练 ===") + + # 加载配置 + config = load_config('configs/default.yaml') + + # 生成随机数据 + batch = generate_random_graph_data() + logger.info(f"生成的批次数据: {batch}") + logger.info(f"批次大小: {batch.num_graphs}") + logger.info(f"总节点数: {batch.num_nodes}") + logger.info(f"总边数: {batch.num_edges}") + + # 初始化模型 + logger.info("初始化模型...") + model = GeoLayoutTransformer(config) + logger.info("模型初始化成功") + + # 初始化训练器 + logger.info("初始化训练器...") + trainer = Trainer(model, config) + logger.info("训练器初始化成功") + + # 测试前向传播 + logger.info("测试前向传播...") + with torch.no_grad(): + # 先测试 GNN 编码器 + gnn_output = model.gnn_encoder(batch) + logger.info(f"GNN 编码器输出形状: {gnn_output.shape}") + + # 测试形状重塑 + num_graphs = batch.num_graphs + nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1] + logger.info(f"每个图的节点数: {nodes_per_graph}") + reshaped_embeddings = gnn_output.view(num_graphs, nodes_per_graph[0], -1) + logger.info(f"重塑后的嵌入形状: {reshaped_embeddings.shape}") + + # 测试 Transformer 核心 + transformer_output = model.transformer_core(reshaped_embeddings) + logger.info(f"Transformer 输出形状: {transformer_output.shape}") + + # 测试完整模型 + output = model(batch) + logger.info(f"前向传播成功,输出形状: {output.shape}") + + # 测试反向传播 + logger.info("测试反向传播...") + optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) + optimizer.zero_grad() + output = model(batch) + + # 对输出进行全局池化,得到图级别的表示 + # 从 [batch_size, seq_len, hidden_dim] 变为 [batch_size, hidden_dim] + graph_output = output.mean(dim=1) + + # 使用 MSE 损失,只比较前 1 个维度(与 batch.y 形状匹配) + loss = torch.nn.MSELoss()(graph_output[:, :1], batch.y) + loss.backward() + optimizer.step() + logger.info(f"反向传播成功,损失值: {loss.item()}") + + logger.info("监督训练测试完成,模型可以正常工作!") + +def test_self_supervised_training(): + """测试自监督训练""" + logger = get_logger("Test_Self_Supervised_Training") + logger.info("\n=== 测试自监督训练 ===") + + # 加载配置 + config = load_config('configs/default.yaml') + + # 生成随机数据 + batch = generate_random_graph_data() + logger.info(f"生成的批次数据: {batch}") + logger.info(f"批次大小: {batch.num_graphs}") + logger.info(f"总节点数: {batch.num_nodes}") + logger.info(f"总边数: {batch.num_edges}") + + # 初始化模型 + logger.info("初始化模型...") + model = GeoLayoutTransformer(config) + logger.info("模型初始化成功") + + # 初始化自监督训练器 + logger.info("初始化自监督训练器...") + trainer = SelfSupervisedTrainer(model, config) + logger.info("自监督训练器初始化成功") + + # 测试前向传播 + logger.info("测试前向传播...") + with torch.no_grad(): + # 测试 GNN 编码器 + gnn_output = model.gnn_encoder(batch) + logger.info(f"GNN 编码器输出形状: {gnn_output.shape}") + + # 测试 Transformer 核心 + num_graphs = batch.num_graphs + nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1] + if not torch.all(nodes_per_graph == nodes_per_graph[0]): + logger.warning("批次中图形的节点数不一致,使用第一个图形的节点数") + nodes_per_graph = nodes_per_graph[0] + + gnn_output_reshaped = gnn_output.view(num_graphs, nodes_per_graph, -1) + transformer_output = model.transformer_core(gnn_output_reshaped) + logger.info(f"Transformer 核心输出形状: {transformer_output.shape}") + + # 测试完整模型前向传播 + logger.info("测试完整模型前向传播...") + with torch.no_grad(): + output = model(batch) + logger.info(f"完整模型前向传播成功,输出形状: {output.shape}") + + logger.info("自监督训练测试完成,模型可以正常工作!") + +def main(): + """主函数""" + logger = get_logger("Test_Model_Run") + logger.info("开始测试模型是否可以正常跑通...") + + try: + # 测试监督训练 + test_supervised_training() + + # 测试自监督训练 + test_self_supervised_training() + + logger.info("\n✅ 所有测试通过,模型可以正常跑通!") + logger.info("模型已准备就绪,可以使用真实数据进行训练。") + except Exception as e: + logger.error(f"❌ 测试失败: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + +if __name__ == "__main__": + main() diff --git a/uv.lock b/uv.lock index d11ebc7..c08d9b8 100644 --- a/uv.lock +++ b/uv.lock @@ -1,8 +1,16 @@ -# uv.lock version = 1 revision = 2 requires-python = ">=3.12" +[[package]] +name = "absl-py" +version = "2.4.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/64/c7/8de93764ad66968d19329a7e0c147a2bb3c7054c554d4a119111b8f9440f/absl_py-2.4.0.tar.gz", hash = "sha256:8c6af82722b35cf71e0f4d1d47dcaebfff286e27110a99fc359349b247dfb5d4", size = 116543, upload-time = "2026-01-28T10:17:05.322Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl", hash = "sha256:88476fd881ca8aab94ffa78b7b6c632a782ab3ba1cd19c9bd423abc4fb4cd28d", size = 135750, upload-time = "2026-01-28T10:17:04.19Z" }, +] + [[package]] name = "aiohappyeyeballs" version = "2.6.1" @@ -234,20 +242,35 @@ sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e8/b5/a12ef182943856 wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/51/a8/cff9bd17789b41c2f09f4e4b574fd05c5860e0b5602b3708b6c1de4b6c82/gdstk-0.9.61-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:61f0ee05cdce9b4163ea812cbf2e2f5d8d01a293fa118ff98348280306bd91d6", size = 923143, upload-time = "2025-08-28T10:16:42.55Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/9b/13/c97316d18510e2dcb3aeba0e13f4c6d7ad34884be62006b910f382bdfbc6/gdstk-0.9.61-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fff1b104b6775e4c27ab2751b3f4ac6c1ce86a4e9afd5e5535ac4acefa6a7a07", size = 475724, upload-time = "2025-08-28T10:16:44.434Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/19/49/40aed8aae97054b08a0063a583ef55e3ab0e335441d6d339615d8593892a/gdstk-0.9.61-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5218f8c5ab13b6e979665c0a7dc1272768003a1cb7add0682483837f7485faed", size = 536850, upload-time = "2025-10-20T11:25:50.578Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/35/7c/4d324ae83dac2ba15fcc61d688019ad79b8938c027d2467415cb804f8058/gdstk-0.9.61-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4e79f3881d3b3666a600efd5b2c131454507f69d3c9b9eaf383d106cfbd6e7bc", size = 600640, upload-time = "2025-08-28T10:16:45.627Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ef/cd/addb66d3740a654495e413fb28d6c4e4d9ac94bc0c25dc31f2868fb6e18c/gdstk-0.9.61-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e90a6e24c2145320e53e953a59c6297fd25c17c6ef098fa8602e64e64a5390ea", size = 536849, upload-time = "2025-08-28T10:16:47.039Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d7/37/43fb416068f0722f1a9ccc67fe6a8f78c93ef3bb9cdf37a2edd276316b23/gdstk-0.9.61-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3a49401cbd26c5a17a4152d1befa73efb21af694524557bf09d15f4c8a874e6", size = 540029, upload-time = "2025-10-20T11:26:06.539Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/f0/d4a24c1b6636454812b4990dc590d72b0a37f185122c6fa19d4b0d22a90e/gdstk-0.9.61-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:8738ac63bbe29dcb5abae6a19d207c4e0857f9dc1bd405c85af8a87f0dcfb348", size = 535749, upload-time = "2025-08-28T10:16:48.157Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/85/67/ec1ef9f67ac26554d55f4e667e602698fb7f520ef89dfe6cb4550152eec1/gdstk-0.9.61-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:23bb023a49f3321673d0e32cdce2e2705a51d9e12328c928723ded49af970520", size = 1711671, upload-time = "2025-08-28T10:16:49.963Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/89/73/4fd73bbc7500e383500a9edca4c09ce38bffd44eefb256ce17a41c526dbe/gdstk-0.9.61-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:81c2f19cab89623d1f56848e7a16e2fab82a93c61c8f7aa73f5ff59840b60c0f", size = 1535161, upload-time = "2025-08-28T10:16:51.181Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a0/b1/d740423bd7436c522295a09cadc0e21d59afa418185cffffce46a6eb85b0/gdstk-0.9.61-cp312-cp312-win_amd64.whl", hash = "sha256:4474f015ecc228b210165287cb7eea65639ea6308f60105cb49e970079bddc2b", size = 500139, upload-time = "2025-08-28T10:16:52.496Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/07/9a/7ea7b7a295e029542d4aeb252d01c1bcc8e724df26155f9b5f432b02d02a/gdstk-0.9.61-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:3beeae846fc523c7e3a01c47edcd3b7dd83c29650e56b82a371e528f9cb0ec3e", size = 923024, upload-time = "2025-08-28T10:16:53.613Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b8/3d/5aa9b1a4665259702e9f17e03a9a114a873df25c9ba2c9e782ff25fb11e9/gdstk-0.9.61-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:575a21639b31e2fab4d9e918468b8b40a58183028db563e5963be594bff1403d", size = 475687, upload-time = "2025-08-28T10:16:54.886Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/80/13/ec783d8de5d9b4e51763102cac6da124f16747e4c73f166c36a105065008/gdstk-0.9.61-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:90d54b48223dcbb8257769faaa87542d12a749d8486e8d1187a45d06e9422859", size = 536872, upload-time = "2025-10-20T11:25:52.902Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/20/4a/365ca49b76ee3d70a0d044fcc44272c85754fd763e0b3f3e9f498b8bf4a1/gdstk-0.9.61-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35405bed95542a0b10f343b165ce0ad80740bf8127a4507565ec74222e6ec8d3", size = 600630, upload-time = "2025-08-28T10:16:56.326Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/62/3e/815fd2977ff1d885ad87d0a54deb19926fd025933325c7a27625c1c0a0c3/gdstk-0.9.61-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b311ddf8982995b52ac3bf3b32a6cf6d918afc4e66dea527d531e8af73896231", size = 536873, upload-time = "2025-08-28T10:16:57.516Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b0/8b/1ba0abc4fb3c60015d23894aa9d093f473fdc337584f9c1d7afe96d6f9f5/gdstk-0.9.61-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6dcbfc60fba92d10f1c7d612b5409c343fcaf2a380640e9fb01c504ca948b412", size = 540029, upload-time = "2025-10-20T11:26:09.727Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/74/72/cc46f132741e541995ede7fccf9820f105fb2296ab70192bd27de56190f2/gdstk-0.9.61-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:fab67ccdd8029ef7eb873f8c98f875dc2665a5e45af7cf3d2a7a0f401826a1d3", size = 535763, upload-time = "2025-08-28T10:16:58.621Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/5b/db/72196721fedfc38cf21158e5a436d73b41ea244b78c1053462a35d1e42cb/gdstk-0.9.61-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:5852749e203d6978e06d02f8ef9e29ce4512cb1aedeb62c37b8e8b2c10c4f529", size = 1711669, upload-time = "2025-08-28T10:16:59.838Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/36/ec/eeebcc95c2741e1f39da08c95fdcc56b3d0f5305ad732d8a66213b6fa0b8/gdstk-0.9.61-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7ee38a54c799e77dbe219266f765bbd3b2906b62bc7b6fb64b1387e6db3dd187", size = 1535165, upload-time = "2025-08-28T10:17:01.027Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/cb/a8/653335b1ec13306b023a9aa1e2072e8bab5c0a5376c138066006504198c3/gdstk-0.9.61-cp313-cp313-win_amd64.whl", hash = "sha256:6abb396873b2660dd7863d664b3822f00547bf7f216af27be9f1f812bc5e8027", size = 500117, upload-time = "2025-08-28T10:17:02.459Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/33/52/28cd8720357d6b892ac19684e4af57f7264ba8eea0ea8078e17f0476408c/gdstk-0.9.61-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:a674af8be5cf1f8ea9f6c5b5f165f797d7e2ed74cbca68b4a22adb92b515fb35", size = 916091, upload-time = "2025-10-20T11:25:35.497Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/fe/f1/c55c7b7b0158a8540716a4fea3aaf0288726122afc0e9af1f0564b79605b/gdstk-0.9.61-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:38ec0b7285d6c9bf8cbc279731dc0d314633cda2ce9e6f9053554b3e5f004fcd", size = 474548, upload-time = "2025-10-20T11:25:38.086Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ab/83/23907af9b349d79af54c4a79ac7972b9a52f8cee48b366ee9635cf9a32d9/gdstk-0.9.61-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3b63a77b57fb441c8017217aaf1e8b13d93cbee822031a8e2826adb716e01dd4", size = 537072, upload-time = "2025-10-20T11:25:55.473Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2c/e7/e930928f9a03b896f51b6e975dbb652cffa81c61338609d6289c3f48313c/gdstk-0.9.61-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7fae6eee627e837d1405b47d381ccd33dbba85473b1bb3822bdc8ae41dbc0dc", size = 540132, upload-time = "2025-10-20T11:26:11.379Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6d/9d/56afbb84cccb07751f8532591bfc3a1608833943446b51978b52daaaa2b1/gdstk-0.9.61-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:9e396694cac24bd87d0e38c37e6740d9ba0c13f6c9f2211a871d62288430f069", size = 1578033, upload-time = "2025-10-20T11:26:18.707Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/58/14/c854d3ef2b0ece30c0cb4c034d3b7f58425086383a229c960d813a163861/gdstk-0.9.61-cp314-cp314-win_amd64.whl", hash = "sha256:7ea0c1200dc53b794e9c0cc6fe3ea51e49113dfdd9c3109e1961cda3cc2197c7", size = 513072, upload-time = "2025-10-20T11:25:21.089Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/be/0a/5e9c3be1327d36556a5e8d16c6696614fdc1df0d28c0629b829785245d76/gdstk-0.9.61-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:616dd1c3e7aea4a98aeb03db7cf76a853d134c54690790eaa25c63eede7b869a", size = 925091, upload-time = "2025-10-20T11:25:40.772Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6a/4a/a72de67ea1c217537ae537d7e96b6a0b3c7427177bd1439c89e7617ad3f0/gdstk-0.9.61-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b0e898202fbb7fd4c39f8404831415a0aa0445656342102c4e77d4a7c2c15a1d", size = 477723, upload-time = "2025-10-20T11:25:44.535Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0c/94/5034a646ee3bda58210cd377b241582161c8683489210cba762d4f7f06d1/gdstk-0.9.61-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:29bb862a1a814f5bbd6f8bbc2f99e1163df9e6307071cb6e11251dbe7542feb5", size = 541773, upload-time = "2025-10-20T11:25:57.407Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/bc/90/3e5883b528dcc7c4daa74925276a09ff274004518deea48c859b9215120b/gdstk-0.9.61-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c6c2a08d82a683aff50dc63f2943ed805d32d46bd984cbd4ac9cf876146d0ef9", size = 544525, upload-time = "2025-10-20T11:26:13.877Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/46/0b/549a0e72982b87011013705cdb552e2acfdb882ceaeb41a3fe020e981ec8/gdstk-0.9.61-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3ba52f95763052a6968583942e6531ceca20c14c762d44fe2bd887445e2f73b6", size = 1582069, upload-time = "2025-10-20T11:26:21.189Z" }, ] [[package]] @@ -260,6 +283,7 @@ dependencies = [ { name = "pandas" }, { name = "pyyaml" }, { name = "scikit-learn" }, + { name = "tensorboard" }, { name = "torch" }, { name = "torch-geometric" }, { name = "torchvision" }, @@ -272,11 +296,53 @@ requires-dist = [ { name = "pandas", specifier = ">=2.3.2" }, { name = "pyyaml", specifier = ">=6.0.2" }, { name = "scikit-learn", specifier = ">=1.7.1" }, + { name = "tensorboard", specifier = ">=2.20.0" }, { name = "torch", specifier = ">=2.8.0" }, { name = "torch-geometric", specifier = ">=2.6.1" }, { name = "torchvision", specifier = ">=0.23.0" }, ] +[[package]] +name = "grpcio" +version = "1.78.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/06/8a/3d098f35c143a89520e568e6539cc098fcd294495910e359889ce8741c84/grpcio-1.78.0.tar.gz", hash = "sha256:7382b95189546f375c174f53a5fa873cef91c4b8005faa05cc5b3beea9c4f1c5", size = 12852416, upload-time = "2026-02-06T09:57:18.093Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/4e/f4/7384ed0178203d6074446b3c4f46c90a22ddf7ae0b3aee521627f54cfc2a/grpcio-1.78.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:f9ab915a267fc47c7e88c387a3a28325b58c898e23d4995f765728f4e3dedb97", size = 5913985, upload-time = "2026-02-06T09:55:26.832Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/81/ed/be1caa25f06594463f685b3790b320f18aea49b33166f4141bfdc2bfb236/grpcio-1.78.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3f8904a8165ab21e07e58bf3e30a73f4dffc7a1e0dbc32d51c61b5360d26f43e", size = 11811853, upload-time = "2026-02-06T09:55:29.224Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/24/a7/f06d151afc4e64b7e3cc3e872d331d011c279aaab02831e40a81c691fb65/grpcio-1.78.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:859b13906ce098c0b493af92142ad051bf64c7870fa58a123911c88606714996", size = 6475766, upload-time = "2026-02-06T09:55:31.825Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8a/a8/4482922da832ec0082d0f2cc3a10976d84a7424707f25780b82814aafc0a/grpcio-1.78.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:b2342d87af32790f934a79c3112641e7b27d63c261b8b4395350dad43eff1dc7", size = 7170027, upload-time = "2026-02-06T09:55:34.7Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/54/bf/f4a3b9693e35d25b24b0b39fa46d7d8a3c439e0a3036c3451764678fec20/grpcio-1.78.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:12a771591ae40bc65ba67048fa52ef4f0e6db8279e595fd349f9dfddeef571f9", size = 6690766, upload-time = "2026-02-06T09:55:36.902Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c7/b9/521875265cc99fe5ad4c5a17010018085cae2810a928bf15ebe7d8bcd9cc/grpcio-1.78.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:185dea0d5260cbb2d224c507bf2a5444d5abbb1fa3594c1ed7e4c709d5eb8383", size = 7266161, upload-time = "2026-02-06T09:55:39.824Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/05/86/296a82844fd40a4ad4a95f100b55044b4f817dece732bf686aea1a284147/grpcio-1.78.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:51b13f9aed9d59ee389ad666b8c2214cc87b5de258fa712f9ab05f922e3896c6", size = 8253303, upload-time = "2026-02-06T09:55:42.353Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f3/e4/ea3c0caf5468537f27ad5aab92b681ed7cc0ef5f8c9196d3fd42c8c2286b/grpcio-1.78.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fd5f135b1bd58ab088930b3c613455796dfa0393626a6972663ccdda5b4ac6ce", size = 7698222, upload-time = "2026-02-06T09:55:44.629Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d7/47/7f05f81e4bb6b831e93271fb12fd52ba7b319b5402cbc101d588f435df00/grpcio-1.78.0-cp312-cp312-win32.whl", hash = "sha256:94309f498bcc07e5a7d16089ab984d42ad96af1d94b5a4eb966a266d9fcabf68", size = 4066123, upload-time = "2026-02-06T09:55:47.644Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ad/e7/d6914822c88aa2974dbbd10903d801a28a19ce9cd8bad7e694cbbcf61528/grpcio-1.78.0-cp312-cp312-win_amd64.whl", hash = "sha256:9566fe4ababbb2610c39190791e5b829869351d14369603702e890ef3ad2d06e", size = 4797657, upload-time = "2026-02-06T09:55:49.86Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/05/a9/8f75894993895f361ed8636cd9237f4ab39ef87fd30db17467235ed1c045/grpcio-1.78.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:ce3a90455492bf8bfa38e56fbbe1dbd4f872a3d8eeaf7337dc3b1c8aa28c271b", size = 5920143, upload-time = "2026-02-06T09:55:52.035Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/55/06/0b78408e938ac424100100fd081189451b472236e8a3a1f6500390dc4954/grpcio-1.78.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:2bf5e2e163b356978b23652c4818ce4759d40f4712ee9ec5a83c4be6f8c23a3a", size = 11803926, upload-time = "2026-02-06T09:55:55.494Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/88/93/b59fe7832ff6ae3c78b813ea43dac60e295fa03606d14d89d2e0ec29f4f3/grpcio-1.78.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8f2ac84905d12918e4e55a16da17939eb63e433dc11b677267c35568aa63fc84", size = 6478628, upload-time = "2026-02-06T09:55:58.533Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ed/df/e67e3734527f9926b7d9c0dde6cd998d1d26850c3ed8eeec81297967ac67/grpcio-1.78.0-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:b58f37edab4a3881bc6c9bca52670610e0c9ca14e2ea3cf9debf185b870457fb", size = 7173574, upload-time = "2026-02-06T09:56:01.786Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a6/62/cc03fffb07bfba982a9ec097b164e8835546980aec25ecfa5f9c1a47e022/grpcio-1.78.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:735e38e176a88ce41840c21bb49098ab66177c64c82426e24e0082500cc68af5", size = 6692639, upload-time = "2026-02-06T09:56:04.529Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/bf/9a/289c32e301b85bdb67d7ec68b752155e674ee3ba2173a1858f118e399ef3/grpcio-1.78.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:2045397e63a7a0ee7957c25f7dbb36ddc110e0cfb418403d110c0a7a68a844e9", size = 7268838, upload-time = "2026-02-06T09:56:08.397Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0e/79/1be93f32add280461fa4773880196572563e9c8510861ac2da0ea0f892b6/grpcio-1.78.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a9f136fbafe7ccf4ac7e8e0c28b31066e810be52d6e344ef954a3a70234e1702", size = 8251878, upload-time = "2026-02-06T09:56:10.914Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/65/65/793f8e95296ab92e4164593674ae6291b204bb5f67f9d4a711489cd30ffa/grpcio-1.78.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:748b6138585379c737adc08aeffd21222abbda1a86a0dca2a39682feb9196c20", size = 7695412, upload-time = "2026-02-06T09:56:13.593Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1c/9f/1e233fe697ecc82845942c2822ed06bb522e70d6771c28d5528e4c50f6a4/grpcio-1.78.0-cp313-cp313-win32.whl", hash = "sha256:271c73e6e5676afe4fc52907686670c7cea22ab2310b76a59b678403ed40d670", size = 4064899, upload-time = "2026-02-06T09:56:15.601Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/4d/27/d86b89e36de8a951501fb06a0f38df19853210f341d0b28f83f4aa0ffa08/grpcio-1.78.0-cp313-cp313-win_amd64.whl", hash = "sha256:f2d4e43ee362adfc05994ed479334d5a451ab7bc3f3fee1b796b8ca66895acb4", size = 4797393, upload-time = "2026-02-06T09:56:17.882Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/29/f2/b56e43e3c968bfe822fa6ce5bca10d5c723aa40875b48791ce1029bb78c7/grpcio-1.78.0-cp314-cp314-linux_armv7l.whl", hash = "sha256:e87cbc002b6f440482b3519e36e1313eb5443e9e9e73d6a52d43bd2004fcfd8e", size = 5920591, upload-time = "2026-02-06T09:56:20.758Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/5d/81/1f3b65bd30c334167bfa8b0d23300a44e2725ce39bba5b76a2460d85f745/grpcio-1.78.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:c41bc64626db62e72afec66b0c8a0da76491510015417c127bfc53b2fe6d7f7f", size = 11813685, upload-time = "2026-02-06T09:56:24.315Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0e/1c/bbe2f8216a5bd3036119c544d63c2e592bdf4a8ec6e4a1867592f4586b26/grpcio-1.78.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8dfffba826efcf366b1e3ccc37e67afe676f290e13a3b48d31a46739f80a8724", size = 6487803, upload-time = "2026-02-06T09:56:27.367Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/16/5c/a6b2419723ea7ddce6308259a55e8e7593d88464ce8db9f4aa857aba96fa/grpcio-1.78.0-cp314-cp314-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:74be1268d1439eaaf552c698cdb11cd594f0c49295ae6bb72c34ee31abbe611b", size = 7173206, upload-time = "2026-02-06T09:56:29.876Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/df/1e/b8801345629a415ea7e26c83d75eb5dbe91b07ffe5210cc517348a8d4218/grpcio-1.78.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:be63c88b32e6c0f1429f1398ca5c09bc64b0d80950c8bb7807d7d7fb36fb84c7", size = 6693826, upload-time = "2026-02-06T09:56:32.305Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/34/84/0de28eac0377742679a510784f049738a80424b17287739fc47d63c2439e/grpcio-1.78.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:3c586ac70e855c721bda8f548d38c3ca66ac791dc49b66a8281a1f99db85e452", size = 7277897, upload-time = "2026-02-06T09:56:34.915Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ca/9c/ad8685cfe20559a9edb66f735afdcb2b7d3de69b13666fdfc542e1916ebd/grpcio-1.78.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:35eb275bf1751d2ffbd8f57cdbc46058e857cf3971041521b78b7db94bdaf127", size = 8252404, upload-time = "2026-02-06T09:56:37.553Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3c/05/33a7a4985586f27e1de4803887c417ec7ced145ebd069bc38a9607059e2b/grpcio-1.78.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:207db540302c884b8848036b80db352a832b99dfdf41db1eb554c2c2c7800f65", size = 7696837, upload-time = "2026-02-06T09:56:40.173Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/73/77/7382241caf88729b106e49e7d18e3116216c778e6a7e833826eb96de22f7/grpcio-1.78.0-cp314-cp314-win32.whl", hash = "sha256:57bab6deef2f4f1ca76cc04565df38dc5713ae6c17de690721bdf30cb1e0545c", size = 4142439, upload-time = "2026-02-06T09:56:43.258Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/48/b2/b096ccce418882fbfda4f7496f9357aaa9a5af1896a9a7f60d9f2b275a06/grpcio-1.78.0-cp314-cp314-win_amd64.whl", hash = "sha256:dce09d6116df20a96acfdbf85e4866258c3758180e8c49845d6ba8248b6d0bbb", size = 4929852, upload-time = "2026-02-06T09:56:45.885Z" }, +] + [[package]] name = "idna" version = "3.10" @@ -307,6 +373,15 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" }, ] +[[package]] +name = "markdown" +version = "3.10.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2b/f4/69fa6ed85ae003c2378ffa8f6d2e3234662abd02c10d216c0ba96081a238/markdown-3.10.2.tar.gz", hash = "sha256:994d51325d25ad8aa7ce4ebaec003febcce822c3f8c911e3b17c52f7f589f950", size = 368805, upload-time = "2026-02-09T14:57:26.942Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/de/1f/77fa3081e4f66ca3576c896ae5d31c3002ac6607f9747d2e3aa49227e464/markdown-3.10.2-py3-none-any.whl", hash = "sha256:e91464b71ae3ee7afd3017d9f358ef0baf158fd9a298db92f1d4761133824c36", size = 108180, upload-time = "2026-02-09T14:57:25.787Z" }, +] + [[package]] name = "markupsafe" version = "3.0.2" @@ -615,6 +690,15 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, ] +[[package]] +name = "packaging" +version = "26.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/65/ee/299d360cdc32edc7d2cf530f3accf79c4fca01e96ffc950d8a52213bd8e4/packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4", size = 143416, upload-time = "2026-01-21T20:50:39.064Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" }, +] + [[package]] name = "pandas" version = "2.3.2" @@ -772,6 +856,21 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/cc/35/cc0aaecf278bb4575b8555f2b137de5ab821595ddae9da9d3cd1da4072c7/propcache-0.3.2-py3-none-any.whl", hash = "sha256:98f1ec44fb675f5052cccc8e609c46ed23a35a1cfd18545ad4e29002d858a43f", size = 12663, upload-time = "2025-06-09T22:56:04.484Z" }, ] +[[package]] +name = "protobuf" +version = "6.33.5" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ba/25/7c72c307aafc96fa87062aa6291d9f7c94836e43214d43722e86037aac02/protobuf-6.33.5.tar.gz", hash = "sha256:6ddcac2a081f8b7b9642c09406bc6a4290128fce5f471cddd165960bb9119e5c", size = 444465, upload-time = "2026-01-29T21:51:33.494Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b1/79/af92d0a8369732b027e6d6084251dd8e782c685c72da161bd4a2e00fbabb/protobuf-6.33.5-cp310-abi3-win32.whl", hash = "sha256:d71b040839446bac0f4d162e758bea99c8251161dae9d0983a3b88dee345153b", size = 425769, upload-time = "2026-01-29T21:51:21.751Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/55/75/bb9bc917d10e9ee13dee8607eb9ab963b7cf8be607c46e7862c748aa2af7/protobuf-6.33.5-cp310-abi3-win_amd64.whl", hash = "sha256:3093804752167bcab3998bec9f1048baae6e29505adaf1afd14a37bddede533c", size = 437118, upload-time = "2026-01-29T21:51:24.022Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a2/6b/e48dfc1191bc5b52950246275bf4089773e91cb5ba3592621723cdddca62/protobuf-6.33.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:a5cb85982d95d906df1e2210e58f8e4f1e3cdc088e52c921a041f9c9a0386de5", size = 427766, upload-time = "2026-01-29T21:51:25.413Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/4e/b1/c79468184310de09d75095ed1314b839eb2f72df71097db9d1404a1b2717/protobuf-6.33.5-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:9b71e0281f36f179d00cbcb119cb19dec4d14a81393e5ea220f64b286173e190", size = 324638, upload-time = "2026-01-29T21:51:26.423Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c5/f5/65d838092fd01c44d16037953fd4c2cc851e783de9b8f02b27ec4ffd906f/protobuf-6.33.5-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:8afa18e1d6d20af15b417e728e9f60f3aa108ee76f23c3b2c07a2c3b546d3afd", size = 339411, upload-time = "2026-01-29T21:51:27.446Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/9b/53/a9443aa3ca9ba8724fdfa02dd1887c1bcd8e89556b715cfbacca6b63dbec/protobuf-6.33.5-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:cbf16ba3350fb7b889fca858fb215967792dc125b35c7976ca4818bee3521cf0", size = 323465, upload-time = "2026-01-29T21:51:28.925Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/57/bf/2086963c69bdac3d7cff1cc7ff79b8ce5ea0bec6797a017e1be338a46248/protobuf-6.33.5-py3-none-any.whl", hash = "sha256:69915a973dd0f60f31a08b8318b73eab2bd6a392c79184b3612226b0a3f8ec02", size = 170687, upload-time = "2026-01-29T21:51:32.557Z" }, +] + [[package]] name = "psutil" version = "7.0.0" @@ -973,6 +1072,36 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, ] +[[package]] +name = "tensorboard" +version = "2.20.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "grpcio" }, + { name = "markdown" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pillow" }, + { name = "protobuf" }, + { name = "setuptools" }, + { name = "tensorboard-data-server" }, + { name = "werkzeug" }, +] +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/9c/d9/a5db55f88f258ac669a92858b70a714bbbd5acd993820b41ec4a96a4d77f/tensorboard-2.20.0-py3-none-any.whl", hash = "sha256:9dc9f978cb84c0723acf9a345d96c184f0293d18f166bb8d59ee098e6cfaaba6", size = 5525680, upload-time = "2025-07-17T19:20:49.638Z" }, +] + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7a/13/e503968fefabd4c6b2650af21e110aa8466fe21432cd7c43a84577a89438/tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b7/85/dabeaf902892922777492e1d253bb7e1264cadce3cea932f7ff599e53fea/tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530" }, +] + [[package]] name = "threadpoolctl" version = "3.6.0" @@ -1120,6 +1249,18 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, ] +[[package]] +name = "werkzeug" +version = "3.1.5" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/5a/70/1469ef1d3542ae7c2c7b72bd5e3a4e6ee69d7978fa8a3af05a38eca5becf/werkzeug-3.1.5.tar.gz", hash = "sha256:6a548b0e88955dd07ccb25539d7d0cc97417ee9e179677d22c7041c8f078ce67", size = 864754, upload-time = "2026-01-08T17:49:23.247Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ad/e4/8d97cca767bcc1be76d16fb76951608305561c6e056811587f36cb1316a8/werkzeug-3.1.5-py3-none-any.whl", hash = "sha256:5111e36e91086ece91f93268bb39b4a35c1e6f1feac762c9c822ded0a4e322dc", size = 225025, upload-time = "2026-01-08T17:49:21.859Z" }, +] + [[package]] name = "yarl" version = "1.20.1"