initial commit
This commit is contained in:
75
README.md
Normal file
75
README.md
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# IC Layout Skeleton Generation (prototype)
|
||||||
|
|
||||||
|
This repo provides a prototype pipeline to handle strictly-Manhattan binary IC layout images: extract skeletons, vectorize skeletons into Manhattan polylines, train a skeleton-generation model, and expand skeletons back to images.
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
- `scripts/`: Core processing scripts.
|
||||||
|
- `skeleton_extract.py`: Binarize & extract skeleton PNG.
|
||||||
|
- `vectorize_skeleton.py`: Trace skeleton PNG into JSON polylines.
|
||||||
|
- `expand_skeleton.py`: Rasterize polyline JSON back to a PNG.
|
||||||
|
- `models/`: PyTorch models (e.g., `SkeletonVAE`).
|
||||||
|
- `train/`: Training scripts.
|
||||||
|
- `datasets/`: Dataset building helpers.
|
||||||
|
- `out/`: Output directory for generated data and logs (ignored by git).
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
This project uses `uv` for dependency management.
|
||||||
|
|
||||||
|
1. **Install uv**:
|
||||||
|
```bash
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Install dependencies**:
|
||||||
|
```bash
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### 1. Prepare Data
|
||||||
|
Ensure your dataset is available. For example, link your data folder:
|
||||||
|
```bash
|
||||||
|
ln -s ~/Documents/data/ICCAD2019/img ./ICCAD2019/img
|
||||||
|
```
|
||||||
|
(Note: the workspace may already contain an `ICCAD2019` link; if so, skip this step.)
|
||||||
|
|
||||||
|
### 2. Run the Pipeline
|
||||||
|
|
||||||
|
You can run individual scripts using `uv run`.
|
||||||
|
|
||||||
|
**Extract Skeleton:**
|
||||||
|
```bash
|
||||||
|
uv run scripts/skeleton_extract.py path/to/image.png out/result_dir --denoise 3
|
||||||
|
```
|
||||||
|
|
||||||
|
**Vectorize Skeleton:**
|
||||||
|
```bash
|
||||||
|
uv run scripts/vectorize_skeleton.py out/result_dir/image_sk.png out/result_dir/image_vec.json
|
||||||
|
```
|
||||||
|
|
||||||
|
**Expand (Reconstruct) Image:**
|
||||||
|
```bash
|
||||||
|
uv run scripts/expand_skeleton.py out/result_dir/image_vec.json out/result_dir/image_recon.png
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Build Dataset
|
||||||
|
Batch process a folder of images:
|
||||||
|
```bash
|
||||||
|
uv run datasets/build_skeleton_dataset.py ICCAD2019/img out/dataset_processed
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Train Model
|
||||||
|
Train the VAE on the processed skeletons:
|
||||||
|
```bash
|
||||||
|
uv run train/train_skeleton_vae.py out/dataset_processed --epochs 20 --batch 16
|
||||||
|
```
|
||||||
|
|
||||||
|
## Debugging
|
||||||
|
A debug script is provided to run the full pipeline on a few random images from the dataset:
|
||||||
|
```bash
|
||||||
|
uv run debug_pipeline.py
|
||||||
|
```
|
||||||
|
Results will be in `out/debug`.
|
||||||
39
datasets/build_skeleton_dataset.py
Normal file
39
datasets/build_skeleton_dataset.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Build dataset: given input images folder, extract skeletons and vectorize, save pairs.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Batch-process a folder of layout images into skeleton/vector pairs.

    For every ``*.png`` / ``*.jpg`` in ``img_folder`` this creates
    ``out_folder/<stem>/`` containing the skeleton PNG (via
    scripts/skeleton_extract.py) and the vectorized polyline JSON (via
    scripts/vectorize_skeleton.py). Raises CalledProcessError if either
    sub-script fails.
    """
    import sys  # local import: used only to locate the running interpreter

    p = argparse.ArgumentParser()
    p.add_argument('img_folder')
    p.add_argument('out_folder')
    p.add_argument('--invert', action='store_true')
    args = p.parse_args()

    out = Path(args.out_folder)
    out.mkdir(parents=True, exist_ok=True)

    scripts_dir = Path(__file__).parents[1] / 'scripts'
    imgs = list(Path(args.img_folder).glob('*.png')) + list(Path(args.img_folder).glob('*.jpg'))
    for im in imgs:
        od = out / im.stem
        od.mkdir(exist_ok=True)
        # Call skeleton_extract with the current interpreter instead of a
        # hard-coded 'python3' so the pipeline also works inside venvs/uv.
        cmd = [sys.executable, str(scripts_dir / 'skeleton_extract.py'), str(im), str(od)]
        if args.invert:
            cmd.append('--invert')
        subprocess.run(cmd, check=True)

        # Vectorize the extracted skeleton. The JSON is written to
        # '<stem>_vec.json' (as documented in the README): the previous
        # '<stem>_sk.json' name collided with — and silently overwrote —
        # the summary JSON that skeleton_extract.py writes.
        sk_png = od / (im.stem + '_sk.png')
        outjson = od / (im.stem + '_vec.json')
        cmd2 = [sys.executable, str(scripts_dir / 'vectorize_skeleton.py'), str(sk_png), str(outjson)]
        subprocess.run(cmd2, check=True)

    print('Built dataset in', out)


if __name__ == '__main__':
    main()
|
||||||
87
debug_pipeline.py
Normal file
87
debug_pipeline.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
def run_command(cmd):
    """Echo *cmd* to stdout, then execute it; raises CalledProcessError on failure."""
    joined = ' '.join(cmd)
    print(f"Running: {joined}")
    subprocess.run(cmd, check=True)
|
||||||
|
|
||||||
|
def main():
    """Run the extract -> vectorize -> expand pipeline on a few random images.

    Picks up to 5 random PNGs from ICCAD2019/img and writes per-image results
    under out/debug, which is cleared at the start of every run.
    """
    # Setup paths
    base_dir = Path(__file__).parent
    img_dir = base_dir / "ICCAD2019" / "img"
    # README documents the debug output location as out/debug; the previous
    # "debug1" directory name contradicted it.
    out_dir = base_dir / "out" / "debug"

    if out_dir.exists():
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Collect candidate images
    all_images = list(img_dir.glob("*.png"))
    if not all_images:
        print("No images found in ICCAD2019/img")
        return

    selected_images = random.sample(all_images, min(5, len(all_images)))

    print(f"Selected {len(selected_images)} images for debugging.")

    for img_path in selected_images:
        print(f"\nProcessing {img_path.name}...")
        case_dir = out_dir / img_path.stem
        case_dir.mkdir(exist_ok=True)

        # 1. Skeleton extraction.
        # NOTE(review): skeleton_extract.py assumes a white subject on a
        # black background; pass --invert here if the dataset is the
        # opposite polarity — confirm against the actual images.
        extract_cmd = [
            "python3", "scripts/skeleton_extract.py",
            str(img_path),
            str(case_dir),
            "--denoise", "3"
        ]
        run_command(extract_cmd)

        sk_png = case_dir / (img_path.stem + "_sk.png")
        if not sk_png.exists():
            print(f"Failed to generate skeleton for {img_path.name}")
            continue

        # 2. Vectorization
        vec_json = case_dir / (img_path.stem + "_vec.json")
        vec_cmd = [
            "python3", "scripts/vectorize_skeleton.py",
            str(sk_png),
            str(vec_json)
        ]
        run_command(vec_cmd)

        # 3. Expansion (reconstruction)
        recon_png = case_dir / (img_path.stem + "_recon.png")
        expand_cmd = [
            "python3", "scripts/expand_skeleton.py",
            str(vec_json),
            str(recon_png),
            "--line-width", "3"  # Adjust as needed
        ]
        run_command(expand_cmd)

    print(f"\nDebug run complete. Check results in {out_dir}")


if __name__ == "__main__":
    main()
|
||||||
0
models/__init__.py
Normal file
0
models/__init__.py
Normal file
62
models/skeleton_vae.py
Normal file
62
models/skeleton_vae.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""
|
||||||
|
Simple convolutional VAE for skeleton images (single-channel)
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class ConvEncoder(nn.Module):
    """Convolutional encoder mapping a 1x64x64 image to latent (mu, logvar).

    Three stride-2 convolutions downsample 64 -> 32 -> 16 -> 8, so the
    flattened feature size is 128*8*8; inputs are assumed to be 64x64
    single-channel tensors.
    """

    def __init__(self, z_dim=64):
        super().__init__()
        layers = [
            nn.Conv2d(1, 32, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.enc = nn.Sequential(*layers)
        feat = 128 * 8 * 8
        self.fc_mu = nn.Linear(feat, z_dim)
        self.fc_logvar = nn.Linear(feat, z_dim)

    def forward(self, x):
        """Return the latent Gaussian parameters (mu, logvar) for batch *x*."""
        features = self.enc(x)
        return self.fc_mu(features), self.fc_logvar(features)
|
||||||
|
|
||||||
|
|
||||||
|
class ConvDecoder(nn.Module):
    """Decoder mapping a z_dim latent vector back to a 1x64x64 image in [0, 1].

    Mirrors ConvEncoder: three stride-2 transposed convolutions upsample
    8 -> 16 -> 32 -> 64, finished with a Sigmoid.
    """

    def __init__(self, z_dim=64):
        super().__init__()
        feat = 128 * 8 * 8
        self.fc = nn.Linear(z_dim, feat)
        self.dec = nn.Sequential(
            nn.Unflatten(1, (128, 8, 8)),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode latent batch *z* into image space."""
        hidden = self.fc(z)
        return self.dec(hidden)
|
||||||
|
|
||||||
|
|
||||||
|
class SkeletonVAE(nn.Module):
    """Convolutional VAE over single-channel 64x64 skeleton images."""

    def __init__(self, z_dim=64):
        super().__init__()
        self.encoder = ConvEncoder(z_dim)
        self.decoder = ConvDecoder(z_dim)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return noise * std + mu

    def forward(self, x):
        """Encode *x*, sample a latent, decode; returns (recon, mu, logvar)."""
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        recon = self.decoder(latent)
        return recon, mu, logvar
|
||||||
23
pyproject.toml
Normal file
23
pyproject.toml
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
[project]
|
||||||
|
name = "ic-layout-generate"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "IC Layout Generation using Skeleton Extraction and Expansion"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.9"
|
||||||
|
dependencies = [
|
||||||
|
"numpy",
|
||||||
|
"opencv-python-headless",
|
||||||
|
"scikit-image",
|
||||||
|
"networkx",
|
||||||
|
"shapely",
|
||||||
|
"Pillow",
|
||||||
|
"torch",
|
||||||
|
"torchvision",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["models", "scripts", "datasets", "train"]
|
||||||
8
requirements.txt
Normal file
8
requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
numpy
|
||||||
|
opencv-python-headless
|
||||||
|
scikit-image
|
||||||
|
networkx
|
||||||
|
shapely
|
||||||
|
Pillow
|
||||||
|
torch
|
||||||
|
torchvision
|
||||||
67
run_experiment.py
Normal file
67
run_experiment.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Run the full experiment:
|
||||||
|
1. Build dataset from raw images
|
||||||
|
2. Train VAE model
|
||||||
|
3. Generate new skeletons
|
||||||
|
"""
|
||||||
|
import subprocess
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def run_command(cmd):
    """Print the command about to be executed, then run it (raises on failure)."""
    print("Running: " + " ".join(cmd))
    subprocess.run(cmd, check=True)
|
||||||
|
|
||||||
|
def main():
    """Drive the full experiment: build dataset, train the VAE, generate samples.

    Steps 1 and 2 can be skipped with --skip_data / --skip_train; step 3
    always runs and expects the checkpoint produced by training at
    <model_dir>/vae_best.pth.
    """
    import sys  # local import: locate the current interpreter for subprocesses

    p = argparse.ArgumentParser()
    p.add_argument('--img_dir', type=str, default='ICCAD2019/img', help='Input raw images directory')
    p.add_argument('--data_dir', type=str, default='out/dataset', help='Output dataset directory')
    p.add_argument('--model_dir', type=str, default='out/models', help='Directory to save models')
    p.add_argument('--gen_dir', type=str, default='out/generated', help='Directory for generated samples')
    p.add_argument('--epochs', type=int, default=20, help='Training epochs')
    p.add_argument('--batch_size', type=int, default=16, help='Batch size')
    p.add_argument('--skip_data', action='store_true', help='Skip dataset building')
    p.add_argument('--skip_train', action='store_true', help='Skip training')
    args = p.parse_args()

    # Use the running interpreter rather than a hard-coded 'python3' so the
    # sub-scripts execute inside the same venv/uv environment.

    # 1. Build Dataset
    if not args.skip_data:
        print("=== Step 1: Building Dataset ===")
        # Note: build_skeleton_dataset.py expects img_folder and out_folder
        cmd = [
            sys.executable, 'datasets/build_skeleton_dataset.py',
            args.img_dir,
            args.data_dir
        ]
        run_command(cmd)

    # 2. Train Model
    model_path = Path(args.model_dir) / 'vae_best.pth'
    if not args.skip_train:
        print("=== Step 2: Training VAE ===")
        cmd = [
            sys.executable, 'train/train_skeleton_vae.py',
            args.data_dir,
            '--epochs', str(args.epochs),
            '--batch', str(args.batch_size),
            '--save_path', str(model_path),
            '--sample_dir', str(Path(args.model_dir) / 'train_samples')
        ]
        run_command(cmd)

    # 3. Generate
    print("=== Step 3: Generating Skeletons ===")
    cmd = [
        sys.executable, 'scripts/generate_skeleton.py',
        '--model_path', str(model_path),
        '--out_dir', args.gen_dir,
        '--num_samples', '20'
    ]
    run_command(cmd)

    print("=== Experiment Complete ===")
    print(f"Generated images are in {args.gen_dir}")


if __name__ == '__main__':
    main()
|
||||||
43
scripts/expand_skeleton.py
Normal file
43
scripts/expand_skeleton.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
expand_skeleton.py
|
||||||
|
Rasterize Manhattan polylines (JSON) back to a binary layout image.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
def draw_polylines(polylines, shape, line_width=3):
    """Rasterize polylines onto a binary canvas of the given (h, w) shape.

    Parameters
    ----------
    polylines : iterable of point lists, each a sequence of (x, y) pairs
    shape : (height, width) of the output image
    line_width : stroke thickness in pixels

    Returns a uint8 array with 1 on drawn strokes and 0 elsewhere.
    """
    height, width = shape
    canvas = np.zeros((height, width), dtype='uint8')
    for polyline in polylines:
        if len(polyline) < 2:
            # a single point defines no segment to draw
            continue
        vertices = [(int(px), int(py)) for px, py in polyline]
        for start, end in zip(vertices, vertices[1:]):
            cv2.line(canvas, start, end, color=255, thickness=line_width)
    return (canvas > 127).astype('uint8')
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry: read polyline JSON and write the rasterized binary PNG.

    The JSON is expected to carry 'polylines' (list of point lists) and
    'shape' ([h, w], defaulting to 512x512 when absent).
    """
    p = argparse.ArgumentParser()
    p.add_argument('polyjson')
    p.add_argument('outpng')
    p.add_argument('--line-width', type=int, default=3)
    args = p.parse_args()

    # Close the JSON file deterministically instead of leaking the handle
    # (the original json.load(open(...)) never closed it).
    with open(args.polyjson, 'r') as f:
        data = json.load(f)
    polylines = data.get('polylines', [])
    shape = tuple(data.get('shape', [512, 512]))
    img = draw_polylines(polylines, shape, line_width=args.line_width)
    cv2.imwrite(args.outpng, img * 255)
    print('Wrote', args.outpng)


if __name__ == '__main__':
    main()
|
||||||
63
scripts/generate_skeleton.py
Normal file
63
scripts/generate_skeleton.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Generate new skeletons using the trained VAE model.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
import torch
|
||||||
|
from torchvision.utils import save_image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Add root to path to allow importing models
|
||||||
|
sys.path.append(str(Path(__file__).parents[1]))
|
||||||
|
|
||||||
|
from models.skeleton_vae import SkeletonVAE
|
||||||
|
|
||||||
|
def main():
    """Sample latent vectors from the prior and decode them into skeletons.

    Saves both the raw probability maps and hard-thresholded binary
    versions (gen_<i>_raw.png / gen_<i>_bin.png) into --out_dir.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, required=True, help='Path to trained .pth model')
    parser.add_argument('--out_dir', type=str, default='out/generated', help='Output directory for generated images')
    parser.add_argument('--num_samples', type=int, default=10, help='Number of samples to generate')
    parser.add_argument('--z_dim', type=int, default=64, help='Latent dimension size')
    parser.add_argument('--threshold', type=float, default=0.5, help='Binarization threshold')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build the model, then bail out early if the checkpoint is missing.
    model = SkeletonVAE(z_dim=args.z_dim).to(device)
    if not Path(args.model_path).exists():
        print(f"Error: Model not found at {args.model_path}")
        return

    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Generating {args.num_samples} samples...")

    with torch.no_grad():
        # z ~ N(0, I): decode straight from the prior.
        z = torch.randn(args.num_samples, args.z_dim).to(device)
        generated = model.decoder(z)

        for i in range(args.num_samples):
            img = generated[i]
            # raw probability map
            save_image(img, out_dir / f'gen_{i}_raw.png')
            # hard-thresholded binary skeleton
            bin_img = (img > args.threshold).float()
            save_image(bin_img, out_dir / f'gen_{i}_bin.png')

    print(f"Saved generated images to {out_dir}")


if __name__ == '__main__':
    main()
|
||||||
79
scripts/skeleton_extract.py
Normal file
79
scripts/skeleton_extract.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
skeleton_extract.py
|
||||||
|
Extract Manhattan-style skeleton from binary layout images.
|
||||||
|
Supports optional color inversion and basic denoising.
|
||||||
|
Outputs skeleton PNG and a JSON summary with connected-component counts.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import cv2
|
||||||
|
from skimage.morphology import skeletonize
|
||||||
|
from skimage import img_as_ubyte
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(path, invert=False):
    """Load a grayscale image and binarize it with Otsu thresholding.

    Returns a uint8 array of 0/1 where 1 marks foreground pixels.
    Set invert=True when the subject is dark on a light background.
    """
    gray = np.array(Image.open(path).convert('L'))
    # Otsu picks the binarization threshold automatically from the histogram.
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    if invert:
        binary = 255 - binary
    return binary // 255
|
||||||
|
|
||||||
|
|
||||||
|
def denoise(bin_img, kernel=3):
    """Remove speckle noise from a 0/1 image via morphological opening then closing."""
    element = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel, kernel))
    work = (bin_img * 255).astype('uint8')
    work = cv2.morphologyEx(work, cv2.MORPH_OPEN, element)
    work = cv2.morphologyEx(work, cv2.MORPH_CLOSE, element)
    return work // 255
|
||||||
|
|
||||||
|
|
||||||
|
def extract_skeleton(bin_img):
    """Thin a 0/1 image down to its one-pixel-wide skeleton (uint8 0/1)."""
    mask = bin_img > 0  # skimage.skeletonize expects a boolean image
    return skeletonize(mask).astype('uint8')
|
||||||
|
|
||||||
|
|
||||||
|
def save_png(arr, path):
    """Write a 0/1 array to *path* as an 8-bit PNG (1 -> 255)."""
    Image.fromarray((arr * 255).astype('uint8')).save(path)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry: binarize, optionally denoise, skeletonize, and save results.

    Writes <stem>_bin.png, <stem>_sk.png and a <stem>_sk.json summary
    (skeleton pixel count and connected-component count) into the
    output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input')
    parser.add_argument('outdir')
    parser.add_argument('--invert', action='store_true', help='Invert black/white before processing')
    parser.add_argument('--denoise', type=int, default=3, help='Denoise kernel size')
    args = parser.parse_args()

    inp = Path(args.input)
    out = Path(args.outdir)
    out.mkdir(parents=True, exist_ok=True)

    bin_img = load_image(inp, invert=args.invert)
    if args.denoise and args.denoise > 0:
        bin_img = denoise(bin_img, kernel=args.denoise)

    sk = extract_skeleton(bin_img)

    save_png(bin_img, out / (inp.stem + '_bin.png'))
    save_png(sk, out / (inp.stem + '_sk.png'))

    # Summary stats. connectedComponents counts the background label as a
    # component, hence the -1.
    num_pixels = int(sk.sum())
    components = cv2.connectedComponents((sk * 255).astype('uint8'))[0] - 1
    info = {'input': str(inp), 'pixels_in_skeleton': int(num_pixels), 'components': int(components)}
    with open(out / (inp.stem + '_sk.json'), 'w') as f:
        json.dump(info, f, indent=2)

    print('Saved:', out / (inp.stem + '_sk.png'))


if __name__ == '__main__':
    main()
|
||||||
308
scripts/vectorize_skeleton.py
Normal file
308
scripts/vectorize_skeleton.py
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
vectorize_skeleton.py
|
||||||
|
Trace skeleton PNG to Manhattan polylines (simple 4-neighbor tracing) and export JSON.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import json
|
||||||
|
import cv2
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
def load_skeleton(path):
    """Load a skeleton PNG as a 0/1 uint8 array (thresholded at 127)."""
    gray = np.array(Image.open(path).convert('L'))
    return (gray > 127).astype('uint8')
|
||||||
|
|
||||||
|
|
||||||
|
def neighbors4(y, x, h, w):
    """Yield the in-bounds 4-connected neighbors (ny, nx) of pixel (y, x).

    Order is fixed: right, down, left, up — callers rely on deterministic
    traversal order.
    """
    candidates = ((y, x + 1), (y + 1, x), (y, x - 1), (y - 1, x))
    for ny, nx in candidates:
        if 0 <= ny < h and 0 <= nx < w:
            yield ny, nx
|
||||||
|
|
||||||
|
|
||||||
|
def trace_components(sk):
    """Group skeleton pixels into 4-connected components via BFS.

    Returns a list of components, each a list of (x, y) pixel coordinates
    in BFS discovery order (scan order determines which component and which
    start pixel comes first).
    """
    h, w = sk.shape
    visited = np.zeros_like(sk, dtype=bool)
    comps = []
    for y in range(h):
        for x in range(w):
            if not sk[y, x] or visited[y, x]:
                continue
            # BFS from this unvisited skeleton pixel.
            visited[y, x] = True
            queue = deque([(y, x)])
            pts = []
            while queue:
                cy, cx = queue.popleft()
                pts.append((int(cx), int(cy)))
                # 4-neighborhood: right, down, left, up (fixed order).
                for dy, dx in ((0, 1), (1, 0), (0, -1), (-1, 0)):
                    ny, nx = cy + dy, cx + dx
                    if 0 <= ny < h and 0 <= nx < w and sk[ny, nx] and not visited[ny, nx]:
                        visited[ny, nx] = True
                        queue.append((ny, nx))
            comps.append(pts)
    return comps
|
||||||
|
|
||||||
|
|
||||||
|
def prune_segments(segments, min_len=5, merge_tol=3):
    """Clean up a chain of Manhattan segments produced by simplify_manhattan.

    Each segment is a tuple ``(orientation, value, start_idx, end_idx)``
    where orientation is 'H' or 'V', value is the shared y (for 'H') or
    x (for 'V') coordinate, and start_idx/end_idx index into the traced
    point list. Three passes are applied:
      1. merge consecutive same-orientation segments whose values are
         within merge_tol (length-weighted average of the values),
      2. drop a short leading/trailing tail segment (< min_len points),
      3. collapse short internal "bumps" (A -> short B -> A) back into
         one segment.
    """
    if not segments: return []

    # 1. Merge consecutive same-direction segments (simple pass).
    # Repeat until a full pass makes no merge, so chains of three or more
    # near-collinear segments collapse as well.
    while True:
        merged = []
        changed = False
        if segments:
            curr = segments[0]
            for i in range(1, len(segments)):
                next_seg = segments[i]
                if curr[0] == next_seg[0] and abs(curr[1] - next_seg[1]) <= merge_tol:
                    # Merge: extend end_idx and re-calculate the value as a
                    # point-count-weighted average of the two segments.
                    len1 = curr[3] - curr[2] + 1
                    len2 = next_seg[3] - next_seg[2] + 1
                    new_val = int(round((curr[1]*len1 + next_seg[1]*len2) / (len1+len2)))
                    curr = (curr[0], new_val, curr[2], next_seg[3])
                    changed = True
                else:
                    merged.append(curr)
                    curr = next_seg
            merged.append(curr)
            segments = merged
        if not changed:
            break

    # 2. Remove short tails at either end of the chain.
    # Check start
    if len(segments) > 1:
        s0 = segments[0]
        l0 = s0[3] - s0[2] + 1
        if l0 < min_len:
            segments.pop(0)

    # Check end
    if len(segments) > 1:
        s_last = segments[-1]
        l_last = s_last[3] - s_last[2] + 1
        if l_last < min_len:
            segments.pop(-1)

    # 3. Remove short internal bumps (Z-shape):
    # H1 -> V(short) -> H2 collapses to a single H when H1 and H2 align.
    # When a bump is collapsed the merged segment is written back into
    # segments[i+2] so it can keep participating in further bump merges.
    final_segs = []
    i = 0
    while i < len(segments):
        curr = segments[i]
        merged_bump = False
        if i + 2 < len(segments):
            mid = segments[i+1]
            next_seg = segments[i+2]

            # Check pattern: A -> B(short) -> A, with B perpendicular to A.
            if curr[0] == next_seg[0] and mid[0] != curr[0]:
                len_mid = mid[3] - mid[2] + 1
                if len_mid < min_len:
                    # Only collapse when the two outer segments line up.
                    if abs(curr[1] - next_seg[1]) <= merge_tol:
                        # Merge all three into one segment spanning
                        # curr.start .. next_seg.end; the value is weighted
                        # by the two outer segments' lengths only.
                        len1 = curr[3] - curr[2] + 1
                        len3 = next_seg[3] - next_seg[2] + 1
                        new_val = int(round((curr[1]*len1 + next_seg[1]*len3) / (len1+len3)))
                        # New segment spans from curr.start to next_seg.end
                        new_seg = (curr[0], new_val, curr[2], next_seg[3])

                        # Store the merged segment at i+2 and resume the loop
                        # there so it can merge with subsequent segments too.
                        segments[i+2] = new_seg
                        i += 2
                        merged_bump = True
                        continue

        if not merged_bump:
            final_segs.append(curr)
            i += 1

    return final_segs
|
||||||
|
|
||||||
|
|
||||||
|
def simplify_manhattan(pts, tolerance=2):
    """Snap an ordered pixel chain to an axis-aligned (Manhattan) polyline.

    Greedily splits *pts* into maximal horizontal/vertical runs (a run
    continues while the off-axis coordinate stays within *tolerance* of the
    run's first point), prunes/merges the runs via prune_segments, then
    emits corner vertices connecting consecutive runs. Returns a list of
    (x, y) vertices; may be empty if pruning removes everything.
    """
    if not pts: return []
    if len(pts) < 2: return pts

    segments = []
    n = len(pts)
    i = 0
    while i < n - 1:
        start_pt = pts[i]

        # Longest horizontal run starting at i (y stays near start_pt's y).
        k_h = i + 1
        while k_h < n:
            if abs(pts[k_h][1] - start_pt[1]) > tolerance:
                break
            k_h += 1
        len_h = k_h - i

        # Longest vertical run starting at i (x stays near start_pt's x).
        k_v = i + 1
        while k_v < n:
            if abs(pts[k_v][0] - start_pt[0]) > tolerance:
                break
            k_v += 1
        len_v = k_v - i

        if len_h >= len_v:
            # Horizontal wins: snap the run to its mean y.
            segment_pts = pts[i:k_h]
            avg_y = int(round(np.mean([p[1] for p in segment_pts])))
            segments.append(('H', avg_y, i, k_h - 1))
            i = k_h - 1
        else:
            # Vertical wins: snap the run to its mean x.
            segment_pts = pts[i:k_v]
            avg_x = int(round(np.mean([p[0] for p in segment_pts])))
            segments.append(('V', avg_x, i, k_v - 1))
            i = k_v - 1

    # --- Pruning / Refining ---
    segments = prune_segments(segments, min_len=5, merge_tol=3)

    if not segments:
        return []

    out_poly = []

    # First vertex: the first traced point projected onto its segment axis.
    first_seg = segments[0]
    first_pt_orig = pts[first_seg[2]]
    if first_seg[0] == 'H':
        curr = (first_pt_orig[0], first_seg[1])
    else:
        curr = (first_seg[1], first_pt_orig[1])
    out_poly.append(curr)

    for idx in range(len(segments) - 1):
        s1 = segments[idx]
        s2 = segments[idx+1]

        # After pruning, s1 and s2 may no longer share an endpoint (gaps
        # can appear), so connect s1's end to s2's start with an explicit
        # Manhattan corner — or, for parallel segments, a perpendicular
        # bridge through the midpoint.

        # s1 end point, projected onto the segment's snapped axis value.
        if s1[0] == 'H':
            p1_end = (pts[s1[3]][0], s1[1])
        else:
            p1_end = (s1[1], pts[s1[3]][1])

        # s2 start point, projected likewise.
        if s2[0] == 'H':
            p2_start = (pts[s2[2]][0], s2[1])
        else:
            p2_start = (s2[1], pts[s2[2]][1])

        # Connect p1_end to p2_start.
        if s1[0] == 'H' and s2[0] == 'V':
            # H(y1) -> V(x2): single corner at the intersection (x2, y1).
            corner = (s2[1], s1[1])
            if out_poly[-1] != corner: out_poly.append(corner)
        elif s1[0] == 'V' and s2[0] == 'H':
            # V(x1) -> H(y2): single corner at the intersection (x1, y2).
            corner = (s1[1], s2[1])
            if out_poly[-1] != corner: out_poly.append(corner)
        else:
            # Parallel segments that survived merging (gap too large):
            # bridge them with a perpendicular jog through the midpoint.
            if out_poly[-1] != p1_end: out_poly.append(p1_end)

            if s1[0] == 'H':
                # Both horizontal: vertical bridge at the midpoint x.
                mid_x = (p1_end[0] + p2_start[0]) // 2
                c1 = (mid_x, p1_end[1])
                c2 = (mid_x, p2_start[1])
                if out_poly[-1] != c1: out_poly.append(c1)
                if c1 != c2: out_poly.append(c2)
                if c2 != p2_start: out_poly.append(p2_start)
            else:
                # Both vertical: horizontal bridge at the midpoint y.
                mid_y = (p1_end[1] + p2_start[1]) // 2
                c1 = (p1_end[0], mid_y)
                c2 = (p2_start[0], mid_y)
                if out_poly[-1] != c1: out_poly.append(c1)
                if c1 != c2: out_poly.append(c2)
                if c2 != p2_start: out_poly.append(p2_start)

    # Last vertex: the last traced point projected onto its segment axis.
    last_seg = segments[-1]
    last_pt_orig = pts[last_seg[3]]
    if last_seg[0] == 'H':
        end = (last_pt_orig[0], last_seg[1])
    else:
        end = (last_seg[1], last_pt_orig[1])
    if out_poly[-1] != end: out_poly.append(end)

    return out_poly
|
||||||
|
|
||||||
|
|
||||||
|
def polyline_from_pixels(pts):
    """Order unordered skeleton pixels into a path, then Manhattan-simplify it.

    Starts from the topmost-leftmost pixel and greedily follows 4-connected
    neighbors, falling back to the nearest remaining pixel (Manhattan
    distance, ties broken by y then x) when the chain breaks. The ordered
    chain is then snapped to axis-aligned segments via simplify_manhattan.
    """
    if not pts:
        return []
    remaining = set(pts)
    # start from the leftmost-topmost pixel
    start = min(pts, key=lambda p: (p[1], p[0]))
    seq = [start]
    remaining.remove(start)
    while remaining:
        x, y = seq[-1]
        nxt = None
        # prefer a direct 4-connected continuation of the chain
        for cand in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
            if cand in remaining:
                nxt = cand
                break
        if nxt is None:
            # chain broken: jump to the closest remaining pixel
            nxt = min(remaining, key=lambda p: (abs(p[0] - x) + abs(p[1] - y), p[1], p[0]))
        seq.append(nxt)
        remaining.remove(nxt)

    return simplify_manhattan(seq, tolerance=2)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: trace a skeleton PNG into Manhattan polylines (JSON)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('skel_png')
    parser.add_argument('outjson')
    args = parser.parse_args()

    skeleton = load_skeleton(args.skel_png)

    # Trace each connected component; drop degenerate traces (< 2 points).
    polylines = [
        pl
        for pl in (polyline_from_pixels(pts) for pts in trace_components(skeleton))
        if len(pl) > 1
    ]

    payload = {'polylines': polylines, 'num': len(polylines), 'shape': skeleton.shape}
    with open(args.outjson, 'w') as f:
        json.dump(payload, f)
    print('Wrote', args.outjson)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point.
if __name__ == '__main__':
    main()
|
||||||
136
train/train_skeleton_vae.py
Normal file
136
train/train_skeleton_vae.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Minimal training loop for SkeletonVAE. Expects a directory of skeleton PNGs.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from PIL import Image
|
||||||
|
import torchvision.transforms as T
|
||||||
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from models.skeleton_vae import SkeletonVAE
|
||||||
|
|
||||||
|
|
||||||
|
class SkeletonDataset(Dataset):
    """Dataset of skeleton PNGs (files matching ``*_sk.png``) as tensors.

    Each item is a single-channel float tensor of shape (1, size, size)
    with values in [0, 1] (via ``ToTensor``).
    """

    def __init__(self, folder, size=64):
        # Sort the glob results: filesystem glob order is not deterministic,
        # and a stable ordering makes runs reproducible given a fixed seed.
        self.files = sorted(Path(folder).glob('**/*_sk.png'))
        self.tr = T.Compose([T.Resize((size, size)), T.ToTensor()])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        p = self.files[idx]
        # Use a context manager so the underlying file handle is closed
        # promptly; PIL opens lazily and would otherwise keep it open until
        # GC, which leaks descriptors under multi-worker DataLoaders.
        with Image.open(p) as im:
            gray = im.convert('L')
        return self.tr(gray)
|
||||||
|
|
||||||
|
|
||||||
|
def loss_fn(recon_x, x, mu, logvar):
    """Standard VAE objective: summed reconstruction BCE plus KL divergence.

    Returns (total, bce, kld) so callers can log the two terms separately.
    """
    bce = torch.nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kld = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1.0)
    return bce + kld, bce, kld
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logger(log_dir):
    """Configure root logging to both ``<log_dir>/train.log`` and the console.

    Creates the directory if needed and returns this module's logger.
    """
    out_dir = Path(log_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    file_handler = logging.FileHandler(out_dir / 'train.log')
    console_handler = logging.StreamHandler()
    # force=True replaces any handlers installed by an earlier basicConfig.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[file_handler, console_handler],
        force=True,
    )
    return logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _train_one_epoch(model, dl, opt, device, logger, epoch, epochs):
    """Run one optimization pass over ``dl``.

    Returns the summed (loss, bce, kld) over all samples; ``loss_fn`` uses
    reduction='sum', so callers divide by the dataset size for per-sample
    averages.
    """
    model.train()
    total_loss = 0.0
    total_bce = 0.0
    total_kld = 0.0

    for i, xb in enumerate(dl):
        xb = xb.to(device)
        recon, mu, logvar = model(xb)
        loss, bce, kld = loss_fn(recon, xb, mu, logvar)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.item()
        total_bce += bce.item()
        total_kld += kld.item()

        if (i + 1) % 10 == 0:
            logger.info(f"Epoch [{epoch+1}/{epochs}] Batch [{i+1}/{len(dl)}] "
                        f"Loss: {loss.item()/len(xb):.4f} (BCE: {bce.item()/len(xb):.4f}, KLD: {kld.item()/len(xb):.4f})")

    return total_loss, total_bce, total_kld


def _save_reconstructions(model, dl, device, sample_dir, epoch):
    """Save a grid of inputs (top row) and their reconstructions (bottom row)."""
    # Local import: torchvision.utils is only needed for sampling.
    from torchvision.utils import save_image
    with torch.no_grad():
        model.eval()
        # Pull one (shuffled) batch to visualise.
        xb = next(iter(dl)).to(device)
        recon, _, _ = model(xb)
        vis = torch.cat([xb[:8], recon[:8]], dim=0)
        save_image(vis, Path(sample_dir) / f'epoch_{epoch+1}.png', nrow=8)


def main():
    """Train SkeletonVAE on a folder of skeleton PNGs (CLI entry point)."""
    p = argparse.ArgumentParser()
    p.add_argument('skel_folder')
    p.add_argument('--epochs', type=int, default=10)
    p.add_argument('--batch', type=int, default=16)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--save_path', type=str, default='out/vae_model.pth')
    p.add_argument('--sample_dir', type=str, default='out/samples')
    args = p.parse_args()

    Path(args.sample_dir).mkdir(parents=True, exist_ok=True)
    Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)

    logger = setup_logger(Path(args.save_path).parent)

    ds = SkeletonDataset(args.skel_folder, size=64)
    if len(ds) == 0:
        logger.error(f"No skeleton images found in {args.skel_folder}")
        return

    dl = DataLoader(ds, batch_size=args.batch, shuffle=True, num_workers=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")
    logger.info(f"Dataset size: {len(ds)} images")

    model = SkeletonVAE(z_dim=64).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    logger.info("Starting training...")
    for epoch in range(args.epochs):
        start_time = time.time()
        total_loss, total_bce, total_kld = _train_one_epoch(
            model, dl, opt, device, logger, epoch, args.epochs)

        # Per-sample averages (loss_fn sums over the batch).
        avg_loss = total_loss / len(ds)
        avg_bce = total_bce / len(ds)
        avg_kld = total_kld / len(ds)
        epoch_time = time.time() - start_time

        logger.info(f'Epoch {epoch+1}/{args.epochs} completed in {epoch_time:.2f}s')
        logger.info(f' Average Loss: {avg_loss:.4f}')
        logger.info(f' BCE: {avg_bce:.4f} | KLD: {avg_kld:.4f}')

        # Checkpoint every epoch so an interrupted run keeps the latest weights.
        torch.save(model.state_dict(), args.save_path)

        # Save sample reconstructions every 5 epochs.
        if (epoch + 1) % 5 == 0:
            logger.info(f"Saving reconstruction samples to {args.sample_dir}")
            _save_reconstructions(model, dl, device, args.sample_dir, epoch)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point.
if __name__ == '__main__':
    main()
|
||||||
Reference in New Issue
Block a user