添加数据增强方案以及扩散生成模型的想法

This commit is contained in:
Jiao77
2025-10-20 21:14:03 +08:00
parent d6d00cf088
commit 08f488f0d8
22 changed files with 1903 additions and 190 deletions

View File

@@ -1,6 +1,6 @@
import os
import json
from typing import Tuple
from typing import Tuple, Optional
import cv2
import numpy as np
@@ -70,6 +70,8 @@ class ICLayoutTrainingDataset(Dataset):
patch_size: int = 256,
transform=None,
scale_range: Tuple[float, float] = (1.0, 1.0),
use_albu: bool = False,
albu_params: Optional[dict] = None,
) -> None:
self.image_dir = image_dir
self.image_paths = [
@@ -80,6 +82,28 @@ class ICLayoutTrainingDataset(Dataset):
self.patch_size = patch_size
self.transform = transform
self.scale_range = scale_range
# 可选的 albumentations 管道
self.albu = None
if use_albu:
try:
import albumentations as A # 延迟导入,避免环境未安装时报错
p = albu_params or {}
elastic_prob = float(p.get("prob", 0.3))
alpha = float(p.get("alpha", 40))
sigma = float(p.get("sigma", 6))
alpha_affine = float(p.get("alpha_affine", 6))
use_bc = bool(p.get("brightness_contrast", True))
use_noise = bool(p.get("gauss_noise", True))
transforms_list = [
A.ElasticTransform(alpha=alpha, sigma=sigma, alpha_affine=alpha_affine, p=elastic_prob),
]
if use_bc:
transforms_list.append(A.RandomBrightnessContrast(p=0.5))
if use_noise:
transforms_list.append(A.GaussNoise(var_limit=(5.0, 20.0), p=0.3))
self.albu = A.Compose(transforms_list)
except Exception:
self.albu = None
def __len__(self) -> int:
    """Return the number of image paths held by this dataset."""
    num_images = len(self.image_paths)
    return num_images
@@ -102,22 +126,27 @@ class ICLayoutTrainingDataset(Dataset):
patch = image.crop((x, y, x + crop_size, y + crop_size))
patch = patch.resize((self.patch_size, self.patch_size), Image.Resampling.LANCZOS)
# 亮度/对比度增强
if np.random.random() < 0.5:
brightness_factor = np.random.uniform(0.8, 1.2)
patch = patch.point(lambda px: int(np.clip(px * brightness_factor, 0, 255)))
if np.random.random() < 0.5:
contrast_factor = np.random.uniform(0.8, 1.2)
patch = patch.point(lambda px: int(np.clip(((px - 128) * contrast_factor) + 128, 0, 255)))
if np.random.random() < 0.3:
patch_np = np.array(patch, dtype=np.float32)
noise = np.random.normal(0, 5, patch_np.shape)
patch_np = np.clip(patch_np + noise, 0, 255)
patch = Image.fromarray(patch_np.astype(np.uint8))
# Photometric/elastic augmentation (applied before the geometric transform H)
patch_np_uint8 = np.array(patch)
if self.albu is not None:
patch_np_uint8 = self.albu(image=patch_np_uint8)["image"]
patch = Image.fromarray(patch_np_uint8)
else:
# 原有轻量光度增强
if np.random.random() < 0.5:
brightness_factor = np.random.uniform(0.8, 1.2)
patch = patch.point(lambda px: int(np.clip(px * brightness_factor, 0, 255)))
if np.random.random() < 0.5:
contrast_factor = np.random.uniform(0.8, 1.2)
patch = patch.point(lambda px: int(np.clip(((px - 128) * contrast_factor) + 128, 0, 255)))
if np.random.random() < 0.3:
patch_np = np.array(patch, dtype=np.float32)
noise = np.random.normal(0, 5, patch_np.shape)
patch_np = np.clip(patch_np + noise, 0, 255)
patch = Image.fromarray(patch_np.astype(np.uint8))
patch_np_uint8 = np.array(patch)
# Random rotation and mirroring (8 discrete transforms)
theta_deg = int(np.random.choice([0, 90, 180, 270]))