添加数据增强方案以及扩散生成模型的想法

This commit is contained in:
Jiao77
2025-10-20 21:14:03 +08:00
parent d6d00cf088
commit 08f488f0d8
22 changed files with 1903 additions and 190 deletions

View File

@@ -0,0 +1,46 @@
#!/usr/bin/env python3
"""
Prepare raster patch dataset and optional condition maps for diffusion training.
Planned inputs:
- --src_dirs: one or more directories containing PNG layout images
- --out_dir: output root for images/ and conditions/
- --size: patch size (e.g., 256)
- --stride: sliding stride for patch extraction
- --min_fg_ratio: minimum foreground ratio to keep a patch (0-1)
- --make_conditions: flags to generate edge/skeleton/distance maps
Current status: CLI skeleton and TODOs only.
"""
from __future__ import annotations
import argparse
from pathlib import Path
def main() -> None:
    """CLI entry point: parse options and pre-create the dataset skeleton dirs."""
    parser = argparse.ArgumentParser(description="Prepare patch dataset for diffusion training (skeleton)")
    parser.add_argument("--src_dirs", type=str, nargs="+", help="Source PNG dirs for layouts")
    parser.add_argument("--out_dir", type=str, required=True, help="Output root directory")
    parser.add_argument("--size", type=int, default=256, help="Patch size")
    parser.add_argument("--stride", type=int, default=256, help="Patch stride")
    parser.add_argument("--min_fg_ratio", type=float, default=0.02, help="Min foreground ratio to keep a patch")
    parser.add_argument("--make_edge", action="store_true", help="Generate edge map conditions (e.g., Sobel/Canny)")
    parser.add_argument("--make_skeleton", action="store_true", help="Generate morphological skeleton condition")
    parser.add_argument("--make_dist", action="store_true", help="Generate distance transform condition")
    args = parser.parse_args()

    output_root = Path(args.out_dir)
    output_root.mkdir(parents=True, exist_ok=True)
    # Dataset layout: images/ holds raster patches, conditions/ holds condition maps.
    for subdir in ("images", "conditions"):
        (output_root / subdir).mkdir(exist_ok=True)
    # TODO: implement extraction loop over src_dirs, crop patches, filter by min_fg_ratio,
    # and save into images/; generate optional condition maps into conditions/ mirroring filenames.
    # Keep file naming consistent: images/xxx.png, conditions/xxx_edge.png, etc.
    print("[TODO] Implement patch extraction and condition map generation.")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
"""
Sample layout patches using a trained diffusion model (skeleton).
Outputs raster PNGs into a target directory compatible with current training pipeline (no H pairing).
Current status: CLI skeleton and TODOs only.
"""
from __future__ import annotations
import argparse
from pathlib import Path
def main() -> None:
    """CLI entry point: parse sampling options and prepare the output directory."""
    parser = argparse.ArgumentParser(description="Sample layout patches from diffusion model (skeleton)")
    parser.add_argument("--ckpt", type=str, required=True, help="Path to trained diffusion checkpoint or HF repo id")
    parser.add_argument("--out_dir", type=str, required=True, help="Directory to write sampled PNGs")
    parser.add_argument("--num", type=int, default=200)
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--guidance", type=float, default=5.0)
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--cond_dir", type=str, default=None, help="Optional condition maps directory")
    parser.add_argument("--cond_types", type=str, nargs="*", default=None, help="e.g., edge skeleton dist")
    args = parser.parse_args()

    target_dir = Path(args.out_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    # TODO: load pipeline from ckpt, set scheduler, handle conditions if provided,
    # sample args.num images, save as PNG files into out_dir.
    print("[TODO] Implement diffusion sampling and PNG saving.")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,37 @@
#!/usr/bin/env python3
"""
Train a diffusion model for layout patch generation (skeleton).
Planned: fine-tune Stable Diffusion (or Latent Diffusion) with optional ControlNet edge/skeleton conditions.
Dependencies to consider: diffusers, transformers, accelerate, torch, torchvision, opencv-python.
Current status: CLI skeleton and TODOs only.
"""
from __future__ import annotations
import argparse
def main() -> None:
    """CLI entry point: parse training options (implementation still TODO)."""
    parser = argparse.ArgumentParser(description="Train diffusion model for layout patches (skeleton)")
    parser.add_argument("--data_dir", type=str, required=True, help="Prepared dataset root (images/ + conditions/)")
    parser.add_argument("--output_dir", type=str, required=True, help="Checkpoint output directory")
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--max_steps", type=int, default=100000)
    parser.add_argument("--use_controlnet", action="store_true", help="Train with ControlNet conditioning")
    parser.add_argument("--condition_types", type=str, nargs="*", default=["edge"], help="e.g., edge skeleton dist")
    args = parser.parse_args()
    # TODO: implement dataset/dataloader (images and optional conditions)
    # TODO: load base pipeline (Stable Diffusion or Latent Diffusion) and optionally ControlNet
    # TODO: set up optimizer, LR schedule, EMA, gradient accumulation, and run training loop
    # TODO: save periodic checkpoints to output_dir
    print("[TODO] Implement diffusion training loop and checkpoints.")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,90 @@
#!/usr/bin/env python3
"""
Programmatic synthetic IC layout generator using gdstk.
Generates GDS files with simple standard-cell-like patterns, wires, and vias.
"""
from __future__ import annotations
import argparse
from pathlib import Path
import random
import gdstk
def build_standard_cell(cell_name: str, rng: random.Random, layer: int = 1, datatype: int = 0) -> gdstk.Cell:
    """Create a randomized standard-cell-like gdstk cell.

    Layers used: `layer` for the cell body, `layer + 1` for poly fingers,
    `layer + 2` for contact/via squares.
    """
    cell = gdstk.Cell(cell_name)

    # Cell body rectangle with a randomized footprint.
    width = rng.uniform(0.8, 2.0)
    height = rng.uniform(1.6, 4.0)
    cell.add(gdstk.rectangle((0, 0), (width, height), layer=layer, datatype=datatype))

    # Evenly pitched vertical poly fingers (0.1 wide) across the body.
    num_fingers = rng.randint(1, 4)
    pitch = width / (num_fingers + 1)
    for finger in range(1, num_fingers + 1):
        cx = finger * pitch
        cell.add(gdstk.rectangle((cx - 0.05, 0), (cx + 0.05, height), layer=layer + 1, datatype=datatype))

    # Random 0.1 x 0.1 contact/via squares kept inside the body.
    for _ in range(rng.randint(2, 6)):
        vx = rng.uniform(0.1, width - 0.1)
        vy = rng.uniform(0.1, height - 0.1)
        cell.add(gdstk.rectangle((vx - 0.05, vy - 0.05), (vx + 0.05, vy + 0.05), layer=layer + 2, datatype=datatype))
    return cell
def generate_layout(out_path: Path, width: float, height: float, seed: int, rows: int, cols: int, density: float):
    """Write a GDS file with a grid of randomly chosen standard-cell variants.

    `density` in [0, 1] is the probability that a grid site receives an
    instance; each placed instance is jittered by up to 10% of the grid pitch.
    """
    rng = random.Random(seed)
    library = gdstk.Library()
    top_cell = gdstk.Cell("TOP")

    # Four reusable standard-cell variants (RNG is consumed in a fixed order).
    variants = [build_standard_cell(f"SC_{i}", rng, layer=1) for i in range(4)]

    col_pitch = width / cols
    row_pitch = height / rows
    for row in range(rows):
        for col in range(cols):
            # Randomly skip sites so the fill matches the requested density.
            if rng.random() > density:
                continue
            chosen = rng.choice(variants)
            x = col * col_pitch + rng.uniform(0.0, 0.1 * col_pitch)
            y = row * row_pitch + rng.uniform(0.0, 0.1 * row_pitch)
            top_cell.add(gdstk.Reference(chosen, (x, y)))

    library.add(*variants)
    library.add(top_cell)
    library.write_gds(str(out_path))
def main():
    """Parse CLI options and emit `num_samples` GDS files into the output dir."""
    parser = argparse.ArgumentParser(description="Generate synthetic IC layouts (GDS)")
    parser.add_argument("--out-dir", type=str, default="data/synthetic/gds")
    parser.add_argument("--out_dir", dest="out_dir", type=str, help="Alias of --out-dir")
    parser.add_argument("--num-samples", type=int, default=10)
    parser.add_argument("--num", dest="num_samples", type=int, help="Alias of --num-samples")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--width", type=float, default=200.0)
    parser.add_argument("--height", type=float, default=200.0)
    parser.add_argument("--rows", type=int, default=10)
    parser.add_argument("--cols", type=int, default=10)
    parser.add_argument("--density", type=float, default=0.5)
    args = parser.parse_args()

    target_dir = Path(args.out_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    # Derive one independent per-sample seed from the master seed.
    seed_rng = random.Random(args.seed)
    for idx in range(args.num_samples):
        sample_seed = seed_rng.randint(0, 2**31 - 1)
        gds_path = target_dir / f"chip_{idx:06d}.gds"
        generate_layout(gds_path, args.width, args.height, sample_seed, args.rows, args.cols, args.density)
        print(f"[OK] Generated {gds_path}")


if __name__ == "__main__":
    main()

160
tools/layout2png.py Normal file
View File

@@ -0,0 +1,160 @@
#!/usr/bin/env python3
"""
Batch convert GDS to PNG.
Priority:
1) Use KLayout in headless batch mode (most accurate view fidelity for IC layouts).
2) Fallback to gdstk(read) -> write SVG -> cairosvg to PNG (no KLayout dependency at runtime).
"""
from __future__ import annotations
import argparse
from pathlib import Path
import subprocess
import sys
import tempfile
import cairosvg
def klayout_convert(gds_path: Path, png_path: Path, dpi: int, layermap: str | None = None, line_width: int | None = None, bgcolor: str | None = None) -> bool:
    """Render a GDS file to PNG via KLayout in headless batch mode.

    A temporary Python macro with the paths embedded is written and executed
    with ``klayout -zz -b -r``.

    Args:
        gds_path: Input GDS file.
        png_path: Output PNG file.
        dpi: Pixel size argument passed to ``LayoutView.save_image``.
        layermap: Optional color spec with "LAYER/DATATYPE:#RRGGBB" entries
            separated by commas; malformed entries are silently skipped.
        bgcolor: Optional KLayout 'background-color' config value.
        line_width: Optional 'default-draw-line-width' config value.

    Returns:
        True when KLayout exits cleanly and the PNG exists; False otherwise
        (including when the ``klayout`` binary is not installed).
    """
    # Build optional display-configuration snippets for the generated macro.
    layer_cfg_code = ""
    if layermap:
        # layermap format: "LAYER/DATATYPE:#RRGGBB,..."
        layer_cfg_code += "lprops = pya.LayerPropertiesNode()\n"
        for spec in layermap.split(","):
            spec = spec.strip()
            if not spec:
                continue
            try:
                ld, color = spec.split(":")
                layer_s, datatype_s = ld.split("/")
                color = color.strip()
                layer_cfg_code += (
                    "lp = pya.LayerPropertiesNode()\n"
                    f"lp.layer = int({int(layer_s)})\n"
                    f"lp.datatype = int({int(datatype_s)})\n"
                    f"lp.fill_color = pya.Color.from_string('{color}')\n"
                    f"lp.frame_color = pya.Color.from_string('{color}')\n"
                    "lprops.insert(lp)\n"
                )
            except Exception:
                # Ignore malformed entries
                continue
        layer_cfg_code += "cv.set_layer_properties(lprops)\n"
    line_width_code = ""
    if line_width is not None:
        line_width_code = f"cv.set_config('default-draw-line-width', '{int(line_width)}')\n"
    bg_code = ""
    if bgcolor:
        bg_code = f"cv.set_config('background-color', '{bgcolor}')\n"
    script = f"""
import pya
ly = pya.Layout()
ly.read(r"{gds_path}")
cv = pya.LayoutView()
cv.load_layout(ly, 0)
cv.max_hier_levels = 20
{bg_code}
{line_width_code}
{layer_cfg_code}
cv.zoom_fit()
cv.save_image(r"{png_path}", {dpi}, 0)
"""
    macro_path: Path | None = None
    try:
        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tf:
            tf.write(script)
            macro_path = Path(tf.name)
        # Run klayout in batch mode
        res = subprocess.run(["klayout", "-zz", "-b", "-r", str(macro_path)], check=False, capture_output=True, text=True)
        ok = res.returncode == 0 and png_path.exists()
        if not ok and res.stderr:
            # Print stderr for visibility when running manually
            sys.stderr.write(res.stderr)
        return ok
    except Exception:
        # Covers FileNotFoundError (klayout binary missing) and any other failure.
        return False
    finally:
        # BUGFIX: the original leaked the temp macro whenever subprocess.run
        # raised (e.g. klayout not installed); always clean it up here.
        if macro_path is not None:
            try:
                macro_path.unlink(missing_ok=True)
            except Exception:
                pass
def gdstk_fallback(gds_path: Path, png_path: Path, dpi: int) -> bool:
    """Fallback renderer: gdstk reads the GDS and emits an SVG, which
    cairosvg rasterizes to PNG.

    Note: This may differ visually from KLayout depending on layers/styles.
    Returns True on success, False on any failure (missing gdstk, no top
    cells, SVG export or rasterization errors).
    """
    try:
        import gdstk  # local import to avoid import cost when not needed

        svg_path = png_path.with_suffix(".svg")
        library = gdstk.read_gds(str(gds_path))
        top_cells = library.top_level()
        if not top_cells:
            return False
        # Render the first top cell; recent gdstk exposes write_svg on Cell,
        # older releases only on Library.
        try:
            top_cells[0].write_svg(str(svg_path))  # type: ignore[attr-defined]
        except Exception:
            try:
                library.write_svg(str(svg_path))  # type: ignore[attr-defined]
            except Exception:
                return False
        # Rasterize the intermediate SVG, then discard it (best effort).
        cairosvg.svg2png(url=str(svg_path), write_to=str(png_path), dpi=dpi)
        try:
            svg_path.unlink()
        except Exception:
            pass
        return True
    except Exception:
        return False
def main():
    """Convert every .gds under --in to a .png under --out, preferring KLayout."""
    parser = argparse.ArgumentParser(description="Convert GDS files to PNG")
    parser.add_argument("--in", dest="in_dir", type=str, required=True, help="Input directory containing .gds files")
    parser.add_argument("--out", dest="out_dir", type=str, required=True, help="Output directory to place .png files")
    parser.add_argument("--dpi", type=int, default=600, help="Output resolution in DPI for rasterization")
    parser.add_argument("--layermap", type=str, default=None, help="Layer color map, e.g. '1/0:#00FF00,2/0:#FF0000'")
    parser.add_argument("--line_width", type=int, default=None, help="Default draw line width in pixels for KLayout display")
    parser.add_argument("--bgcolor", type=str, default=None, help="Background color, e.g. '#000000' or 'black'")
    args = parser.parse_args()

    src_dir = Path(args.in_dir)
    dst_dir = Path(args.out_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)
    gds_files = sorted(src_dir.glob("*.gds"))
    if not gds_files:
        print(f"[WARN] No GDS files found in {src_dir}")
        return

    converted = 0
    for gds in gds_files:
        png_path = dst_dir / (gds.stem + ".png")
        # KLayout first (best fidelity); fall back to gdstk+SVG on failure.
        done = klayout_convert(gds, png_path, args.dpi, layermap=args.layermap, line_width=args.line_width, bgcolor=args.bgcolor)
        if not done:
            done = gdstk_fallback(gds, png_path, args.dpi)
        if done:
            converted += 1
            print(f"[OK] {gds.name} -> {png_path}")
        else:
            print(f"[FAIL] {gds.name}")
    print(f"Done. {converted}/{len(gds_files)} converted.")


if __name__ == "__main__":
    main()

68
tools/preview_dataset.py Normal file
View File

@@ -0,0 +1,68 @@
#!/usr/bin/env python3
"""
Quickly preview training pairs (original, transformed, H) from ICLayoutTrainingDataset.
Saves a grid image for visual inspection.
"""
from __future__ import annotations
import argparse
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from torchvision.utils import make_grid, save_image
from data.ic_dataset import ICLayoutTrainingDataset
from utils.data_utils import get_transform
def to_pil(t: torch.Tensor) -> Image.Image:
    """Convert a CHW float tensor to a PIL RGB image.

    3-channel inputs are assumed normalized to [-1, 1] (Normalize(0.5, 0.5))
    and are un-normalized; 1-channel inputs are assumed in [0, 1] and are
    replicated to 3 channels. Raises ValueError for any other shape.

    BUGFIX/cleanup vs original: the shape check now happens up front, and the
    dead `x = x` branch (re-testing an already-established 3-channel case) is
    removed. Valid-input behavior is unchanged.
    """
    x = t.clone()
    if x.dim() != 3 or x.size(0) not in (1, 3):
        raise ValueError("Unexpected tensor shape")
    if x.size(0) == 3:
        x = (x * 0.5) + 0.5  # invert Normalize(mean=0.5, std=0.5)
    x = (x * 255.0).clamp(0, 255).byte()
    if x.size(0) == 1:
        # Grayscale: replicate the single channel to RGB.
        x = x.repeat(3, 1, 1)
    np_img = x.permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(np_img)
def main():
    """Render a grid of (original, transformed) training pairs for inspection."""
    parser = argparse.ArgumentParser(description="Preview dataset samples")
    parser.add_argument("--dir", dest="image_dir", type=str, required=True, help="PNG images directory")
    parser.add_argument("--out", dest="out_path", type=str, default="preview.png")
    parser.add_argument("--n", dest="num", type=int, default=8)
    parser.add_argument("--patch", dest="patch_size", type=int, default=256)
    parser.add_argument("--elastic", dest="use_elastic", action="store_true")
    args = parser.parse_args()

    dataset = ICLayoutTrainingDataset(
        args.image_dir,
        patch_size=args.patch_size,
        transform=get_transform(),
        scale_range=(1.0, 1.0),
        use_albu=args.use_elastic,
        albu_params={"prob": 0.5},
    )
    # Each sample contributes (original, transformed); the homography is unused here.
    tiles = []
    for idx in range(min(args.num, len(dataset))):
        original, transformed, _ = dataset[idx]
        tiles.extend((original, transformed))
    grid = make_grid(torch.stack(tiles, dim=0), nrow=2, padding=2)
    save_image(grid, args.out_path)
    print(f"Saved preview to {args.out_path}")


if __name__ == "__main__":
    main()

76
tools/smoke_test.py Normal file
View File

@@ -0,0 +1,76 @@
#!/usr/bin/env python3
"""
Minimal smoke test:
1) Generate a tiny synthetic set (num=8) and rasterize to PNG
2) Validate H consistency (n=4, with/without elastic)
3) Run a short training loop (epochs=1-2) to verify end-to-end pipeline
Prints PASS/FAIL with basic stats.
"""
from __future__ import annotations
import argparse
import subprocess
import os
import sys
from pathlib import Path
def run(cmd: list[str]) -> int:
    """Run *cmd* as a child process with the project root on PYTHONPATH.

    Args:
        cmd: Full argv list for the child process.

    Returns:
        The child's return code.
    """
    print("[RUN]", " ".join(cmd))
    env = os.environ.copy()
    # Ensure project root on PYTHONPATH for child processes
    root = Path(__file__).resolve().parents[1]
    existing = env.get("PYTHONPATH")
    # BUGFIX: join with os.pathsep (":" on POSIX, ";" on Windows) instead of
    # a hard-coded ":", which silently broke child imports on Windows.
    env["PYTHONPATH"] = f"{root}{os.pathsep}{existing}" if existing else str(root)
    return subprocess.call(cmd, env=env)
def main() -> None:
    """Run the end-to-end smoke test: generate -> rasterize -> validate -> train.

    Each stage exits with a distinct non-zero status (2-6) on failure so logs
    identify the failing stage at a glance.
    """
    parser = argparse.ArgumentParser(description="Minimal smoke test for E2E pipeline")
    parser.add_argument("--root", type=str, default="data/smoke", help="Root dir for smoke test outputs")
    parser.add_argument("--config", type=str, default="configs/base_config.yaml")
    args = parser.parse_args()
    root = Path(args.root)
    gds_dir = root / "gds"
    png_dir = root / "png"
    gds_dir.mkdir(parents=True, exist_ok=True)
    png_dir.mkdir(parents=True, exist_ok=True)
    # rc accumulates child return codes via |=; any non-zero stage aborts below.
    rc = 0
    # 1) Generate a tiny set
    rc |= run([sys.executable, "tools/generate_synthetic_layouts.py", "--out_dir", gds_dir.as_posix(), "--num", "8", "--seed", "123"])
    if rc != 0:
        print("[FAIL] generate synthetic")
        sys.exit(2)
    # 2) Rasterize
    rc |= run([sys.executable, "tools/layout2png.py", "--in", gds_dir.as_posix(), "--out", png_dir.as_posix(), "--dpi", "600"])
    if rc != 0:
        print("[FAIL] layout2png")
        sys.exit(3)
    # 3) Validate H (n=4, both no-elastic and elastic)
    rc |= run([sys.executable, "tools/validate_h_consistency.py", "--dir", png_dir.as_posix(), "--out", (root/"validate_no_elastic").as_posix(), "--n", "4"])
    rc |= run([sys.executable, "tools/validate_h_consistency.py", "--dir", png_dir.as_posix(), "--out", (root/"validate_elastic").as_posix(), "--n", "4", "--elastic"])
    if rc != 0:
        print("[FAIL] validate H")
        sys.exit(4)
    # 4) Write back config via synth_pipeline and run short training (1 epoch)
    #    --num 0 skips regeneration; this step only updates the YAML config.
    rc |= run([sys.executable, "tools/synth_pipeline.py", "--out_root", root.as_posix(), "--num", "0", "--dpi", "600", "--config", args.config, "--ratio", "0.3", "--enable_elastic", "--no_preview"])
    if rc != 0:
        print("[FAIL] synth_pipeline config update")
        sys.exit(5)
    # Train 1 epoch to smoke the loop
    rc |= run([sys.executable, "train.py", "--config", args.config, "--epochs", "1"])
    if rc != 0:
        print("[FAIL] train 1 epoch")
        sys.exit(6)
    print("[PASS] Smoke test completed successfully.")


if __name__ == "__main__":
    main()

169
tools/synth_pipeline.py Normal file
View File

@@ -0,0 +1,169 @@
#!/usr/bin/env python3
"""
One-click synthetic data pipeline:
1) Generate synthetic GDS using tools/generate_synthetic_layouts.py
2) Rasterize GDS to PNG using tools/layout2png.py (KLayout preferred, fallback gdstk+SVG)
3) Preview random training pairs using tools/preview_dataset.py (optional)
4) Validate homography consistency using tools/validate_h_consistency.py (optional)
5) Optionally update a YAML config to enable synthetic mixing and elastic augmentation
"""
from __future__ import annotations
import argparse
import subprocess
import sys
from pathlib import Path
from omegaconf import OmegaConf
def run_cmd(cmd: list[str]) -> None:
    """Echo and execute *cmd*; abort the whole pipeline if it fails."""
    print("[RUN]", " ".join(str(c) for c in cmd))
    result = subprocess.run(cmd)
    if result.returncode == 0:
        return
    raise SystemExit(f"Command failed with code {result.returncode}: {' '.join(map(str, cmd))}")
# Pipeline stages mapped to the helper scripts that implement them.
essential_scripts = {
    "gen": Path("tools/generate_synthetic_layouts.py"),
    "gds2png": Path("tools/layout2png.py"),
    "preview": Path("tools/preview_dataset.py"),
    "validate": Path("tools/validate_h_consistency.py"),
}


def ensure_scripts_exist() -> None:
    """Abort early with a clear message if any required helper script is absent."""
    missing = [str(path) for path in essential_scripts.values() if not path.exists()]
    if missing:
        raise SystemExit(f"Missing required scripts: {missing}")
def update_config(config_path: Path, png_dir: Path, ratio: float, enable_elastic: bool) -> None:
    """Enable synthetic-data mixing (and optionally elastic augmentation) in a YAML config."""
    cfg = OmegaConf.load(config_path)
    # Make sure the synthetic node exists, then switch mixing on.
    if "synthetic" not in cfg:
        cfg.synthetic = {}
    cfg.synthetic.enabled = True
    cfg.synthetic.png_dir = png_dir.as_posix()
    cfg.synthetic.ratio = float(ratio)
    if enable_elastic:
        if "augment" not in cfg:
            cfg.augment = {}
        if "elastic" not in cfg.augment:
            cfg.augment.elastic = {}
        cfg.augment.elastic.enabled = True
        # Numeric parameters keep whatever values the config already carries.
        for key, default in (("alpha", 40), ("sigma", 6), ("alpha_affine", 6), ("prob", 0.3)):
            if key not in cfg.augment.elastic:
                cfg.augment.elastic[key] = default
        # Photometric defaults
        if "photometric" not in cfg.augment:
            cfg.augment.photometric = {"brightness_contrast": True, "gauss_noise": True}
    OmegaConf.save(config=cfg, f=config_path)
    print(f"[OK] Config updated: {config_path}")
def main() -> None:
    """Drive the full synthetic-data pipeline and update the training config.

    Stages: GDS generation -> PNG rasterization -> optional preview ->
    optional homography validation -> YAML config update (synthetic mixing,
    optional elastic augmentation, optional diffusion PNG directory).
    """
    parser = argparse.ArgumentParser(description="One-click synthetic data pipeline")
    parser.add_argument("--out_root", type=str, default="data/synthetic", help="Root output dir for gds/png/preview")
    parser.add_argument("--num", type=int, default=200, help="Number of GDS samples to generate")
    parser.add_argument("--dpi", type=int, default=600, help="Rasterization DPI for PNG rendering")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--ratio", type=float, default=0.3, help="Mixing ratio for synthetic data in training")
    parser.add_argument("--config", type=str, default="configs/base_config.yaml", help="YAML config to update")
    parser.add_argument("--enable_elastic", action="store_true", help="Also enable elastic augmentation in config")
    parser.add_argument("--no_preview", action="store_true", help="Skip preview generation")
    parser.add_argument("--validate_h", action="store_true", help="Run homography consistency validation on rendered PNGs")
    parser.add_argument("--validate_n", type=int, default=6, help="Number of samples for H validation")
    parser.add_argument("--diffusion_dir", type=str, default=None, help="Directory of diffusion-generated PNGs to include")
    # Rendering style passthrough
    parser.add_argument("--layermap", type=str, default=None, help="Layer color map for KLayout, e.g. '1/0:#00FF00,2/0:#FF0000'")
    parser.add_argument("--line_width", type=int, default=None, help="Default draw line width for KLayout display")
    parser.add_argument("--bgcolor", type=str, default=None, help="Background color for KLayout display")
    args = parser.parse_args()
    # Fail fast if any helper script is missing.
    ensure_scripts_exist()
    out_root = Path(args.out_root)
    gds_dir = out_root / "gds"
    png_dir = out_root / "png"
    gds_dir.mkdir(parents=True, exist_ok=True)
    png_dir.mkdir(parents=True, exist_ok=True)
    # 1) Generate GDS
    run_cmd([sys.executable, str(essential_scripts["gen"]), "--out_dir", gds_dir.as_posix(), "--num", str(args.num), "--seed", str(args.seed)])
    # 2) GDS -> PNG (KLayout style flags are forwarded only when provided)
    gds2png_cmd = [
        sys.executable, str(essential_scripts["gds2png"]),
        "--in", gds_dir.as_posix(),
        "--out", png_dir.as_posix(),
        "--dpi", str(args.dpi),
    ]
    if args.layermap:
        gds2png_cmd += ["--layermap", args.layermap]
    if args.line_width is not None:
        gds2png_cmd += ["--line_width", str(args.line_width)]
    if args.bgcolor:
        gds2png_cmd += ["--bgcolor", args.bgcolor]
    run_cmd(gds2png_cmd)
    # 3) Preview (optional)
    if not args.no_preview:
        preview_path = out_root / "preview.png"
        preview_cmd = [sys.executable, str(essential_scripts["preview"]), "--dir", png_dir.as_posix(), "--out", preview_path.as_posix(), "--n", "8"]
        if args.enable_elastic:
            preview_cmd.append("--elastic")
        run_cmd(preview_cmd)
    # 4) Validate homography consistency (optional)
    if args.validate_h:
        validate_dir = out_root / "validate_h"
        validate_cmd = [
            sys.executable, str(essential_scripts["validate"]),
            "--dir", png_dir.as_posix(),
            "--out", validate_dir.as_posix(),
            "--n", str(args.validate_n),
        ]
        if args.enable_elastic:
            validate_cmd.append("--elastic")
        run_cmd(validate_cmd)
    # 5) Update YAML config
    update_config(Path(args.config), png_dir, args.ratio, args.enable_elastic)
    # Include diffusion dir if provided (no automatic sampling here; integration only)
    if args.diffusion_dir:
        cfg = OmegaConf.load(args.config)
        if "synthetic" not in cfg:
            cfg.synthetic = {}
        if "diffusion" not in cfg.synthetic:
            cfg.synthetic.diffusion = {}
        cfg.synthetic.diffusion.enabled = True
        cfg.synthetic.diffusion.png_dir = Path(args.diffusion_dir).as_posix()
        # Keep ratio default at 0 unless user updates later; or reuse a small default like 0.1? Keep 0.0 for safety.
        if "ratio" not in cfg.synthetic.diffusion:
            cfg.synthetic.diffusion.ratio = 0.0
        OmegaConf.save(config=cfg, f=args.config)
        print(f"[OK] Config updated with diffusion_dir: {args.diffusion_dir}")
    # Final summary of produced artifacts.
    print("\n[Done] Synthetic pipeline completed.")
    print(f"- GDS: {gds_dir}")
    print(f"- PNG: {png_dir}")
    if args.diffusion_dir:
        print(f"- Diffusion PNGs: {Path(args.diffusion_dir)}")
    if not args.no_preview:
        print(f"- Preview: {out_root / 'preview.png'}")
    if args.validate_h:
        print(f"- H validation: {out_root / 'validate_h'}")
    print(f"- Updated config: {args.config}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,117 @@
#!/usr/bin/env python3
"""
Validate homography consistency produced by ICLayoutTrainingDataset.
For random samples, we check that cv2.warpPerspective(original, H) ≈ transformed.
Saves visual composites and prints basic metrics (MSE / PSNR).
"""
from __future__ import annotations
import argparse
from pathlib import Path
import sys
import cv2
import numpy as np
import torch
from PIL import Image
# Ensure project root is on sys.path when running as a script
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from data.ic_dataset import ICLayoutTrainingDataset
def tensor_to_u8_img(t: torch.Tensor) -> np.ndarray:
    """Convert 1xHxW or 3xHxW float tensor in [0,1] to uint8 HxW or HxWx3."""
    if t.dim() != 3:
        raise ValueError(f"Expect 3D tensor, got {t.shape}")
    channels = t.size(0)
    if channels == 1:
        scaled = t.squeeze(0).cpu().numpy() * 255.0
    elif channels == 3:
        scaled = t.permute(1, 2, 0).cpu().numpy() * 255.0
    else:
        raise ValueError(f"Unexpected channels: {t.size(0)}")
    return scaled.clip(0, 255).astype(np.uint8)
def mse(a: np.ndarray, b: np.ndarray) -> float:
    """Mean squared error between two images, computed in float32."""
    delta = a.astype(np.float32) - b.astype(np.float32)
    return float(np.mean(np.square(delta)))


def psnr(a: np.ndarray, b: np.ndarray) -> float:
    """Peak signal-to-noise ratio (dB) for 8-bit images; inf when near-identical."""
    err = mse(a, b)
    return float('inf') if err <= 1e-8 else 10.0 * np.log10((255.0 * 255.0) / err)
def main() -> None:
    """Check that warpPerspective(original, H) reproduces the transformed patch.

    For up to --n dataset samples, warps the original grayscale patch with the
    dataset-provided homography, compares it to the transformed patch via
    MSE/PSNR, and saves [orig | warped | transformed | absdiff] composites.
    """
    parser = argparse.ArgumentParser(description="Validate homography consistency")
    parser.add_argument("--dir", dest="image_dir", type=str, required=True, help="PNG images directory")
    parser.add_argument("--out", dest="out_dir", type=str, default="validate_h_out", help="Output directory for composites")
    parser.add_argument("--n", dest="num", type=int, default=8, help="Number of samples to validate")
    parser.add_argument("--patch", dest="patch_size", type=int, default=256)
    parser.add_argument("--elastic", dest="use_elastic", action="store_true")
    args = parser.parse_args()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Use no photometric/Sobel transform here to compare raw grayscale content
    ds = ICLayoutTrainingDataset(
        args.image_dir,
        patch_size=args.patch_size,
        transform=None,
        scale_range=(1.0, 1.0),
        use_albu=args.use_elastic,
        albu_params={"prob": 0.5},
    )
    n = min(args.num, len(ds))
    if n == 0:
        print("[WARN] Empty dataset.")
        return
    mses = []
    psnrs = []
    for i in range(n):
        patch_t, trans_t, H2x3_t = ds[i]
        # Convert to uint8 arrays
        patch_u8 = tensor_to_u8_img(patch_t)
        trans_u8 = tensor_to_u8_img(trans_t)
        # Collapse any 3-channel output to grayscale before comparing.
        if patch_u8.ndim == 3:
            patch_u8 = cv2.cvtColor(patch_u8, cv2.COLOR_BGR2GRAY)
        if trans_u8.ndim == 3:
            trans_u8 = cv2.cvtColor(trans_u8, cv2.COLOR_BGR2GRAY)
        # Reconstruct 3x3 H from the dataset's 2x3 matrix by appending [0, 0, 1]
        H2x3 = H2x3_t.numpy()
        H = np.vstack([H2x3, [0.0, 0.0, 1.0]]).astype(np.float32)
        # Warp original with H
        warped = cv2.warpPerspective(patch_u8, H, (patch_u8.shape[1], patch_u8.shape[0]))
        # Metrics
        m = mse(warped, trans_u8)
        p = psnr(warped, trans_u8)
        mses.append(m)
        psnrs.append(p)
        # Composite image: [orig | warped | transformed | absdiff]
        diff = cv2.absdiff(warped, trans_u8)
        comp = np.concatenate([
            patch_u8, warped, trans_u8, diff
        ], axis=1)
        out_path = out_dir / f"sample_{i:03d}.png"
        cv2.imwrite(out_path.as_posix(), comp)
        print(f"[OK] sample {i}: MSE={m:.2f}, PSNR={p:.2f} dB -> {out_path}")
    print(f"\nSummary: MSE avg={np.mean(mses):.2f} ± {np.std(mses):.2f}, PSNR avg={np.mean(psnrs):.2f} dB")


if __name__ == "__main__":
    main()