initial commit
This commit is contained in:
75
README.md
Normal file
75
README.md
Normal file
@@ -0,0 +1,75 @@
|
||||
# IC Layout Skeleton Generation (prototype)

This repo provides a prototype pipeline to handle strictly-Manhattan binary IC layout images: extract skeletons, vectorize skeletons into Manhattan polylines, train a skeleton-generation model, and expand skeletons back to images.
|
||||
|
||||
## Project Structure
|
||||
|
||||
- `scripts/`: Core processing scripts.
|
||||
- `skeleton_extract.py`: Binarize & extract skeleton PNG.
|
||||
- `vectorize_skeleton.py`: Trace skeleton PNG into JSON polylines.
|
||||
- `expand_skeleton.py`: Rasterize polyline JSON back to a PNG.
|
||||
- `models/`: PyTorch models (e.g., `SkeletonVAE`).
|
||||
- `train/`: Training scripts.
|
||||
- `datasets/`: Dataset building helpers.
|
||||
- `out/`: Output directory for generated data and logs (ignored by git).
|
||||
|
||||
## Setup
|
||||
|
||||
This project uses `uv` for dependency management.
|
||||
|
||||
1. **Install uv**:
|
||||
```bash
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
```
|
||||
|
||||
2. **Install dependencies**:
|
||||
```bash
|
||||
uv sync
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Prepare Data
|
||||
Ensure your dataset is available. For example, link your data folder:
|
||||
```bash
|
||||
ln -s ~/Documents/data/ICCAD2019/img ./ICCAD2019/img
|
||||
```
|
||||
(Note: if already configured, the workspace contains an `ICCAD2019` symlink and this step can be skipped.)
|
||||
|
||||
### 2. Run the Pipeline
|
||||
|
||||
You can run individual scripts using `uv run`.
|
||||
|
||||
**Extract Skeleton:**
|
||||
```bash
|
||||
uv run scripts/skeleton_extract.py path/to/image.png out/result_dir --denoise 3
|
||||
```
|
||||
|
||||
**Vectorize Skeleton:**
|
||||
```bash
|
||||
uv run scripts/vectorize_skeleton.py out/result_dir/image_sk.png out/result_dir/image_vec.json
|
||||
```
|
||||
|
||||
**Expand (Reconstruct) Image:**
|
||||
```bash
|
||||
uv run scripts/expand_skeleton.py out/result_dir/image_vec.json out/result_dir/image_recon.png
|
||||
```
|
||||
|
||||
### 3. Build Dataset
|
||||
Batch process a folder of images:
|
||||
```bash
|
||||
uv run datasets/build_skeleton_dataset.py ICCAD2019/img out/dataset_processed
|
||||
```
|
||||
|
||||
### 4. Train Model
|
||||
Train the VAE on the processed skeletons:
|
||||
```bash
|
||||
uv run train/train_skeleton_vae.py out/dataset_processed --epochs 20 --batch 16
|
||||
```
|
||||
|
||||
## Debugging
|
||||
A debug script is provided to run the full pipeline on a few random images from the dataset:
|
||||
```bash
|
||||
uv run debug_pipeline.py
|
||||
```
|
||||
Results will be in `out/debug`.
|
||||
39
datasets/build_skeleton_dataset.py
Normal file
39
datasets/build_skeleton_dataset.py
Normal file
@@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Build dataset: given input images folder, extract skeletons and vectorize, save pairs.
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
|
||||
|
||||
def main():
    """Batch-process a folder of images: extract and vectorize a skeleton for each.

    For every ``*.png``/``*.jpg`` in ``img_folder``, a subdirectory of
    ``out_folder`` (named after the image stem) is created containing the
    skeleton PNG and its vectorized polyline JSON.
    """
    import sys  # local import: only needed here to locate the current interpreter

    p = argparse.ArgumentParser()
    p.add_argument('img_folder')
    p.add_argument('out_folder')
    p.add_argument('--invert', action='store_true')
    args = p.parse_args()

    out = Path(args.out_folder)
    out.mkdir(parents=True, exist_ok=True)

    scripts_dir = Path(__file__).parents[1] / 'scripts'
    # sorted() makes processing order deterministic across filesystems.
    imgs = sorted(Path(args.img_folder).glob('*.png')) + sorted(Path(args.img_folder).glob('*.jpg'))
    for im in imgs:
        od = out / im.stem
        od.mkdir(exist_ok=True)
        # Run skeleton extraction with the same interpreter as this script;
        # hard-coding 'python3' breaks under venv launchers and on Windows.
        cmd = [sys.executable, str(scripts_dir / 'skeleton_extract.py'), str(im), str(od)]
        if args.invert:
            cmd.append('--invert')
        subprocess.run(cmd, check=True)
        # Vectorize the skeleton PNG produced above into polyline JSON.
        sk_png = od / (im.stem + '_sk.png')
        outjson = od / (im.stem + '_sk.json')
        cmd2 = [sys.executable, str(scripts_dir / 'vectorize_skeleton.py'), str(sk_png), str(outjson)]
        subprocess.run(cmd2, check=True)

    print('Built dataset in', out)


if __name__ == '__main__':
    main()
|
||||
87
debug_pipeline.py
Normal file
87
debug_pipeline.py
Normal file
@@ -0,0 +1,87 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
def run_command(cmd):
    """Echo *cmd* to stdout, then execute it, raising on a non-zero exit."""
    printable = ' '.join(cmd)
    print(f"Running: {printable}")
    subprocess.run(cmd, check=True)
|
||||
|
||||
def main():
    """Run the extract -> vectorize -> expand pipeline on a few random images.

    Picks up to 5 random PNGs from ``ICCAD2019/img`` and writes all
    intermediate artifacts under ``out/debug/<image-stem>/``.
    """
    import sys  # local: used to invoke helper scripts with this interpreter

    # Setup paths. README documents debug results under out/debug.
    base_dir = Path(__file__).parent
    img_dir = base_dir / "ICCAD2019" / "img"
    out_dir = base_dir / "out" / "debug"
    scripts_dir = base_dir / "scripts"

    if out_dir.exists():
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Get images
    all_images = list(img_dir.glob("*.png"))
    if not all_images:
        print("No images found in ICCAD2019/img")
        return

    selected_images = random.sample(all_images, min(5, len(all_images)))
    print(f"Selected {len(selected_images)} images for debugging.")

    for img_path in selected_images:
        print(f"\nProcessing {img_path.name}...")
        case_dir = out_dir / img_path.stem
        case_dir.mkdir(exist_ok=True)

        # 1. Skeleton extraction. skeleton_extract.py treats white pixels as
        # the subject; pass --invert manually if your data is black-on-white.
        extract_cmd = [
            sys.executable, str(scripts_dir / "skeleton_extract.py"),
            str(img_path),
            str(case_dir),
            "--denoise", "3",
        ]
        run_command(extract_cmd)

        sk_png = case_dir / (img_path.stem + "_sk.png")
        if not sk_png.exists():
            print(f"Failed to generate skeleton for {img_path.name}")
            continue

        # 2. Vectorization: skeleton PNG -> Manhattan polyline JSON.
        vec_json = case_dir / (img_path.stem + "_vec.json")
        run_command([
            sys.executable, str(scripts_dir / "vectorize_skeleton.py"),
            str(sk_png),
            str(vec_json),
        ])

        # 3. Expansion (reconstruction): polyline JSON -> binary PNG.
        recon_png = case_dir / (img_path.stem + "_recon.png")
        run_command([
            sys.executable, str(scripts_dir / "expand_skeleton.py"),
            str(vec_json),
            str(recon_png),
            "--line-width", "3",  # Adjust as needed
        ])

    print(f"\nDebug run complete. Check results in {out_dir}")


if __name__ == "__main__":
    main()
|
||||
0
models/__init__.py
Normal file
0
models/__init__.py
Normal file
62
models/skeleton_vae.py
Normal file
62
models/skeleton_vae.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Simple convolutional VAE for skeleton images (single-channel)
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ConvEncoder(nn.Module):
    """Convolutional encoder: (B, 1, 64, 64) image -> (mu, logvar), each (B, z_dim)."""

    def __init__(self, z_dim=64):
        super().__init__()
        # Three stride-2 convs halve the spatial size each time: 64 -> 32 -> 16 -> 8.
        layers = []
        for cin, cout in ((1, 32), (32, 64), (64, 128)):
            layers.append(nn.Conv2d(cin, cout, 4, 2, 1))
            layers.append(nn.ReLU())
        layers.append(nn.Flatten())
        self.enc = nn.Sequential(*layers)
        feat = 128 * 8 * 8  # flattened feature size for 64x64 inputs
        self.fc_mu = nn.Linear(feat, z_dim)
        self.fc_logvar = nn.Linear(feat, z_dim)

    def forward(self, x):
        """Return the Gaussian posterior parameters (mu, logvar) for batch *x*."""
        feats = self.enc(x)
        return self.fc_mu(feats), self.fc_logvar(feats)
|
||||
|
||||
|
||||
class ConvDecoder(nn.Module):
    """Convolutional decoder: (B, z_dim) latent -> (B, 1, 64, 64) map in [0, 1]."""

    def __init__(self, z_dim=64):
        super().__init__()
        self.fc = nn.Linear(z_dim, 128 * 8 * 8)
        # Mirror of the encoder: 8 -> 16 -> 32 -> 64 spatial via transposed convs.
        layers = [nn.Unflatten(1, (128, 8, 8))]
        for cin, cout in ((128, 64), (64, 32)):
            layers += [nn.ConvTranspose2d(cin, cout, 4, 2, 1), nn.ReLU()]
        layers += [nn.ConvTranspose2d(32, 1, 4, 2, 1), nn.Sigmoid()]
        self.dec = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent batch *z* into image space."""
        return self.dec(self.fc(z))
|
||||
|
||||
|
||||
class SkeletonVAE(nn.Module):
    """VAE over single-channel skeleton images: conv encoder + reparameterized conv decoder."""

    def __init__(self, z_dim=64):
        super().__init__()
        self.encoder = ConvEncoder(z_dim)
        self.decoder = ConvDecoder(z_dim)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return noise * std + mu

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for input batch *x*."""
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        recon = self.decoder(latent)
        return recon, mu, logvar
|
||||
23
pyproject.toml
Normal file
23
pyproject.toml
Normal file
@@ -0,0 +1,23 @@
|
||||
[project]
|
||||
name = "ic-layout-generate"
|
||||
version = "0.1.0"
|
||||
description = "IC Layout Generation using Skeleton Extraction and Expansion"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
dependencies = [
|
||||
"numpy",
|
||||
"opencv-python-headless",
|
||||
"scikit-image",
|
||||
"networkx",
|
||||
"shapely",
|
||||
"Pillow",
|
||||
"torch",
|
||||
"torchvision",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["models", "scripts", "datasets", "train"]
|
||||
8
requirements.txt
Normal file
8
requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
numpy
|
||||
opencv-python-headless
|
||||
scikit-image
|
||||
networkx
|
||||
shapely
|
||||
Pillow
|
||||
torch
|
||||
torchvision
|
||||
67
run_experiment.py
Normal file
67
run_experiment.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Run the full experiment:
|
||||
1. Build dataset from raw images
|
||||
2. Train VAE model
|
||||
3. Generate new skeletons
|
||||
"""
|
||||
import subprocess
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
def run_command(cmd):
    """Print the command being executed, then run it; fails fast on a non-zero exit."""
    print("Running: " + " ".join(cmd))
    subprocess.run(cmd, check=True)
|
||||
|
||||
def main():
    """Drive the full experiment: build dataset, train the VAE, generate samples.

    Each stage shells out to the corresponding script; ``--skip_data`` /
    ``--skip_train`` allow resuming from a later stage.
    """
    import sys  # local: used to invoke stage scripts with this interpreter

    p = argparse.ArgumentParser()
    p.add_argument('--img_dir', type=str, default='ICCAD2019/img', help='Input raw images directory')
    p.add_argument('--data_dir', type=str, default='out/dataset', help='Output dataset directory')
    p.add_argument('--model_dir', type=str, default='out/models', help='Directory to save models')
    p.add_argument('--gen_dir', type=str, default='out/generated', help='Directory for generated samples')
    p.add_argument('--epochs', type=int, default=20, help='Training epochs')
    p.add_argument('--batch_size', type=int, default=16, help='Batch size')
    p.add_argument('--skip_data', action='store_true', help='Skip dataset building')
    p.add_argument('--skip_train', action='store_true', help='Skip training')
    args = p.parse_args()

    # 1. Build Dataset (skeleton extraction + vectorization per image).
    if not args.skip_data:
        print("=== Step 1: Building Dataset ===")
        # build_skeleton_dataset.py expects img_folder and out_folder.
        run_command([
            sys.executable, 'datasets/build_skeleton_dataset.py',
            args.img_dir,
            args.data_dir,
        ])

    # 2. Train the VAE on the processed skeletons.
    model_path = Path(args.model_dir) / 'vae_best.pth'
    if not args.skip_train:
        print("=== Step 2: Training VAE ===")
        run_command([
            sys.executable, 'train/train_skeleton_vae.py',
            args.data_dir,
            '--epochs', str(args.epochs),
            '--batch', str(args.batch_size),
            '--save_path', str(model_path),
            '--sample_dir', str(Path(args.model_dir) / 'train_samples'),
        ])

    # 3. Sample new skeletons from the trained model.
    print("=== Step 3: Generating Skeletons ===")
    run_command([
        sys.executable, 'scripts/generate_skeleton.py',
        '--model_path', str(model_path),
        '--out_dir', args.gen_dir,
        '--num_samples', '20',
    ])

    print("=== Experiment Complete ===")
    print(f"Generated images are in {args.gen_dir}")


if __name__ == '__main__':
    main()
|
||||
43
scripts/expand_skeleton.py
Normal file
43
scripts/expand_skeleton.py
Normal file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
expand_skeleton.py
|
||||
Rasterize Manhattan polylines (JSON) back to a binary layout image.
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import json
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def draw_polylines(polylines, shape, line_width=3):
    """Rasterize polylines onto a blank (h, w) canvas.

    Returns a uint8 0/1 mask where drawn strokes are 1.
    """
    height, width = shape
    canvas = np.zeros((height, width), dtype='uint8')
    for poly in polylines:
        if len(poly) < 2:
            continue  # a single point has no segment to draw
        verts = [(int(px), int(py)) for px, py in poly]
        for a, b in zip(verts, verts[1:]):
            cv2.line(canvas, a, b, color=255, thickness=line_width)
    return (canvas > 127).astype('uint8')
|
||||
|
||||
|
||||
def main():
    """CLI: read polyline JSON, rasterize it, and write a binary PNG."""
    p = argparse.ArgumentParser()
    p.add_argument('polyjson')
    p.add_argument('outpng')
    p.add_argument('--line-width', type=int, default=3)
    args = p.parse_args()

    # Close the input file deterministically rather than leaking the handle.
    with open(args.polyjson, 'r') as f:
        data = json.load(f)
    polylines = data.get('polylines', [])
    # Fall back to a 512x512 canvas when the JSON carries no shape.
    shape = tuple(data.get('shape', [512, 512]))
    img = draw_polylines(polylines, shape, line_width=args.line_width)
    cv2.imwrite(args.outpng, img * 255)
    print('Wrote', args.outpng)


if __name__ == '__main__':
    main()
|
||||
63
scripts/generate_skeleton.py
Normal file
63
scripts/generate_skeleton.py
Normal file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate new skeletons using the trained VAE model.
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
import numpy as np
|
||||
|
||||
# Add root to path to allow importing models
|
||||
sys.path.append(str(Path(__file__).parents[1]))
|
||||
|
||||
from models.skeleton_vae import SkeletonVAE
|
||||
|
||||
def main():
    """Sample random latents from N(0, I) and decode them into skeleton images.

    For each sample, writes both the raw probability map (gen_i_raw.png) and a
    thresholded binary version (gen_i_bin.png) into the output directory.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--model_path', type=str, required=True, help='Path to trained .pth model')
    p.add_argument('--out_dir', type=str, default='out/generated', help='Output directory for generated images')
    p.add_argument('--num_samples', type=int, default=10, help='Number of samples to generate')
    p.add_argument('--z_dim', type=int, default=64, help='Latent dimension size')
    p.add_argument('--threshold', type=float, default=0.5, help='Binarization threshold')
    args = p.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build the model skeleton, then restore the trained weights.
    model = SkeletonVAE(z_dim=args.z_dim).to(device)
    if not Path(args.model_path).exists():
        print(f"Error: Model not found at {args.model_path}")
        return

    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Generating {args.num_samples} samples...")

    with torch.no_grad():
        # Latents drawn from the standard normal prior, decoded in one batch.
        z = torch.randn(args.num_samples, args.z_dim).to(device)
        generated = model.decoder(z)

        for i, img in enumerate(generated):
            # Raw probability map straight from the decoder.
            save_image(img, out_dir / f'gen_{i}_raw.png')
            # Hard-thresholded binary skeleton.
            bin_img = (img > args.threshold).float()
            save_image(bin_img, out_dir / f'gen_{i}_bin.png')

    print(f"Saved generated images to {out_dir}")


if __name__ == '__main__':
    main()
|
||||
79
scripts/skeleton_extract.py
Normal file
79
scripts/skeleton_extract.py
Normal file
@@ -0,0 +1,79 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
skeleton_extract.py
|
||||
Extract Manhattan-style skeleton from binary layout images.
|
||||
Supports optional color inversion and basic denoising.
|
||||
Outputs skeleton PNG and a JSON summary with connected-component counts.
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import cv2
|
||||
from skimage.morphology import skeletonize
|
||||
from skimage import img_as_ubyte
|
||||
import json
|
||||
|
||||
|
||||
def load_image(path, invert=False):
    """Load *path* as grayscale, Otsu-binarize it, and return a 0/1 uint8 array.

    With ``invert=True`` the foreground/background roles are swapped after
    thresholding.
    """
    gray = np.array(Image.open(path).convert('L'))
    # Otsu picks the binarization threshold automatically from the histogram.
    _, th = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    if invert:
        th = 255 - th
    return th // 255
|
||||
|
||||
|
||||
def denoise(bin_img, kernel=3):
    """Suppress speckle noise with a morphological open followed by a close.

    Input and output are 0/1 arrays; work happens in 0/255 space for OpenCV.
    """
    elem = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel, kernel))
    work = (bin_img * 255).astype('uint8')
    work = cv2.morphologyEx(work, cv2.MORPH_OPEN, elem)   # drop small specks
    work = cv2.morphologyEx(work, cv2.MORPH_CLOSE, elem)  # fill small holes
    return work // 255
|
||||
|
||||
|
||||
def extract_skeleton(bin_img):
    """Thin the 0/1 foreground down to a 1-pixel-wide skeleton (uint8 0/1)."""
    mask = bin_img > 0  # skimage.skeletonize expects a boolean image
    return skeletonize(mask).astype('uint8')
|
||||
|
||||
|
||||
def save_png(arr, path):
    """Write a 0/1 array to *path* as an 8-bit PNG (1 -> 255)."""
    Image.fromarray((arr * 255).astype('uint8')).save(path)
|
||||
|
||||
|
||||
def main():
    """CLI: binarize an input image, skeletonize it, and write the outputs.

    Produces ``<stem>_bin.png``, ``<stem>_sk.png`` and a small JSON summary
    (``<stem>_sk.json``) in the output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input')
    parser.add_argument('outdir')
    parser.add_argument('--invert', action='store_true', help='Invert black/white before processing')
    parser.add_argument('--denoise', type=int, default=3, help='Denoise kernel size')
    args = parser.parse_args()

    src = Path(args.input)
    dst = Path(args.outdir)
    dst.mkdir(parents=True, exist_ok=True)

    binary = load_image(src, invert=args.invert)
    if args.denoise and args.denoise > 0:
        binary = denoise(binary, kernel=args.denoise)

    skeleton = extract_skeleton(binary)

    save_png(binary, dst / (src.stem + '_bin.png'))
    save_png(skeleton, dst / (src.stem + '_sk.png'))

    # Summary JSON: skeleton pixel count plus connected-component count
    # (connectedComponents includes the background label, hence the -1).
    summary = {
        'input': str(src),
        'pixels_in_skeleton': int(skeleton.sum()),
        'components': int(cv2.connectedComponents((skeleton * 255).astype('uint8'))[0] - 1),
    }
    with open(dst / (src.stem + '_sk.json'), 'w') as f:
        json.dump(summary, f, indent=2)

    print('Saved:', dst / (src.stem + '_sk.png'))


if __name__ == '__main__':
    main()
|
||||
308
scripts/vectorize_skeleton.py
Normal file
308
scripts/vectorize_skeleton.py
Normal file
@@ -0,0 +1,308 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
vectorize_skeleton.py
|
||||
Trace skeleton PNG to Manhattan polylines (simple 4-neighbor tracing) and export JSON.
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import json
|
||||
import cv2
|
||||
from collections import deque
|
||||
|
||||
|
||||
def load_skeleton(path):
    """Load a skeleton PNG as a 0/1 uint8 array (pixels > 127 count as 'on')."""
    gray = np.array(Image.open(path).convert('L'))
    return (gray > 127).astype('uint8')
|
||||
|
||||
|
||||
def neighbors4(y, x, h, w):
    """Yield the in-bounds 4-connected neighbors of (y, x) on an h-by-w grid.

    Order is fixed (E, S, W, N) so callers that depend on traversal order
    stay deterministic.
    """
    for ny, nx in ((y, x + 1), (y + 1, x), (y, x - 1), (y - 1, x)):
        if 0 <= ny < h and 0 <= nx < w:
            yield ny, nx
|
||||
|
||||
|
||||
def trace_components(sk):
    """Group the 'on' pixels of skeleton *sk* into 4-connected components.

    Returns a list of components, each a list of (x, y) pixel coordinates in
    BFS discovery order (scan order: top-to-bottom, left-to-right).
    """
    h, w = sk.shape
    seen = np.zeros_like(sk, dtype=bool)
    comps = []
    for y in range(h):
        for x in range(w):
            if not sk[y, x] or seen[y, x]:
                continue
            # Flood-fill this component breadth-first.
            queue = deque([(y, x)])
            seen[y, x] = True
            pts = []
            while queue:
                cy, cx = queue.popleft()
                pts.append((int(cx), int(cy)))
                for ny, nx in neighbors4(cy, cx, h, w):
                    if sk[ny, nx] and not seen[ny, nx]:
                        seen[ny, nx] = True
                        queue.append((ny, nx))
            comps.append(pts)
    return comps
|
||||
|
||||
|
||||
def prune_segments(segments, min_len=5, merge_tol=3):
    """Clean up a list of Manhattan segments produced by simplify_manhattan().

    Each segment is a tuple ``(axis, value, start_idx, end_idx)`` where axis is
    'H' or 'V', value is the segment's constant coordinate (y for 'H', x for
    'V'), and start/end index into the original ordered point list.

    Three passes:
      1. repeatedly merge consecutive same-axis segments whose values differ
         by at most *merge_tol* (length-weighted average value);
      2. drop a short leading/trailing segment (length < *min_len*);
      3. collapse short internal "bumps" (A -> short B -> A with aligned A's).
    """
    if not segments: return []

    # 1. Merge consecutive same-direction segments (simple pass).
    # Repeat until a full pass makes no merge, so chains of more than two
    # mergeable segments eventually collapse.
    while True:
        merged = []
        changed = False
        if segments:
            curr = segments[0]
            for i in range(1, len(segments)):
                next_seg = segments[i]
                if curr[0] == next_seg[0] and abs(curr[1] - next_seg[1]) <= merge_tol:
                    # Merge: extend end_idx, re-calculate the length-weighted value.
                    len1 = curr[3] - curr[2] + 1
                    len2 = next_seg[3] - next_seg[2] + 1
                    new_val = int(round((curr[1]*len1 + next_seg[1]*len2) / (len1+len2)))
                    curr = (curr[0], new_val, curr[2], next_seg[3])
                    changed = True
                else:
                    merged.append(curr)
                    curr = next_seg
            merged.append(curr)
        segments = merged
        if not changed:
            break

    # 2. Remove short tails (only when more than one segment remains,
    # so we never empty the list here).
    # Check start
    if len(segments) > 1:
        s0 = segments[0]
        l0 = s0[3] - s0[2] + 1
        if l0 < min_len:
            segments.pop(0)

    # Check end
    if len(segments) > 1:
        s_last = segments[-1]
        l_last = s_last[3] - s_last[2] + 1
        if l_last < min_len:
            segments.pop(-1)

    # 3. Remove short internal bumps (Z-shape):
    # H1 -> V(short) -> H2 becomes one H segment when H1 and H2 are aligned
    # within merge_tol (and symmetrically for V -> H(short) -> V).
    final_segs = []
    i = 0
    while i < len(segments):
        curr = segments[i]
        merged_bump = False
        if i + 2 < len(segments):
            mid = segments[i+1]
            next_seg = segments[i+2]

            # Check pattern: A -> B(short) -> A
            if curr[0] == next_seg[0] and mid[0] != curr[0]:
                len_mid = mid[3] - mid[2] + 1
                if len_mid < min_len:
                    # Check alignment of curr and next_seg
                    if abs(curr[1] - next_seg[1]) <= merge_tol:
                        # Merge all three into one segment spanning
                        # curr.start .. next_seg.end, value weighted by the
                        # two outer segments' lengths.
                        len1 = curr[3] - curr[2] + 1
                        len3 = next_seg[3] - next_seg[2] + 1
                        new_val = int(round((curr[1]*len1 + next_seg[1]*len3) / (len1+len3)))
                        new_seg = (curr[0], new_val, curr[2], next_seg[3])

                        # Store the merged segment in slot i+2 and resume the
                        # scan there, so it can merge again with what follows.
                        segments[i+2] = new_seg
                        i += 2
                        merged_bump = True
                        continue

        if not merged_bump:
            final_segs.append(curr)
            i += 1

    return final_segs
|
||||
|
||||
|
||||
def simplify_manhattan(pts, tolerance=2):
    """Snap an ordered pixel path *pts* (list of (x, y)) to a Manhattan polyline.

    Greedily covers the path with maximal near-horizontal / near-vertical runs
    (a run continues while the off-axis coordinate stays within *tolerance* of
    the run's start), prunes the runs via prune_segments(), then emits corner
    points where consecutive runs meet. Returns a list of (x, y) vertices.
    """
    if not pts: return []
    if len(pts) < 2: return pts

    segments = []
    n = len(pts)
    i = 0
    while i < n - 1:
        start_pt = pts[i]

        # Check how far a horizontal run would extend from i
        # (y stays within tolerance of the start point's y).
        k_h = i + 1
        while k_h < n:
            if abs(pts[k_h][1] - start_pt[1]) > tolerance:
                break
            k_h += 1
        len_h = k_h - i

        # Check how far a vertical run would extend from i.
        k_v = i + 1
        while k_v < n:
            if abs(pts[k_v][0] - start_pt[0]) > tolerance:
                break
            k_v += 1
        len_v = k_v - i

        if len_h >= len_v:
            # Horizontal run wins: record ('H', averaged y, start idx, end idx).
            segment_pts = pts[i:k_h]
            avg_y = int(round(np.mean([p[1] for p in segment_pts])))
            segments.append(('H', avg_y, i, k_h - 1))
            i = k_h - 1
        else:
            # Vertical run: record ('V', averaged x, start idx, end idx).
            segment_pts = pts[i:k_v]
            avg_x = int(round(np.mean([p[0] for p in segment_pts])))
            segments.append(('V', avg_x, i, k_v - 1))
            i = k_v - 1

    # --- Pruning / refining: merge near-collinear runs, drop short tails/bumps ---
    segments = prune_segments(segments, min_len=5, merge_tol=3)

    if not segments:
        return []

    out_poly = []

    # First vertex: the first segment's start point projected onto its axis value.
    first_seg = segments[0]
    first_pt_orig = pts[first_seg[2]]
    if first_seg[0] == 'H':
        curr = (first_pt_orig[0], first_seg[1])
    else:
        curr = (first_seg[1], first_pt_orig[1])
    out_poly.append(curr)

    for idx in range(len(segments) - 1):
        s1 = segments[idx]
        s2 = segments[idx+1]

        # After pruning, s1's end index may not abut s2's start index, so
        # connect s1's end to s2's start with an explicit Manhattan corner.

        # s1 end point (projected onto s1's axis value)
        if s1[0] == 'H':
            p1_end = (pts[s1[3]][0], s1[1])
        else:
            p1_end = (s1[1], pts[s1[3]][1])

        # s2 start point (projected onto s2's axis value)
        if s2[0] == 'H':
            p2_start = (pts[s2[2]][0], s2[1])
        else:
            p2_start = (s2[1], pts[s2[2]][1])

        # Connect p1_end to p2_start
        if s1[0] == 'H' and s2[0] == 'V':
            # H(y1) -> V(x2): the corner is their intersection (x2, y1).
            corner = (s2[1], s1[1])
            if out_poly[-1] != corner: out_poly.append(corner)
        elif s1[0] == 'V' and s2[0] == 'H':
            # V(x1) -> H(y2): the corner is (x1, y2).
            corner = (s1[1], s2[1])
            if out_poly[-1] != corner: out_poly.append(corner)
        else:
            # Two parallel segments survived pruning with a gap between them.
            # Bridge them with a perpendicular jog through the midpoint of the
            # gap, emitting only the vertices that are actually distinct.
            if out_poly[-1] != p1_end: out_poly.append(p1_end)

            if s1[0] == 'H':
                # H ... H: bridge is vertical at the gap's mid-x.
                mid_x = (p1_end[0] + p2_start[0]) // 2
                c1 = (mid_x, p1_end[1])
                c2 = (mid_x, p2_start[1])
                if out_poly[-1] != c1: out_poly.append(c1)
                if c1 != c2: out_poly.append(c2)
                if c2 != p2_start: out_poly.append(p2_start)
            else:
                # V ... V: bridge is horizontal at the gap's mid-y.
                mid_y = (p1_end[1] + p2_start[1]) // 2
                c1 = (p1_end[0], mid_y)
                c2 = (p2_start[0], mid_y)
                if out_poly[-1] != c1: out_poly.append(c1)
                if c1 != c2: out_poly.append(c2)
                if c2 != p2_start: out_poly.append(p2_start)

    # Last vertex: the final segment's end point projected onto its axis value.
    last_seg = segments[-1]
    last_pt_orig = pts[last_seg[3]]
    if last_seg[0] == 'H':
        end = (last_pt_orig[0], last_seg[1])
    else:
        end = (last_seg[1], last_pt_orig[1])
    if out_poly[-1] != end: out_poly.append(end)

    return out_poly
|
||||
|
||||
|
||||
def polyline_from_pixels(pts):
    """Order an unordered list of (x, y) skeleton pixels into a path and simplify it.

    Ordering is greedy: start at the topmost-then-leftmost pixel, prefer an
    unvisited 4-neighbor (checked E, W, S, N), and otherwise jump to the
    nearest remaining pixel by Manhattan distance (ties broken by y, then x).
    The ordered path is then snapped to a Manhattan polyline.
    """
    if not pts:
        return []
    remaining = set(pts)
    start = min(pts, key=lambda p: (p[1], p[0]))
    seq = [start]
    remaining.remove(start)
    while remaining:
        x, y = seq[-1]
        candidates = ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1))
        nxt = next((c for c in candidates if c in remaining), None)
        if nxt is None:
            # No adjacent pixel left: jump to the closest remaining one.
            nxt = min(remaining, key=lambda p: (abs(p[0] - x) + abs(p[1] - y), p[1], p[0]))
        seq.append(nxt)
        remaining.remove(nxt)

    return simplify_manhattan(seq, tolerance=2)
|
||||
|
||||
|
||||
def main():
    """CLI: trace a skeleton PNG into Manhattan polylines and dump them as JSON."""
    parser = argparse.ArgumentParser()
    parser.add_argument('skel_png')
    parser.add_argument('outjson')
    args = parser.parse_args()

    sk = load_skeleton(args.skel_png)
    polylines = []
    for pts in trace_components(sk):
        pl = polyline_from_pixels(pts)
        if len(pl) > 1:  # drop degenerate single-vertex traces
            polylines.append(pl)

    payload = {'polylines': polylines, 'num': len(polylines), 'shape': sk.shape}
    with open(args.outjson, 'w') as f:
        json.dump(payload, f)
    print('Wrote', args.outjson)


if __name__ == '__main__':
    main()
|
||||
136
train/train_skeleton_vae.py
Normal file
136
train/train_skeleton_vae.py
Normal file
@@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Minimal training loop for SkeletonVAE. Expects a directory of skeleton PNGs.
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from PIL import Image
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
import logging
|
||||
import time
|
||||
from models.skeleton_vae import SkeletonVAE
|
||||
|
||||
|
||||
class SkeletonDataset(Dataset):
    """Dataset of skeleton PNGs (files matching ``**/*_sk.png`` under *folder*).

    Each item is a single-channel float tensor resized to (size, size).
    """

    def __init__(self, folder, size=64):
        self.files = list(Path(folder).glob('**/*_sk.png'))
        self.tr = T.Compose([T.Resize((size,size)), T.ToTensor()])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        image = Image.open(path).convert('L')
        return self.tr(image)
|
||||
|
||||
|
||||
def loss_fn(recon_x, x, mu, logvar):
    """Standard VAE objective: summed reconstruction BCE plus KL to N(0, I).

    Returns ``(total, bce, kld)`` so callers can log the components separately.
    """
    bce = torch.nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL(N(mu, sigma^2) || N(0, 1)) summed over all dimensions,
    # algebraically identical to -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
    kld = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)
    return bce + kld, bce, kld
|
||||
|
||||
|
||||
def setup_logger(log_dir):
    """Configure root logging to both ``<log_dir>/train.log`` and the console.

    Creates *log_dir* if needed and returns this module's logger.
    """
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    handlers = [
        logging.FileHandler(log_dir / 'train.log'),
        logging.StreamHandler(),
    ]
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=handlers,
        force=True,  # replace any handlers left over from earlier configuration
    )
    return logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
    """Train SkeletonVAE on a folder of skeleton PNGs.

    CLI: positional skeleton folder plus --epochs/--batch/--lr/--save_path/
    --sample_dir. Saves the model state dict after every epoch and a grid of
    input-vs-reconstruction samples every 5 epochs.
    """
    p = argparse.ArgumentParser()
    p.add_argument('skel_folder')
    p.add_argument('--epochs', type=int, default=10)
    p.add_argument('--batch', type=int, default=16)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--save_path', type=str, default='out/vae_model.pth')
    p.add_argument('--sample_dir', type=str, default='out/samples')
    args = p.parse_args()

    Path(args.sample_dir).mkdir(parents=True, exist_ok=True)
    Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)

    # Log file lives next to the saved model.
    logger = setup_logger(Path(args.save_path).parent)

    ds = SkeletonDataset(args.skel_folder, size=64)
    if len(ds) == 0:
        logger.error(f"No skeleton images found in {args.skel_folder}")
        return

    dl = DataLoader(ds, batch_size=args.batch, shuffle=True, num_workers=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")
    logger.info(f"Dataset size: {len(ds)} images")

    model = SkeletonVAE(z_dim=64).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    logger.info("Starting training...")
    for epoch in range(args.epochs):
        start_time = time.time()
        model.train()
        total_loss = 0.0
        total_bce = 0.0
        total_kld = 0.0

        for i, xb in enumerate(dl):
            xb = xb.to(device)
            recon, mu, logvar = model(xb)
            loss, bce, kld = loss_fn(recon, xb, mu, logvar)

            opt.zero_grad()
            loss.backward()
            opt.step()

            # loss_fn returns sums over the batch, so accumulate raw sums here
            # and normalize once per epoch below.
            total_loss += loss.item()
            total_bce += bce.item()
            total_kld += kld.item()

            if (i + 1) % 10 == 0:
                # Per-batch log is normalized by the batch size for readability.
                logger.info(f"Epoch [{epoch+1}/{args.epochs}] Batch [{i+1}/{len(dl)}] "
                            f"Loss: {loss.item()/len(xb):.4f} (BCE: {bce.item()/len(xb):.4f}, KLD: {kld.item()/len(xb):.4f})")

        # Per-sample averages over the whole dataset.
        avg_loss = total_loss / len(ds)
        avg_bce = total_bce / len(ds)
        avg_kld = total_kld / len(ds)
        epoch_time = time.time() - start_time

        logger.info(f'Epoch {epoch+1}/{args.epochs} completed in {epoch_time:.2f}s')
        logger.info(f' Average Loss: {avg_loss:.4f}')
        logger.info(f' BCE: {avg_bce:.4f} | KLD: {avg_kld:.4f}')

        # Save model after every epoch (overwrites the same path).
        torch.save(model.state_dict(), args.save_path)

        # Save sample reconstructions every 5 epochs.
        if (epoch + 1) % 5 == 0:
            logger.info(f"Saving reconstruction samples to {args.sample_dir}")
            with torch.no_grad():
                model.eval()
                # Reconstruct the first batch of a fresh loader pass.
                xb = next(iter(dl)).to(device)
                recon, _, _ = model(xb)
                # Grid layout: top row(s) inputs, bottom row(s) reconstructions.
                vis = torch.cat([xb[:8], recon[:8]], dim=0)
                from torchvision.utils import save_image
                save_image(vis, Path(args.sample_dir) / f'epoch_{epoch+1}.png', nrow=8)


if __name__ == '__main__':
    main()
|
||||
Reference in New Issue
Block a user