initial commit

This commit is contained in:
Jiao77
2025-08-25 17:54:08 +08:00
commit f187abe72a
28 changed files with 1703 additions and 0 deletions

37
src/data/dataset.py Normal file
View File

@@ -0,0 +1,37 @@
import torch
from torch_geometric.data import Dataset, InMemoryDataset
import os
class LayoutDataset(InMemoryDataset):
"""用于加载预处理后的版图图数据的 PyTorch Geometric 数据集。"""
def __init__(self, root, transform=None, pre_transform=None):
"""
Args:
root: 数据集应保存的根目录。
transform: 一个函数/变换,作用于 `Data` 对象并返回一个转换后的版本。
pre_transform: 一个函数/变换,作用于 `Data` 对象并返回一个转换后的版本。
"""
super(LayoutDataset, self).__init__(root, transform, pre_transform)
# 加载已处理的数据
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
"""如果 `download()` 返回一个路径列表,这里会返回它们的文件名。"""
return [] # 我们不从网络下载原始文件
@property
def processed_file_names(self):
"""在 `processed_dir` 目录中必须存在的文件列表,用以跳过处理步骤。"""
return ['data.pt']
def download(self):
"""从网上下载原始数据到 `raw_dir` 目录。"""
pass # 假设数据是预先处理好的
def process(self):
"""处理原始数据并将其保存到 `processed_dir` 目录。"""
# 如果希望在加载时动态处理数据,可以在这里实现 `scripts/preprocess_gds.py` 中的逻辑。
# 在我们的框架中,我们假设预处理是通过脚本独立完成的。
pass