initial commit

This commit is contained in:
Jiao77
2025-11-24 20:34:50 +08:00
commit 633749886e
15 changed files with 2665 additions and 0 deletions

1
ICCAD2019 Symbolic link
View File

@@ -0,0 +1 @@
../../data/ICCAD2019

75
README.md Normal file
View File

@@ -0,0 +1,75 @@
# IC Layout Skeleton Generation (prototype)
This repo provides a prototype pipeline to handle strictly-Manhattan binary IC layout images: extract skeletons, vectorize skeletons into Manhattan polylines, train a skeleton-generation model, and expand skeletons back to images.
## Project Structure
- `scripts/`: Core processing scripts.
- `skeleton_extract.py`: Binarize & extract skeleton PNG.
- `vectorize_skeleton.py`: Trace skeleton PNG into JSON polylines.
- `expand_skeleton.py`: Rasterize polyline JSON back to a PNG.
- `models/`: PyTorch models (e.g., `SkeletonVAE`).
- `train/`: Training scripts.
- `datasets/`: Dataset building helpers.
- `out/`: Output directory for generated data and logs (ignored by git).
## Setup
This project uses `uv` for dependency management.
1. **Install uv**:
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
2. **Install dependencies**:
```bash
uv sync
```
## Usage
### 1. Prepare Data
Ensure your dataset is available. For example, link your data folder:
```bash
ln -s ~/Documents/data/ICCAD2019/img ./ICCAD2019/img
```
(Note: The workspace already contains a link to `ICCAD2019` if configured).
### 2. Run the Pipeline
You can run individual scripts using `uv run`.
**Extract Skeleton:**
```bash
uv run scripts/skeleton_extract.py path/to/image.png out/result_dir --denoise 3
```
**Vectorize Skeleton:**
```bash
uv run scripts/vectorize_skeleton.py out/result_dir/image_sk.png out/result_dir/image_vec.json
```
**Expand (Reconstruct) Image:**
```bash
uv run scripts/expand_skeleton.py out/result_dir/image_vec.json out/result_dir/image_recon.png
```
### 3. Build Dataset
Batch process a folder of images:
```bash
uv run datasets/build_skeleton_dataset.py ICCAD2019/img out/dataset_processed
```
### 4. Train Model
Train the VAE on the processed skeletons:
```bash
uv run train/train_skeleton_vae.py out/dataset_processed --epochs 20 --batch 16
```
## Debugging
A debug script is provided to run the full pipeline on a few random images from the dataset:
```bash
uv run debug_pipeline.py
```
Results will be in `out/debug`.

View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
"""
Build dataset: given input images folder, extract skeletons and vectorize, save pairs.
"""
import argparse
from pathlib import Path
import subprocess
def main():
    """Batch-build a skeleton dataset.

    For every ``*.png`` / ``*.jpg`` in ``img_folder``, run the
    skeleton-extraction script into a per-image subfolder of ``out_folder``,
    then vectorize the resulting skeleton PNG into a JSON polyline file.
    """
    import sys  # local import: only needed to locate the running interpreter
    p = argparse.ArgumentParser()
    p.add_argument('img_folder')
    p.add_argument('out_folder')
    p.add_argument('--invert', action='store_true')
    args = p.parse_args()
    out = Path(args.out_folder)
    out.mkdir(parents=True, exist_ok=True)
    imgs = list(Path(args.img_folder).glob('*.png')) + list(Path(args.img_folder).glob('*.jpg'))
    scripts_dir = Path(__file__).parents[1] / 'scripts'
    for im in imgs:
        od = out / im.stem
        od.mkdir(exist_ok=True)
        # 1) binarize + skeletonize. Use sys.executable instead of a
        # hard-coded 'python3' so the active (uv-managed) venv is honored.
        cmd = [sys.executable, str(scripts_dir / 'skeleton_extract.py'), str(im), str(od)]
        if args.invert:
            cmd.append('--invert')
        subprocess.run(cmd, check=True)
        # 2) trace the skeleton into Manhattan polylines. Write to
        # '<stem>_vec.json' — the original '<stem>_sk.json' name collided
        # with (and overwrote) the summary JSON skeleton_extract.py emits.
        sk_png = od / (im.stem + '_sk.png')
        outjson = od / (im.stem + '_vec.json')
        cmd2 = [sys.executable, str(scripts_dir / 'vectorize_skeleton.py'), str(sk_png), str(outjson)]
        subprocess.run(cmd2, check=True)
    print('Built dataset in', out)
if __name__ == '__main__':
main()

87
debug_pipeline.py Normal file
View File

@@ -0,0 +1,87 @@
#!/usr/bin/env python3
import os
import random
import subprocess
from pathlib import Path
import shutil
def run_command(cmd):
    """Echo *cmd* to stdout, then execute it; raises CalledProcessError on failure."""
    print("Running: " + " ".join(cmd))
    subprocess.run(cmd, check=True)
def main():
    """Run the extract -> vectorize -> expand pipeline on a few random images.

    Picks up to 5 PNGs from ICCAD2019/img, processes each into its own
    subfolder under out/debug (the location documented in the README), and
    leaves all intermediate artifacts there for manual inspection.
    """
    import sys  # local import: only needed to locate the running interpreter
    base_dir = Path(__file__).parent
    img_dir = base_dir / "ICCAD2019" / "img"
    # README says results land in out/debug; the original wrote out/debug1.
    out_dir = base_dir / "out" / "debug"
    # Start from a clean output directory so stale artifacts never confuse a run.
    if out_dir.exists():
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    all_images = list(img_dir.glob("*.png"))
    if not all_images:
        print("No images found in ICCAD2019/img")
        return
    selected_images = random.sample(all_images, min(5, len(all_images)))
    print(f"Selected {len(selected_images)} images for debugging.")
    for img_path in selected_images:
        print(f"\nProcessing {img_path.name}...")
        case_dir = out_dir / img_path.stem
        case_dir.mkdir(exist_ok=True)
        # 1. Skeleton extraction.
        # NOTE(review): skeleton_extract.py treats white pixels as foreground;
        # pass --invert here if the dataset turns out to be white-background.
        extract_cmd = [
            sys.executable, "scripts/skeleton_extract.py",
            str(img_path),
            str(case_dir),
            "--denoise", "3",
        ]
        run_command(extract_cmd)
        sk_png = case_dir / (img_path.stem + "_sk.png")
        if not sk_png.exists():
            print(f"Failed to generate skeleton for {img_path.name}")
            continue
        # 2. Vectorization: skeleton pixels -> Manhattan polylines (JSON).
        vec_json = case_dir / (img_path.stem + "_vec.json")
        run_command([
            sys.executable, "scripts/vectorize_skeleton.py",
            str(sk_png),
            str(vec_json),
        ])
        # 3. Expansion: rasterize the polylines back into a binary image.
        recon_png = case_dir / (img_path.stem + "_recon.png")
        run_command([
            sys.executable, "scripts/expand_skeleton.py",
            str(vec_json),
            str(recon_png),
            "--line-width", "3",
        ])
    print(f"\nDebug run complete. Check results in {out_dir}")

0
models/__init__.py Normal file
View File

62
models/skeleton_vae.py Normal file
View File

@@ -0,0 +1,62 @@
"""
Simple convolutional VAE for skeleton images (single-channel)
"""
import torch
import torch.nn as nn
class ConvEncoder(nn.Module):
    """Encode a (B, 1, img_size, img_size) image into latent mean/log-variance.

    Three stride-2 convolutions downsample by 8x, so ``img_size`` must be a
    multiple of 8. The default (64) reproduces the original fixed-size
    behaviour, where the flattened feature size was hard-coded to 128*8*8.
    """
    def __init__(self, z_dim=64, img_size=64):
        super().__init__()
        if img_size % 8 != 0:
            raise ValueError("img_size must be a multiple of 8")
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.ReLU(),
            nn.Flatten()
        )
        # Flattened size after three 2x downsamplings of the input.
        feat = 128 * (img_size // 8) ** 2
        self.fc_mu = nn.Linear(feat, z_dim)
        self.fc_logvar = nn.Linear(feat, z_dim)

    def forward(self, x):
        """Return (mu, logvar), each of shape (B, z_dim)."""
        h = self.enc(x)
        return self.fc_mu(h), self.fc_logvar(h)
class ConvDecoder(nn.Module):
    """Decode a (B, z_dim) latent vector into a (B, 1, img_size, img_size) map.

    Mirrors ConvEncoder: three stride-2 transposed convolutions upsample by
    8x, so ``img_size`` must be a multiple of 8 (default 64, the original
    hard-coded size). Output passes through Sigmoid, i.e. values lie in (0, 1).
    """
    def __init__(self, z_dim=64, img_size=64):
        super().__init__()
        if img_size % 8 != 0:
            raise ValueError("img_size must be a multiple of 8")
        # Spatial side length of the feature map fed to the deconv stack.
        side = img_size // 8
        self.fc = nn.Linear(z_dim, 128 * side * side)
        self.dec = nn.Sequential(
            nn.Unflatten(1, (128, side, side)),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        """Return the decoded probability map for latent batch z."""
        h = self.fc(z)
        return self.dec(h)
class SkeletonVAE(nn.Module):
    """Convolutional VAE over single-channel skeleton images.

    Wires ConvEncoder and ConvDecoder together and applies the standard
    reparameterization trick so gradients flow through the sampling step.
    """
    def __init__(self, z_dim=64):
        super().__init__()
        self.encoder = ConvEncoder(z_dim)
        self.decoder = ConvDecoder(z_dim)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) as mu + sigma * eps with eps ~ N(0, I)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return noise * sigma + mu

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for the input batch."""
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        recon = self.decoder(latent)
        return recon, mu, logvar

23
pyproject.toml Normal file
View File

@@ -0,0 +1,23 @@
[project]
name = "ic-layout-generate"
version = "0.1.0"
description = "IC Layout Generation using Skeleton Extraction and Expansion"
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
"numpy",
"opencv-python-headless",
"scikit-image",
"networkx",
"shapely",
"Pillow",
"torch",
"torchvision",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["models", "scripts", "datasets", "train"]

8
requirements.txt Normal file
View File

@@ -0,0 +1,8 @@
numpy
opencv-python-headless
scikit-image
networkx
shapely
Pillow
torch
torchvision

67
run_experiment.py Normal file
View File

@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""
Run the full experiment:
1. Build dataset from raw images
2. Train VAE model
3. Generate new skeletons
"""
import subprocess
import argparse
from pathlib import Path
def run_command(cmd):
    """Print the command being executed, then run it; a non-zero exit raises."""
    joined = " ".join(cmd)
    print(f"Running: {joined}")
    subprocess.run(cmd, check=True)
def main():
    """Drive the three experiment stages: dataset build, VAE training, sampling.

    Each stage shells out to the corresponding script; --skip_data and
    --skip_train allow resuming from a later stage.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--img_dir', type=str, default='ICCAD2019/img', help='Input raw images directory')
    p.add_argument('--data_dir', type=str, default='out/dataset', help='Output dataset directory')
    p.add_argument('--model_dir', type=str, default='out/models', help='Directory to save models')
    p.add_argument('--gen_dir', type=str, default='out/generated', help='Directory for generated samples')
    p.add_argument('--epochs', type=int, default=20, help='Training epochs')
    p.add_argument('--batch_size', type=int, default=16, help='Batch size')
    p.add_argument('--skip_data', action='store_true', help='Skip dataset building')
    p.add_argument('--skip_train', action='store_true', help='Skip training')
    args = p.parse_args()

    # Stage 1: raw images -> skeleton dataset.
    if not args.skip_data:
        print("=== Step 1: Building Dataset ===")
        run_command([
            'python3', 'datasets/build_skeleton_dataset.py',
            args.img_dir,
            args.data_dir,
        ])

    # Stage 2: train the VAE; the checkpoint path is also what stage 3 loads.
    model_path = Path(args.model_dir) / 'vae_best.pth'
    if not args.skip_train:
        print("=== Step 2: Training VAE ===")
        run_command([
            'python3', 'train/train_skeleton_vae.py',
            args.data_dir,
            '--epochs', str(args.epochs),
            '--batch', str(args.batch_size),
            '--save_path', str(model_path),
            '--sample_dir', str(Path(args.model_dir) / 'train_samples'),
        ])

    # Stage 3: sample new skeletons from the trained model.
    print("=== Step 3: Generating Skeletons ===")
    run_command([
        'python3', 'scripts/generate_skeleton.py',
        '--model_path', str(model_path),
        '--out_dir', args.gen_dir,
        '--num_samples', '20',
    ])
    print("=== Experiment Complete ===")
    print(f"Generated images are in {args.gen_dir}")

View File

@@ -0,0 +1,43 @@
#!/usr/bin/env python3
"""
expand_skeleton.py
Rasterize Manhattan polylines (JSON) back to a binary layout image.
"""
import argparse
from pathlib import Path
import json
import numpy as np
import cv2
def draw_polylines(polylines, shape, line_width=3):
    """Rasterize polylines onto a zeroed (h, w) canvas.

    Each polyline is a sequence of (x, y) vertices; consecutive vertices are
    joined with cv2.line at the given thickness. Returns a 0/1 uint8 mask.
    """
    height, width = shape
    canvas = np.zeros((height, width), dtype='uint8')
    for poly in polylines:
        if len(poly) < 2:
            continue  # a lone vertex has no segment to draw
        vertices = [(int(px), int(py)) for px, py in poly]
        for a, b in zip(vertices, vertices[1:]):
            cv2.line(canvas, a, b, color=255, thickness=line_width)
    return (canvas > 127).astype('uint8')
def main():
    """CLI entry: read a polyline JSON file, rasterize it, write a binary PNG.

    The JSON is expected to carry 'polylines' (list of (x, y) vertex lists)
    and 'shape' ([h, w]); missing keys fall back to no polylines / 512x512.
    """
    p = argparse.ArgumentParser()
    p.add_argument('polyjson')
    p.add_argument('outpng')
    p.add_argument('--line-width', type=int, default=3)
    args = p.parse_args()
    # Context manager closes the file deterministically; the original leaked
    # the handle returned by a bare open().
    with open(args.polyjson, 'r') as f:
        data = json.load(f)
    polylines = data.get('polylines', [])
    shape = tuple(data.get('shape', [512, 512]))
    img = draw_polylines(polylines, shape, line_width=args.line_width)
    cv2.imwrite(args.outpng, img * 255)  # scale the 0/1 mask to 0/255 for PNG
    print('Wrote', args.outpng)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,63 @@
#!/usr/bin/env python3
"""
Generate new skeletons using the trained VAE model.
"""
import argparse
import sys
from pathlib import Path
import torch
from torchvision.utils import save_image
import numpy as np
# Add root to path to allow importing models
sys.path.append(str(Path(__file__).parents[1]))
from models.skeleton_vae import SkeletonVAE
def main():
    """Sample latent codes from N(0, I) and decode them with a trained VAE.

    For each of --num_samples latents, writes two PNGs to --out_dir: the raw
    decoder probability map (gen_i_raw.png) and a thresholded binary version
    (gen_i_bin.png).
    """
    p = argparse.ArgumentParser()
    p.add_argument('--model_path', type=str, required=True, help='Path to trained .pth model')
    p.add_argument('--out_dir', type=str, default='out/generated', help='Output directory for generated images')
    p.add_argument('--num_samples', type=int, default=10, help='Number of samples to generate')
    p.add_argument('--z_dim', type=int, default=64, help='Latent dimension size')
    p.add_argument('--threshold', type=float, default=0.5, help='Binarization threshold')
    args = p.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Fail fast: validate the checkpoint path before paying for model
    # construction (the original built the model first, then checked).
    if not Path(args.model_path).exists():
        print(f"Error: Model not found at {args.model_path}")
        return
    model = SkeletonVAE(z_dim=args.z_dim).to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"Generating {args.num_samples} samples...")
    with torch.no_grad():
        # Latents drawn from the VAE prior N(0, I).
        z = torch.randn(args.num_samples, args.z_dim).to(device)
        generated = model.decoder(z)
        for i in range(args.num_samples):
            img = generated[i]
            # Raw sigmoid probability map for inspection.
            save_image(img, out_dir / f'gen_{i}_raw.png')
            # Hard-thresholded binary skeleton.
            bin_img = (img > args.threshold).float()
            save_image(bin_img, out_dir / f'gen_{i}_bin.png')
    print(f"Saved generated images to {out_dir}")

View File

@@ -0,0 +1,79 @@
#!/usr/bin/env python3
"""
skeleton_extract.py
Extract Manhattan-style skeleton from binary layout images.
Supports optional color inversion and basic denoising.
Outputs skeleton PNG and a JSON summary with connected-component counts.
"""
import argparse
from pathlib import Path
import numpy as np
from PIL import Image
import cv2
from skimage.morphology import skeletonize
from skimage import img_as_ubyte
import json
def load_image(path, invert=False):
    """Load an image in grayscale and return a 0/1 uint8 array.

    The threshold is chosen automatically by Otsu's method; with invert=True
    the foreground/background roles are swapped before the 255 -> 1 scaling.
    """
    gray = np.array(Image.open(path).convert('L'))
    # Otsu picks the threshold itself; the 0 passed here is ignored.
    _, mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    if invert:
        mask = 255 - mask
    return mask // 255
def denoise(bin_img, kernel=3):
    """Morphological open-then-close with a kernel x kernel rectangle.

    Opening removes speckle noise; closing fills small holes. Input and
    output are 0/1 uint8 images.
    """
    se = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel, kernel))
    scaled = (bin_img * 255).astype('uint8')
    opened = cv2.morphologyEx(scaled, cv2.MORPH_OPEN, se)
    closed = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, se)
    return closed // 255
def extract_skeleton(bin_img):
    """Thin a 0/1 image down to its one-pixel-wide skeleton (0/1 uint8)."""
    # skeletonize wants a boolean image, hence the > 0 comparison.
    return skeletonize(bin_img > 0).astype('uint8')
def save_png(arr, path):
    """Write a 0/1 array to *path* as an 8-bit PNG (1 maps to 255)."""
    Image.fromarray((arr * 255).astype('uint8')).save(path)
def main():
    """CLI entry: binarize, optionally denoise, skeletonize, and save results.

    Writes <stem>_bin.png, <stem>_sk.png and a <stem>_sk.json summary with
    the skeleton pixel count and connected-component count.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input')
    parser.add_argument('outdir')
    parser.add_argument('--invert', action='store_true', help='Invert black/white before processing')
    parser.add_argument('--denoise', type=int, default=3, help='Denoise kernel size')
    opts = parser.parse_args()
    src = Path(opts.input)
    dst = Path(opts.outdir)
    dst.mkdir(parents=True, exist_ok=True)
    mask = load_image(src, invert=opts.invert)
    if opts.denoise and opts.denoise > 0:
        mask = denoise(mask, kernel=opts.denoise)
    skel = extract_skeleton(mask)
    save_png(mask, dst / (src.stem + '_bin.png'))
    save_png(skel, dst / (src.stem + '_sk.png'))
    # connectedComponents counts the background as a label, hence the -1.
    n_labels = cv2.connectedComponents((skel * 255).astype('uint8'))[0]
    summary = {
        'input': str(src),
        'pixels_in_skeleton': int(skel.sum()),
        'components': int(n_labels - 1),
    }
    with open(dst / (src.stem + '_sk.json'), 'w') as f:
        json.dump(summary, f, indent=2)
    print('Saved:', dst / (src.stem + '_sk.png'))
main()

View File

@@ -0,0 +1,308 @@
#!/usr/bin/env python3
"""
vectorize_skeleton.py
Trace skeleton PNG to Manhattan polylines (simple 4-neighbor tracing) and export JSON.
"""
import argparse
from pathlib import Path
import numpy as np
from PIL import Image
import json
import cv2
from collections import deque
def load_skeleton(path):
    """Read a skeleton PNG and return a 0/1 uint8 array (threshold at 127)."""
    gray = np.array(Image.open(path).convert('L'))
    return (gray > 127).astype('uint8')
def neighbors4(y, x, h, w):
    """Yield the in-bounds 4-connected neighbors of (y, x) in E, S, W, N order."""
    for ny, nx in ((y, x + 1), (y + 1, x), (y, x - 1), (y - 1, x)):
        if 0 <= ny < h and 0 <= nx < w:
            yield ny, nx
def trace_components(sk):
    """Group skeleton pixels into 4-connected components via BFS.

    Returns a list of components, each a list of (x, y) pixel coordinates in
    BFS discovery order (note: x = column index, y = row index).
    """
    h, w = sk.shape
    seen = np.zeros_like(sk, dtype=bool)
    components = []
    for row in range(h):
        for col in range(w):
            if not sk[row, col] or seen[row, col]:
                continue
            # Flood-fill this component breadth-first.
            queue = deque([(row, col)])
            seen[row, col] = True
            pixels = []
            while queue:
                cy, cx = queue.popleft()
                pixels.append((int(cx), int(cy)))
                for ny, nx in neighbors4(cy, cx, h, w):
                    if sk[ny, nx] and not seen[ny, nx]:
                        seen[ny, nx] = True
                        queue.append((ny, nx))
            components.append(pixels)
    return components
def prune_segments(segments, min_len=5, merge_tol=3):
    """Clean up a raw list of Manhattan segments.

    Each segment is a tuple ``(orient, val, start_idx, end_idx)`` where
    ``orient`` is 'H' or 'V', ``val`` is the fixed coordinate (y for 'H',
    x for 'V'), and the index pair points into the original point sequence.

    Three passes:
      1. repeatedly merge consecutive same-orientation segments whose fixed
         coordinates differ by at most ``merge_tol`` (length-weighted average);
      2. drop a too-short leading/trailing segment (dangling tails);
      3. collapse short perpendicular "bumps" (A -> short B -> A) back into
         a single aligned A segment.
    """
    if not segments: return []
    # 1. Merge consecutive same-direction segments. Repeat to a fixed point
    # so chains of more than two mergeable segments collapse fully.
    while True:
        merged = []
        changed = False
        if segments:
            curr = segments[0]
            for i in range(1, len(segments)):
                next_seg = segments[i]
                if curr[0] == next_seg[0] and abs(curr[1] - next_seg[1]) <= merge_tol:
                    # Merge: extend end_idx, recompute the length-weighted value.
                    len1 = curr[3] - curr[2] + 1
                    len2 = next_seg[3] - next_seg[2] + 1
                    new_val = int(round((curr[1]*len1 + next_seg[1]*len2) / (len1+len2)))
                    curr = (curr[0], new_val, curr[2], next_seg[3])
                    changed = True
                else:
                    merged.append(curr)
                    curr = next_seg
            merged.append(curr)
        segments = merged
        if not changed:
            break
    # 2. Remove short tails at either end (likely tracing artifacts).
    # Check start
    if len(segments) > 1:
        s0 = segments[0]
        l0 = s0[3] - s0[2] + 1
        if l0 < min_len:
            segments.pop(0)
    # Check end
    if len(segments) > 1:
        s_last = segments[-1]
        l_last = s_last[3] - s_last[2] + 1
        if l_last < min_len:
            segments.pop(-1)
    # 3. Remove short internal bumps (Z-shapes): for A -> B(short) -> A with
    # the two A segments nearly collinear, fold all three into one A segment.
    final_segs = []
    i = 0
    while i < len(segments):
        curr = segments[i]
        merged_bump = False
        if i + 2 < len(segments):
            mid = segments[i+1]
            next_seg = segments[i+2]
            # Pattern: same orientation on the outside, different in the middle.
            if curr[0] == next_seg[0] and mid[0] != curr[0]:
                len_mid = mid[3] - mid[2] + 1
                if len_mid < min_len:
                    # Outer segments must be nearly collinear to merge.
                    if abs(curr[1] - next_seg[1]) <= merge_tol:
                        # Length-weighted value over the two outer segments.
                        len1 = curr[3] - curr[2] + 1
                        len3 = next_seg[3] - next_seg[2] + 1
                        new_val = int(round((curr[1]*len1 + next_seg[1]*len3) / (len1+len3)))
                        # Merged segment spans curr.start .. next_seg.end.
                        new_seg = (curr[0], new_val, curr[2], next_seg[3])
                        # Write the merge back into segments[i+2] and resume
                        # there, so it can participate in further bump merges.
                        segments[i+2] = new_seg
                        i += 2
                        merged_bump = True
                        continue
        if not merged_bump:
            final_segs.append(curr)
            i += 1
    return final_segs
def simplify_manhattan(pts, tolerance=2):
    """Collapse an ordered pixel path into a Manhattan polyline of corners.

    Greedily scans *pts* (list of (x, y)) into maximal runs that stay within
    ``tolerance`` of a constant y (horizontal) or constant x (vertical),
    always taking the longer of the two candidate runs. The runs are pruned
    via prune_segments, then converted to corner points: perpendicular
    neighbors meet at their intersection, while parallel neighbors (possible
    after pruning) are bridged with a perpendicular jog at the midpoint.

    Returns a list of (x, y) corner points (possibly empty).
    """
    if not pts: return []
    if len(pts) < 2: return pts
    segments = []
    n = len(pts)
    i = 0
    while i < n - 1:
        start_pt = pts[i]
        # Candidate horizontal run: extend while y stays within tolerance.
        k_h = i + 1
        while k_h < n:
            if abs(pts[k_h][1] - start_pt[1]) > tolerance:
                break
            k_h += 1
        len_h = k_h - i
        # Candidate vertical run: extend while x stays within tolerance.
        k_v = i + 1
        while k_v < n:
            if abs(pts[k_v][0] - start_pt[0]) > tolerance:
                break
            k_v += 1
        len_v = k_v - i
        if len_h >= len_v:
            # Horizontal wins; its fixed y is the mean over the run.
            segment_pts = pts[i:k_h]
            avg_y = int(round(np.mean([p[1] for p in segment_pts])))
            segments.append(('H', avg_y, i, k_h - 1))
            i = k_h - 1
        else:
            # Vertical wins; its fixed x is the mean over the run.
            segment_pts = pts[i:k_v]
            avg_x = int(round(np.mean([p[0] for p in segment_pts])))
            segments.append(('V', avg_x, i, k_v - 1))
            i = k_v - 1
    # --- Pruning / Refining ---
    segments = prune_segments(segments, min_len=5, merge_tol=3)
    if not segments:
        return []
    out_poly = []
    # First corner: project the first segment's start pixel onto the segment.
    first_seg = segments[0]
    first_pt_orig = pts[first_seg[2]]
    if first_seg[0] == 'H':
        curr = (first_pt_orig[0], first_seg[1])
    else:
        curr = (first_seg[1], first_pt_orig[1])
    out_poly.append(curr)
    for idx in range(len(segments) - 1):
        s1 = segments[idx]
        s2 = segments[idx+1]
        # After pruning, s1 and s2 may not be index-adjacent, so connect
        # s1's projected end to s2's projected start with a Manhattan path.
        if s1[0] == 'H':
            p1_end = (pts[s1[3]][0], s1[1])
        else:
            p1_end = (s1[1], pts[s1[3]][1])
        if s2[0] == 'H':
            p2_start = (pts[s2[2]][0], s2[1])
        else:
            p2_start = (s2[1], pts[s2[2]][1])
        if s1[0] == 'H' and s2[0] == 'V':
            # H(y1) meets V(x2) at (x2, y1).
            corner = (s2[1], s1[1])
            if out_poly[-1] != corner: out_poly.append(corner)
        elif s1[0] == 'V' and s2[0] == 'H':
            # V(x1) meets H(y2) at (x1, y2).
            corner = (s1[1], s2[1])
            if out_poly[-1] != corner: out_poly.append(corner)
        else:
            # Parallel segments separated by a gap: bridge with a
            # perpendicular jog at the midpoint between the two endpoints.
            if out_poly[-1] != p1_end: out_poly.append(p1_end)
            if s1[0] == 'H':
                # Two horizontals: the bridge is vertical.
                mid_x = (p1_end[0] + p2_start[0]) // 2
                c1 = (mid_x, p1_end[1])
                c2 = (mid_x, p2_start[1])
                if out_poly[-1] != c1: out_poly.append(c1)
                if c1 != c2: out_poly.append(c2)
                if c2 != p2_start: out_poly.append(p2_start)
            else:
                # Two verticals: the bridge is horizontal.
                mid_y = (p1_end[1] + p2_start[1]) // 2
                c1 = (p1_end[0], mid_y)
                c2 = (p2_start[0], mid_y)
                if out_poly[-1] != c1: out_poly.append(c1)
                if c1 != c2: out_poly.append(c2)
                if c2 != p2_start: out_poly.append(p2_start)
    # Last corner: project the last segment's end pixel onto the segment.
    last_seg = segments[-1]
    last_pt_orig = pts[last_seg[3]]
    if last_seg[0] == 'H':
        end = (last_pt_orig[0], last_seg[1])
    else:
        end = (last_seg[1], last_pt_orig[1])
    if out_poly[-1] != end: out_poly.append(end)
    return out_poly
def polyline_from_pixels(pts):
    """Order unordered skeleton pixels into a path, then simplify it.

    Starts at the topmost-then-leftmost pixel and greedily walks to a
    4-connected neighbor when one exists (scanning E, W, S, N), otherwise
    jumps to the nearest remaining pixel by Manhattan distance (ties broken
    by y, then x). The ordered sequence is collapsed to Manhattan corners.
    """
    if not pts:
        return []
    remaining = set(pts)
    start = min(pts, key=lambda p: (p[1], p[0]))
    seq = [start]
    remaining.remove(start)
    while remaining:
        x, y = seq[-1]
        step = None
        # Prefer a direct 4-neighbor of the current path head.
        for cand in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
            if cand in remaining:
                step = cand
                break
        if step is None:
            # No adjacent pixel left: jump to the closest remaining one.
            step = min(remaining, key=lambda p: (abs(p[0] - x) + abs(p[1] - y), p[1], p[0]))
        seq.append(step)
        remaining.remove(step)
    return simplify_manhattan(seq, tolerance=2)
def main():
    """CLI entry: load a skeleton PNG, trace it, and dump polylines as JSON."""
    parser = argparse.ArgumentParser()
    parser.add_argument('skel_png')
    parser.add_argument('outjson')
    opts = parser.parse_args()
    skel = load_skeleton(opts.skel_png)
    polylines = []
    for pts in trace_components(skel):
        poly = polyline_from_pixels(pts)
        if len(poly) > 1:  # drop degenerate single-point traces
            polylines.append(poly)
    payload = {'polylines': polylines, 'num': len(polylines), 'shape': skel.shape}
    with open(opts.outjson, 'w') as f:
        json.dump(payload, f)
    print('Wrote', opts.outjson)
if __name__ == '__main__':
main()

136
train/train_skeleton_vae.py Normal file
View File

@@ -0,0 +1,136 @@
#!/usr/bin/env python3
"""
Minimal training loop for SkeletonVAE. Expects a directory of skeleton PNGs.
"""
import argparse
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
import numpy as np
import logging
import time
from models.skeleton_vae import SkeletonVAE
class SkeletonDataset(Dataset):
    """Dataset of '*_sk.png' skeleton images, resized and tensorized.

    Recursively collects every '*_sk.png' under *folder*; each item is a
    single-channel float tensor of shape (1, size, size) with values in [0, 1].
    """
    def __init__(self, folder, size=64):
        self.files = list(Path(folder).glob('**/*_sk.png'))
        self.tr = T.Compose([T.Resize((size, size)), T.ToTensor()])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        image = Image.open(path).convert('L')
        return self.tr(image)
def loss_fn(recon_x, x, mu, logvar):
    """Standard VAE objective: summed BCE reconstruction plus KL divergence.

    Returns (total, bce, kld) so callers can log the two terms separately.
    """
    bce = torch.nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld, bce, kld
def setup_logger(log_dir):
    """Configure root logging to both *log_dir*/train.log and the console.

    force=True resets any handlers installed by earlier basicConfig calls so
    repeated invocations do not duplicate output. Returns this module's logger.
    """
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    handlers = [
        logging.FileHandler(log_dir / 'train.log'),
        logging.StreamHandler(),
    ]
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=handlers,
        force=True,
    )
    return logging.getLogger(__name__)
def main():
    """Train SkeletonVAE on a folder of skeleton PNGs.

    CLI: positional skeleton folder plus --epochs/--batch/--lr, a checkpoint
    path (--save_path, overwritten every epoch) and a directory for periodic
    reconstruction grids (--sample_dir, written every 5th epoch).
    """
    p = argparse.ArgumentParser()
    p.add_argument('skel_folder')
    p.add_argument('--epochs', type=int, default=10)
    p.add_argument('--batch', type=int, default=16)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--save_path', type=str, default='out/vae_model.pth')
    p.add_argument('--sample_dir', type=str, default='out/samples')
    args = p.parse_args()
    Path(args.sample_dir).mkdir(parents=True, exist_ok=True)
    Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)
    # train.log lives next to the checkpoint.
    logger = setup_logger(Path(args.save_path).parent)
    ds = SkeletonDataset(args.skel_folder, size=64)
    if len(ds) == 0:
        logger.error(f"No skeleton images found in {args.skel_folder}")
        return
    dl = DataLoader(ds, batch_size=args.batch, shuffle=True, num_workers=4)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")
    logger.info(f"Dataset size: {len(ds)} images")
    model = SkeletonVAE(z_dim=64).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
    logger.info("Starting training...")
    for epoch in range(args.epochs):
        start_time = time.time()
        model.train()
        # Epoch accumulators; loss_fn returns *summed* losses, so dividing
        # by the dataset size below yields per-sample averages.
        total_loss = 0.0
        total_bce = 0.0
        total_kld = 0.0
        for i, xb in enumerate(dl):
            xb = xb.to(device)
            recon, mu, logvar = model(xb)
            loss, bce, kld = loss_fn(recon, xb, mu, logvar)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
            total_bce += bce.item()
            total_kld += kld.item()
            # Per-sample batch losses, logged every 10 batches.
            if (i + 1) % 10 == 0:
                logger.info(f"Epoch [{epoch+1}/{args.epochs}] Batch [{i+1}/{len(dl)}] "
                            f"Loss: {loss.item()/len(xb):.4f} (BCE: {bce.item()/len(xb):.4f}, KLD: {kld.item()/len(xb):.4f})")
        avg_loss = total_loss / len(ds)
        avg_bce = total_bce / len(ds)
        avg_kld = total_kld / len(ds)
        epoch_time = time.time() - start_time
        logger.info(f'Epoch {epoch+1}/{args.epochs} completed in {epoch_time:.2f}s')
        logger.info(f'  Average Loss: {avg_loss:.4f}')
        logger.info(f'  BCE: {avg_bce:.4f} | KLD: {avg_kld:.4f}')
        # Checkpoint every epoch (overwrites the same file).
        torch.save(model.state_dict(), args.save_path)
        # Every 5th epoch, dump an inputs-vs-reconstructions grid.
        if (epoch + 1) % 5 == 0:
            logger.info(f"Saving reconstruction samples to {args.sample_dir}")
            with torch.no_grad():
                model.eval()
                # Reconstruct one (shuffled) batch for visual inspection.
                xb = next(iter(dl)).to(device)
                recon, _, _ = model(xb)
                # Top row: inputs; bottom row: reconstructions.
                vis = torch.cat([xb[:8], recon[:8]], dim=0)
                from torchvision.utils import save_image
                save_image(vis, Path(args.sample_dir) / f'epoch_{epoch+1}.png', nrow=8)
if __name__ == '__main__':
main()

1674
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff