initial commit
37
src/data/dataset.py
Normal file
@@ -0,0 +1,37 @@
import torch
from torch_geometric.data import InMemoryDataset


class LayoutDataset(InMemoryDataset):
    """PyTorch Geometric dataset for loading preprocessed layout graph data."""

    def __init__(self, root, transform=None, pre_transform=None):
        """
        Args:
            root: Root directory where the dataset is stored.
            transform: A function/transform that takes a `Data` object and returns a
                transformed version (applied when samples are accessed).
            pre_transform: A function/transform that takes a `Data` object and returns a
                transformed version (applied before the data is saved to disk).
        """
        super(LayoutDataset, self).__init__(root, transform, pre_transform)
        # Load the already-processed data
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """File names expected in `raw_dir`; if they are missing, `download()` is called."""
        return []  # We do not download raw files from the web

    @property
    def processed_file_names(self):
        """Files that must exist in `processed_dir` for the processing step to be skipped."""
        return ['data.pt']

    def download(self):
        """Download raw data into `raw_dir`."""
        pass  # The data is assumed to be preprocessed in advance

    def process(self):
        """Process raw data and save it to `processed_dir`."""
        # To process data on the fly at load time, the logic from
        # `scripts/preprocess_gds.py` could be implemented here.
        # In this framework, preprocessing is done separately by that script.
        pass
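Note: `LayoutDataset` loads a pre-built `data.pt` from `<root>/processed/`. A minimal sketch of how a preprocessing script (e.g. the `scripts/preprocess_gds.py` referenced above) might write that file, assuming PyG's `InMemoryDataset.collate` helper and a list of `Data` objects already in hand:

# Hypothetical sketch: write data.pt in the format LayoutDataset expects.
import os
import torch
from torch_geometric.data import InMemoryDataset

def save_processed(data_list, root):
    # collate() packs a list of Data objects into one big Data plus slice indices
    data, slices = InMemoryDataset.collate(data_list)
    os.makedirs(os.path.join(root, "processed"), exist_ok=True)
    torch.save((data, slices), os.path.join(root, "processed", "data.pt"))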
68
src/data/gds_parser.py
Normal file
@@ -0,0 +1,68 @@
from typing import List, Dict, Tuple

import gdstk
import numpy as np


class GDSParser:
    """Parses a GDSII/OASIS file and extracts the layout geometry inside a given patch."""

    def __init__(self, gds_file: str, layer_mapping: Dict[str, int]):
        """Initialize the GDSParser.

        Args:
            gds_file: Path to the GDSII/OASIS file.
            layer_mapping: Dictionary mapping GDS layer/datatype strings (e.g. "1/0")
                to integer indices.
        """
        self.gds_file = gds_file
        self.layer_mapping = layer_mapping
        # Read the GDS file with gdstk
        self.library = gdstk.read_gds(gds_file)
        # Take the first top-level cell
        self.top_cell = self.library.top_level()[0]

    def get_patches(self, patch_size: float, patch_stride: float) -> List[Tuple[float, float, float, float]]:
        """Generate patch coordinates covering the whole layout.

        Args:
            patch_size: Size of the square patch (in microns).
            patch_stride: Stride of the sliding window (in microns).

        Returns:
            A list of patch bounding boxes (x_min, y_min, x_max, y_max).
        """
        # Bounding box of the top cell; gdstk returns ((x_min, y_min), (x_max, y_max))
        (x_min, y_min), (x_max, y_max) = self.top_cell.bounding_box()
        patches = []
        # Slide a window over the x and y ranges with the given stride
        for x in np.arange(x_min, x_max, patch_stride):
            for y in np.arange(y_min, y_max, patch_stride):
                patches.append((x, y, x + patch_size, y + patch_size))
        return patches

    def extract_geometries_from_patch(self, patch_bbox: Tuple[float, float, float, float]) -> List[Dict]:
        """Extract all geometry objects that intersect the given patch.

        Args:
            patch_bbox: Patch bounding box (x_min, y_min, x_max, y_max).

        Returns:
            A list of dictionaries, each describing one geometry (polygon, layer index, bounding box).
        """
        x_min, y_min, x_max, y_max = patch_bbox
        # All polygons of the top cell (recursing through the hierarchy)
        polygons = self.top_cell.get_polygons()
        geometries = []
        for poly in polygons:
            layer_str = f"{poly.layer}/{poly.datatype}"
            # Only keep layers that are defined in layer_mapping
            if layer_str not in self.layer_mapping:
                continue
            # Simple bounding-box intersection test
            (p_xmin, p_ymin), (p_xmax, p_ymax) = poly.bounding_box()
            if not (p_xmax < x_min or p_xmin > x_max or p_ymax < y_min or p_ymin > y_max):
                geometries.append({
                    "polygon": poly,
                    "layer": self.layer_mapping[layer_str],
                    "bbox": (p_xmin, p_ymin, p_xmax, p_ymax)
                })
        return geometries
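A minimal usage sketch of the parser (the GDS file name, layer numbers, and patch sizes below are placeholders):

# Hypothetical usage; file name, layer numbers, and sizes are placeholders.
parser = GDSParser("chip.gds", layer_mapping={"1/0": 0, "2/0": 1})
patches = parser.get_patches(patch_size=2.0, patch_stride=1.0)
geometries = parser.extract_geometries_from_patch(patches[0])
print(len(patches), len(geometries))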
83
src/data/graph_constructor.py
Normal file
@@ -0,0 +1,83 @@
from typing import List, Dict

import numpy as np
import torch
from scipy.spatial import cKDTree
from torch_geometric.data import Data


class GraphConstructor:
    """Builds PyTorch Geometric Data objects (graphs) from a list of geometries."""

    def __init__(self, edge_strategy: str = "knn", knn_k: int = 8, radius_d: float = 1.0):
        """
        Args:
            edge_strategy: Strategy for creating edges ('knn' or 'radius').
            knn_k: K (number of nearest neighbours) for the KNN strategy.
            radius_d: Radius for the radius-graph strategy.
        """
        self.edge_strategy = edge_strategy
        self.knn_k = knn_k
        self.radius_d = radius_d

    def construct_graph(self, geometries: List[Dict], label: int = 0) -> Data:
        """Build a graph for a single patch.

        Args:
            geometries: List of geometry dictionaries produced by GDSParser.
            label: Graph label (e.g. 0 for non-hotspot, 1 for hotspot).

        Returns:
            A PyTorch Geometric Data object, or None if the patch is empty.
        """
        # Return None if there are no geometries in this patch
        if not geometries:
            return None

        node_features = []
        node_positions = []
        # Extract per-geometry features
        for geo in geometries:
            x_min, y_min, x_max, y_max = geo["bbox"]
            width = x_max - x_min
            height = y_max - y_min
            area = width * height
            centroid_x = x_min + width / 2
            centroid_y = y_min + height / 2

            # Features: centroid coordinates, width, height, area
            features = [centroid_x, centroid_y, width, height, area]
            node_features.append(features)
            node_positions.append([centroid_x, centroid_y])

        # Convert features and positions to PyTorch tensors
        x = torch.tensor(node_features, dtype=torch.float)
        pos = torch.tensor(node_positions, dtype=torch.float)

        # Create edges according to the selected strategy
        edge_index = self._create_edges(pos)

        # Build the graph data object
        data = Data(x=x, edge_index=edge_index, pos=pos, y=torch.tensor([label], dtype=torch.float))
        return data

    def _create_edges(self, node_positions: torch.Tensor) -> torch.Tensor:
        """Create edges according to the selected strategy."""
        nodes_np = node_positions.numpy()
        # A graph with fewer than two nodes has no edges
        if len(nodes_np) < 2:
            return torch.empty((2, 0), dtype=torch.long)

        if self.edge_strategy == "knn":
            # Efficient k-nearest-neighbour search with a cKDTree
            tree = cKDTree(nodes_np)
            # Query k+1 neighbours per point (the nearest neighbour is the point itself),
            # never asking for more neighbours than there are points
            k = min(self.knn_k + 1, len(nodes_np))
            dist, ind = tree.query(nodes_np, k=k)
            # Build the edge list, excluding self-loops
            row = np.repeat(np.arange(len(nodes_np)), k - 1)
            col = ind[:, 1:].flatten()
            edge_index = torch.tensor(np.stack([row, col]), dtype=torch.long)

        elif self.edge_strategy == "radius":
            # Find all point pairs within the given radius
            tree = cKDTree(nodes_np)
            pairs = np.array(list(tree.query_pairs(r=self.radius_d)))
            if len(pairs) == 0:
                edge_index = torch.empty((2, 0), dtype=torch.long)
            else:
                # query_pairs returns each pair once (i < j); add both directions
                edges = np.concatenate([pairs, pairs[:, ::-1]], axis=0)
                edge_index = torch.tensor(edges.T, dtype=torch.long)
        else:
            raise ValueError(f"Unknown edge construction strategy: {self.edge_strategy}")

        return edge_index
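Taken together with `GDSParser`, a minimal preprocessing sketch might look like the following (file name, layer mapping, patch sizes, and labels are placeholders; in this repo the real logic is expected to live in `scripts/preprocess_gds.py`):

# Hypothetical preprocessing sketch; all concrete values are placeholders.
parser = GDSParser("chip.gds", layer_mapping={"1/0": 0})
constructor = GraphConstructor(edge_strategy="knn", knn_k=8)

data_list = []
for bbox in parser.get_patches(patch_size=2.0, patch_stride=1.0):
    geoms = parser.extract_geometries_from_patch(bbox)
    graph = constructor.construct_graph(geoms, label=0)  # real labels would come from hotspot annotations
    if graph is not None:
        data_list.append(graph)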
0
src/data/__init__.py
Normal file
46
src/engine/evaluator.py
Normal file
@@ -0,0 +1,46 @@
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from torch_geometric.loader import DataLoader

from ..utils.logging import get_logger


class Evaluator:
    """Handles model evaluation."""

    def __init__(self, model):
        self.model = model
        self.logger = get_logger(self.__class__.__name__)

    def evaluate(self, dataloader: DataLoader):
        """Evaluate the model on the given dataset."""
        self.model.eval()  # Put the model in evaluation mode
        all_probs = []
        all_preds = []
        all_labels = []

        # Evaluate without tracking gradients
        with torch.no_grad():
            for batch in dataloader:
                output = self.model(batch)
                # Convert logits to probabilities with sigmoid, then threshold at 0.5
                probs = torch.sigmoid(output).view(-1)
                preds = probs > 0.5
                all_probs.append(probs.cpu())
                all_preds.append(preds.cpu())
                all_labels.append(batch.y.cpu())

        # Concatenate predictions and labels from all batches
        all_probs = torch.cat(all_probs).numpy()
        all_preds = torch.cat(all_preds).numpy()
        all_labels = torch.cat(all_labels).numpy()

        # Compute the evaluation metrics
        accuracy = accuracy_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds)
        recall = recall_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds)
        # AUC-ROC is computed from the predicted probabilities, not the thresholded labels
        auc = roc_auc_score(all_labels, all_probs)

        self.logger.info("Evaluation results:")
        self.logger.info(f"  Accuracy:  {accuracy:.4f}")
        self.logger.info(f"  Precision: {precision:.4f}")
        self.logger.info(f"  Recall:    {recall:.4f}")
        self.logger.info(f"  F1-Score:  {f1:.4f}")
        self.logger.info(f"  AUC-ROC:   {auc:.4f}")

        return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc}
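A minimal evaluation sketch (the dataset root and batch size are placeholders; `model` stands for a trained GeoLayoutTransformer):

# Hypothetical evaluation run; dataset root and batch size are placeholders.
from torch_geometric.loader import DataLoader

dataset = LayoutDataset(root="data/hotspot")
loader = DataLoader(dataset, batch_size=32, shuffle=False)
metrics = Evaluator(model).evaluate(loader)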
0
src/engine/__init__.py
Normal file
77
src/engine/self_supervised.py
Normal file
@@ -0,0 +1,77 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch_geometric.loader import DataLoader

from ..utils.logging import get_logger


class SelfSupervisedTrainer:
    """Handles the self-supervised pretraining loop (masked layout modelling)."""

    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.logger = get_logger(self.__class__.__name__)
        self.optimizer = AdamW(self.model.parameters(), lr=config['pretraining']['learning_rate'])
        # Mean-squared-error loss for reconstructing the embeddings
        self.criterion = nn.MSELoss()

    def train_epoch(self, dataloader: DataLoader):
        """Run a single pretraining epoch."""
        self.model.train()
        total_loss = 0
        mask_ratio = self.config['pretraining']['mask_ratio']

        for batch in dataloader:
            self.optimizer.zero_grad()

            # 1. Compute the original patch embeddings (the reconstruction targets)
            with torch.no_grad():
                original_embeddings = self.model.gnn_encoder(batch)

            # 2. Build the mask and corrupt the input
            num_patches = original_embeddings.size(0)
            num_masked = int(mask_ratio * num_patches)
            # Randomly select the patch indices to mask
            masked_indices = torch.randperm(num_patches)[:num_masked]

            # Create a corrupted copy of the embeddings.
            # This is a simplified placeholder: a more robust approach would mask the
            # node features in the batch itself, and the [MASK] vector would be
            # registered as a model parameter so that it is actually learned.
            # Here it is simply re-sampled per batch and never updated.
            corrupted_embeddings = original_embeddings.clone()
            mask_embedding = torch.randn(original_embeddings.size(1), device=original_embeddings.device)
            corrupted_embeddings[masked_indices] = mask_embedding

            # 3. Reshape for the Transformer
            num_graphs = batch.num_graphs
            nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
            corrupted_embeddings = corrupted_embeddings.view(num_graphs, int(nodes_per_graph[0]), -1)

            # 4. Pass the corrupted embeddings through the Transformer for reconstruction
            # Note: only transformer_core is used here, not the task head
            reconstructed_embeddings = self.model.transformer_core(corrupted_embeddings)

            # 5. Compute the loss only on the masked patches
            # Flatten the Transformer output back to (N, D) to line up with the original embeddings
            reconstructed_flat = reconstructed_embeddings.view(-1, original_embeddings.size(1))
            # Compare only the masked positions
            loss = self.criterion(
                reconstructed_flat[masked_indices],
                original_embeddings[masked_indices]
            )

            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        self.logger.info(f"Pretraining loss: {avg_loss:.4f}")
        return avg_loss

    def run(self, train_loader: DataLoader):
        """Run the full pretraining procedure."""
        self.logger.info("Starting self-supervised pretraining...")
        for epoch in range(self.config['pretraining']['epochs']):
            self.logger.info(f"Epoch {epoch + 1}/{self.config['pretraining']['epochs']}")
            self.train_epoch(train_loader)
        self.logger.info("Pretraining finished.")
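For reference, a minimal sketch of the config keys `SelfSupervisedTrainer` reads (values are placeholders):

# Hypothetical pretraining config; values are placeholders.
pretrain_config = {
    "pretraining": {"learning_rate": 1e-4, "mask_ratio": 0.15, "epochs": 10}
}
# trainer = SelfSupervisedTrainer(model, pretrain_config)
# trainer.run(train_loader)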
65
src/engine/trainer.py
Normal file
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW
from torch_geometric.loader import DataLoader

from ..utils.logging import get_logger


class Trainer:
    """Handles the supervised training loop."""

    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.logger = get_logger(self.__class__.__name__)

        # Choose the optimizer according to the config
        if config['training']['optimizer'] == 'adam':
            self.optimizer = Adam(self.model.parameters(), lr=config['training']['learning_rate'], weight_decay=config['training']['weight_decay'])
        elif config['training']['optimizer'] == 'adamw':
            self.optimizer = AdamW(self.model.parameters(), lr=config['training']['learning_rate'], weight_decay=config['training']['weight_decay'])
        else:
            raise ValueError(f"Unsupported optimizer: {config['training']['optimizer']}")

        # Choose the loss function according to the config
        if config['training']['loss_function'] == 'bce':
            # BCEWithLogitsLoss combines a sigmoid with BCELoss and is more numerically stable
            self.criterion = nn.BCEWithLogitsLoss()
        # Other loss functions, such as focal loss, can be added here
        else:
            raise ValueError(f"Unsupported loss function: {config['training']['loss_function']}")

    def train_epoch(self, dataloader: DataLoader):
        """Run a single training epoch."""
        self.model.train()  # Put the model in training mode
        total_loss = 0
        for batch in dataloader:
            self.optimizer.zero_grad()  # Reset the gradients

            # Forward pass
            output = self.model(batch)

            # Prepare the target labels.
            # Labels are assumed to be graph-level and are reshaped to match the output.
            target = batch.y.view_as(output)

            # Compute the loss
            loss = self.criterion(output, target)
            # Backward pass
            loss.backward()
            # Update the weights
            self.optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        self.logger.info(f"Training loss: {avg_loss:.4f}")
        return avg_loss

    def run(self, train_loader: DataLoader, val_loader: DataLoader):
        """Run the full training procedure."""
        self.logger.info("Starting training...")
        for epoch in range(self.config['training']['epochs']):
            self.logger.info(f"Epoch {epoch + 1}/{self.config['training']['epochs']}")
            self.train_epoch(train_loader)
            # A validation step (e.g. calling Evaluator on val_loader) would go here
        self.logger.info("Training finished.")
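For reference, a minimal sketch of the config keys `Trainer` reads (values are placeholders):

# Hypothetical training config; values are placeholders.
train_config = {
    "training": {"optimizer": "adamw", "learning_rate": 1e-3, "weight_decay": 1e-4,
                 "loss_function": "bce", "epochs": 50}
}
# trainer = Trainer(model, train_config)
# trainer.run(train_loader, val_loader)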
0
src/__init__.py
Normal file
84
src/models/geo_layout_transformer.py
Normal file
@@ -0,0 +1,84 @@
import torch
import torch.nn as nn

from .gnn_encoder import GNNEncoder
from .transformer_core import TransformerCore
from .task_heads import ClassificationHead, MatchingHead


class GeoLayoutTransformer(nn.Module):
    """The full Geo-Layout Transformer model."""

    def __init__(self, config: dict):
        """Initialize the model.

        Args:
            config: Configuration dictionary containing all model hyperparameters.
        """
        super(GeoLayoutTransformer, self).__init__()
        self.config = config

        # 1. GNN encoder: encodes each layout patch into an embedding vector
        self.gnn_encoder = GNNEncoder(
            node_input_dim=config['model']['gnn']['node_input_dim'],
            hidden_dim=config['model']['gnn']['hidden_dim'],
            output_dim=config['model']['gnn']['output_dim'],
            num_layers=config['model']['gnn']['num_layers'],
            gnn_type=config['model']['gnn']['gnn_type']
        )

        # 2. Transformer backbone: captures the global context between patches
        self.transformer_core = TransformerCore(
            hidden_dim=config['model']['transformer']['hidden_dim'],
            num_layers=config['model']['transformer']['num_layers'],
            num_heads=config['model']['transformer']['num_heads'],
            dropout=config['model']['transformer']['dropout']
        )

        # 3. Task-specific head: created dynamically from the config
        self.task_head = None
        if 'task_head' in config['model']:
            head_config = config['model']['task_head']
            if head_config['type'] == 'classification':
                self.task_head = ClassificationHead(
                    input_dim=head_config['input_dim'],
                    hidden_dim=head_config['hidden_dim'],
                    output_dim=head_config['output_dim']
                )
            elif head_config['type'] == 'matching':
                self.task_head = MatchingHead(
                    input_dim=head_config['input_dim'],
                    output_dim=head_config['output_dim']
                )
            # Other task heads can be added here

    def forward(self, data) -> torch.Tensor:
        """
        Args:
            data: A PyG Batch object containing a batch of graphs.

        Returns:
            The final output tensor from the task head.
        """
        # 1. Get the patch embeddings from the GNN encoder.
        # PyG's DataLoader automatically collates graphs into a single Batch object.
        patch_embeddings = self.gnn_encoder(data)

        # 2. Reshape for the Transformer: [batch_size, seq_len, hidden_dim].
        # This requires knowing how many patches (nodes) each graph in the batch contains,
        # which is available from the `ptr` attribute of the PyG Batch object.
        num_graphs = data.num_graphs
        # `ptr` holds the cumulative node counts; differencing it gives the nodes per graph
        nodes_per_graph = data.ptr[1:] - data.ptr[:-1]
        # Assumes every graph in the batch has the same number of patches
        # (which holds for our sliding-window approach)
        patch_embeddings = patch_embeddings.view(num_graphs, int(nodes_per_graph[0]), -1)

        # 3. Pass the sequence of patch embeddings through the Transformer
        contextual_embeddings = self.transformer_core(patch_embeddings)

        # 4. Pass the result to the task head
        if self.task_head:
            output = self.task_head(contextual_embeddings)
        else:
            # Without a task head (e.g. during self-supervised pretraining),
            # return the contextual embeddings directly
            output = contextual_embeddings

        return output
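For reference, a minimal config the constructor above expects might look like the following (all values are placeholders; the GNN `output_dim`, the Transformer `hidden_dim`, and the head `input_dim` have to agree for the forward pass to work):

# Hypothetical model config; values are placeholders.
model_config = {
    "model": {
        "gnn": {"node_input_dim": 5, "hidden_dim": 64, "output_dim": 128,
                "num_layers": 3, "gnn_type": "gcn"},
        "transformer": {"hidden_dim": 128, "num_layers": 4, "num_heads": 8, "dropout": 0.1},
        "task_head": {"type": "classification", "input_dim": 128, "hidden_dim": 64, "output_dim": 1},
    }
}
model = GeoLayoutTransformer(model_config)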
61
src/models/gnn_encoder.py
Normal file
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, global_mean_pool


class GNNEncoder(nn.Module):
    """GNN-based encoder that produces patch embeddings."""

    def __init__(self, node_input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, gnn_type: str = 'gcn'):
        """
        Args:
            node_input_dim: Dimension of the input node features.
            hidden_dim: Dimension of the hidden layers.
            output_dim: Dimension of the output patch embedding.
            num_layers: Number of GNN layers.
            gnn_type: Type of GNN layer to use ('gcn', 'graphsage', 'gat').
        """
        super(GNNEncoder, self).__init__()
        self.layers = nn.ModuleList()
        # Input layer
        self.layers.append(self.get_gnn_layer(node_input_dim, hidden_dim, gnn_type))

        # Hidden layers
        for _ in range(num_layers - 2):
            self.layers.append(self.get_gnn_layer(hidden_dim, hidden_dim, gnn_type))

        # Output layer
        self.layers.append(self.get_gnn_layer(hidden_dim, output_dim, gnn_type))

        # Readout function that aggregates node embeddings into a graph-level embedding
        self.readout = global_mean_pool

    def get_gnn_layer(self, in_channels, out_channels, gnn_type):
        """Return a GNN layer of the requested type."""
        if gnn_type == 'gcn':
            return GCNConv(in_channels, out_channels)
        elif gnn_type == 'graphsage':
            return SAGEConv(in_channels, out_channels)
        elif gnn_type == 'gat':
            # Note: GATConv may need extra arguments such as the number of attention heads
            return GATConv(in_channels, out_channels)
        else:
            raise ValueError(f"Unsupported GNN type: {gnn_type}")

    def forward(self, data) -> torch.Tensor:
        """
        Args:
            data: A PyTorch Geometric Data or Batch object.

        Returns:
            A tensor of graph-level (patch) embeddings.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Pass through all GNN layers
        for layer in self.layers:
            x = layer(x, edge_index)
            x = torch.relu(x)

        # Global pooling to obtain the graph-level embedding
        graph_embedding = self.readout(x, batch)
        return graph_embedding
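A small shape-check sketch (hyperparameters are placeholders; `node_input_dim=5` corresponds to the five per-geometry features built in GraphConstructor):

# Hypothetical shape check; hyperparameters are placeholders.
import torch
from torch_geometric.data import Data, Batch

encoder = GNNEncoder(node_input_dim=5, hidden_dim=64, output_dim=128, num_layers=3, gnn_type='graphsage')
g = Data(x=torch.randn(4, 5), edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]))
batch = Batch.from_data_list([g, g])
print(encoder(batch).shape)  # torch.Size([2, 128])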
0
src/models/__init__.py
Normal file
51
src/models/task_heads.py
Normal file
@@ -0,0 +1,51 @@
import torch
import torch.nn as nn


class ClassificationHead(nn.Module):
    """A simple MLP head for classification tasks."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super(ClassificationHead, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor from the Transformer backbone.

        Returns:
            The final classification logits.
        """
        # We could take the embedding of the first token ([CLS]-style) or pool over the sequence;
        # for simplicity, mean-pool over the sequence dimension.
        x_pooled = torch.mean(x, dim=1)

        out = self.fc1(x_pooled)
        out = self.relu(out)
        out = self.fc2(out)
        return out


class MatchingHead(nn.Module):
    """Head that learns similarity embeddings for layout matching."""

    def __init__(self, input_dim: int, output_dim: int):
        super(MatchingHead, self).__init__()
        self.projection = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor from the Transformer backbone.

        Returns:
            A single embedding vector representing the whole input graph (e.g. an IP block).
        """
        # Global mean pooling to obtain one vector per sequence
        graph_embedding = torch.mean(x, dim=1)
        # Project into the final embedding space
        similarity_embedding = self.projection(graph_embedding)
        # L2-normalise the embedding so cosine similarity can be used
        similarity_embedding = nn.functional.normalize(similarity_embedding, p=2, dim=1)
        return similarity_embedding
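Because MatchingHead L2-normalises its output, the cosine similarity between two layouts reduces to a dot product; a minimal sketch (the tensors here are random placeholders standing in for two MatchingHead outputs):

# Hypothetical similarity check; emb_a / emb_b stand in for MatchingHead outputs.
import torch

emb_a = torch.nn.functional.normalize(torch.randn(4, 32), dim=1)
emb_b = torch.nn.functional.normalize(torch.randn(4, 32), dim=1)
cosine_sim = (emb_a * emb_b).sum(dim=1)  # equals cosine similarity for unit-norm vectors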
65
src/models/transformer_core.py
Normal file
@@ -0,0 +1,65 @@
import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Injects positional information into the input sequence."""

    def __init__(self, d_model: int, max_len: int = 5000):
        super(PositionalEncoding, self).__init__()
        # Pre-compute a sufficiently large positional-encoding matrix
        pe = torch.zeros(max_len, d_model)
        # Position indices, shape [max_len, 1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Denominator terms for the sine and cosine functions
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Even dimensions use sine
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd dimensions use cosine
        pe[:, 1::2] = torch.cos(position * div_term)
        # Reshape to [max_len, 1, d_model] to match (seq_len, batch, features) inputs
        pe = pe.unsqueeze(0).transpose(0, 1)
        # Register pe as a buffer so it is not treated as a model parameter
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape [seq_len, batch_size, embedding_dim].
        """
        # Add the positional encoding to the input
        x = x + self.pe[:x.size(0), :]
        return x


class TransformerCore(nn.Module):
    """Transformer backbone for global context modelling."""

    def __init__(self, hidden_dim: int, num_layers: int, num_heads: int, dropout: float = 0.1):
        super(TransformerCore, self).__init__()
        self.pos_encoder = PositionalEncoding(hidden_dim)
        # A single Transformer encoder layer
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, dropout=dropout, batch_first=True)
        # Stack multiple encoder layers into the full Transformer encoder
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
        """
        Args:
            patch_embeddings: Tensor of shape [batch_size, seq_len, hidden_dim]
                holding the embeddings of all patches.

        Returns:
            A tensor of shape [batch_size, seq_len, hidden_dim] enriched with global context.
        """
        # Note: nn.TransformerEncoderLayer expects (seq_len, batch, features) when
        # batch_first=False and (batch, seq_len, features) when batch_first=True.
        # Our input is [batch_size, seq_len, hidden_dim], so batch_first=True is used.

        # The PositionalEncoding above is written for (seq_len, batch, features),
        # so transpose around it.
        src = patch_embeddings.transpose(0, 1)  # -> [seq_len, batch_size, hidden_dim]
        src = self.pos_encoder(src)
        src = src.transpose(0, 1)  # back to [batch_size, seq_len, hidden_dim]

        # Pass the position-aware embeddings through the Transformer
        output = self.transformer_encoder(src)
        return output
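A quick shape check of the core on its own (all dimensions are placeholders):

# Hypothetical shape check; dimensions are placeholders.
import torch

core = TransformerCore(hidden_dim=128, num_layers=2, num_heads=8)
x = torch.randn(2, 16, 128)   # [batch_size, seq_len (patches), hidden_dim]
print(core(x).shape)          # torch.Size([2, 16, 128])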
36
src/utils/config_loader.py
Normal file
@@ -0,0 +1,36 @@
import yaml


def load_config(config_file: str) -> dict:
    """Load a YAML configuration file.

    Args:
        config_file: Path to the YAML configuration file.

    Returns:
        A dictionary with the configuration.
    """
    with open(config_file, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config


def merge_configs(base_config: dict, task_config: dict) -> dict:
    """Merge a task-specific configuration into the base configuration.

    Args:
        base_config: The base (default) configuration.
        task_config: The task-specific configuration to merge in.

    Returns:
        The merged configuration dictionary.
    """
    merged = base_config.copy()  # Shallow copy of the base configuration
    # Walk over the key/value pairs of the task configuration
    for key, value in task_config.items():
        # If both sides hold a dict for this key, merge recursively
        if isinstance(value, dict) and key in merged and isinstance(merged[key], dict):
            merged[key] = merge_configs(merged[key], value)
        # Otherwise the task value simply overrides the base value
        else:
            merged[key] = value
    return merged
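A small example of the override behaviour (the keys and values are placeholders):

# Hypothetical merge example; values are placeholders.
base = {"training": {"learning_rate": 1e-3, "epochs": 50}}
task = {"training": {"learning_rate": 5e-4}}
print(merge_configs(base, task))
# {'training': {'learning_rate': 0.0005, 'epochs': 50}}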
0
src/utils/__init__.py
Normal file
31
src/utils/logging.py
Normal file
@@ -0,0 +1,31 @@
import logging
import sys


def get_logger(name: str, level=logging.INFO) -> logging.Logger:
    """Create and configure a logger.

    Args:
        name: Name of the logger.
        level: Logging level.

    Returns:
        A configured logger instance.
    """
    # Get (or create) the logger with the given name
    logger = logging.getLogger(name)
    # Set the logger's level
    logger.setLevel(level)

    # Handler that writes log records to stdout
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(level)

    # Formatter attached to the handler
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)

    # Attach the handler to the logger (only if none is attached yet)
    if not logger.handlers:
        logger.addHandler(handler)

    return logger