initial commit
This commit is contained in:
43
scripts/expand_skeleton.py
Normal file
43
scripts/expand_skeleton.py
Normal file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
expand_skeleton.py
|
||||
Rasterize Manhattan polylines (JSON) back to a binary layout image.
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import json
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def draw_polylines(polylines, shape, line_width=3):
    """Rasterize polylines onto a blank canvas and return a 0/1 mask.

    Args:
        polylines: iterable of point sequences; every point is an (x, y)
            pair whose coordinates are truncated with ``int()``. Sequences
            with fewer than two points are ignored.
        shape: ``(height, width)`` of the output image.
        line_width: stroke thickness passed to ``cv2.line``.

    Returns:
        A ``uint8`` array of the requested shape holding 1 on drawn pixels
        and 0 elsewhere.
    """
    height, width = shape
    canvas = np.zeros((height, width), dtype='uint8')
    for polyline in polylines:
        if len(polyline) < 2:
            continue
        vertices = [(int(px), int(py)) for px, py in polyline]
        # Stroke each pair of consecutive vertices.
        for start, end in zip(vertices, vertices[1:]):
            cv2.line(canvas, start, end, color=255, thickness=line_width)
    # Strokes are painted with 255; collapse the canvas to a 0/1 mask.
    return (canvas > 127).astype('uint8')
|
||||
|
||||
|
||||
def main():
    """CLI entry point: rasterize a polyline JSON file into a PNG.

    Positional arguments are the input JSON path and the output PNG path;
    ``--line-width`` sets the stroke thickness (default 3). The JSON object
    is read for the keys ``polylines`` (default: empty list) and ``shape``
    (default: ``[512, 512]``).

    Raises:
        OSError: if ``cv2.imwrite`` reports that the image was not written.
    """
    p = argparse.ArgumentParser()
    p.add_argument('polyjson')
    p.add_argument('outpng')
    p.add_argument('--line-width', type=int, default=3)
    args = p.parse_args()

    # Use a context manager so the input handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(args.polyjson, 'r') as f:
        data = json.load(f)
    polylines = data.get('polylines', [])
    shape = tuple(data.get('shape', [512, 512]))
    img = draw_polylines(polylines, shape, line_width=args.line_width)

    # cv2.imwrite signals failure through its boolean return value, so check
    # it rather than announcing success unconditionally.
    if not cv2.imwrite(args.outpng, img * 255):
        raise OSError(f'Failed to write {args.outpng}')
    print('Wrote', args.outpng)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
63
scripts/generate_skeleton.py
Normal file
63
scripts/generate_skeleton.py
Normal file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate new skeletons using the trained VAE model.
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
import numpy as np
|
||||
|
||||
# Add root to path to allow importing models
|
||||
sys.path.append(str(Path(__file__).parents[1]))
|
||||
|
||||
from models.skeleton_vae import SkeletonVAE
|
||||
|
||||
def main():
    """CLI entry point: sample latent vectors and decode them into images.

    For every sample two files are written into ``--out_dir``:
    ``gen_{i}_raw.png`` (the decoder output as-is) and ``gen_{i}_bin.png``
    (the output compared against ``--threshold``).

    Exits with status 1 when ``--model_path`` does not point to a file.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--model_path', type=str, required=True, help='Path to trained .pth model')
    p.add_argument('--out_dir', type=str, default='out/generated', help='Output directory for generated images')
    p.add_argument('--num_samples', type=int, default=10, help='Number of samples to generate')
    p.add_argument('--z_dim', type=int, default=64, help='Latent dimension size')
    p.add_argument('--threshold', type=float, default=0.5, help='Binarization threshold')
    args = p.parse_args()

    # Validate the checkpoint before building the model so a bad path fails
    # fast, and report the failure through stderr and a non-zero exit status
    # so shell pipelines can detect it.
    if not Path(args.model_path).is_file():
        print(f"Error: Model not found at {args.model_path}", file=sys.stderr)
        sys.exit(1)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load model
    model = SkeletonVAE(z_dim=args.z_dim).to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Generating {args.num_samples} samples...")

    with torch.no_grad():
        # Sample from standard normal distribution
        z = torch.randn(args.num_samples, args.z_dim).to(device)

        # Decode
        generated = model.decoder(z)

        # Save raw and binarized
        for i in range(args.num_samples):
            img = generated[i]

            # Save raw decoder output
            save_image(img, out_dir / f'gen_{i}_raw.png')

            # Binarize
            bin_img = (img > args.threshold).float()
            save_image(bin_img, out_dir / f'gen_{i}_bin.png')

    print(f"Saved generated images to {out_dir}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
79
scripts/skeleton_extract.py
Normal file
79
scripts/skeleton_extract.py
Normal file
@@ -0,0 +1,79 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
skeleton_extract.py
|
||||
Extract Manhattan-style skeleton from binary layout images.
|
||||
Supports optional color inversion and basic denoising.
|
||||
Outputs skeleton PNG and a JSON summary with connected-component counts.
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import cv2
|
||||
from skimage.morphology import skeletonize
|
||||
from skimage import img_as_ubyte
|
||||
import json
|
||||
|
||||
|
||||
def load_image(path, invert=False):
    """Read an image as grayscale and binarize it with Otsu's threshold.

    Args:
        path: image file to open.
        invert: when true, swap foreground and background after thresholding.

    Returns:
        Array of 0/1 values (the 0/255 thresholded image floor-divided by 255).
    """
    gray = np.array(Image.open(path).convert('L'))
    # The explicit threshold 0 is a placeholder: THRESH_OTSU picks the level.
    mode = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    binary = cv2.threshold(gray, 0, 255, mode)[1]
    if invert:
        binary = 255 - binary
    return binary // 255
|
||||
|
||||
|
||||
def denoise(bin_img, kernel=3):
    """Clean a 0/1 mask with a morphological opening followed by a closing.

    Args:
        bin_img: mask with values 0 and 1.
        kernel: side length of the square structuring element.

    Returns:
        The filtered mask, again with values 0 and 1.
    """
    element = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel, kernel))
    # Work on a 0/255 uint8 image for the morphology calls.
    scaled = (bin_img * 255).astype('uint8')
    opened = cv2.morphologyEx(scaled, cv2.MORPH_OPEN, element)
    closed = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, element)
    return closed // 255
|
||||
|
||||
|
||||
def extract_skeleton(bin_img):
    """Skeletonize the non-zero pixels of ``bin_img``.

    Returns:
        The result of ``skimage.morphology.skeletonize`` cast to ``uint8``.
    """
    # skeletonize is fed a boolean foreground mask.
    foreground = bin_img > 0
    skeleton = skeletonize(foreground)
    return skeleton.astype('uint8')
|
||||
|
||||
|
||||
def save_png(arr, path):
    """Scale ``arr`` by 255, cast it to ``uint8`` and save it to ``path``."""
    scaled = (arr * 255).astype('uint8')
    Image.fromarray(scaled).save(path)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: binarize an image, skeletonize it and write a summary.

    Writes ``<stem>_bin.png``, ``<stem>_sk.png`` and ``<stem>_sk.json`` into
    the output directory, where ``<stem>`` is the input file name without
    its suffix. The JSON holds the input path, the number of skeleton pixels
    and the connected-component count.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input')
    parser.add_argument('outdir')
    parser.add_argument('--invert', action='store_true', help='Invert black/white before processing')
    parser.add_argument('--denoise', type=int, default=3, help='Denoise kernel size')
    args = parser.parse_args()

    src = Path(args.input)
    out_dir = Path(args.outdir)
    out_dir.mkdir(parents=True, exist_ok=True)

    mask = load_image(src, invert=args.invert)
    # A kernel size of zero or less disables denoising.
    if args.denoise and args.denoise > 0:
        mask = denoise(mask, kernel=args.denoise)
    skeleton = extract_skeleton(mask)

    skeleton_png = out_dir / (src.stem + '_sk.png')
    save_png(mask, out_dir / (src.stem + '_bin.png'))
    save_png(skeleton, skeleton_png)

    # connectedComponents counts the background as a label; subtract it.
    label_count = cv2.connectedComponents((skeleton * 255).astype('uint8'))[0]
    summary = {
        'input': str(src),
        'pixels_in_skeleton': int(skeleton.sum()),
        'components': int(label_count - 1),
    }
    with open(out_dir / (src.stem + '_sk.json'), 'w') as fh:
        json.dump(summary, fh, indent=2)

    print('Saved:', skeleton_png)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
308
scripts/vectorize_skeleton.py
Normal file
308
scripts/vectorize_skeleton.py
Normal file
@@ -0,0 +1,308 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
vectorize_skeleton.py
|
||||
Trace skeleton PNG to Manhattan polylines (simple 4-neighbor tracing) and export JSON.
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import json
|
||||
import cv2
|
||||
from collections import deque
|
||||
|
||||
|
||||
def load_skeleton(path):
    """Load an image as grayscale and return a ``uint8`` mask of pixels > 127."""
    gray = np.array(Image.open(path).convert('L'))
    return (gray > 127).astype('uint8')
|
||||
|
||||
|
||||
def neighbors4(y, x, h, w):
    """Yield the in-bounds 4-connected neighbours of ``(y, x)``.

    Neighbours are produced as ``(row, col)`` pairs in the fixed order
    ``x + 1``, ``y + 1``, ``x - 1``, ``y - 1``; any that fall outside an
    ``h`` x ``w`` grid are skipped.
    """
    candidates = ((y, x + 1), (y + 1, x), (y, x - 1), (y - 1, x))
    for row, col in candidates:
        if row < 0 or row >= h or col < 0 or col >= w:
            continue
        yield row, col
|
||||
|
||||
|
||||
def trace_components(sk):
    """Group the non-zero pixels of ``sk`` into 4-connected components.

    The image is scanned row by row; each unvisited foreground pixel seeds a
    breadth-first flood fill over its 4-neighbours.

    Returns:
        A list with one entry per component, each a list of ``(x, y)`` integer
        tuples in the order the flood fill visited them.
    """
    rows, cols = sk.shape
    seen = np.zeros_like(sk, dtype=bool)
    components = []
    for seed_y in range(rows):
        for seed_x in range(cols):
            if not sk[seed_y, seed_x] or seen[seed_y, seed_x]:
                continue
            # New component: flood-fill from this seed.
            seen[seed_y, seed_x] = True
            frontier = deque([(seed_y, seed_x)])
            pixels = []
            while frontier:
                cy, cx = frontier.popleft()
                # Stored as (x, y), i.e. column first.
                pixels.append((int(cx), int(cy)))
                for ny, nx in neighbors4(cy, cx, rows, cols):
                    if seen[ny, nx] or not sk[ny, nx]:
                        continue
                    seen[ny, nx] = True
                    frontier.append((ny, nx))
            components.append(pixels)
    return components
|
||||
|
||||
|
||||
def prune_segments(segments, min_len=5, merge_tol=3):
    """Tidy a chain of axis-aligned segments.

    Each segment is a 4-item sequence ``(direction, value, start_idx,
    end_idx)``; its length is ``end_idx - start_idx + 1``. Three passes run
    in order:

    1. Neighbouring segments with the same direction whose values differ by
       at most ``merge_tol`` are fused (length-weighted, rounded value),
       repeated until a pass fuses nothing.
    2. If more than one segment remains, a first segment shorter than
       ``min_len`` is dropped; the same check is then applied to the last.
    3. A segment shorter than ``min_len`` lying between two segments of the
       other direction whose values differ by at most ``merge_tol`` is
       removed and its two neighbours are fused.

    The input sequence is not modified.

    Returns:
        A new list of segments; fused segments are tuples.
    """
    if not segments:
        return []

    def span(seg):
        return seg[3] - seg[2] + 1

    def fuse(first, second):
        # Length-weighted average of the two values, spanning both index ranges.
        w1, w2 = span(first), span(second)
        level = int(round((first[1] * w1 + second[1] * w2) / (w1 + w2)))
        return (first[0], level, first[2], second[3])

    # Pass 1: collapse runs of near-collinear neighbours until stable.
    work = segments
    fused_any = True
    while fused_any:
        fused_any = False
        collapsed = [work[0]]
        for candidate in work[1:]:
            tail = collapsed[-1]
            if tail[0] == candidate[0] and abs(tail[1] - candidate[1]) <= merge_tol:
                collapsed[-1] = fuse(tail, candidate)
                fused_any = True
            else:
                collapsed.append(candidate)
        work = collapsed

    # Pass 2: trim a short leading stub, then a short trailing stub.
    if len(work) > 1 and span(work[0]) < min_len:
        del work[0]
    if len(work) > 1 and span(work[-1]) < min_len:
        del work[-1]

    # Pass 3: remove short jogs between two aligned segments.
    result = []
    total = len(work)
    pos = 0
    while pos < total:
        head = work[pos]
        if pos + 2 < total:
            jog = work[pos + 1]
            after = work[pos + 2]
            is_bump = (
                head[0] == after[0]
                and jog[0] != head[0]
                and span(jog) < min_len
                and abs(head[1] - after[1]) <= merge_tol
            )
            if is_bump:
                # Park the fused segment in the slot of the far neighbour so
                # that it can take part in a further bump merge.
                work[pos + 2] = fuse(head, after)
                pos += 2
                continue
        result.append(head)
        pos += 1
    return result
|
||||
|
||||
|
||||
def simplify_manhattan(pts, tolerance=2, min_len=5, merge_tol=3):
    """Approximate an ordered point path by an axis-aligned polyline.

    The path is cut greedily into runs. From the current point, the longest
    run whose y stays within ``tolerance`` of the start (horizontal) is
    compared with the longest run whose x does (vertical); the longer one
    wins, ties going to horizontal. Each run becomes a segment
    ``(direction, value, start_idx, end_idx)`` whose value is the rounded
    mean coordinate. The segments are cleaned by ``prune_segments`` and then
    stitched together with right-angle corners.

    Args:
        pts: ordered sequence of ``(x, y)`` points.
        tolerance: maximum deviation from the run's first point on the
            constrained axis.
        min_len: forwarded to ``prune_segments``.
        merge_tol: forwarded to ``prune_segments``.

    Returns:
        A list of ``(x, y)`` vertices. Empty input gives ``[]``; a
        single-point input is returned unchanged.
    """
    if not pts:
        return []
    if len(pts) < 2:
        return pts

    segments = []
    n = len(pts)
    i = 0
    while i < n - 1:
        start_pt = pts[i]

        # Longest horizontal run: y stays within tolerance of the start.
        k_h = i + 1
        while k_h < n and abs(pts[k_h][1] - start_pt[1]) <= tolerance:
            k_h += 1

        # Longest vertical run: x stays within tolerance of the start.
        k_v = i + 1
        while k_v < n and abs(pts[k_v][0] - start_pt[0]) <= tolerance:
            k_v += 1

        if k_h >= k_v:
            avg_y = int(round(np.mean([p[1] for p in pts[i:k_h]])))
            segments.append(('H', avg_y, i, k_h - 1))
            run_end = k_h - 1
        else:
            avg_x = int(round(np.mean([p[0] for p in pts[i:k_v]])))
            segments.append(('V', avg_x, i, k_v - 1))
            run_end = k_v - 1

        # Consecutive runs normally share their boundary point, so the next
        # run starts at run_end. When the next point is beyond tolerance on
        # BOTH axes the run is a single point and run_end == i; without
        # stepping forward the loop would never terminate.
        i = max(run_end, i + 1)

    # --- Pruning / Refining ---
    segments = prune_segments(segments, min_len=min_len, merge_tol=merge_tol)
    if not segments:
        return []

    def project(seg, idx):
        # Snap point idx onto the segment's axis line.
        if seg[0] == 'H':
            return (pts[idx][0], seg[1])
        return (seg[1], pts[idx][1])

    first_seg = segments[0]
    out_poly = [project(first_seg, first_seg[2])]

    for s1, s2 in zip(segments, segments[1:]):
        if s1[0] == 'H' and s2[0] == 'V':
            # H(y1) -> V(x2): the lines intersect at (x2, y1).
            corner = (s2[1], s1[1])
            if out_poly[-1] != corner:
                out_poly.append(corner)
        elif s1[0] == 'V' and s2[0] == 'H':
            # V(x1) -> H(y2): the lines intersect at (x1, y2).
            corner = (s1[1], s2[1])
            if out_poly[-1] != corner:
                out_poly.append(corner)
        else:
            # Parallel neighbours that were not merged: join the end of s1
            # to the start of s2 with a perpendicular bridge placed halfway
            # between them, so the result stays axis-aligned.
            p1_end = project(s1, s1[3])
            p2_start = project(s2, s2[2])
            if out_poly[-1] != p1_end:
                out_poly.append(p1_end)

            if s1[0] == 'H':
                # Vertical bridge at the mid x.
                mid_x = (p1_end[0] + p2_start[0]) // 2
                c1 = (mid_x, p1_end[1])
                c2 = (mid_x, p2_start[1])
            else:
                # Horizontal bridge at the mid y.
                mid_y = (p1_end[1] + p2_start[1]) // 2
                c1 = (p1_end[0], mid_y)
                c2 = (p2_start[0], mid_y)

            if out_poly[-1] != c1:
                out_poly.append(c1)
            if c1 != c2:
                out_poly.append(c2)
            if c2 != p2_start:
                out_poly.append(p2_start)

    # Close the polyline at the projected end of the last segment.
    last_seg = segments[-1]
    end = project(last_seg, last_seg[3])
    if out_poly[-1] != end:
        out_poly.append(end)

    return out_poly
|
||||
|
||||
|
||||
def polyline_from_pixels(pts):
    """Order a component's pixels into a path and simplify it.

    ``pts`` is an unordered collection of ``(x, y)`` pixels. The walk starts
    at the pixel with the smallest y (ties: smallest x) and repeatedly moves
    to an unvisited 4-neighbour, trying ``x + 1``, ``x - 1``, ``y + 1``,
    ``y - 1`` in that order. When none is left it jumps to the unvisited
    pixel with the smallest Manhattan distance (ties: smallest y, then x).
    The ordered path is handed to ``simplify_manhattan`` with tolerance 2.

    Returns:
        The simplified polyline, or ``[]`` for empty input.
    """
    if not pts:
        return []

    remaining = set(pts)
    # Seed: topmost pixel, leftmost among ties.
    seed = min(pts, key=lambda p: (p[1], p[0]))
    remaining.remove(seed)
    path = [seed]

    while remaining:
        x, y = path[-1]
        adjacent = ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1))
        step = next((cand for cand in adjacent if cand in remaining), None)
        if step is None:
            # Dead end: jump to the closest pixel still unvisited.
            step = min(remaining, key=lambda p: (abs(p[0]-x)+abs(p[1]-y), p[1], p[0]))
        path.append(step)
        remaining.remove(step)

    return simplify_manhattan(path, tolerance=2)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: vectorize a skeleton PNG and dump the result as JSON.

    The output object has the keys ``polylines`` (polylines with more than
    one vertex), ``num`` (their count) and ``shape`` (the skeleton array's
    shape).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('skel_png')
    parser.add_argument('outjson')
    args = parser.parse_args()

    skeleton = load_skeleton(args.skel_png)
    traced = (polyline_from_pixels(component) for component in trace_components(skeleton))
    # Drop results that collapsed to fewer than two vertices.
    polylines = [polyline for polyline in traced if len(polyline) > 1]

    payload = {'polylines': polylines, 'num': len(polylines), 'shape': skeleton.shape}
    with open(args.outjson, 'w') as fh:
        json.dump(payload, fh)
    print('Wrote', args.outjson)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user