initial commit
This commit is contained in:
90
scripts/preprocess_gds.py
Normal file
90
scripts/preprocess_gds.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import argparse
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torch_geometric.data import InMemoryDataset, Data
|
||||
|
||||
from src.utils.config_loader import load_config
|
||||
from src.data.gds_parser import GDSParser
|
||||
from src.data.graph_constructor import GraphConstructor
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
# 这是一个辅助的数据集类,仅用于在预处理脚本中保存数据
|
||||
class TempDataset(InMemoryDataset):
|
||||
def __init__(self, root, data_list=None):
|
||||
self.data_list = data_list
|
||||
super(TempDataset, self).__init__(root)
|
||||
self.data, self.slices = self.collate(data_list)
|
||||
|
||||
@property
|
||||
def raw_file_names(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def processed_file_names(self):
|
||||
return ['data.pt']
|
||||
|
||||
def download(self):
|
||||
pass
|
||||
|
||||
def process(self):
|
||||
# 数据已在外部处理好,直接保存
|
||||
torch.save((self.data, self.slices), self.processed_paths[0])
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="将 GDSII 文件预处理为图数据。")
|
||||
parser.add_argument("--config-file", required=True, help="配置文件的路径。")
|
||||
parser.add_argument("--gds-file", required=True, help="要处理的 GDSII 文件的路径。")
|
||||
parser.add_argument("--output-dir", required=True, help="保存处理后图数据的目录。")
|
||||
# 可以添加一个参数来指定标签文件,例如 DRC 报告
|
||||
# parser.add_argument("--label-file", help="标签文件的路径。")
|
||||
args = parser.parse_args()
|
||||
|
||||
logger = get_logger("GDS_Preprocessor")
|
||||
|
||||
logger.info(f"从 {args.config_file} 加载配置")
|
||||
config = load_config(args.config_file)
|
||||
|
||||
logger.info(f"为 {args.gds_file} 初始化 GDSParser")
|
||||
gds_parser = GDSParser(args.gds_file, config['data']['layer_mapping'])
|
||||
|
||||
logger.info("初始化 GraphConstructor")
|
||||
graph_constructor = GraphConstructor(
|
||||
edge_strategy=config['data']['graph_construction']['edge_strategy'],
|
||||
knn_k=config['data']['graph_construction']['knn_k'],
|
||||
radius_d=config['data']['graph_construction']['radius_d']
|
||||
)
|
||||
|
||||
logger.info("正在生成区块...")
|
||||
patches = gds_parser.get_patches(config['data']['patch_size'], config['data']['patch_stride'])
|
||||
logger.info(f"生成了 {len(patches)} 个区块。")
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
graph_list = []
|
||||
# 使用 tqdm 显示进度条
|
||||
for patch_bbox in tqdm(patches, desc="处理区块中"):
|
||||
geometries = gds_parser.extract_geometries_from_patch(patch_bbox)
|
||||
if geometries:
|
||||
# 在真实场景中,您需要从 DRC 报告等来源获取标签
|
||||
# 在这个占位符中,我们假设一个虚拟标签 0
|
||||
# TODO: 实现从标签文件加载标签的逻辑
|
||||
graph = graph_constructor.construct_graph(geometries, label=0)
|
||||
if graph:
|
||||
# PyG 要求 Data 对象具有 y 属性
|
||||
if not hasattr(graph, 'y'):
|
||||
graph.y = torch.tensor([0], dtype=torch.float)
|
||||
graph_list.append(graph)
|
||||
|
||||
logger.info(f"成功构建了 {len(graph_list)} 个图。")
|
||||
|
||||
if graph_list:
|
||||
# 使用 PyG 的 InMemoryDataset 格式保存数据,以便高效加载
|
||||
logger.info("正在将数据保存为 PyG InMemoryDataset 格式...")
|
||||
dataset = TempDataset(root=args.output_dir, data_list=graph_list)
|
||||
logger.info(f"已将处理好的数据保存到 {dataset.processed_paths[0]}")
|
||||
else:
|
||||
logger.warning("没有生成任何图数据,不进行保存。")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
63
scripts/visualize_attention.py
Normal file
63
scripts/visualize_attention.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import argparse
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
from src.utils.config_loader import load_config
|
||||
from src.models.geo_layout_transformer import GeoLayoutTransformer
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="可视化来自已训练模型的注意力图。")
|
||||
parser.add_argument("--config-file", required=True, help="模型配置文件的路径。")
|
||||
parser.add_argument("--model-path", required=True, help="已训练模型检查点的路径。")
|
||||
parser.add_argument("--patch-data", required=True, help="区块数据样本(.pt 文件)的路径。")
|
||||
args = parser.parse_args()
|
||||
|
||||
logger = get_logger("Attention_Visualizer")
|
||||
|
||||
logger.info("这是一个用于注意力可视化的占位符脚本。")
|
||||
logger.info("完整的实现需要加载一个训练好的模型、一个数据样本,然后提取注意力权重。")
|
||||
|
||||
# 1. 加载配置和模型
|
||||
# logger.info("正在加载模型...")
|
||||
# config = load_config(args.config_file)
|
||||
# model = GeoLayoutTransformer(config)
|
||||
# model.load_state_dict(torch.load(args.model_path))
|
||||
# model.eval()
|
||||
|
||||
# 2. 加载一个数据样本
|
||||
# logger.info(f"正在加载数据样本从 {args.patch_data}")
|
||||
# sample_data = torch.load(args.patch_data)
|
||||
|
||||
# 3. 注册钩子(Hook)到模型中以提取注意力权重
|
||||
# 这是一个复杂的过程,需要访问 nn.MultiheadAttention 模块的前向传播过程。
|
||||
# attention_weights = []
|
||||
# def hook(module, input, output):
|
||||
# # output[1] 是注意力权重
|
||||
# attention_weights.append(output[1])
|
||||
# model.transformer_core.transformer_encoder.layers[0].self_attn.register_forward_hook(hook)
|
||||
|
||||
# 4. 运行一次前向传播以获取权重
|
||||
# logger.info("正在运行前向传播...")
|
||||
# with torch.no_grad():
|
||||
# # 模型需要修改以支持返回注意力权重,或者通过钩子获取
|
||||
# _ = model(sample_data)
|
||||
|
||||
# 5. 绘制注意力图
|
||||
# if attention_weights:
|
||||
# logger.info("正在绘制注意力图...")
|
||||
# # attention_weights[0] 的形状是 [batch_size, num_heads, seq_len, seq_len]
|
||||
# # 我们取第一项,并在所有头上取平均值
|
||||
# avg_attention = attention_weights[0][0].mean(dim=0).cpu().numpy()
|
||||
# plt.figure(figsize=(10, 10))
|
||||
# sns.heatmap(avg_attention, cmap='viridis')
|
||||
# plt.title("区块之间的平均注意力图")
|
||||
# plt.xlabel("区块索引")
|
||||
# plt.ylabel("区块索引")
|
||||
# plt.show()
|
||||
# else:
|
||||
# logger.warning("未能提取注意力权重。")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user