initial commit

This commit is contained in:
Jiao77
2025-08-25 17:54:08 +08:00
commit f187abe72a
28 changed files with 1703 additions and 0 deletions

37
src/data/dataset.py Normal file
View File

@@ -0,0 +1,37 @@
import torch
from torch_geometric.data import Dataset, InMemoryDataset
import os
class LayoutDataset(InMemoryDataset):
"""用于加载预处理后的版图图数据的 PyTorch Geometric 数据集。"""
def __init__(self, root, transform=None, pre_transform=None):
"""
Args:
root: 数据集应保存的根目录。
transform: 一个函数/变换,作用于 `Data` 对象并返回一个转换后的版本。
pre_transform: 一个函数/变换,作用于 `Data` 对象并返回一个转换后的版本。
"""
super(LayoutDataset, self).__init__(root, transform, pre_transform)
# 加载已处理的数据
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
"""如果 `download()` 返回一个路径列表,这里会返回它们的文件名。"""
return [] # 我们不从网络下载原始文件
@property
def processed_file_names(self):
"""在 `processed_dir` 目录中必须存在的文件列表,用以跳过处理步骤。"""
return ['data.pt']
def download(self):
"""从网上下载原始数据到 `raw_dir` 目录。"""
pass # 假设数据是预先处理好的
def process(self):
"""处理原始数据并将其保存到 `processed_dir` 目录。"""
# 如果希望在加载时动态处理数据,可以在这里实现 `scripts/preprocess_gds.py` 中的逻辑。
# 在我们的框架中,我们假设预处理是通过脚本独立完成的。
pass

68
src/data/gds_parser.py Normal file
View File

@@ -0,0 +1,68 @@
from typing import List, Dict, Tuple
import gdstk
import numpy as np
class GDSParser:
"""解析 GDSII/OASIS 文件,提取指定区块内的版图几何图形。"""
def __init__(self, gds_file: str, layer_mapping: Dict[str, int]):
"""初始化 GDSParser。
Args:
gds_file: GDSII/OASIS 文件的路径。
layer_mapping: 一个字典,将 GDS 的层/数据类型字符串(例如 "1/0")映射到整数索引。
"""
self.gds_file = gds_file
self.layer_mapping = layer_mapping
# 使用 gdstk 读取 GDS 文件
self.library = gdstk.read_gds(gds_file)
# 获取顶层单元
self.top_cell = self.library.top_level()[0]
def get_patches(self, patch_size: float, patch_stride: float) -> List[Tuple[float, float, float, float]]:
"""生成覆盖整个版图的区块坐标。
Args:
patch_size: 正方形区块的尺寸(单位:微米)。
patch_stride: 滑动窗口的步长(单位:微米)。
Returns:
一个包含所有区块边界框 (x_min, y_min, x_max, y_max) 的列表。
"""
# 获取顶层单元的边界框
x_min, y_min, x_max, y_max = self.top_cell.bb()
patches = []
# 使用步长在 x 和 y 方向上生成区块
for x in np.arange(x_min, x_max, patch_stride):
for y in np.arange(y_min, y_max, patch_stride):
patches.append((x, y, x + patch_size, y + patch_size))
return patches
def extract_geometries_from_patch(self, patch_bbox: Tuple[float, float, float, float]) -> List[Dict]:
"""从给定的区块中提取所有几何对象。
Args:
patch_bbox: 区块的边界框 (x_min, y_min, x_max, y_max)。
Returns:
一个字典列表,每个字典代表一个几何对象及其属性(多边形、层、边界框)。
"""
x_min, y_min, x_max, y_max = patch_bbox
# 获取单元内的所有多边形
polygons = self.top_cell.get_polygons(by_spec=True)
geometries = []
# 遍历所有多边形
for (layer, datatype), poly_list in polygons.items():
layer_str = f"{layer}/{datatype}"
# 只处理在 layer_mapping 中定义的层
if layer_str in self.layer_mapping:
for poly in poly_list:
# 简单的边界框相交检查
p_xmin, p_ymin, p_xmax, p_ymax = poly.bb()
if not (p_xmax < x_min or p_xmin > x_max or p_ymax < y_min or p_ymin > y_max):
geometries.append({
"polygon": poly,
"layer": self.layer_mapping[layer_str],
"bbox": (p_xmin, p_ymin, p_xmax, p_ymax)
})
return geometries

View File

@@ -0,0 +1,83 @@
from typing import List, Dict
import torch
from torch_geometric.data import Data
from scipy.spatial import cKDTree
import numpy as np
class GraphConstructor:
"""从几何图形列表中构建 PyTorch Geometric 的 Data 对象(即图)。"""
def __init__(self, edge_strategy: str = "knn", knn_k: int = 8, radius_d: float = 1.0):
"""
Args:
edge_strategy: 创建边的策略('knn''radius')。
knn_k: KNN 策略中的 K最近邻的数量
radius_d: 半径图策略中的半径大小。
"""
self.edge_strategy = edge_strategy
self.knn_k = knn_k
self.radius_d = radius_d
def construct_graph(self, geometries: List[Dict], label: int = 0) -> Data:
"""为单个区块构建一个图。
Args:
geometries: 来自 GDSParser 的几何图形字典列表。
label: 图的标签例如0 表示非热点1 表示热点)。
Returns:
一个 PyTorch Geometric 的 Data 对象。
"""
# 如果没有几何图形,则返回 None
if not geometries:
return None
node_features = []
node_positions = []
# 提取每个几何图形的特征
for geo in geometries:
x_min, y_min, x_max, y_max = geo["bbox"]
width = x_max - x_min
height = y_max - y_min
area = width * height
centroid_x = x_min + width / 2
centroid_y = y_min + height / 2
# 特征包括:中心点坐标、宽度、高度、面积
features = [centroid_x, centroid_y, width, height, area]
node_features.append(features)
node_positions.append([centroid_x, centroid_y])
# 将特征和位置转换为 PyTorch 张量
x = torch.tensor(node_features, dtype=torch.float)
pos = torch.tensor(node_positions, dtype=torch.float)
# 根据选定的策略创建边
edge_index = self._create_edges(pos)
# 创建图数据对象
data = Data(x=x, edge_index=edge_index, pos=pos, y=torch.tensor([label], dtype=torch.float))
return data
def _create_edges(self, node_positions: torch.Tensor) -> torch.Tensor:
"""根据选定的策略创建边。"""
nodes_np = node_positions.numpy()
if self.edge_strategy == "knn":
# 使用 cKDTree 进行高效的 K 最近邻搜索
tree = cKDTree(nodes_np)
# 查询每个点的 k+1 个最近邻(包括自身)
dist, ind = tree.query(nodes_np, k=self.knn_k + 1)
# 创建边列表,排除自环
row = np.repeat(np.arange(len(nodes_np)), self.knn_k)
col = ind[:, 1:].flatten()
edge_index = torch.tensor([row, col], dtype=torch.long)
elif self.edge_strategy == "radius":
# 使用 cKDTree 查找在指定半径内的所有点对
tree = cKDTree(nodes_np)
pairs = tree.query_pairs(r=self.radius_d)
edge_index = torch.tensor(list(pairs), dtype=torch.long).t().contiguous()
else:
raise ValueError(f"未知的边构建策略: {self.edge_strategy}")
return edge_index

0
src/data/init.py Normal file
View File

46
src/engine/evaluator.py Normal file
View File

@@ -0,0 +1,46 @@
import torch
from torch_geometric.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from ..utils.logging import get_logger
class Evaluator:
"""处理模型评估。"""
def __init__(self, model):
self.model = model
self.logger = get_logger(self.__class__.__name__)
def evaluate(self, dataloader: DataLoader):
"""在给定的数据集上评估模型。"""
self.model.eval() # 将模型设置为评估模式
all_preds = []
all_labels = []
# 在没有梯度计算的上下文中进行评估
with torch.no_grad():
for batch in dataloader:
output = self.model(batch)
# 使用 sigmoid 将 logits 转换为概率,然后以 0.5 为阈值进行分类
preds = torch.sigmoid(output) > 0.5
all_preds.append(preds.cpu())
all_labels.append(batch.y.cpu())
# 将所有批次的预测和标签连接起来
all_preds = torch.cat(all_preds).numpy()
all_labels = torch.cat(all_labels).numpy()
# 计算各种评估指标
accuracy = accuracy_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds)
recall = recall_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds)
auc = roc_auc_score(all_labels, all_preds)
self.logger.info(f"评估结果:")
self.logger.info(f" 准确率 (Accuracy): {accuracy:.4f}")
self.logger.info(f" 精确率 (Precision): {precision:.4f}")
self.logger.info(f" 召回率 (Recall): {recall:.4f}")
self.logger.info(f" F1 分数 (F1-Score): {f1:.4f}")
self.logger.info(f" AUC-ROC: {auc:.4f}")
return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc}

0
src/engine/init.py Normal file
View File

View File

@@ -0,0 +1,77 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch_geometric.data import DataLoader
from ..utils.logging import get_logger
class SelfSupervisedTrainer:
"""处理自监督预训练循环(掩码版图建模)。"""
def __init__(self, model, config):
self.model = model
self.config = config
self.logger = get_logger(self.__class__.__name__)
self.optimizer = AdamW(self.model.parameters(), lr=config['pretraining']['learning_rate'])
# 使用均方误差损失来重建嵌入向量
self.criterion = nn.MSELoss()
def train_epoch(self, dataloader: DataLoader):
"""运行单个预训练周期。"""
self.model.train()
total_loss = 0
mask_ratio = self.config['pretraining']['mask_ratio']
for batch in dataloader:
self.optimizer.zero_grad()
# 1. 获取原始的区块嵌入(作为重建的目标)
with torch.no_grad():
original_embeddings = self.model.gnn_encoder(batch)
# 2. 创建掩码并损坏输入
num_patches = original_embeddings.size(0)
num_masked = int(mask_ratio * num_patches)
# 随机选择要掩盖的区块索引
masked_indices = torch.randperm(num_patches)[:num_masked]
# 创建一个损坏的嵌入副本
# 这是一个简化的方法。更稳健的方法是直接在批次数据中掩盖特征。
# 在这个占位符中,我们直接掩盖嵌入向量。
corrupted_embeddings = original_embeddings.clone()
# 创建一个可学习的 [MASK] 嵌入
mask_embedding = nn.Parameter(torch.randn(original_embeddings.size(1), device=original_embeddings.device))
corrupted_embeddings[masked_indices] = mask_embedding
# 3. 为 Transformer 重塑形状
num_graphs = batch.num_graphs
nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
corrupted_embeddings = corrupted_embeddings.view(num_graphs, nodes_per_graph[0], -1)
# 4. 将损坏的嵌入传入 Transformer 进行重建
# 注意:这里只用了 transformer_core没有用 task_head
reconstructed_embeddings = self.model.transformer_core(corrupted_embeddings)
# 5. 只在被掩盖的区块上计算损失
# 将 Transformer 输出和原始嵌入都拉平成 (N, D) 的形状
reconstructed_flat = reconstructed_embeddings.view(-1, original_embeddings.size(1))
# 只选择被掩盖的那些进行比较
loss = self.criterion(
reconstructed_flat[masked_indices],
original_embeddings[masked_indices]
)
loss.backward()
self.optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
self.logger.info(f"预训练损失: {avg_loss:.4f}")
return avg_loss
def run(self, train_loader: DataLoader):
"""运行完整的预训练流程。"""
self.logger.info("开始自监督预训练...")
for epoch in range(self.config['pretraining']['epochs']):
self.logger.info(f"周期 {epoch+1}/{self.config['pretraining']['epochs']}")
self.train_epoch(train_loader)
self.logger.info("预训练完成。")

65
src/engine/trainer.py Normal file
View File

@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW
from torch_geometric.data import DataLoader
from ..utils.logging import get_logger
class Trainer:
"""处理(监督学习)训练循环。"""
def __init__(self, model, config):
self.model = model
self.config = config
self.logger = get_logger(self.__class__.__name__)
# 根据配置选择优化器
if config['training']['optimizer'] == 'adam':
self.optimizer = Adam(self.model.parameters(), lr=config['training']['learning_rate'], weight_decay=config['training']['weight_decay'])
elif config['training']['optimizer'] == 'adamw':
self.optimizer = AdamW(self.model.parameters(), lr=config['training']['learning_rate'], weight_decay=config['training']['weight_decay'])
else:
raise ValueError(f"不支持的优化器: {config['training']['optimizer']}")
# 根据配置选择损失函数
if config['training']['loss_function'] == 'bce':
# BCEWithLogitsLoss 结合了 Sigmoid 和 BCELoss更数值稳定
self.criterion = nn.BCEWithLogitsLoss()
# 在此添加其他损失函数,如 focal loss
else:
raise ValueError(f"不支持的损失函数: {config['training']['loss_function']}")
def train_epoch(self, dataloader: DataLoader):
"""运行单个训练周期epoch"""
self.model.train() # 将模型设置为训练模式
total_loss = 0
for batch in dataloader:
self.optimizer.zero_grad() # 清空梯度
# 前向传播
output = self.model(batch)
# 准备目标标签
# 假设标签在图级别,并且需要调整形状以匹配输出
target = batch.y.view_as(output)
# 计算损失
loss = self.criterion(output, target)
# 反向传播
loss.backward()
# 更新权重
self.optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
self.logger.info(f"训练损失: {avg_loss:.4f}")
return avg_loss
def run(self, train_loader: DataLoader, val_loader: DataLoader):
"""运行完整的训练流程。"""
self.logger.info("开始训练...")
for epoch in range(self.config['training']['epochs']):
self.logger.info(f"周期 {epoch+1}/{self.config['training']['epochs']}")
self.train_epoch(train_loader)
# 在此处添加验证步骤,例如调用 Evaluator
self.logger.info("训练完成。")

0
src/init.py Normal file
View File

View File

@@ -0,0 +1,84 @@
import torch
import torch.nn as nn
from .gnn_encoder import GNNEncoder
from .transformer_core import TransformerCore
from .task_heads import ClassificationHead, MatchingHead
class GeoLayoutTransformer(nn.Module):
"""完整的 Geo-Layout Transformer 模型。"""
def __init__(self, config: dict):
"""初始化模型。
Args:
config: 包含所有模型超参数的配置字典。
"""
super(GeoLayoutTransformer, self).__init__()
self.config = config
# 1. GNN 编码器用于将每个版图区块patch编码为嵌入向量
self.gnn_encoder = GNNEncoder(
node_input_dim=config['model']['gnn']['node_input_dim'],
hidden_dim=config['model']['gnn']['hidden_dim'],
output_dim=config['model']['gnn']['output_dim'],
num_layers=config['model']['gnn']['num_layers'],
gnn_type=config['model']['gnn']['gnn_type']
)
# 2. Transformer 骨干网络:用于捕捉区块之间的全局上下文关系
self.transformer_core = TransformerCore(
hidden_dim=config['model']['transformer']['hidden_dim'],
num_layers=config['model']['transformer']['num_layers'],
num_heads=config['model']['transformer']['num_heads'],
dropout=config['model']['transformer']['dropout']
)
# 3. 特定于任务的头:根据配置动态创建
self.task_head = None
if 'task_head' in config['model']:
head_config = config['model']['task_head']
if head_config['type'] == 'classification':
self.task_head = ClassificationHead(
input_dim=head_config['input_dim'],
hidden_dim=head_config['hidden_dim'],
output_dim=head_config['output_dim']
)
elif head_config['type'] == 'matching':
self.task_head = MatchingHead(
input_dim=head_config['input_dim'],
output_dim=head_config['output_dim']
)
# 可在此处添加其他任务头
def forward(self, data) -> torch.Tensor:
"""
Args:
data: 一个 PyG 的 Batch 对象,包含了一批次的图数据。
Returns:
来自任务头的最终输出张量。
"""
# 1. 从 GNN 编码器获取区块嵌入
# PyG 的 DataLoader 会自动将图数据打包成一个大的 Batch 对象
patch_embeddings = self.gnn_encoder(data)
# 2. 为 Transformer 重塑形状: [batch_size, seq_len, hidden_dim]
# 这需要知道批次中每个图包含多少个区块(节点)。
# 我们可以从 PyG Batch 对象的 `ptr` 属性中获取此信息。
num_graphs = data.num_graphs
# `ptr` 记录了每个图的节点数累积和,通过相减得到每个图的节点数
nodes_per_graph = data.ptr[1:] - data.ptr[:-1]
# 假设批次内所有图的区块数相同(对于我们的滑动窗口方法是成立的)
patch_embeddings = patch_embeddings.view(num_graphs, nodes_per_graph[0], -1)
# 3. 将区块嵌入序列传入 Transformer
contextual_embeddings = self.transformer_core(patch_embeddings)
# 4. 将结果传入任务头
if self.task_head:
output = self.task_head(contextual_embeddings)
else:
# 如果没有定义任务头(例如在自监督预训练中),则返回上下文嵌入
output = contextual_embeddings
return output

61
src/models/gnn_encoder.py Normal file
View File

@@ -0,0 +1,61 @@
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, global_mean_pool
class GNNEncoder(nn.Module):
"""基于 GNN 的编码器用于生成区块Patch的嵌入向量。"""
def __init__(self, node_input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, gnn_type: str = 'gcn'):
"""
Args:
node_input_dim: 输入节点特征的维度。
hidden_dim: 隐藏层的维度。
output_dim: 输出区块嵌入向量的维度。
num_layers: GNN 层的数量。
gnn_type: 使用的 GNN 层类型('gcn', 'graphsage', 'gat')。
"""
super(GNNEncoder, self).__init__()
self.layers = nn.ModuleList()
# 输入层
self.layers.append(self.get_gnn_layer(node_input_dim, hidden_dim, gnn_type))
# 隐藏层
for _ in range(num_layers - 2):
self.layers.append(self.get_gnn_layer(hidden_dim, hidden_dim, gnn_type))
# 输出层
self.layers.append(self.get_gnn_layer(hidden_dim, output_dim, gnn_type))
# 读出函数,用于将节点嵌入聚合为图级别的嵌入
self.readout = global_mean_pool
def get_gnn_layer(self, in_channels, out_channels, gnn_type):
"""根据类型获取 GNN 层。"""
if gnn_type == 'gcn':
return GCNConv(in_channels, out_channels)
elif gnn_type == 'graphsage':
return SAGEConv(in_channels, out_channels)
elif gnn_type == 'gat':
# 注意GATConv 可能需要额外的参数,如 heads
return GATConv(in_channels, out_channels)
else:
raise ValueError(f"不支持的 GNN 类型: {gnn_type}")
def forward(self, data) -> torch.Tensor:
"""
Args:
data: 一个 PyTorch Geometric 的 Data 或 Batch 对象。
Returns:
一个代表区块的图级别嵌入的张量。
"""
x, edge_index, batch = data.x, data.edge_index, data.batch
# 通过所有 GNN 层
for layer in self.layers:
x = layer(x, edge_index)
x = torch.relu(x)
# 全局池化以获得图级别的嵌入
graph_embedding = self.readout(x, batch)
return graph_embedding

0
src/models/init.py Normal file
View File

51
src/models/task_heads.py Normal file
View File

@@ -0,0 +1,51 @@
import torch
import torch.nn as nn
class ClassificationHead(nn.Module):
"""一个用于分类任务的简单多层感知机MLP任务头。"""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
super(ClassificationHead, self).__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim, output_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: 来自 Transformer 骨干网络的输入张量。
Returns:
最终的分类 logits。
"""
# 我们可以取第一个 token类似 [CLS])的嵌入,或者进行平均池化
# 为简单起见,我们假设在序列维度上进行平均池化
x_pooled = torch.mean(x, dim=1)
out = self.fc1(x_pooled)
out = self.relu(out)
out = self.fc2(out)
return out
class MatchingHead(nn.Module):
"""用于学习版图匹配的相似性嵌入的任务头。"""
def __init__(self, input_dim: int, output_dim: int):
super(MatchingHead, self).__init__()
self.projection = nn.Linear(input_dim, output_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: 来自 Transformer 骨干网络的输入张量。
Returns:
代表整个输入图(例如一个 IP 模块)的单个嵌入向量。
"""
# 全局平均池化,为整个序列获取一个单一的向量
graph_embedding = torch.mean(x, dim=1)
# 投影到最终的嵌入空间
similarity_embedding = self.projection(graph_embedding)
# 对嵌入进行 L2 归一化,以便使用余弦相似度
similarity_embedding = nn.functional.normalize(similarity_embedding, p=2, dim=1)
return similarity_embedding

View File

@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
"""向输入序列中注入位置信息。"""
def __init__(self, d_model: int, max_len: int = 5000):
super(PositionalEncoding, self).__init__()
# 创建一个足够大的位置编码矩阵
pe = torch.zeros(max_len, d_model)
# 创建位置信息 [max_len, 1]
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# 计算用于正弦和余弦函数的分母项
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
# 计算偶数维度的位置编码(使用正弦)
pe[:, 0::2] = torch.sin(position * div_term)
# 计算奇数维度的位置编码(使用余弦)
pe[:, 1::2] = torch.cos(position * div_term)
# 调整形状以匹配输入 [max_len, 1, d_model]
pe = pe.unsqueeze(0).transpose(0, 1)
# 将 pe 注册为 buffer这样它不会被视为模型参数
self.register_buffer('pe', pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: 张量,形状为 [seq_len, batch_size, embedding_dim]
"""
# 将位置编码加到输入张量上
x = x + self.pe[:x.size(0), :]
return x
class TransformerCore(nn.Module):
"""用于全局上下文建模的 Transformer 骨干网络。"""
def __init__(self, hidden_dim: int, num_layers: int, num_heads: int, dropout: float = 0.1):
super(TransformerCore, self).__init__()
self.pos_encoder = PositionalEncoding(hidden_dim)
# 定义 Transformer 编码器层
encoder_layers = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, dropout=dropout, batch_first=True)
# 堆叠多个编码器层形成完整的 Transformer 编码器
self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
def forward(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
"""
Args:
patch_embeddings: 形状为 [batch_size, seq_len, hidden_dim] 的张量,
代表所有区块的嵌入向量。
Returns:
一个形状为 [batch_size, seq_len, hidden_dim] 的、包含全局上下文信息的张量。
"""
# 注意PyTorch 的 TransformerEncoderLayer 期望的输入形状是 (seq_len, batch, features)
# 如果 batch_first=False或者 (batch, seq_len, features) 如果 batch_first=True。
# 我们的输入是 [batch_size, seq_len, hidden_dim],所以我们设置 batch_first=True。
# 我们使用的 PositionalEncoding 是为 (seq_len, batch, features) 设计的,所以需要调整一下形状
src = patch_embeddings.transpose(0, 1) # 转换为 [seq_len, batch_size, hidden_dim]
src = self.pos_encoder(src)
src = src.transpose(0, 1) # 转换回 [batch_size, seq_len, hidden_dim]
# 将带有位置信息的嵌入传入 Transformer
output = self.transformer_encoder(src)
return output

View File

@@ -0,0 +1,36 @@
import yaml
from pathlib import Path
def load_config(config_file: str) -> dict:
"""加载 YAML 配置文件。
Args:
config_file: YAML 配置文件的路径。
Returns:
包含配置信息的字典。
"""
with open(config_file, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
return config
def merge_configs(base_config: dict, task_config: dict) -> dict:
"""将特定于任务的配置合并到基础配置中。
Args:
base_config: 基础(默认)配置。
task_config: 要合并的特定于任务的配置。
Returns:
合并后的配置字典。
"""
merged = base_config.copy() # 复制基础配置
# 遍历任务配置中的键值对
for key, value in task_config.items():
# 如果值是字典且键也存在于合并后的配置中,则递归合并
if isinstance(value, dict) and key in merged and isinstance(merged[key], dict):
merged[key] = merge_configs(merged[key], value)
# 否则,直接用任务配置的值覆盖
else:
merged[key] = value
return merged

0
src/utils/init.py Normal file
View File

31
src/utils/logging.py Normal file
View File

@@ -0,0 +1,31 @@
import logging
import sys
def get_logger(name: str, level=logging.INFO) -> logging.Logger:
"""创建并配置一个日志记录器。
Args:
name: 日志记录器的名称。
level: 日志记录级别。
Returns:
一个配置好的日志记录器实例。
"""
# 获取指定名称的日志记录器
logger = logging.getLogger(name)
# 设置日志记录器的级别
logger.setLevel(level)
# 创建一个处理器,用于将日志记录输出到标准输出
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(level)
# 创建一个格式化器,并将其添加到处理器
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
# 将处理器添加到日志记录器(如果尚未添加)
if not logger.handlers:
logger.addHandler(handler)
return logger