39 lines
1.6 KiB
Python
39 lines
1.6 KiB
Python
# src/data/dataset.py
|
|
import torch
|
|
from torch_geometric.data import Dataset, InMemoryDataset
|
|
import os
|
|
|
|
class LayoutDataset(InMemoryDataset):
|
|
"""用于加载预处理后的版图图数据的 PyTorch Geometric 数据集。"""
|
|
|
|
def __init__(self, root, transform=None, pre_transform=None):
|
|
"""
|
|
Args:
|
|
root: 数据集应保存的根目录。
|
|
transform: 一个函数/变换,作用于 `Data` 对象并返回一个转换后的版本。
|
|
pre_transform: 一个函数/变换,作用于 `Data` 对象并返回一个转换后的版本。
|
|
"""
|
|
super(LayoutDataset, self).__init__(root, transform, pre_transform)
|
|
# 加载已处理的数据
|
|
self.data, self.slices = torch.load(self.processed_paths[0])
|
|
|
|
@property
|
|
def raw_file_names(self):
|
|
"""如果 `download()` 返回一个路径列表,这里会返回它们的文件名。"""
|
|
return [] # 我们不从网络下载原始文件
|
|
|
|
@property
|
|
def processed_file_names(self):
|
|
"""在 `processed_dir` 目录中必须存在的文件列表,用以跳过处理步骤。"""
|
|
return ['data.pt']
|
|
|
|
def download(self):
|
|
"""从网上下载原始数据到 `raw_dir` 目录。"""
|
|
pass # 假设数据是预先处理好的
|
|
|
|
def process(self):
|
|
"""处理原始数据并将其保存到 `processed_dir` 目录。"""
|
|
# 如果希望在加载时动态处理数据,可以在这里实现 `scripts/preprocess_gds.py` 中的逻辑。
|
|
# 在我们的框架中,我们假设预处理是通过脚本独立完成的。
|
|
pass
|