improve IC Layout Diffusion model 20251120
319
tools/diffusion/generate_optimized.py
Executable file
@@ -0,0 +1,319 @@
#!/usr/bin/env python3
"""
Generate IC layout images with the optimized diffusion model.
"""

import os
import sys
import torch
import torch.nn.functional as F
from pathlib import Path
import logging
import argparse
import yaml
from PIL import Image
import numpy as np

# Import the optimized modules
from ic_layout_diffusion_optimized import (
    ManhattanAwareUNet,
    OptimizedNoiseScheduler,
    OptimizedDiffusionTrainer,
    manhattan_post_process,
    manhattan_regularization_loss
)


def setup_logging():
    """Set up logging to stdout and a log file."""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler('diffusion_generation.log')
        ]
    )
    return logging.getLogger(__name__)


def load_model(checkpoint_path, device):
    """Load a trained model from a checkpoint."""
    logger = logging.getLogger(__name__)

    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Read the configuration stored in the checkpoint, if any
    config = checkpoint.get('config', {})

    # Build the model
    model = ManhattanAwareUNet(
        in_channels=1,
        out_channels=1,
        use_edge_condition=config.get('edge_condition', False)
    ).to(device)

    # Load the weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    logger.info(f"Model loaded: {checkpoint_path}")
    logger.info(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")

    return model, config


def ddim_sample(model, scheduler, num_samples, image_size, device, num_steps=50, eta=0.0):
    """DDIM sampling, faster than standard DDPM sampling."""
    model.eval()

    # Start from pure noise
    x = torch.randn(num_samples, 1, image_size, image_size).to(device)

    # Select the timesteps to visit
    timesteps = torch.linspace(scheduler.num_timesteps - 1, 0, num_steps).long().to(device)

    with torch.no_grad():
        for i, t in enumerate(timesteps):
            t_batch = torch.full((num_samples,), t.item(), device=device, dtype=torch.long)

            # Predict the noise
            predicted_noise = model(x, t_batch)

            # Per-step coefficients, reshaped to broadcast over (B, C, H, W)
            alpha_cumprod_t = scheduler.alphas_cumprod[t].view(1, 1, 1, 1)
            beta_t = scheduler.betas[t].view(1, 1, 1, 1)
            sqrt_one_minus_alpha_cumprod_t = scheduler.sqrt_one_minus_alphas_cumprod[t].view(1, 1, 1, 1)

            # Estimate the original image x_0
            x_0_pred = (x - sqrt_one_minus_alpha_cumprod_t * predicted_noise) / torch.sqrt(alpha_cumprod_t)

            # Step towards the previous timestep
            if i < len(timesteps) - 1:
                alpha_cumprod_t_prev = scheduler.alphas_cumprod[timesteps[i + 1]].view(1, 1, 1, 1)
                sqrt_alpha_cumprod_t_prev = torch.sqrt(alpha_cumprod_t_prev)

                # Noise scale for the stochastic part of the step
                variance = eta * torch.sqrt(beta_t)

                # DDIM update for the previous timestep
                x = sqrt_alpha_cumprod_t_prev * x_0_pred + torch.sqrt(1 - alpha_cumprod_t_prev - variance**2) * predicted_noise

                if eta > 0:
                    noise = torch.randn_like(x)
                    x += variance * noise
            else:
                x = x_0_pred

            # Keep intermediate values in a bounded range
            x = torch.clamp(x, -2.0, 2.0)

    return torch.clamp(x, 0.0, 1.0)
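
# For reference, a sketch of the update the loop above implements (the DDIM rule of
# Song et al., 2021); here sigma_t = eta * sqrt(beta_t) is used as a simplified noise
# scale rather than the full DDIM variance:
#   x0_hat = (x_t - sqrt(1 - a_bar_t) * eps_hat) / sqrt(a_bar_t)
#   x_prev = sqrt(a_bar_prev) * x0_hat + sqrt(1 - a_bar_prev - sigma_t**2) * eps_hat + sigma_t * z
# eta = 0 gives the deterministic DDIM trajectory; eta > 0 adds stochasticity.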


def generate_with_guidance(model, scheduler, num_samples, image_size, device,
                           guidance_scale=1.0, num_steps=50, use_ddim=True, eta=0.0):
    """Guided sampling (can be extended to classifier-free guidance)."""

    if use_ddim:
        # DDIM sampling
        samples = ddim_sample(model, scheduler, num_samples, image_size, device, num_steps, eta)
    else:
        # Standard DDPM sampling
        trainer = OptimizedDiffusionTrainer(model, scheduler, device)
        samples = trainer.generate(num_samples, image_size, save_dir=None, use_post_process=False)

    return samples


def evaluate_generation_quality(samples, device):
    """Evaluate the quality of the generated samples."""
    logger = logging.getLogger(__name__)

    quality_metrics = {}

    # 1. Manhattan-geometry compliance
    manhattan_loss = manhattan_regularization_loss(samples, device)
    quality_metrics['manhattan_compliance'] = float(manhattan_loss.item())

    # 2. Edge sharpness (Sobel gradient magnitude)
    sobel_x = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], device=device, dtype=samples.dtype)
    sobel_y = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], device=device, dtype=samples.dtype)

    edge_x = F.conv2d(samples, sobel_x, padding=1)
    edge_y = F.conv2d(samples, sobel_y, padding=1)
    edge_magnitude = torch.sqrt(edge_x**2 + edge_y**2)

    quality_metrics['edge_sharpness'] = float(torch.mean(edge_magnitude).item())

    # 3. Contrast
    quality_metrics['contrast'] = float(torch.std(samples).item())

    # 4. Sparsity (IC layouts are typically sparse)
    quality_metrics['sparsity'] = float((samples < 0.1).float().mean().item())

    logger.info("Generation quality metrics:")
    for metric, value in quality_metrics.items():
        logger.info(f"  {metric}: {value:.4f}")

    return quality_metrics


def generate_optimized_samples(args):
    """Main entry point for generating optimized samples."""
    logger = setup_logging()

    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    # Set the random seed
    torch.manual_seed(args.seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(args.seed)

    # Create the output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load the model
    model, config = load_model(args.checkpoint, device)

    # Create the noise scheduler
    scheduler = OptimizedNoiseScheduler(
        num_timesteps=config.get('timesteps', args.timesteps),
        schedule_type=config.get('schedule_type', args.schedule_type)
    )

    # Generation parameters
    generation_config = {
        'num_samples': args.num_samples,
        'image_size': args.image_size,
        'guidance_scale': args.guidance_scale,
        'num_steps': args.num_steps,
        'use_ddim': args.use_ddim,
        'eta': args.eta,
        'seed': args.seed
    }

    # Save the generation configuration
    with open(output_dir / 'generation_config.yaml', 'w') as f:
        yaml.dump(generation_config, f, default_flow_style=False)

    logger.info(f"Generating {args.num_samples} samples...")
    logger.info(f"Sampling steps: {args.num_steps}, DDIM: {args.use_ddim}, ETA: {args.eta}")

    # Generate in batches to avoid running out of memory
    all_samples = []
    batch_size = min(args.batch_size, args.num_samples)
    num_batches = (args.num_samples + batch_size - 1) // batch_size

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, args.num_samples)
        current_batch_size = end_idx - start_idx

        logger.info(f"Generating batch {batch_idx + 1}/{num_batches} ({current_batch_size} samples)")

        # Generate the samples
        with torch.no_grad():
            samples = generate_with_guidance(
                model, scheduler, current_batch_size, args.image_size, device,
                args.guidance_scale, args.num_steps, args.use_ddim, args.eta
            )

        # Post-processing
        if args.use_post_process:
            samples = manhattan_post_process(samples, threshold=args.post_process_threshold)

        all_samples.append(samples)

        # Save the current batch immediately
        batch_dir = output_dir / f"batch_{batch_idx + 1}"
        batch_dir.mkdir(exist_ok=True)

        for i in range(current_batch_size):
            img_tensor = samples[i].cpu()
            img_array = (img_tensor.squeeze().numpy() * 255).astype(np.uint8)
            img = Image.fromarray(img_array, mode='L')
            img.save(batch_dir / f"sample_{start_idx + i:06d}.png")

    # Concatenate all samples
    all_samples = torch.cat(all_samples, dim=0)

    # Evaluate generation quality
    quality_metrics = evaluate_generation_quality(all_samples, device)

    # Save the quality metrics
    with open(output_dir / 'quality_metrics.yaml', 'w') as f:
        yaml.dump(quality_metrics, f, default_flow_style=False)

    # Create the quality report
    report_content = f"""
IC Layout Diffusion Model Generation Report
============================================

Generation configuration:
- Model checkpoint: {args.checkpoint}
- Number of samples: {args.num_samples}
- Image size: {args.image_size}x{args.image_size}
- Sampling steps: {args.num_steps}
- DDIM sampling: {args.use_ddim}
- Post-processing: {args.use_post_process}

Quality metrics:
- Manhattan-geometry compliance: {quality_metrics['manhattan_compliance']:.4f} (lower is better)
- Edge sharpness: {quality_metrics['edge_sharpness']:.4f}
- Contrast: {quality_metrics['contrast']:.4f}
- Sparsity: {quality_metrics['sparsity']:.4f}

Output directory: {output_dir}
"""

    with open(output_dir / 'generation_report.txt', 'w') as f:
        f.write(report_content)

    logger.info("Generation complete!")
    logger.info(f"Samples saved to: {output_dir}")
    logger.info(f"Quality report: {output_dir / 'generation_report.txt'}")


def main():
    parser = argparse.ArgumentParser(description="Generate IC layouts with the optimized diffusion model")

    # Required arguments
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to the model checkpoint')
    parser.add_argument('--output_dir', type=str, required=True, help='Output directory')

    # Generation parameters
    parser.add_argument('--num_samples', type=int, default=200, help='Number of samples to generate')
    parser.add_argument('--image_size', type=int, default=256, help='Image size')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')

    # Sampling parameters
    parser.add_argument('--num_steps', type=int, default=50, help='Number of sampling steps')
    parser.add_argument('--use_ddim', action='store_true', default=True, help='Use DDIM sampling')
    parser.add_argument('--guidance_scale', type=float, default=1.0, help='Guidance scale')
    parser.add_argument('--eta', type=float, default=0.0, help='DDIM eta parameter (0 = deterministic, 1 = stochastic)')

    # Post-processing parameters
    parser.add_argument('--use_post_process', action='store_true', default=True, help='Enable post-processing')
    parser.add_argument('--post_process_threshold', type=float, default=0.5, help='Post-processing threshold')

    # Model configuration (fallbacks for values not stored in the checkpoint)
    parser.add_argument('--timesteps', type=int, default=1000, help='Number of diffusion timesteps')
    parser.add_argument('--schedule_type', type=str, default='cosine',
                        choices=['linear', 'cosine'], help='Noise schedule type')

    args = parser.parse_args()

    # Check that the checkpoint file exists
    if not Path(args.checkpoint).exists():
        print(f"Error: checkpoint file does not exist: {args.checkpoint}")
        sys.exit(1)

    # Start generation
    generate_optimized_samples(args)


if __name__ == "__main__":
    main()
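
# Example invocation (a sketch; the checkpoint and output paths are placeholders):
#   python tools/diffusion/generate_optimized.py \
#       --checkpoint checkpoints/ic_layout_best.pt \
#       --output_dir outputs/generated_layouts \
#       --num_samples 64 --image_size 256 --num_steps 50 --eta 0.0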