56 lines
1.9 KiB
Python
56 lines
1.9 KiB
Python
import os
|
|
from PIL import Image
|
|
from torch.utils.data import Dataset
|
|
import json
|
|
|
|
class ICLayoutDataset(Dataset):
    """Dataset of grayscale IC layout images with optional JSON annotations.

    Each sample is an ``(image, annotation)`` pair. Images are the ``.png``
    files found directly inside ``image_dir``; the annotation for an image is
    the ``.json`` file with the same base name inside ``annotation_dir``.
    """

    def __init__(self, image_dir, annotation_dir=None, transform=None):
        """
        Initialize the IC layout dataset.

        Args:
            image_dir (str): Directory path containing PNG format IC layout images.
            annotation_dir (str, optional): Directory path containing JSON format
                annotation files. When omitted, every sample gets an empty annotation.
            transform (callable, optional): Optional transform applied to each
                loaded grayscale image (e.g., Sobel edge detection).

        Raises:
            OSError: If ``image_dir`` cannot be listed.
        """
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.transform = transform

        image_ext = '.png'
        # os.listdir() yields entries in arbitrary order; sort so that a given
        # index always refers to the same sample across runs and machines.
        self.images = sorted(
            name for name in os.listdir(image_dir) if name.endswith(image_ext)
        )

        if annotation_dir:
            # Swap only the trailing extension. str.replace() would rewrite
            # every '.png' occurring anywhere in the file name.
            self.annotations = [
                name[:-len(image_ext)] + '.json' for name in self.images
            ]
        else:
            self.annotations = [None] * len(self.images)

    def __len__(self):
        """
        Return the number of images in the dataset.

        Returns:
            int: Dataset size.
        """
        return len(self.images)

    def __getitem__(self, idx):
        """
        Get image and annotation at specified index.

        Args:
            idx (int): Image index.

        Returns:
            tuple: (image, annotation), where image is the grayscale image
            (after ``transform``, if one was given) and annotation is the
            parsed content of the matching JSON file, or an empty dict when
            no annotation directory was configured or the file is missing.

        Raises:
            IndexError: If ``idx`` is out of range.
            json.JSONDecodeError: If the annotation file is not valid JSON.
        """
        img_path = os.path.join(self.image_dir, self.images[idx])
        # convert() returns an independent copy, so the underlying file can be
        # closed deterministically as soon as the pixel data has been read.
        with Image.open(img_path) as raw:
            image = raw.convert('L')  # Convert to grayscale
        if self.transform:
            image = self.transform(image)

        annotation = {}
        if self.annotation_dir and self.annotations[idx]:
            ann_path = os.path.join(self.annotation_dir, self.annotations[idx])
            try:
                with open(ann_path, 'r', encoding='utf-8') as f:
                    annotation = json.load(f)
            except FileNotFoundError:
                # A missing annotation file is not an error: the sample is
                # simply treated as unlabeled.
                pass

        return image, annotation