56 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			56 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| from PIL import Image
 | |
| from torch.utils.data import Dataset
 | |
| import json
 | |
| 
 | |
class ICLayoutDataset(Dataset):
    """Dataset of grayscale IC layout images with optional JSON annotations.

    Images are the ``.png`` regular files found directly in ``image_dir``,
    indexed in sorted filename order. The annotation for an image is the
    ``.json`` file with the same stem in ``annotation_dir`` (when that
    directory is given and the file exists).
    """

    def __init__(self, image_dir, annotation_dir=None, transform=None):
        """
        Initialize the IC layout dataset.

        Args:
            image_dir (str): Directory path containing PNG format IC layout images.
            annotation_dir (str, optional): Directory path containing JSON format annotation files.
            transform (callable, optional): Optional transform to apply to images (e.g., Sobel edge detection).
        """
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        # Skip non-files (e.g. a subdirectory named "x.png") that would crash on load.
        self.images = sorted(
            f
            for f in os.listdir(image_dir)
            if f.endswith('.png') and os.path.isfile(os.path.join(image_dir, f))
        )
        if annotation_dir:
            # Swap only the trailing extension; str.replace would also rewrite any
            # earlier '.png' inside the name (e.g. 'a.png.v2.png').
            self.annotations = [os.path.splitext(f)[0] + '.json' for f in self.images]
        else:
            self.annotations = [None] * len(self.images)

    def __len__(self):
        """
        Return the number of images in the dataset.

        Returns:
            int: Dataset size.
        """
        return len(self.images)

    def __getitem__(self, idx):
        """
        Get image and annotation at specified index.

        Args:
            idx (int): Image index.

        Returns:
            tuple: (image, annotation), where image is the grayscale image after the
            optional transform, and annotation is the parsed JSON annotation, or an
            empty dict when no annotation directory is configured or the file is missing.
        """
        img_path = os.path.join(self.image_dir, self.images[idx])
        # Image.open is lazy and keeps the file open; convert() yields an independent
        # image, so close the source handle immediately instead of leaking it.
        with Image.open(img_path) as img:
            image = img.convert('L')  # Convert to grayscale
        if self.transform is not None:
            image = self.transform(image)

        annotation = {}
        if self.annotation_dir and self.annotations[idx]:
            ann_path = os.path.join(self.annotation_dir, self.annotations[idx])
            try:
                with open(ann_path, 'r', encoding='utf-8') as f:
                    annotation = json.load(f)
            except FileNotFoundError:
                pass  # Missing annotation files are allowed; keep the empty dict.

        return image, annotation
