Compare commits

10 commits: 7ef7d6d3bc ... main

| Author | SHA1 | Date |
|---|---|---|
|  | 13a749431a |  |
|  | ed8270b0f3 |  |
|  | f4e04f9b3c |  |
|  | 5783702047 |  |
|  | 3911e705d8 |  |
|  | 51086f364b |  |
|  | 7cc845b71a |  |
|  | d8186d9d13 |  |
|  | d110130008 |  |
|  | b4929311d7 |  |
.gitignore (vendored, 1 line added)

@@ -1,2 +1,3 @@
 # .gitignore
 reference/
+.venv/
README.md (136 lines changed)

@@ -1,3 +1,4 @@
 <!-- README.md -->
+<div align="center">
 
 <p align="center">
@@ -5,11 +6,11 @@
 </p>
 
 <p>
-  <a href="https://github.com/your-username/Geo-Layout-Transformer/stargazers"><img src="https://img.shields.io/github/stars/your-username/Geo-Layout-Transformer.svg" /></a>
-  <a href="https://github.com/your-username/Geo-Layout-Transformer/network/members"><img src="https://img.shields.io/github/forks/your-username/Geo-Layout-Transformer.svg" /></a>
-  <a href="https://github.com/your-username/Geo-Layout-Transformer/issues"><img src="https://img.shields.io/github/issues-raw/your-username/Geo-Layout-Transformer" /></a>
-  <a href="https://github.com/your-username/Geo-Layout-Transformer/issues?q=is%3Aissue+is%3Aclosed"><img src="https://img.shields.io/github/issues-closed-raw/your-username/Geo-Layout-Transformer" /></a>
-  <a><img src="https://img.shields.io/badge/python-3.9%2B-blue" /></a>
+  <a href="http://jiao77.cn:3012/Jiao77/Geo-Layout-Transformer/stargazers"><img src="https://img.shields.io/github/stars/your-username/Geo-Layout-Transformer.svg" /></a>
+  <a href="http://jiao77.cn:3012/Jiao77/Geo-Layout-Transformer/network/members"><img src="https://img.shields.io/github/forks/your-username/Geo-Layout-Transformer.svg" /></a>
+  <a href="http://jiao77.cn:3012/Jiao77/Geo-Layout-Transformer/issues"><img src="https://img.shields.io/github/issues-raw/your-username/Geo-Layout-Transformer" /></a>
+  <a href="http://jiao77.cn:3012/Jiao77/Geo-Layout-Transformer/issues?q=is%3Aissue+is%3Aclosed"><img src="https://img.shields.io/github/issues-closed-raw/your-username/Geo-Layout-Transformer" /></a>
+  <a><img src="https://img.shields.io/badge/python-3.12%2B-blue" /></a>
   <a><img src="https://img.shields.io/badge/PyTorch-2.x-orange" /></a>
 </p>
 
@@ -19,7 +20,7 @@
 
 </div>
 
-# Geo-Layout Transformer 🚀
+# Geo-Layout Transformer 🚀 🔬
 
 **A Unified, Self-Supervised Foundation Model for Physical Design Analysis**
 
@@ -34,12 +35,12 @@
 
 ## 🖥️ Supported Systems
 
-- **Python**: 3.9+
+- **Python**: 3.12+
 - **OS**: macOS 13+/Apple Silicon, Linux (Ubuntu 20.04/22.04). Windows via **WSL2** recommended
 - **Frameworks**: PyTorch, PyTorch Geometric (with CUDA optional)
 - **EDA I/O**: GDSII/OASIS (via `klayout` Python API)
 
-## 1. Vision
+## 1. Vision 🎯
 
 The **Geo-Layout Transformer** is a research project aimed at creating a paradigm shift in Electronic Design Automation (EDA) for physical design. Instead of relying on a fragmented set of heuristic-based tools, we are building a single, unified foundation model that understands the deep, contextual "language" of semiconductor layouts.
 
@@ -51,7 +52,7 @@ By leveraging a novel hybrid **Graph Neural Network (GNN) + Transformer** archit
 
 Our vision is to move from disparate, task-specific tools to a centralized, reusable "Layout Understanding Engine" that accelerates the design cycle and pushes the boundaries of PPA (Power, Performance, and Area).
 
-## 2. Core Architecture
+## 2. Core Architecture 🏗️
 
 The model's architecture is designed to hierarchically process layout information, mimicking how a human expert analyzes a design from local details to global context.
 
@@ -93,53 +94,96 @@ Geo-Layout-Transformer/
 └─ README*.md              # English/Chinese documentation
 ```
 
-## 3. Getting Started
+## 3. Getting Started ⚙️
 
-### 3.1. Prerequisites
+### 3.1. Prerequisites 🧰
 
-* Python 3.9+
-* A Conda environment is highly recommended.
+* Python 3.12+
+* Dependency management: using uv is recommended for fast, reproducible installs (uv.lock provided). Conda/Python is supported as an alternative.
 * Access to EDA tools for generating labeled data (e.g., a DRC engine for hotspot labels).
 
-### 3.2. Installation
+### 3.2. Installation 🚧
 
-1. **Clone the repository:**
-   ```bash
-   git clone https://github.com/your-username/Geo-Layout-Transformer.git
-   cd Geo-Layout-Transformer
-   ```
+#### A) Using uv (recommended)
 
-2. **Create and activate the Conda environment:**
-   ```bash
-   conda create -n geo_trans python=3.9
-   conda activate geo_trans
-   ```
+1) Install uv (one-time):
 
-3. **Install dependencies:**
-   This project requires PyTorch and PyTorch Geometric (PyG). Please follow the official installation instructions for your specific CUDA version.
+   ```bash
+   curl -LsSf https://astral.sh/uv/install.sh | sh
+   ```
 
-   * **PyTorch:** [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
-   * **PyG:** [https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)
+2) Clone the repository:
 
-   After installing PyTorch and PyG, install the remaining dependencies:
-   ```bash
-   pip install -r requirements.txt
-   ```
-   *(Note: You may need to install `klayout` separately via its own package manager or build from source to enable its Python API.)*
+   ```bash
+   git clone http://jiao77.cn:3012/Jiao77/Geo-Layout-Transformer.git
+   cd Geo-Layout-Transformer
+   ```
+
+3) Ensure Python 3.12 is available (uv can manage it):
+
+   ```bash
+   uv python install 3.12
+   ```
+
+4) Create the environment and install dependencies from uv.lock/pyproject:
+
+   ```bash
+   uv sync
+   ```
+
+Notes:
+- For CUDA builds of PyTorch/PyG, follow the official installers first, then install the rest via uv:
+  - PyTorch: https://pytorch.org/get-started/locally/
+  - PyG: https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html
+  After installing the correct Torch/PyG wheels, you may run `uv sync --frozen` to install the remaining packages.
+- You may need to install `klayout` separately (package manager or from source) to enable its Python API.
+
+#### B) Using Python/Conda (alternative)
+
+1) Clone the repository:
+
+   ```bash
+   git clone http://jiao77.cn:3012/Jiao77/Geo-Layout-Transformer.git
+   cd Geo-Layout-Transformer
+   ```
+
+2) Create and activate an environment (Conda example):
+
+   ```bash
+   conda create -n geo_trans python=3.12
+   conda activate geo_trans
+   ```
+
+3) Install PyTorch and PyTorch Geometric per your CUDA setup:
+
+   - PyTorch: https://pytorch.org/get-started/locally/
+   - PyG: https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html
+
+4) Install the remaining dependencies:
+
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+> Tip: GPU is optional. For CPU-only environments, install the CPU variants of PyTorch/PyG.
+> Note: You may need to install `klayout` separately to enable its Python API.
 
-## 4. Project Usage
+## 4. Project Usage 🛠️
 
 The project workflow is divided into two main stages: data preprocessing and model training.
 
-### 4.1. Stage 1: Data Preprocessing
+### 4.1. Stage 1: Data Preprocessing 🧩
 
 The first step is to convert your GDSII/OASIS files into a graph dataset that the model can consume.
 
 1. Place your layout files in the `data/gds/` directory.
 2. Configure the preprocessing parameters in `configs/default.yaml`. You will need to define patch size, stride, layer mappings, and how to construct graph edges.
 3. Run the preprocessing script:
+   - Using uv (recommended):
+     ```bash
+     uv run python scripts/preprocess_gds.py --config-file configs/default.yaml --gds-file data/gds/my_design.gds --output-dir data/processed/my_design/
+     ```
+   - Using Python/Conda:
+     ```bash
+     python scripts/preprocess_gds.py --config-file configs/default.yaml --gds-file data/gds/my_design.gds --output-dir data/processed/my_design/
+     ```
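The patch windowing in step 2 can be illustrated in pure Python (a minimal sketch with hypothetical helper names; the actual pipeline lives in `scripts/preprocess_gds.py` and uses gdstk on real polygons):

```python
from typing import List, Optional, Tuple

Rect = Tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax)

def patch_windows(bbox: Rect, patch_size: float, stride: float) -> List[Rect]:
    """Slide a square window of side `patch_size` over `bbox` with `stride`."""
    xmin, ymin, xmax, ymax = bbox
    windows = []
    y = ymin
    while y < ymax:
        x = xmin
        while x < xmax:
            windows.append((x, y, x + patch_size, y + patch_size))
            x += stride
        y += stride
    return windows

def clip_rect(rect: Rect, window: Rect) -> Optional[Rect]:
    """Intersect an axis-aligned rectangle with a patch window.

    Returns None when the rectangle lies entirely outside the window.
    """
    x0 = max(rect[0], window[0])
    y0 = max(rect[1], window[1])
    x1 = min(rect[2], window[2])
    y1 = min(rect[3], window[3])
    return (x0, y0, x1, y1) if x0 < x1 and y0 < y1 else None
```

With `patch_size: 5.0` and `patch_stride: 2.5` (the values used by the sample-data script below), overlapping windows are generated and each shape is clipped to every window it intersects.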
@@ -161,20 +205,24 @@ When building a graph for each patch, we now preserve both global and per-patch
 
 This follows the spirit of LayoutGMN's structural encoding while staying compatible with our GNN encoder.
 
-### 4.2. Stage 2: Model Training
+### 4.2. Stage 2: Model Training 🏋️
 
 Once the dataset is ready, you can train the Geo-Layout Transformer.
 
-#### Self-Supervised Pre-training (Recommended)
+#### Self-Supervised Pre-training (Recommended) ⚡
 
 To build a powerful foundation model, we first pre-train it on unlabeled data using a "Masked Layout Modeling" task.
 
 ```bash
+# Using uv (recommended)
+uv run python main.py --config-file configs/default.yaml --mode pretrain --data-dir data/processed/my_design/
+
+# Using Python/Conda
 python main.py --config-file configs/default.yaml --mode pretrain --data-dir data/processed/my_design/
 ```
 This will train the model to understand the fundamental "grammar" of physical layouts without requiring any expensive labels.
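Conceptually, the masking step of "Masked Layout Modeling" looks like the sketch below (numpy only, not the project's actual implementation; zeroed rows stand in for a learned mask token, and `mask_ratio` mirrors the `pretraining.mask_ratio: 0.15` config value):

```python
import numpy as np

rng = np.random.default_rng(0)

def mask_patches(patch_emb: np.ndarray, mask_ratio: float = 0.15):
    """Randomly mask a fraction of patch embeddings.

    Returns (masked copy, masked-row indices). The training objective would
    then reconstruct the original rows at the masked indices.
    """
    num_patches = patch_emb.shape[0]
    num_masked = max(1, int(round(num_patches * mask_ratio)))
    idx = rng.choice(num_patches, size=num_masked, replace=False)
    masked = patch_emb.copy()
    masked[idx] = 0.0  # stand-in for a learned [MASK] embedding
    return masked, idx
```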
-#### Supervised Fine-tuning
+#### Supervised Fine-tuning 🎯
 
 After pre-training, you can fine-tune the model on a smaller, labeled dataset for a specific task like hotspot detection.
 
@@ -182,10 +230,14 @@ After pre-training, you can fine-tune the model on a smaller, labeled dataset fo
 2. Use a task-specific config file (e.g., `hotspot_detection.yaml`) that defines the model head and loss function.
 3. Run the main script in `train` mode:
    ```bash
+   # Using uv (recommended)
+   uv run python main.py --config-file configs/hotspot_detection.yaml --mode train --data-dir data/processed/labeled_hotspots/ --checkpoint-path /path/to/pretrained_model.pth
+
+   # Using Python/Conda
    python main.py --config-file configs/hotspot_detection.yaml --mode train --data-dir data/processed/labeled_hotspots/ --checkpoint-path /path/to/pretrained_model.pth
    ```
 
-## 5. Roadmap & Contribution
+## 5. Roadmap & Contribution 🗺️
 
 This project is ambitious and we welcome contributions. Our future roadmap includes:
 
@@ -196,7 +248,7 @@ This project is ambitious and we welcome contributions. Our future roadmap inclu
 
 Please feel free to open an issue or submit a pull request.
 
-## Acknowledgments
+## Acknowledgments 🙏
 
 We stand on the shoulders of open-source communities. This project draws inspiration and/or utilities from:
 
@@ -206,7 +258,3 @@ We stand on the shoulders of open-source communities. This project draws inspira
 - Research works such as LayoutGMN (graph matching for structural similarity) that informed our polygon/graph handling design
 
 If your work is used and not listed here, please open an issue or PR so we can properly credit you.
 
 ---
 
 Made with ❤️ for EDA research and open-source collaboration.
README_zh.md (138 lines changed)

@@ -1,3 +1,4 @@
 <!-- README_zh.md -->
+<div align="center">
 
 <p align="center">
@@ -5,11 +6,11 @@
 </p>
 
 <p>
-  <a href="https://github.com/your-username/Geo-Layout-Transformer/stargazers"><img src="https://img.shields.io/github/stars/your-username/Geo-Layout-Transformer.svg" /></a>
-  <a href="https://github.com/your-username/Geo-Layout-Transformer/network/members"><img src="https://img.shields.io/github/forks/your-username/Geo-Layout-Transformer.svg" /></a>
-  <a href="https://github.com/your-username/Geo-Layout-Transformer/issues"><img src="https://img.shields.io/github/issues-raw/your-username/Geo-Layout-Transformer" /></a>
-  <a href="https://github.com/your-username/Geo-Layout-Transformer/issues?q=is%3Aissue+is%3Aclosed"><img src="https://img.shields.io/github/issues-closed-raw/your-username/Geo-Layout-Transformer" /></a>
-  <a><img src="https://img.shields.io/badge/python-3.9%2B-blue" /></a>
+  <a href="http://jiao77.cn:3012/Jiao77/Geo-Layout-Transformer/stargazers"><img src="https://img.shields.io/github/stars/your-username/Geo-Layout-Transformer.svg" /></a>
+  <a href="http://jiao77.cn:3012/Jiao77/Geo-Layout-Transformer/network/members"><img src="https://img.shields.io/github/forks/your-username/Geo-Layout-Transformer.svg" /></a>
+  <a href="http://jiao77.cn:3012/Jiao77/Geo-Layout-Transformer/issues"><img src="https://img.shields.io/github/issues-raw/your-username/Geo-Layout-Transformer" /></a>
+  <a href="http://jiao77.cn:3012/Jiao77/Geo-Layout-Transformer/issues?q=is%3Aissue+is%3Aclosed"><img src="https://img.shields.io/github/issues-closed-raw/your-username/Geo-Layout-Transformer" /></a>
+  <a><img src="https://img.shields.io/badge/python-3.12%2B-blue" /></a>
   <a><img src="https://img.shields.io/badge/PyTorch-2.x-orange" /></a>
 </p>
 
@@ -19,27 +20,27 @@
 
 </div>
 
-# Geo-Layout Transformer 🚀
+# Geo-Layout Transformer 🚀 🔬
 
 **一个用于物理设计分析的统一、自监督基础模型**
 
 ---
 
-## ✨ 亮点
+## ✨ 亮点 🌟
 
 - **统一基础模型**:覆盖多种物理设计分析任务
 - **混合 GNN + Transformer**:从局部到全局建模版图语义
 - **自监督预训练**:在无标签 GDSII/OASIS 上学习强泛化表示
 - **模块化任务头**:轻松适配(如热点检测、连通性验证)
 
-## 🖥️ 支持系统
+## 🖥️ 支持系统 💻
 
-- **Python**:3.9+
+- **Python**:3.12+
 - **操作系统**:macOS 13+/Apple Silicon、Linux(Ubuntu 20.04/22.04)。Windows 建议使用 **WSL2**
 - **深度学习框架**:PyTorch、PyTorch Geometric(CUDA 可选)
 - **EDA I/O**:GDSII/OASIS(通过 `klayout` Python API)
 
-## 1. 项目愿景
+## 1. 项目愿景 🎯
 
 **Geo-Layout Transformer** 是一个旨在推动电子设计自动化(EDA)物理设计领域范式转变的研究项目。我们不再依赖于一套零散的、基于启发式规则的工具,而是致力于构建一个统一的基础模型,使其能够理解半导体版图深层次的、上下文相关的“设计语言”。
 
@@ -51,7 +52,7 @@
 
 我们的愿景是,从目前分散的、任务特定的工具,演进为一个集中的、可复用的“版图理解引擎”,从而加速设计周期,并突破 PPA(功耗、性能、面积)的极限。
 
-## 2. 核心架构
+## 2. 核心架构 🏗️
 
 该模型的架构设计旨在分层处理版图信息,模仿人类专家从局部细节到全局上下文分析设计的过程。
 
@@ -65,7 +66,7 @@
 
 4. **特定任务头**:从 Transformer 输出的、具有全局上下文感知能力的最终嵌入,被送入简单、轻量级的神经网络“头”(Head)中,以执行特定的下游任务。这种模块化设计使得核心模型能够以最小的代价适应新的应用。
 
-## 🧭 项目结构
+## 🧭 项目结构 📁
 
 ```text
 Geo-Layout-Transformer/
@@ -93,53 +94,96 @@ Geo-Layout-Transformer/
 └─ README*.md              # 中英文文档
 ```
 
-## 3. 快速上手
+## 3. 快速上手 ⚙️
 
-### 3.1. 环境要求
+### 3.1. 环境要求 🧰
 
-* Python 3.9+
-* 强烈建议使用 Conda 进行环境管理。
+* Python 3.12+
+* 依赖管理:推荐使用 uv(已提供 uv.lock)来进行快速、可复现的安装;也支持使用 Conda/Python 作为替代方案。
 * 能够访问 EDA 工具以生成带标签的数据(例如,使用 DRC 工具生成热点标签)。
 
-### 3.2. 安装步骤
+### 3.2. 安装步骤 🚧
 
-1. **克隆代码仓库:**
-   ```bash
-   git clone https://github.com/your-username/Geo-Layout-Transformer.git
-   cd Geo-Layout-Transformer
-   ```
+#### A) 使用 uv(推荐)
 
-2. **创建并激活 Conda 环境:**
-   ```bash
-   conda create -n geo_trans python=3.9
-   conda activate geo_trans
-   ```
+1)安装 uv(一次性):
 
-3. **安装依赖:**
-   本项目需要 PyTorch 和 PyTorch Geometric (PyG)。请根据您的 CUDA 版本遵循官方指南进行安装。
+   ```bash
+   curl -LsSf https://astral.sh/uv/install.sh | sh
+   ```
 
-   * **PyTorch:** [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
-   * **PyG:** [https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)
+2)克隆代码仓库:
 
-   安装完 PyTorch 和 PyG 后,安装其余的依赖项:
-   ```bash
-   pip install -r requirements.txt
-   ```
-   *(注意:您可能需要通过 `klayout` 自身的包管理器或从源码编译来单独安装它,以启用其 Python API。)*
+   ```bash
+   git clone http://jiao77.cn:3012/Jiao77/Geo-Layout-Transformer.git
+   cd Geo-Layout-Transformer
+   ```
+
+3)确保系统可用 Python 3.12(uv 可管理):
+
+   ```bash
+   uv python install 3.12
+   ```
+
+4)基于 uv.lock/pyproject 创建环境并安装依赖:
+
+   ```bash
+   uv sync
+   ```
+
+说明:
+- 如需安装带 CUDA 的 PyTorch/PyG,请先根据官方说明安装对应版本,然后再用 uv 安装其余依赖:
+  - PyTorch: https://pytorch.org/get-started/locally/
+  - PyG: https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html
+  正确安装 Torch/PyG 轮子后,可执行 `uv sync --frozen` 安装剩余依赖。
+- 若需要 `klayout` 的 Python API,可能需要通过其包管理器或源码单独安装。
+
+#### B) 使用 Python/Conda(备选)
+
+1)克隆代码仓库:
+
+   ```bash
+   git clone http://jiao77.cn:3012/Jiao77/Geo-Layout-Transformer.git
+   cd Geo-Layout-Transformer
+   ```
+
+2)创建并激活环境(以 Conda 为例):
+
+   ```bash
+   conda create -n geo_trans python=3.12
+   conda activate geo_trans
+   ```
+
+3)根据 CUDA 环境安装 PyTorch 和 PyTorch Geometric:
+
+   - PyTorch: https://pytorch.org/get-started/locally/
+   - PyG: https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html
+
+4)安装其余依赖:
+
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+> 提示:GPU 不是必须的。仅 CPU 环境可安装 PyTorch/PyG 的 CPU 版本。
+> 说明:如需 `klayout` 的 Python API,可能需要单独安装。
 
-## 4. 项目使用
+## 4. 项目使用 🛠️
 
 项目的工作流程分为两个主要阶段:数据预处理和模型训练。
 
-### 4.1. 阶段一:数据预处理
+### 4.1. 阶段一:数据预处理 🧩
 
 第一步是将您的 GDSII/OASIS 文件转换为模型可以使用的图数据集。
 
 1. 将您的版图文件放入 `data/gds/` 目录。
 2. 在 `configs/default.yaml` 中配置预处理参数。您需要定义区块大小、步长、层映射以及图边的构建方式。
 3. 运行预处理脚本:
+   - 使用 uv(推荐):
+     ```bash
+     uv run python scripts/preprocess_gds.py --config-file configs/default.yaml --gds-file data/gds/my_design.gds --output-dir data/processed/my_design/
+     ```
+   - 使用 Python/Conda:
+     ```bash
+     python scripts/preprocess_gds.py --config-file configs/default.yaml --gds-file data/gds/my_design.gds --output-dir data/processed/my_design/
+     ```
@@ -161,7 +205,7 @@ Geo-Layout-Transformer/
 
 该设计借鉴了 LayoutGMN 的结构编码思想,同时与我们现有的 GNN 编码器保持兼容。
 
-### 4.2. 阶段二:模型训练
+### 4.2. 阶段二:模型训练 🏋️
 
 数据集准备就绪后,您就可以开始训练 Geo-Layout Transformer。
 
@@ -170,6 +214,10 @@ Geo-Layout-Transformer/
 
 为了构建一个强大的基础模型,我们首先在无标签数据上使用“掩码版图建模”任务对其进行预训练。
 
 ```bash
+# 使用 uv(推荐)
+uv run python main.py --config-file configs/default.yaml --mode pretrain --data-dir data/processed/my_design/
+
+# 使用 Python/Conda
 python main.py --config-file configs/default.yaml --mode pretrain --data-dir data/processed/my_design/
 ```
 这将训练模型理解物理版图的基本“语法”,而无需任何昂贵的标签。
@@ -182,10 +230,14 @@ python main.py --config-file configs/default.yaml --mode pretrain --data-dir dat
 
 2. 使用一个特定于任务的配置文件(例如 `hotspot_detection.yaml`),其中定义了模型的任务头和损失函数。
 3. 在 `train` 模式下运行主脚本:
    ```bash
+   # 使用 uv(推荐)
+   uv run python main.py --config-file configs/hotspot_detection.yaml --mode train --data-dir data/processed/labeled_hotspots/ --checkpoint-path /path/to/pretrained_model.pth
+
+   # 使用 Python/Conda
    python main.py --config-file configs/hotspot_detection.yaml --mode train --data-dir data/processed/labeled_hotspots/ --checkpoint-path /path/to/pretrained_model.pth
    ```
 
-## 5. 发展路线与贡献
+## 5. 发展路线与贡献 🗺️
 
 这是一个宏伟的项目,我们欢迎任何形式的贡献。我们未来的发展路线图包括:
 
@@ -196,7 +248,7 @@ python main.py --config-file configs/default.yaml --mode pretrain --data-dir dat
 
 欢迎随时提出 Issue 或提交 Pull Request。
 
-## 致谢
+## 致谢 🙏
 
 本项目离不开开源社区的贡献与启发,特别感谢:
 
@@ -206,7 +258,3 @@ python main.py --config-file configs/default.yaml --mode pretrain --data-dir dat
 - 研究工作 LayoutGMN(面向结构相似性的图匹配),启发了我们对多边形/图构建的设计
 
 若您的工作被本项目使用但尚未列出,欢迎提交 Issue 或 PR 以便完善致谢。
 
 ---
 
 Made with ❤️ 面向 EDA 研究与开源协作。
TODO.md (new file, 103 lines)

@@ -0,0 +1,103 @@
# Geo-Layout-Transformer TODOs

This file summarizes the project goals, architecture overview, current completion status, and improvement plan, grouped by priority with actionable checklists.

## Project Goals (brief)
- Build a unified foundation model for physical-design layout understanding, targeting tasks such as hotspot detection, connectivity verification, and structural matching.
- Use a hybrid "GNN patch encoder + global Transformer" architecture, supporting self-supervised pre-training and task-head fine-tuning.

## Architecture Overview (code locations)
- Data layer: `src/data/`
  - `gds_parser.py`: GDSII/OASIS parsing, per-patch clipping, and geometric feature extraction (using gdstk).
  - `graph_constructor.py`: builds PyG graphs from geometric objects (node features, KNN/radius edges, metadata).
  - `dataset.py`: InMemoryDataset that loads the processed `.pt` data.
- Model layer: `src/models/`
  - `gnn_encoder.py`: patch encoder with switchable GCN/GraphSAGE/GAT plus global pooling.
  - `transformer_core.py`: Transformer encoder (sinusoidal positional encoding + encoder stack).
  - `task_heads.py`: classification/matching heads; `geo_layout_transformer.py` assembles the end-to-end model.
- Training and evaluation: `src/engine/`
  - `trainer.py`: supervised training loop (BCEWithLogitsLoss); focal loss and similar are not yet implemented.
  - `evaluator.py`: Accuracy/Precision/Recall/F1/AUC metric computation.
  - `self_supervised.py`: placeholder "masked layout modeling" pipeline, not yet stable (see improvement items).
- Scripts and entry points:
  - `scripts/preprocess_gds.py`: GDS → graph-dataset pipeline (saved as an InMemoryDataset).
  - `scripts/visualize_attention.py`: attention-visualization placeholder; details still to implement.
  - `main.py`: loads configs, builds data/models, and runs in pretrain/train/eval modes.
- Configs: `configs/default.yaml`, `configs/hotspot_detection.yaml`
- Dependencies and versions: `pyproject.toml` (Python >=3.12, Torch/PyG, etc.); lock file `uv.lock`.
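The KNN edge strategy used by `graph_constructor.py` can be illustrated as follows (a numpy sketch of the idea only; the actual implementation builds PyG tensors and may differ in tie-breaking and distance metric):

```python
import numpy as np

def knn_edges(coords: np.ndarray, k: int) -> np.ndarray:
    """Return a (2, N*k) edge index connecting each node to its k nearest
    neighbors (self-loops excluded), i.e., a KNN graph in COO form."""
    n = coords.shape[0]
    # Pairwise squared distances, shape (n, n)
    diff = coords[:, None, :] - coords[None, :, :]
    dist2 = (diff ** 2).sum(-1)
    np.fill_diagonal(dist2, np.inf)  # exclude self as a neighbor
    nbrs = np.argsort(dist2, axis=1)[:, :k]  # (n, k) nearest indices
    src = np.repeat(np.arange(n), k)
    return np.stack([src, nbrs.reshape(-1)])
```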
## Current Completion Status (rough assessment)
- Done
  - GDS parsing and patch clipping (including clipped polygons and metadata such as area ratios).
  - Graph construction (node geometry/layer features, KNN/radius edges, PyG Data wrapping).
  - Basic data flow through the GNN encoder (GCN/GraphSAGE/GAT) and the Transformer backbone.
  - Supervised Trainer (BCEWithLogitsLoss) and the Evaluator metrics pipeline.
  - Preprocessing script and InMemoryDataset persistence; basic logging and config loading/merging.
  - Install/run instructions in the READMEs (uv recommended; Conda/pip as alternative).
- In progress / placeholder
  - Self-supervised pre-training (self_supervised): the masking strategy and dimension reshaping rest on assumptions; needs debugging and validation.
  - Attention-visualization script: explanatory comments only; not yet wired to model weights or actual attention extraction.
  - Dataset splitting in main.py: train/val currently reuse the same data source; a TODO remains.
- Missing / needs improvement
  - Richer support for task heads and losses (e.g., focal loss, class weights, masking/sampling).
  - Training-loop engineering: validation, early stopping, best-model saving, learning-rate scheduling.
  - A rigorous self-supervised objective (mask indices aligned with batch/ptr, masking, reconstruction head/projector).
  - Reproducible experiment scripts and a minimal data sample; unit tests and fast CI checks.
  - CUDA/large-graph memory management (gradient accumulation, mixed precision, GraphSAINT/Cluster-GCN, etc.).
  - Observability (TensorBoard/CSVLogger, random seeds, config provenance and version tracking).

## Prioritized Checklist (actionable items)

### P0 (immediate)
- [x] Dataset splitting and DataLoader pipeline
  - Add configurable train/val/test split ratios and a random seed in `main.py`; support loading each split from a directory/manifest.
  - Add a `splits` field to `configs/default.yaml`; update the usage notes in `README*`.
- [x] Supervised-training engineering
  - Add a validation phase and best-model saving in `trainer.py` (`torch.save` to a specified path).
  - Introduce a learning-rate scheduler (e.g., StepLR/CosineAnnealingWarmRestarts) and an early-stopping policy.
  - Support class weights/focal loss: add a `focal_loss` implementation in `trainer.py`, selectable via config.
- [x] Self-supervised pre-training fixes
  - Make the per-graph patch-sequence mapping explicit: generate mask indices per graph from `batch.ptr` to avoid cross-graph confusion.
  - Apply masking at the input-feature/graph-structure level rather than on already-pooled graph-level embeddings; or add a "node level → patch aggregation → reconstruction head" path.
  - Add a reconstruction head (MLP) to `transformer_core` or a separate module to regress the original patch representation; provide unit tests.
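The `focal_loss` item above refers to the standard binary focal loss. A minimal numpy sketch for reference (the real `trainer.py` implementation would operate on torch tensors):

```python
import numpy as np

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Binary focal loss on raw logits.

    Down-weights well-classified examples by (1 - p_t)**gamma so training
    focuses on hard examples; alpha balances the two classes.
    With alpha=1 and gamma=0 this reduces to plain binary cross-entropy.
    """
    p = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=float)))  # sigmoid
    t = np.asarray(targets, dtype=float)
    p_t = t * p + (1 - t) * (1 - p)              # probability of the true class
    alpha_t = t * alpha + (1 - t) * (1 - alpha)  # per-class weighting
    loss = -alpha_t * (1 - p_t) ** gamma * np.log(np.clip(p_t, 1e-8, 1.0))
    return loss.mean()
```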
### P1 (high priority)
- [x] Task-head and loss extensions
  - Add multi-label classification and regression heads in `task_heads.py`; add pluggable pooling (CLS token/mean/max/attention pool).
  - Support multi-task training configs in `trainer.py` (weighted combinations of different heads/losses).
- [x] Training and logging observability
  - Add TensorBoard/CSVLogger; record per-epoch metrics, learning rate, and wall time; save the `config` and `git` commit info.
  - Fix random seeds (PyTorch/NumPy/environment variables); provide `set_seed()` in `utils` and call it at the entry point.
- [x] Reproducible experiments and minimal data
  - Provide a minimal GDS example and a corresponding small processed `.pt` sample for CI and quick user onboarding.
  - Add a one-command small-sample pipeline script in `scripts/` (preprocess → train → eval).
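A minimal sketch of the `set_seed()` helper described above (the project's actual `src/utils/seed.py` may differ; the torch calls are guarded so the sketch also runs where torch is absent or CPU-only):

```python
import os
import random

import numpy as np

def set_seed(seed: int) -> None:
    """Seed Python, NumPy, and (if available) PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # no-op on CPU-only builds
    except ImportError:
        pass  # torch not installed; random/numpy are still seeded
```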
### P2 (medium priority)
- [x] Large-graph/performance optimization
  - Introduce mixed precision (`torch.cuda.amp`), gradient accumulation, and optionally smaller batches; monitor GPU memory.
  - Explore large-graph training strategies such as GraphSAINT/Cluster-GCN and combine them with the current patch partitioning.
- [ ] I/O and ecosystem integration
  - Optional integration of the `klayout` Python API with install-script notes; complete and test the OASIS parsing path.
  - Extend the edge strategies in `graph_constructor.py` with learnable/geometry-based variants (e.g., cross-layer connection edges).
- [x] Interpretability and visualization
  - Finish `scripts/visualize_attention.py`: register hooks to extract attention/feature maps, plot them, and save to `docs/`.
  - Support geometry-overlay visualization on top of `Data.node_meta` (patch bounding boxes and local polygons).

### P3 (later)
- [ ] Richer self-supervised tasks
  - Contrastive learning (SimCLR/GraphCL/MaskGIT-style), context prediction, rotation/crop augmentations, etc.
- [ ] Generative directions
  - Conditioned on the Transformer encoding, explore generative tasks for layout-fragment reconstruction/extension.
- [ ] Documentation and examples
  - Add training-curve examples, a model architecture diagram, and an FAQ to `README*`.

## Risks and Edge Cases (suggested handling)
- Empty patches/sparse boundaries: the preprocessing stage should drop patches with no geometry or too many isolated nodes, and report their share.
- Class imbalance: provide positive/negative resampling or loss weighting; report confusion matrices and PR curves in evaluation output.
- Versions and compatibility: the Python requirement has been raised to 3.12+; supporting older Pythons requires backporting dependencies and testing.
- Randomness: fix random seeds and record them in logs to ensure reproducibility.

---

Maintainers can proceed in the priority order above; when an item is finished, check its box and reference this entry in the PR for tracking.
configs/default.yaml

@@ -1,3 +1,4 @@
+# configs/default.yaml
 # Default Configuration for Geo-Layout Transformer
 
 # 1. Data Preprocessing
@@ -21,7 +22,7 @@ model:
     hidden_dim: 128
     output_dim: 256  # Dimension of the patch embedding
     num_layers: 4
-    gnn_type: "rgat"  # 'rgat', 'gcn', 'graphsage'
+    gnn_type: "gat"   # 'gat', 'gcn', 'graphsage'
 
   # Transformer Backbone
   transformer:
@@ -41,9 +42,25 @@ training:
   optimizer: "adamw"
   loss_function: "bce"  # 'bce', 'focal_loss'
   weight_decay: 0.01
+  scheduler: "cosine"   # 'step', 'cosine'
+  scheduler_T_0: 10
+  scheduler_T_mult: 2
+  early_stopping_patience: 10
+  save_dir: "checkpoints"
+  log_dir: "logs"
+  use_amp: false  # whether to enable mixed-precision training
+  gradient_accumulation_steps: 1  # number of gradient-accumulation steps
+
+# 4. Data Splits
+splits:
+  train_ratio: 0.8
+  val_ratio: 0.1
+  test_ratio: 0.1
+  random_seed: 42
 
 # 4. Self-Supervised Pre-training
 pretraining:
   mask_ratio: 0.15
   epochs: 200
   learning_rate: 0.0005
+  early_stopping_patience: 10
configs/hotspot_detection.yaml

@@ -1,3 +1,4 @@
+# configs/hotspot_detection.yaml
 # Hotspot Detection Task Configuration
 
 # Inherits from default.yaml

examples/generate_sample_data.py (new file, 102 lines)

@@ -0,0 +1,102 @@
#!/usr/bin/env python3
"""
Script to generate sample data.
- Creates a simple GDS file
- Processes it with preprocess_gds.py to produce a sample dataset
"""
import os
import sys
import gdstk
import numpy as np

# Add the project root to the Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

def create_simple_gds(output_file):
    """Create a simple GDS file containing a few rectangles."""
    # Create a new library
    lib = gdstk.Library("simple_layout")

    # Create a new cell
    top_cell = lib.new_cell("TOP")

    # Add a few rectangles on different layers
    # Layer 1: metal 1
    rect1 = gdstk.rectangle((0, 0), (10, 10), layer=1, datatype=0)
    top_cell.add(rect1)

    # Layer 2: via
    via = gdstk.rectangle((4, 4), (6, 6), layer=2, datatype=0)
    top_cell.add(via)

    # Layer 3: metal 2
    rect2 = gdstk.rectangle((2, 2), (8, 8), layer=3, datatype=0)
    top_cell.add(rect2)

    # Save the GDS file
    lib.write_gds(output_file)
    print(f"Created GDS file: {output_file}")

def preprocess_sample_data(gds_file, output_dir):
    """Process the GDS file with preprocess_gds.py to produce a sample dataset."""
    import subprocess

    # Make sure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Path to the preprocess_gds.py script
    script_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "scripts", "preprocess_gds.py")

    # Layer-mapping configuration
    layer_mapping = {
        "1/0": 0,  # metal 1
        "2/0": 1,  # via
        "3/0": 2   # metal 2
    }

    # Build the command
    cmd = [
        sys.executable, script_path,
        "--gds-file", gds_file,
        "--output-dir", output_dir,
        "--patch-size", "5.0",
        "--patch-stride", "2.5"
    ]

    # Add the layer-mapping arguments
    for layer_str, idx in layer_mapping.items():
        cmd.extend(["--layer-mapping", f"{layer_str}:{idx}"])

    print(f"Running preprocessing command: {' '.join(cmd)}")

    # Execute the command
    result = subprocess.run(cmd, capture_output=True, text=True)

    if result.returncode == 0:
        print("Preprocessing completed successfully!")
        print("Output:")
        print(result.stdout)
    else:
        print("Preprocessing failed!")
        print("Errors:")
        print(result.stderr)

def main():
    """Entry point."""
    # Define paths
    examples_dir = os.path.dirname(os.path.abspath(__file__))
    gds_file = os.path.join(examples_dir, "simple_layout.gds")
    output_dir = os.path.join(examples_dir, "processed_data")

    # Create the GDS file
    create_simple_gds(gds_file)

    # Preprocess the data
    preprocess_sample_data(gds_file, output_dir)

    print("\nSample data generation complete!")
    print(f"GDS file: {gds_file}")
    print(f"Processed data: {output_dir}")

if __name__ == "__main__":
    main()
89
examples/run_sample_flow.py
Normal file
89
examples/run_sample_flow.py
Normal file
@@ -0,0 +1,89 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
一键运行的小样流程脚本
|
||||
- 生成示例数据
|
||||
- 训练模型
|
||||
- 评估模型
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
def run_command(cmd, cwd=None):
|
||||
"""运行命令并打印输出"""
|
||||
print(f"\n运行命令: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
print("Output:")
print(result.stdout)
if result.stderr:
    print("Errors:")
    print(result.stderr)
if result.returncode != 0:
    print(f"Command failed with return code: {result.returncode}")
    sys.exit(1)
return result


def generate_sample_data():
    """Generate sample data."""
    print("\n=== Step 1: generate sample data ===")
    script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "generate_sample_data.py")
    run_command([sys.executable, script_path])
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "processed_data")


def train_model(data_dir):
    """Train the model."""
    print("\n=== Step 2: train the model ===")
    main_script = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "main.py")
    config_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "configs", "hotspot_detection.yaml")

    # Run the training command
    cmd = [
        sys.executable, main_script,
        "--config-file", config_file,
        "--mode", "train",
        "--data-dir", data_dir
    ]
    run_command(cmd)


def evaluate_model(data_dir):
    """Evaluate the model."""
    print("\n=== Step 3: evaluate the model ===")
    main_script = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "main.py")
    config_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "configs", "hotspot_detection.yaml")

    # Run the evaluation command
    cmd = [
        sys.executable, main_script,
        "--config-file", config_file,
        "--mode", "eval",
        "--data-dir", data_dir
    ]
    run_command(cmd)


def main():
    """Entry point."""
    start_time = time.time()

    print("Geo-Layout Transformer quick-start pipeline")
    print("==============================")

    # Step 1: generate sample data
    data_dir = generate_sample_data()

    # Step 2: train the model
    train_model(data_dir)

    # Step 3: evaluate the model
    evaluate_model(data_dir)

    total_time = time.time() - start_time
    print(f"\n=== Pipeline finished ===")
    print(f"Total time: {total_time:.2f} s")
    print("The example pipeline ran successfully!")


if __name__ == "__main__":
    main()
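The `run_command` helper above is essentially a thin wrapper over `subprocess.run` that echoes output and fails fast. A minimal self-contained sketch of the same pattern (raising instead of calling `sys.exit`, so it can be reused as a library function; this is an illustrative variant, not the repository's code) could look like:

```python
import subprocess
import sys

def run_command(cmd, cwd=None):
    """Run a command, echo its output, and fail fast on a non-zero exit code."""
    result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
    if result.stdout:
        print(result.stdout, end="")
    if result.stderr:
        print(result.stderr, end="", file=sys.stderr)
    if result.returncode != 0:
        # Raise rather than sys.exit(1) so callers can decide how to handle failure
        raise RuntimeError(f"command failed with return code {result.returncode}")
    return result
```

Raising an exception instead of exiting makes the helper testable and composable; the top-level script can still catch it and exit.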
examples/simple_layout.gds (new file, 1 line; diff suppressed because one or more lines are too long)
Binary files not shown.
main.py (48 lines changed)
@@ -1,8 +1,10 @@
# main.py
import argparse
import torch
from torch.utils.data import random_split

from src.utils.config_loader import load_config, merge_configs
from src.utils.logging import get_logger
from src.utils.seed import set_seed
from src.data.dataset import LayoutDataset
from torch_geometric.data import DataLoader
from src.models.geo_layout_transformer import GeoLayoutTransformer
@@ -27,21 +29,45 @@ def main():
    task_config = load_config(args.config_file)
    config = merge_configs(base_config, task_config)

    # Set the random seed so experiments are reproducible
    random_seed = config['splits']['random_seed']
    logger.info(f"Setting random seed: {random_seed}")
    set_seed(random_seed)

    # Load the data
    logger.info(f"Loading dataset from {args.data_dir}")
    dataset = LayoutDataset(root=args.data_dir)

    # TODO: implement a more complete dataset-split strategy
    # This is a simplified loading path. In practice the dataset must be split
    # into training, validation, and test sets, for example:
    # train_size = int(0.8 * len(dataset))
    # val_size = len(dataset) - train_size
    # train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    # train_loader = DataLoader(train_dataset, batch_size=config['training']['batch_size'], shuffle=True)
    # val_loader = DataLoader(val_dataset, batch_size=config['training']['batch_size'], shuffle=False)
    # Implement the dataset split
    logger.info("Splitting the dataset...")
    train_ratio = config['splits']['train_ratio']
    val_ratio = config['splits']['val_ratio']
    test_ratio = config['splits']['test_ratio']
    random_seed = config['splits']['random_seed']

    train_loader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=True)
    val_loader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=False)
    # Compute the size of each split
    train_size = int(train_ratio * len(dataset))
    val_size = int(val_ratio * len(dataset))
    test_size = len(dataset) - train_size - val_size

    # Keep the split sizes sane
    if test_size < 0:
        test_size = 0
        val_size = len(dataset) - train_size

    # Split the dataset
    train_dataset, val_dataset, test_dataset = random_split(
        dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(random_seed)
    )

    # Create the data loaders
    train_loader = DataLoader(train_dataset, batch_size=config['training']['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config['training']['batch_size'], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=config['training']['batch_size'], shuffle=False)

    logger.info(f"Dataset split complete: train {len(train_dataset)}, val {len(val_dataset)}, test {len(test_dataset)}")

    # Initialize the model
    logger.info("Initializing the model...")
@@ -62,7 +88,7 @@ def main():
    elif args.mode == 'eval':
        logger.info("Entering evaluation mode...")
        evaluator = Evaluator(model)
        evaluator.evaluate(val_loader)
        evaluator.evaluate(test_loader)

if __name__ == "__main__":
    main()
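The split arithmetic above (train and validation sizes rounded down, the test set taking the remainder so the three sizes always sum to the dataset length) can be sketched without any framework. `split_indices` below is a hypothetical standalone helper, not part of the repository:

```python
import random

def split_indices(n, train_ratio=0.8, val_ratio=0.1, seed=42):
    """Deterministically shuffle indices and split them by ratio.

    The test split receives the remainder, so the three parts always
    partition range(n) exactly, mirroring the size arithmetic in main.py.
    """
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # fixed seed -> reproducible split
    train_size = int(train_ratio * n)
    val_size = int(val_ratio * n)
    train = idx[:train_size]
    val = idx[train_size:train_size + val_size]
    test = idx[train_size + val_size:]
    return train, val, test
```

Giving the remainder to the test set avoids off-by-one losses from rounding the two ratios down independently.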
@@ -4,4 +4,18 @@ version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = []
dependencies = [
    "gdstk>=0.9.61",
    "numpy>=2.3.2",
    "pandas>=2.3.2",
    "pyyaml>=6.0.2",
    "scikit-learn>=1.7.1",
    "tensorboard>=2.20.0",
    "torch>=2.8.0",
    "torch-geometric>=2.6.1",
    "torchvision>=0.23.0",
]

[[tool.uv.index]]
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
default = true
@@ -1,3 +1,4 @@
# requirements.txt
torch
torch-geometric
gdstk
@@ -1,3 +1,4 @@
# scripts/preprocess_gds.py
import argparse
import os
from tqdm import tqdm
@@ -1,7 +1,9 @@
# scripts/visualize_attention.py
import argparse
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import os

from src.utils.config_loader import load_config
from src.models.geo_layout_transformer import GeoLayoutTransformer
@@ -12,52 +14,93 @@ def main():
    parser.add_argument("--config-file", required=True, help="Path to the model config file.")
    parser.add_argument("--model-path", required=True, help="Path to the trained model checkpoint.")
    parser.add_argument("--patch-data", required=True, help="Path to a patch data sample (.pt file).")
    parser.add_argument("--output-dir", default="docs/attention_visualization", help="Directory where attention maps are saved.")
    parser.add_argument("--layer-index", type=int, default=0, help="Index of the Transformer layer to visualize.")
    parser.add_argument("--head-index", type=int, default=-1, help="Index of the attention head to visualize; -1 averages over all heads.")
    args = parser.parse_args()

    logger = get_logger("Attention_Visualizer")

    logger.info("This is a placeholder script for attention visualization.")
    logger.info("A full implementation needs to load a trained model and a data sample, then extract the attention weights.")
    # Make sure the output directory exists
    os.makedirs(args.output_dir, exist_ok=True)

    # 1. Load the config and the model
    # logger.info("Loading the model...")
    # config = load_config(args.config_file)
    # model = GeoLayoutTransformer(config)
    # model.load_state_dict(torch.load(args.model_path))
    # model.eval()
    logger.info("Loading the model...")
    config = load_config(args.config_file)
    model = GeoLayoutTransformer(config)
    model.load_state_dict(torch.load(args.model_path, map_location=torch.device('cpu')))
    model.eval()

    # 2. Load a data sample
    # logger.info(f"Loading a data sample from {args.patch_data}")
    # sample_data = torch.load(args.patch_data)
    logger.info(f"Loading a data sample from {args.patch_data}")
    sample_data = torch.load(args.patch_data)

    # 3. Register a hook on the model to extract the attention weights
    # This is an involved process that needs access to the forward pass of the nn.MultiheadAttention module.
    # attention_weights = []
    # def hook(module, input, output):
    #     # output[1] holds the attention weights
    #     attention_weights.append(output[1])
    # model.transformer_core.transformer_encoder.layers[0].self_attn.register_forward_hook(hook)
    attention_weights = []

    def hook(module, input, output):
        # For PyTorch's nn.MultiheadAttention, output is a tuple:
        # output[0] is the attention output, output[1] the attention weights
        if len(output) > 1:
            attention_weights.append(output[1])

    # Get the self-attention module of the requested layer
    if hasattr(model.transformer_core.transformer_encoder, 'layers'):
        layer = model.transformer_core.transformer_encoder.layers[args.layer_index]
        if hasattr(layer, 'self_attn'):
            layer.self_attn.register_forward_hook(hook)
            logger.info(f"Registered a hook on the self-attention module of Transformer layer {args.layer_index}")
        else:
            logger.error("Self-attention module not found")
            return
    else:
        logger.error("Transformer layers not found")
        return

    # 4. Run one forward pass to collect the weights
    # logger.info("Running the forward pass...")
    # with torch.no_grad():
    #     # The model would need to be modified to return attention weights, or they can be captured via hooks
    #     _ = model(sample_data)
    logger.info("Running the forward pass...")
    with torch.no_grad():
        _ = model(sample_data)

    # 5. Plot the attention maps
    # if attention_weights:
    #     logger.info("Plotting the attention map...")
    #     # attention_weights[0] has shape [batch_size, num_heads, seq_len, seq_len]
    #     # Take the first item and average over all heads
    #     avg_attention = attention_weights[0][0].mean(dim=0).cpu().numpy()
    #     plt.figure(figsize=(10, 10))
    #     sns.heatmap(avg_attention, cmap='viridis')
    #     plt.title("Average attention map between patches")
    #     plt.xlabel("Patch index")
    #     plt.ylabel("Patch index")
    #     plt.show()
    # else:
    #     logger.warning("Failed to extract attention weights.")
    if attention_weights:
        logger.info("Plotting the attention maps...")
        # attention_weights[0] has shape [batch_size, num_heads, seq_len, seq_len]
        attn_weights = attention_weights[0]
        batch_size, num_heads, seq_len, _ = attn_weights.shape

        logger.info(f"Attention weight shape: batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}")

        # Pick the first sample
        sample_attn = attn_weights[0]

        if args.head_index == -1:
            # Average over all heads
            avg_attention = sample_attn.mean(dim=0).cpu().numpy()
            plt.figure(figsize=(12, 10))
            sns.heatmap(avg_attention, cmap='viridis', square=True, vmin=0, vmax=1)
            plt.title(f"Average attention map over all heads (Layer {args.layer_index})")
            plt.xlabel("Patch index")
            plt.ylabel("Patch index")
            output_file = os.path.join(args.output_dir, f"attention_layer_{args.layer_index}_avg.png")
            plt.savefig(output_file, bbox_inches='tight', dpi=150)
            logger.info(f"Saved the averaged attention map to {output_file}")
        else:
            # Visualize the requested attention head
            if 0 <= args.head_index < num_heads:
                head_attention = sample_attn[args.head_index].cpu().numpy()
                plt.figure(figsize=(12, 10))
                sns.heatmap(head_attention, cmap='viridis', square=True, vmin=0, vmax=1)
                plt.title(f"Attention map of head {args.head_index} (Layer {args.layer_index})")
                plt.xlabel("Patch index")
                plt.ylabel("Patch index")
                output_file = os.path.join(args.output_dir, f"attention_layer_{args.layer_index}_head_{args.head_index}.png")
                plt.savefig(output_file, bbox_inches='tight', dpi=150)
                logger.info(f"Saved the attention map of head {args.head_index} to {output_file}")
            else:
                logger.error(f"Head index {args.head_index} is out of range; the valid range is 0-{num_heads-1}")
    else:
        logger.warning("Failed to extract attention weights.")

if __name__ == "__main__":
    main()
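The hook-based extraction above relies on `nn.MultiheadAttention` returning a `(attn_output, attn_weights)` tuple from its forward pass when `need_weights=True`. A minimal standalone sketch of the same pattern, on a bare attention module rather than the repository's model:

```python
import torch
import torch.nn as nn

# A toy attention module standing in for one Transformer layer's self_attn.
attn = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
captured = []

def hook(module, inputs, output):
    # output is (attn_output, attn_weights); weights may be None if need_weights=False
    if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
        captured.append(output[1].detach())

attn.register_forward_hook(hook)

x = torch.randn(1, 5, 8)  # [batch, seq_len, embed_dim]
with torch.no_grad():
    attn(x, x, x, need_weights=True)
```

With the default `average_attn_weights=True`, the captured tensor has shape `[batch, seq_len, seq_len]` (heads already averaged), and each row is a softmax distribution summing to 1.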
@@ -1,3 +1,4 @@
# src/data/dataset.py
import torch
from torch_geometric.data import Dataset, InMemoryDataset
import os
@@ -1,3 +1,4 @@
# src/data/gds_parser.py
from typing import List, Dict, Tuple
import gdstk
import numpy as np
@@ -1,3 +1,4 @@
# src/data/graph_constructor.py
from typing import List, Dict, Tuple
import torch
from torch_geometric.data import Data

@@ -0,0 +1 @@
# src/data/__init__.py
BIN src/engine/__pycache__/evaluator.cpython-312.pyc (new file; binary, not shown)
BIN src/engine/__pycache__/self_supervised.cpython-312.pyc (new file; binary, not shown)
BIN src/engine/__pycache__/trainer.cpython-312.pyc (new file; binary, not shown)

@@ -1,3 +1,4 @@
# src/engine/evaluator.py
import torch
from torch_geometric.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

@@ -0,0 +1 @@
# src/engine/__init__.py
@@ -1,8 +1,12 @@
# src/engine/self_supervised.py
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch_geometric.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from ..utils.logging import get_logger
import os
import time

class SelfSupervisedTrainer:
    """Handles the self-supervised pre-training loop (masked layout modeling)."""
@@ -15,43 +19,164 @@ class SelfSupervisedTrainer:
        # Use a mean-squared-error loss to reconstruct the embedding vectors
        self.criterion = nn.MSELoss()

        # Initialize a learnable [MASK] embedding
        self.mask_embedding = nn.Parameter(torch.randn(config['model']['gnn']['output_dim']))
        # Register it as a model parameter so it can be optimized
        self.model.register_parameter('mask_embedding', self.mask_embedding)

        # Initialize the reconstruction head
        hidden_dim = config['model']['transformer']['hidden_dim']
        output_dim = config['model']['gnn']['output_dim']
        self.reconstruction_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

        # Make sure the checkpoint directory exists
        self.save_dir = config.get('save_dir', 'checkpoints')
        os.makedirs(self.save_dir, exist_ok=True)

        # Initialize the TensorBoard logger
        self.log_dir = config.get('log_dir', 'logs/pretrain')
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir=self.log_dir)

        # Initialize early-stopping state
        self.best_loss = float('inf')
        self.patience = config['pretraining'].get('early_stopping_patience', 10)
        self.counter = 0
        self.early_stop = False

        # Initialize mixed-precision training
        self.use_amp = config['training'].get('use_amp', False)
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None

        # Initialize gradient accumulation
        self.gradient_accumulation_steps = config['training'].get('gradient_accumulation_steps', 1)
        if self.gradient_accumulation_steps > 1:
            self.logger.info(f"Gradient accumulation enabled, accumulation steps: {self.gradient_accumulation_steps}")

    def train_epoch(self, dataloader: DataLoader):
        """Runs a single pre-training epoch."""
        self.model.train()
        self.reconstruction_head.train()
        total_loss = 0
        mask_ratio = self.config['pretraining']['mask_ratio']

        for batch in dataloader:
        for i, batch in enumerate(dataloader):
            # Only clear gradients on the first step of an accumulation window
            if i % self.gradient_accumulation_steps == 0:
                self.optimizer.zero_grad()

            # Mixed-precision training
            if self.use_amp:
                with torch.cuda.amp.autocast():
                    # 1. Get the original patch embeddings (the reconstruction targets)
                    with torch.no_grad():
                        original_embeddings = self.model.gnn_encoder(batch)

                    # 2. Create the mask and corrupt the input
                    num_patches = original_embeddings.size(0)
                    num_masked = int(mask_ratio * num_patches)
                    # Randomly pick the patch indices to mask
                    masked_indices = torch.randperm(num_patches)[:num_masked]

                    # Create a corrupted copy of the embeddings
                    # This is a simplified approach; a more robust one would mask the features directly in the batch.
                    # In this placeholder we mask the embedding vectors themselves.
                    corrupted_embeddings = original_embeddings.clone()
                    # Create a learnable [MASK] embedding
                    mask_embedding = nn.Parameter(torch.randn(original_embeddings.size(1), device=original_embeddings.device))
                    corrupted_embeddings[masked_indices] = mask_embedding

                    # 3. Reshape for the Transformer
                    # 2. Generate mask indices per graph from batch.ptr to avoid mixing graphs
                    num_graphs = batch.num_graphs
                    nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
                    corrupted_embeddings = corrupted_embeddings.view(num_graphs, nodes_per_graph[0], -1)

                    # 4. Feed the corrupted embeddings to the Transformer for reconstruction
                    # Note: only transformer_core is used here, not the task_head
                    reconstructed_embeddings = self.model.transformer_core(corrupted_embeddings)
                    # Make sure every graph has the same number of nodes
                    if not torch.all(nodes_per_graph == nodes_per_graph[0]):
                        self.logger.warning("Graphs in the batch have different node counts; using the first graph's count")
                    nodes_per_graph = nodes_per_graph[0]

                    # 5. Compute the loss only on the masked patches
                    # Generate a mask separately for each graph
                    all_masked_indices = []
                    for j in range(num_graphs):
                        # Start and end indices of this graph's nodes within the batch
                        start_idx = batch.ptr[j]
                        end_idx = batch.ptr[j+1]
                        num_patches = end_idx - start_idx
                        num_masked = int(mask_ratio * num_patches)

                        # Generate mask indices within the current graph
                        graph_masked_indices = torch.randperm(num_patches)[:num_masked] + start_idx
                        all_masked_indices.append(graph_masked_indices)

                    # Merge the mask indices of all graphs
                    masked_indices = torch.cat(all_masked_indices)

                    # 3. Create the corrupted embeddings
                    corrupted_embeddings = original_embeddings.clone()
                    # Use the learnable [MASK] embedding
                    corrupted_embeddings[masked_indices] = self.mask_embedding.to(corrupted_embeddings.device)

                    # 4. Reshape for the Transformer
                    corrupted_embeddings = corrupted_embeddings.view(num_graphs, nodes_per_graph, -1)

                    # 5. Encode the corrupted embeddings with the Transformer
                    encoded_embeddings = self.model.transformer_core(corrupted_embeddings)

                    # 6. Produce the reconstructed embeddings with the reconstruction head
                    reconstructed_embeddings = self.reconstruction_head(encoded_embeddings)

                    # 7. Compute the loss only on the masked patches
                    # Flatten both the Transformer output and the original embeddings to shape (N, D)
                    reconstructed_flat = reconstructed_embeddings.view(-1, original_embeddings.size(1))
                    # Compare only the masked entries
                    loss = self.criterion(
                        reconstructed_flat[masked_indices],
                        original_embeddings[masked_indices]
                    )

                # Scale the loss to prevent gradient underflow
                self.scaler.scale(loss).backward()

                # Only update the weights once the accumulation window is complete
                if (i + 1) % self.gradient_accumulation_steps == 0:
                    # Unscale and update the weights
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
            else:
                # Standard training path
                # 1. Get the original patch embeddings (the reconstruction targets)
                original_embeddings = self.model.gnn_encoder(batch)

                # 2. Generate mask indices per graph from batch.ptr to avoid mixing graphs
                num_graphs = batch.num_graphs
                nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]

                # Make sure every graph has the same number of nodes
                if not torch.all(nodes_per_graph == nodes_per_graph[0]):
                    self.logger.warning("Graphs in the batch have different node counts; using the first graph's count")
                nodes_per_graph = nodes_per_graph[0]

                # Generate a mask separately for each graph
                all_masked_indices = []
                for j in range(num_graphs):
                    # Start and end indices of this graph's nodes within the batch
                    start_idx = batch.ptr[j]
                    end_idx = batch.ptr[j+1]
                    num_patches = end_idx - start_idx
                    num_masked = int(mask_ratio * num_patches)

                    # Generate mask indices within the current graph
                    graph_masked_indices = torch.randperm(num_patches)[:num_masked] + start_idx
                    all_masked_indices.append(graph_masked_indices)

                # Merge the mask indices of all graphs
                masked_indices = torch.cat(all_masked_indices)

                # 3. Create the corrupted embeddings
                corrupted_embeddings = original_embeddings.clone()
                # Use the learnable [MASK] embedding
                corrupted_embeddings[masked_indices] = self.mask_embedding.to(corrupted_embeddings.device)

                # 4. Reshape for the Transformer
                corrupted_embeddings = corrupted_embeddings.view(num_graphs, nodes_per_graph, -1)

                # 5. Encode the corrupted embeddings with the Transformer
                encoded_embeddings = self.model.transformer_core(corrupted_embeddings)

                # 6. Produce the reconstructed embeddings with the reconstruction head
                reconstructed_embeddings = self.reconstruction_head(encoded_embeddings)

                # 7. Compute the loss only on the masked patches
                # Flatten both the Transformer output and the original embeddings to shape (N, D)
                reconstructed_flat = reconstructed_embeddings.view(-1, original_embeddings.size(1))
                # Compare only the masked entries
@@ -61,7 +186,12 @@ class SelfSupervisedTrainer:
                )

                loss.backward()

                # Only update the weights once the accumulation window is complete
                if (i + 1) % self.gradient_accumulation_steps == 0:
                    # Update the weights
                    self.optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
@@ -71,7 +201,63 @@ class SelfSupervisedTrainer:
    def run(self, train_loader: DataLoader):
        """Runs the full pre-training pipeline."""
        self.logger.info("Starting self-supervised pre-training...")
        start_time = time.time()

        for epoch in range(self.config['pretraining']['epochs']):
            if self.early_stop:
                self.logger.info("Early stopping triggered; stopping pre-training.")
                break

            epoch_start_time = time.time()
            self.logger.info(f"Epoch {epoch+1}/{self.config['pretraining']['epochs']}")
            self.train_epoch(train_loader)
            current_loss = self.train_epoch(train_loader)

            # Record the learning rate
            current_lr = self.optimizer.param_groups[0]['lr']

            # Log to TensorBoard
            self.writer.add_scalar('Loss/pretrain', current_loss, epoch)
            self.writer.add_scalar('Learning Rate', current_lr, epoch)

            # Track the epoch duration
            epoch_time = time.time() - epoch_start_time
            self.writer.add_scalar('Time/epoch', epoch_time, epoch)
            self.logger.info(f"Epoch time: {epoch_time:.2f} s")

            # Save the best model when the loss improves
            if current_loss < self.best_loss:
                self.best_loss = current_loss
                self.counter = 0
                # Save the best model
                save_path = os.path.join(self.save_dir, 'best_pretrain_model.pth')
                torch.save({
                    'model_state_dict': self.model.state_dict(),
                    'reconstruction_head_state_dict': self.reconstruction_head.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'best_loss': self.best_loss
                }, save_path)
                self.logger.info(f"Saved the best pre-training model to {save_path}")
            else:
                self.counter += 1
                if self.counter >= self.patience:
                    self.early_stop = True
                    self.logger.info(f"Pre-training loss has not improved for {self.patience} epochs; triggering early stopping.")

        # Total training time
        total_time = time.time() - start_time
        self.logger.info(f"Total pre-training time: {total_time:.2f} s")

        # Save the last model
        save_path = os.path.join(self.save_dir, 'last_pretrain_model.pth')
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'reconstruction_head_state_dict': self.reconstruction_head.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }, save_path)
        self.logger.info(f"Saved the last pre-training model to {save_path}")

        # Close the TensorBoard SummaryWriter
        self.writer.close()

        self.logger.info("Pre-training finished.")
        self.logger.info(f"Best pre-training loss: {self.best_loss:.4f}")
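The core of the masked-layout objective above is small: replace a random subset of embeddings with a learnable [MASK] vector and penalize reconstruction error only at those positions, so gradients flow into the mask embedding through the masked slots. A stripped-down sketch (an identity "reconstruction" stands in for the Transformer and reconstruction head, which are assumptions for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_patches, dim, mask_ratio = 10, 4, 0.3

original = torch.randn(num_patches, dim)           # targets (no gradient needed)
mask_embedding = nn.Parameter(torch.randn(dim))    # learnable [MASK] vector

num_masked = int(mask_ratio * num_patches)
masked_indices = torch.randperm(num_patches)[:num_masked]

# Corrupt a copy: masked rows become the [MASK] embedding
corrupted = original.clone()
corrupted[masked_indices] = mask_embedding

# A real model would encode/decode here; identity keeps the sketch minimal
reconstructed = corrupted

# MSE only over the masked positions, exactly as in the trainer above
loss = nn.MSELoss()(reconstructed[masked_indices], original[masked_indices])
loss.backward()
```

After `backward()`, `mask_embedding.grad` is populated because the loss touches only rows that hold the mask vector; unmasked rows contribute nothing.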
@@ -1,8 +1,35 @@
# src/engine/trainer.py
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import StepLR, CosineAnnealingWarmRestarts
from torch_geometric.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from ..utils.logging import get_logger
from .evaluator import Evaluator
import os
import time

class FocalLoss(nn.Module):
    """A Focal Loss implementation for handling class imbalance."""
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.bce_with_logits = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        bce_loss = self.bce_with_logits(inputs, targets)
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss
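The modulating factor `(1 - pt) ** gamma` is what distinguishes focal loss from plain BCE: `pt` is the probability the model assigns to the true class, so confident correct predictions are down-weighted while hard examples keep nearly their full BCE loss. The same arithmetic for a single binary example, written out as a scalar function:

```python
import math

def focal_loss(logit, target, alpha=1.0, gamma=2.0):
    """Scalar focal loss for one binary example, mirroring BCEWithLogits + modulation."""
    p = 1.0 / (1.0 + math.exp(-logit))                          # sigmoid
    bce = -(target * math.log(p) + (1 - target) * math.log(1 - p))
    pt = math.exp(-bce)                                          # prob. of the true class
    return alpha * (1 - pt) ** gamma * bce                       # down-weight easy examples
```

With `gamma=0` the modulating factor is 1 and the function reduces exactly to binary cross-entropy.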
class Trainer:
    """Handles the (supervised) training loop."""
@@ -24,17 +51,60 @@ class Trainer:
        if config['training']['loss_function'] == 'bce':
            # BCEWithLogitsLoss combines Sigmoid and BCELoss and is more numerically stable
            self.criterion = nn.BCEWithLogitsLoss()
        # Add other loss functions here, such as focal loss
        elif config['training']['loss_function'] == 'focal_loss':
            self.criterion = FocalLoss()
        else:
            raise ValueError(f"Unsupported loss function: {config['training']['loss_function']}")

        # Initialize the learning-rate scheduler
        self.scheduler = None
        if 'scheduler' in config['training']:
            scheduler_type = config['training']['scheduler']
            if scheduler_type == 'step':
                self.scheduler = StepLR(self.optimizer, step_size=config['training'].get('scheduler_step_size', 30), gamma=config['training'].get('scheduler_gamma', 0.1))
            elif scheduler_type == 'cosine':
                self.scheduler = CosineAnnealingWarmRestarts(self.optimizer, T_0=config['training'].get('scheduler_T_0', 10), T_mult=config['training'].get('scheduler_T_mult', 2))

        # Initialize the evaluator
        self.evaluator = Evaluator(model)

        # Initialize early-stopping state
        self.best_val_score = -float('inf')
        self.patience = config['training'].get('early_stopping_patience', 10)
        self.counter = 0
        self.early_stop = False

        # Make sure the checkpoint directory exists
        self.save_dir = config.get('save_dir', 'checkpoints')
        os.makedirs(self.save_dir, exist_ok=True)

        # Initialize the TensorBoard logger
        self.log_dir = config.get('log_dir', 'logs')
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir=self.log_dir)

        # Initialize mixed-precision training
        self.use_amp = config['training'].get('use_amp', False)
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None

        # Initialize gradient accumulation
        self.gradient_accumulation_steps = config['training'].get('gradient_accumulation_steps', 1)
        if self.gradient_accumulation_steps > 1:
            self.logger.info(f"Gradient accumulation enabled, accumulation steps: {self.gradient_accumulation_steps}")

    def train_epoch(self, dataloader: DataLoader):
        """Runs a single training epoch."""
        self.model.train()  # Put the model in training mode
        total_loss = 0
        for batch in dataloader:
            self.optimizer.zero_grad()  # Clear the gradients

        for i, batch in enumerate(dataloader):
            # Only clear gradients on the first step of an accumulation window
            if i % self.gradient_accumulation_steps == 0:
                self.optimizer.zero_grad()

            # Mixed-precision training
            if self.use_amp:
                with torch.cuda.amp.autocast():
                    # Forward pass
                    output = self.model(batch)

@@ -44,8 +114,32 @@ class Trainer:

                # Compute the loss
                loss = self.criterion(output, target)

                # Scale the loss to prevent gradient underflow
                self.scaler.scale(loss).backward()

                # Only update the weights once the accumulation window is complete
                if (i + 1) % self.gradient_accumulation_steps == 0:
                    # Unscale and update the weights
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
            else:
                # Standard training path
                # Forward pass
                output = self.model(batch)

                # Prepare the target labels
                # Assume the labels are graph-level and need reshaping to match the output
                target = batch.y.view_as(output)

                # Compute the loss
                loss = self.criterion(output, target)

                # Backward pass
                loss.backward()

                # Only update the weights once the accumulation window is complete
                if (i + 1) % self.gradient_accumulation_steps == 0:
                    # Update the weights
                    self.optimizer.step()

@@ -55,11 +149,79 @@ class Trainer:
        self.logger.info(f"Training loss: {avg_loss:.4f}")
        return avg_loss

    def validate(self, dataloader: DataLoader):
        """Runs validation and returns the evaluation metrics."""
        self.model.eval()  # Put the model in evaluation mode
        metrics = self.evaluator.evaluate(dataloader)
        return metrics

    def run(self, train_loader: DataLoader, val_loader: DataLoader):
        """Runs the full training pipeline."""
        self.logger.info("Starting training...")
        start_time = time.time()

        for epoch in range(self.config['training']['epochs']):
            if self.early_stop:
                self.logger.info("Early stopping triggered; stopping training.")
                break

            epoch_start_time = time.time()
            self.logger.info(f"Epoch {epoch+1}/{self.config['training']['epochs']}")
            self.train_epoch(train_loader)
            # Add a validation step here, for example by calling the Evaluator

            # Train for one epoch
            train_loss = self.train_epoch(train_loader)

            # Validate
            self.logger.info("Validating...")
            val_metrics = self.validate(val_loader)

            # Step the learning-rate scheduler
            current_lr = self.optimizer.param_groups[0]['lr']
            if self.scheduler:
                self.scheduler.step()
                new_lr = self.optimizer.param_groups[0]['lr']
                self.logger.info(f"Learning rate adjusted from {current_lr:.6f} to {new_lr:.6f}")
                current_lr = new_lr
            else:
                self.logger.info(f"Current learning rate: {current_lr:.6f}")

            # Log to TensorBoard
            self.writer.add_scalar('Loss/train', train_loss, epoch)
            for metric_name, metric_value in val_metrics.items():
                self.writer.add_scalar(f'Metrics/{metric_name}', metric_value, epoch)
            self.writer.add_scalar('Learning Rate', current_lr, epoch)

            # Track the epoch duration
            epoch_time = time.time() - epoch_start_time
            self.writer.add_scalar('Time/epoch', epoch_time, epoch)
            self.logger.info(f"Epoch time: {epoch_time:.2f} s")

            # Save the best model when the validation score improves
            val_score = val_metrics.get('f1', val_metrics.get('accuracy', -1))
            if val_score > self.best_val_score:
                self.best_val_score = val_score
                self.counter = 0
                # Save the best model
                save_path = os.path.join(self.save_dir, 'best_model.pth')
                torch.save(self.model.state_dict(), save_path)
                self.logger.info(f"Saved the best model to {save_path}")
            else:
                self.counter += 1
                if self.counter >= self.patience:
                    self.early_stop = True
                    self.logger.info(f"Validation performance has not improved for {self.patience} epochs; triggering early stopping.")

        # Total training time
        total_time = time.time() - start_time
        self.logger.info(f"Total training time: {total_time:.2f} s")

        # Save the last model
        save_path = os.path.join(self.save_dir, 'last_model.pth')
        torch.save(self.model.state_dict(), save_path)
        self.logger.info(f"Saved the last model to {save_path}")

        # Close the TensorBoard SummaryWriter
        self.writer.close()

        self.logger.info("Training finished.")
        self.logger.info(f"Best validation score: {self.best_val_score:.4f}")
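The accumulation pattern used in both trainers — zero the gradients at the start of a window, call `backward()` every micro-batch, and `step()` only at the window's end — is easy to verify on a toy linear objective. Note that the trainers above do not divide the loss by the number of accumulation steps, so accumulated gradients are summed rather than averaged across micro-batches; the sketch below includes that division:

```python
import torch

torch.manual_seed(0)
w = torch.zeros(3, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

# Four "micro-batches"; the gradient of (w * x).sum() w.r.t. w is x
data = [torch.ones(3) * i for i in range(4)]
accum_steps = 2

for i, x in enumerate(data):
    if i % accum_steps == 0:
        opt.zero_grad()                    # start of an accumulation window
    loss = (w * x).sum() / accum_steps     # divide so gradients average, not sum
    loss.backward()                        # gradients accumulate in w.grad
    if (i + 1) % accum_steps == 0:
        opt.step()                         # one optimizer step per window
```

Two windows run: the first accumulates gradients `(0 + 1)/2`, the second `(2 + 3)/2`, so each `step()` applies the mean gradient of its window, matching a batch twice the micro-batch size.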
@@ -0,0 +1 @@
# src/__init__.py
BIN src/models/__pycache__/geo_layout_transformer.cpython-312.pyc (new file; binary, not shown)
BIN src/models/__pycache__/gnn_encoder.cpython-312.pyc (new file; binary, not shown)
BIN src/models/__pycache__/task_heads.cpython-312.pyc (new file; binary, not shown)
BIN src/models/__pycache__/transformer_core.cpython-312.pyc (new file; binary, not shown)
@@ -1,8 +1,9 @@
# src/models/geo_layout_transformer.py
import torch
import torch.nn as nn
from .gnn_encoder import GNNEncoder
from .transformer_core import TransformerCore
from .task_heads import ClassificationHead, MatchingHead
from .task_heads import ClassificationHead, MultiLabelClassificationHead, RegressionHead, MatchingHead

class GeoLayoutTransformer(nn.Module):
    """The complete Geo-Layout Transformer model."""
@@ -37,16 +38,34 @@ class GeoLayoutTransformer(nn.Module):
        self.task_head = None
        if 'task_head' in config['model']:
            head_config = config['model']['task_head']
            pooling_type = head_config.get('pooling_type', 'mean')

            if head_config['type'] == 'classification':
                self.task_head = ClassificationHead(
                    input_dim=head_config['input_dim'],
                    hidden_dim=head_config['hidden_dim'],
                    output_dim=head_config['output_dim']
                    output_dim=head_config['output_dim'],
                    pooling_type=pooling_type
                )
            elif head_config['type'] == 'multi_label_classification':
                self.task_head = MultiLabelClassificationHead(
                    input_dim=head_config['input_dim'],
                    hidden_dim=head_config['hidden_dim'],
                    output_dim=head_config['output_dim'],
                    pooling_type=pooling_type
                )
            elif head_config['type'] == 'regression':
                self.task_head = RegressionHead(
                    input_dim=head_config['input_dim'],
                    hidden_dim=head_config['hidden_dim'],
                    output_dim=head_config['output_dim'],
                    pooling_type=pooling_type
                )
            elif head_config['type'] == 'matching':
                self.task_head = MatchingHead(
                    input_dim=head_config['input_dim'],
                    output_dim=head_config['output_dim']
                    output_dim=head_config['output_dim'],
                    pooling_type=pooling_type
                )
            # Other task heads can be added here
@@ -1,3 +1,4 @@
# src/models/gnn_encoder.py
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, global_mean_pool
@@ -47,15 +48,14 @@ class GNNEncoder(nn.Module):
            data: A PyTorch Geometric Data or Batch object.

        Returns:
            A tensor holding the graph-level embedding of the block.
            A tensor of node-level embeddings.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x, edge_index = data.x, data.edge_index

        # Pass through all GNN layers
        for layer in self.layers:
            x = layer(x, edge_index)
            x = torch.relu(x)

        # Global pooling to obtain the graph-level embedding
        graph_embedding = self.readout(x, batch)
        return graph_embedding
        # Return node-level embeddings without global pooling
        return x
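For context on what the removed readout computed: `global_mean_pool` averages node features per graph, using a batch vector that assigns each node to a graph. A dependency-free sketch of the same computation:

```python
def global_mean_pool(x, batch):
    """x: list of node feature vectors; batch: graph index for each node."""
    groups = {}
    for feats, g in zip(x, batch):
        groups.setdefault(g, []).append(feats)
    # Average each feature column within every graph, ordered by graph index
    return [[sum(col) / len(nodes) for col in zip(*nodes)]
            for _, nodes in sorted(groups.items())]

# Two graphs with two nodes each
x = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0], [30.0, 40.0]]
batch = [0, 0, 1, 1]
print(global_mean_pool(x, batch))  # [[2.0, 3.0], [20.0, 30.0]]
```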
@@ -0,0 +1 @@
# src/models/__init__.py
@@ -1,11 +1,45 @@
# src/models/task_heads.py
import torch
import torch.nn as nn

class PoolingLayer(nn.Module):
    """Pluggable pooling layer supporting several pooling strategies."""
    def __init__(self, pooling_type: str = 'mean'):
        super(PoolingLayer, self).__init__()
        self.pooling_type = pooling_type

        # Attention pooling needs its own attention mechanism
        if pooling_type == 'attention':
            self.attention = nn.Linear(1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [batch_size, seq_len, hidden_dim]

        Returns:
            Pooled tensor of shape [batch_size, hidden_dim]
        """
        if self.pooling_type == 'mean':
            return torch.mean(x, dim=1)
        elif self.pooling_type == 'max':
            return torch.max(x, dim=1)[0]
        elif self.pooling_type == 'cls':
            # Use the first token as a [CLS] token
            return x[:, 0, :]
        elif self.pooling_type == 'attention':
            # Compute attention weights
            weights = self.attention(torch.ones_like(x[:, :, :1])).softmax(dim=1)
            return (x * weights).sum(dim=1)
        else:
            raise ValueError(f"Unsupported pooling type: {self.pooling_type}")
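One caveat on the `attention` branch above: it scores `torch.ones_like(...)`, so every position receives the same score, the softmax yields uniform weights, and the result reduces to mean pooling. Attention pooling normally scores each position from its own features. A dependency-free sketch, where the scoring vector `w` stands in for a learned `nn.Linear(hidden_dim, 1)`:

```python
import math

def attention_pool(x, w, b=0.0):
    """x: [seq_len][hidden_dim] features; w: [hidden_dim] scoring weights."""
    scores = [sum(wi * xi for wi, xi in zip(w, row)) + b for row in x]
    m = max(scores)                          # stabilize the softmax
    exp = [math.exp(s - m) for s in scores]
    z = sum(exp)
    weights = [e / z for e in exp]           # softmax over positions
    dim = len(x[0])
    return [sum(weights[t] * x[t][d] for t in range(len(x))) for d in range(dim)]

x = [[1.0, 2.0], [3.0, 4.0]]
# With constant scores (w = 0) the weights are uniform: plain mean pooling,
# which is exactly what the ones-based scoring above degenerates to.
print(attention_pool(x, [0.0, 0.0]))  # [2.0, 3.0]
```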

class ClassificationHead(nn.Module):
    """A simple multi-layer perceptron (MLP) task head for classification."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, pooling_type: str = 'mean'):
        super(ClassificationHead, self).__init__()
        self.pooling = PoolingLayer(pooling_type)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
@@ -18,9 +52,60 @@ class ClassificationHead(nn.Module):
        Returns:
            The final classification logits.
        """
        # We could take the first token's embedding ([CLS]-style) or average-pool;
        # for simplicity, assume mean pooling over the sequence dimension
        x_pooled = torch.mean(x, dim=1)
        # Use the configured pooling method
        x_pooled = self.pooling(x)

        out = self.fc1(x_pooled)
        out = self.relu(out)
        out = self.fc2(out)
        return out

class MultiLabelClassificationHead(nn.Module):
    """Task head for multi-label classification."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, pooling_type: str = 'mean'):
        super(MultiLabelClassificationHead, self).__init__()
        self.pooling = PoolingLayer(pooling_type)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input tensor from the Transformer backbone.

        Returns:
            The final multi-label classification logits.
        """
        # Use the configured pooling method
        x_pooled = self.pooling(x)

        out = self.fc1(x_pooled)
        out = self.relu(out)
        out = self.fc2(out)
        return out
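The multi-label head returns one raw logit per label; such heads are typically trained with a sigmoid-based loss (e.g. `BCEWithLogitsLoss`) rather than softmax cross-entropy, and at inference each logit is thresholded independently. A dependency-free sketch of that decision rule:

```python
import math

def predict_labels(logits, threshold=0.5):
    """Turn per-label logits into independent 0/1 decisions via a sigmoid."""
    return [1 if 1.0 / (1.0 + math.exp(-z)) >= threshold else 0 for z in logits]

print(predict_labels([2.0, -1.0, 0.3]))  # [1, 0, 1]
```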

class RegressionHead(nn.Module):
    """Task head for regression."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, pooling_type: str = 'mean'):
        super(RegressionHead, self).__init__()
        self.pooling = PoolingLayer(pooling_type)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input tensor from the Transformer backbone.

        Returns:
            The final regression output.
        """
        # Use the configured pooling method
        x_pooled = self.pooling(x)

        out = self.fc1(x_pooled)
        out = self.relu(out)
@@ -30,8 +115,9 @@ class ClassificationHead(nn.Module):
class MatchingHead(nn.Module):
    """Task head that learns similarity embeddings for layout matching."""

    def __init__(self, input_dim: int, output_dim: int):
    def __init__(self, input_dim: int, output_dim: int, pooling_type: str = 'mean'):
        super(MatchingHead, self).__init__()
        self.pooling = PoolingLayer(pooling_type)
        self.projection = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -42,8 +128,8 @@ class MatchingHead(nn.Module):
        Returns:
            A single embedding vector representing the whole input graph (e.g. an IP block).
        """
        # Global mean pooling to get a single vector for the whole sequence
        graph_embedding = torch.mean(x, dim=1)
        # Use the configured pooling method
        graph_embedding = self.pooling(x)
        # Project into the final embedding space
        similarity_embedding = self.projection(graph_embedding)
        # L2-normalize the embedding so cosine similarity can be used
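The L2 normalization at the end of `MatchingHead.forward` is what makes the dot product of two embeddings equal their cosine similarity. A dependency-free sketch of that comparison:

```python
import math

def l2_normalize(v):
    """Scale a vector to unit length (zero vectors are left unscaled)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]

def cosine_similarity(u, v):
    """Dot product of L2-normalized vectors equals cosine similarity."""
    return sum(a * b for a, b in zip(l2_normalize(u), l2_normalize(v)))

print(cosine_similarity([2.0, 0.0], [5.0, 0.0]))  # 1.0 (same direction)
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0 (orthogonal)
```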
@@ -1,3 +1,4 @@
# src/models/transformer_core.py
import torch
import torch.nn as nn
import math
6
src/utils/__init__.py
Normal file
@@ -0,0 +1,6 @@
# src/utils/__init__.py
from .config_loader import load_config, merge_configs
from .logging import get_logger
from .seed import set_seed

__all__ = ['load_config', 'merge_configs', 'get_logger', 'set_seed']
BIN
src/utils/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/utils/__pycache__/config_loader.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/utils/__pycache__/logging.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/utils/__pycache__/seed.cpython-312.pyc
Normal file
Binary file not shown.
@@ -1,3 +1,4 @@
# src/utils/config_loader.py
import yaml
from pathlib import Path
@@ -0,0 +1 @@
# src/utils/__init__.py
@@ -1,3 +1,4 @@
# src/utils/logging.py
import logging
import sys
33
src/utils/seed.py
Normal file
@@ -0,0 +1,33 @@
# src/utils/seed.py
import random
import numpy as np
import torch
import os


def set_seed(seed: int = 42):
    """
    Set the random seeds to make experiments reproducible.

    Args:
        seed: the seed value
    """
    # Seed Python's built-in RNG
    random.seed(seed)

    # Seed NumPy
    np.random.seed(seed)

    # Seed PyTorch (CPU and GPU)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for multi-GPU setups

    # Optionally force deterministic cuDNN algorithms for full reproducibility
    # (left disabled here in favor of performance):
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    # Seed via the environment as well
    os.environ['PYTHONHASHSEED'] = str(seed)

    print(f"Random seed set to: {seed}")
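Seeding every RNG source, as `set_seed` does, makes runs repeatable: reseeding with the same value replays the same random sequence. A minimal check of that property using only Python's built-in RNG:

```python
import random

def draws(seed, n=3):
    """Reseed and draw n values from Python's built-in RNG."""
    random.seed(seed)
    return [random.random() for _ in range(n)]

# Same seed -> identical sequence; different seed -> (almost surely) different
print(draws(42) == draws(42))  # True
print(draws(42) == draws(43))  # False
```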
199
tests/test_model_run.py
Normal file
@@ -0,0 +1,199 @@
#!/usr/bin/env python3
"""
Smoke-test script verifying that the model runs end to end without real data:
- generate random graph data
- load the model config
- initialize the model
- run a forward and a backward pass
- verify that the model works
"""
import os
import sys
import torch
from torch_geometric.data import Data, Batch

# Add the project root to the Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.utils.config_loader import load_config
from src.models.geo_layout_transformer import GeoLayoutTransformer
from src.engine.trainer import Trainer
from src.engine.self_supervised import SelfSupervisedTrainer
from src.utils.logging import get_logger

def generate_random_graph_data(num_graphs=4, num_nodes_per_graph=8, node_feature_dim=5, edge_feature_dim=0):
    """
    Generate random graph data.

    Args:
        num_graphs: number of graphs
        num_nodes_per_graph: number of nodes per graph
        node_feature_dim: node feature dimension
        edge_feature_dim: edge feature dimension (currently unused)

    Returns:
        A Batch object containing the randomly generated graphs.
    """
    graphs = []

    for _ in range(num_graphs):
        # Random node features
        x = torch.randn(num_nodes_per_graph, node_feature_dim)

        # Random edges (fully connected, no self-loops)
        edge_index = []
        for i in range(num_nodes_per_graph):
            for j in range(num_nodes_per_graph):
                if i != j:
                    edge_index.append([i, j])
        edge_index = torch.tensor(edge_index, dtype=torch.long).t()

        # Random label (assumed to be graph-level)
        y = torch.randn(1, 1)

        # Build the graph
        graph = Data(x=x, edge_index=edge_index, y=y)
        graphs.append(graph)

    # Collate into a batch
    batch = Batch.from_data_list(graphs)
    return batch
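The nested loops in `generate_random_graph_data` build a fully connected directed edge list without self-loops; `itertools.permutations` expresses the same construction in one line:

```python
from itertools import permutations

def fully_connected_edges(num_nodes):
    """All ordered node pairs (i, j) with i != j."""
    return [[i, j] for i, j in permutations(range(num_nodes), 2)]

edges = fully_connected_edges(8)
print(len(edges))  # 56 == 8 * 7 directed edges
```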

def test_supervised_training():
    """Test supervised training."""
    logger = get_logger("Test_Supervised_Training")
    logger.info("=== Testing supervised training ===")

    # Load the config
    config = load_config('configs/default.yaml')

    # Generate random data
    batch = generate_random_graph_data()
    logger.info(f"Generated batch: {batch}")
    logger.info(f"Batch size: {batch.num_graphs}")
    logger.info(f"Total nodes: {batch.num_nodes}")
    logger.info(f"Total edges: {batch.num_edges}")

    # Initialize the model
    logger.info("Initializing model...")
    model = GeoLayoutTransformer(config)
    logger.info("Model initialized successfully")

    # Initialize the trainer
    logger.info("Initializing trainer...")
    trainer = Trainer(model, config)
    logger.info("Trainer initialized successfully")

    # Test the forward pass
    logger.info("Testing forward pass...")
    with torch.no_grad():
        # First test the GNN encoder
        gnn_output = model.gnn_encoder(batch)
        logger.info(f"GNN encoder output shape: {gnn_output.shape}")

        # Test the reshape step
        num_graphs = batch.num_graphs
        nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
        logger.info(f"Nodes per graph: {nodes_per_graph}")
        reshaped_embeddings = gnn_output.view(num_graphs, nodes_per_graph[0], -1)
        logger.info(f"Reshaped embedding shape: {reshaped_embeddings.shape}")

        # Test the Transformer core
        transformer_output = model.transformer_core(reshaped_embeddings)
        logger.info(f"Transformer output shape: {transformer_output.shape}")

        # Test the full model
        output = model(batch)
        logger.info(f"Forward pass succeeded, output shape: {output.shape}")

    # Test the backward pass
    logger.info("Testing backward pass...")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    optimizer.zero_grad()
    output = model(batch)

    # Global-pool the output to get a graph-level representation:
    # from [batch_size, seq_len, hidden_dim] to [batch_size, hidden_dim]
    graph_output = output.mean(dim=1)

    # MSE loss over the first dimension only (to match the shape of batch.y)
    loss = torch.nn.MSELoss()(graph_output[:, :1], batch.y)
    loss.backward()
    optimizer.step()
    logger.info(f"Backward pass succeeded, loss: {loss.item()}")

    logger.info("Supervised training test complete; the model works!")
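The `batch.ptr[1:] - batch.ptr[:-1]` step above recovers the per-graph node counts from the batch's cumulative node offsets (`ptr`). In plain Python terms:

```python
# ptr holds cumulative node offsets: graph k owns nodes ptr[k] .. ptr[k+1]-1
ptr = [0, 8, 16, 24, 32]  # 4 graphs of 8 nodes each, as in the test data above
nodes_per_graph = [b - a for a, b in zip(ptr, ptr[1:])]
print(nodes_per_graph)  # [8, 8, 8, 8]
```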

def test_self_supervised_training():
    """Test self-supervised training."""
    logger = get_logger("Test_Self_Supervised_Training")
    logger.info("\n=== Testing self-supervised training ===")

    # Load the config
    config = load_config('configs/default.yaml')

    # Generate random data
    batch = generate_random_graph_data()
    logger.info(f"Generated batch: {batch}")
    logger.info(f"Batch size: {batch.num_graphs}")
    logger.info(f"Total nodes: {batch.num_nodes}")
    logger.info(f"Total edges: {batch.num_edges}")

    # Initialize the model
    logger.info("Initializing model...")
    model = GeoLayoutTransformer(config)
    logger.info("Model initialized successfully")

    # Initialize the self-supervised trainer
    logger.info("Initializing self-supervised trainer...")
    trainer = SelfSupervisedTrainer(model, config)
    logger.info("Self-supervised trainer initialized successfully")

    # Test the forward pass
    logger.info("Testing forward pass...")
    with torch.no_grad():
        # Test the GNN encoder
        gnn_output = model.gnn_encoder(batch)
        logger.info(f"GNN encoder output shape: {gnn_output.shape}")

        # Test the Transformer core
        num_graphs = batch.num_graphs
        nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
        if not torch.all(nodes_per_graph == nodes_per_graph[0]):
            logger.warning("Graphs in the batch have different node counts; using the first graph's count")
        nodes_per_graph = nodes_per_graph[0]

        gnn_output_reshaped = gnn_output.view(num_graphs, nodes_per_graph, -1)
        transformer_output = model.transformer_core(gnn_output_reshaped)
        logger.info(f"Transformer core output shape: {transformer_output.shape}")

    # Test the full model forward pass
    logger.info("Testing full model forward pass...")
    with torch.no_grad():
        output = model(batch)
        logger.info(f"Full model forward pass succeeded, output shape: {output.shape}")

    logger.info("Self-supervised training test complete; the model works!")

def main():
    """Entry point."""
    logger = get_logger("Test_Model_Run")
    logger.info("Checking that the model runs end to end...")

    try:
        # Test supervised training
        test_supervised_training()

        # Test self-supervised training
        test_self_supervised_training()

        logger.info("\n✅ All tests passed; the model runs end to end!")
        logger.info("The model is ready to train on real data.")
    except Exception as e:
        logger.error(f"❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()