#!/usr/bin/env python3
"""
Complete training script for the optimized diffusion model.
"""

import sys
import torch
import torch.optim as optim
from pathlib import Path
import logging
import yaml
from torch.utils.data import DataLoader
import argparse

# Import the optimized modules
from ic_layout_diffusion_optimized import (
    ICDiffusionDataset,
    ManhattanAwareUNet,
    OptimizedNoiseScheduler,
    OptimizedDiffusionTrainer
)


def setup_logging():
    """Configure logging to stdout and a log file."""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler('diffusion_training.log')
        ]
    )
    return logging.getLogger(__name__)


def save_checkpoint(model, optimizer, scheduler, epoch, losses, checkpoint_path):
    """Save a training checkpoint."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if hasattr(scheduler, 'state_dict') else None,
        'losses': losses
    }
    torch.save(checkpoint, checkpoint_path)
    logging.info(f"Checkpoint saved: {checkpoint_path}")


def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None):
    """Load a checkpoint and restore model/optimizer/scheduler state."""
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None and checkpoint.get('scheduler_state_dict'):
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # The saved epoch is the last completed one, so resume from the next epoch.
    start_epoch = checkpoint.get('epoch', -1) + 1
    # Losses are stored as a list of per-epoch records, so default to a list.
    losses = checkpoint.get('losses', [])
    logging.info(f"Checkpoint loaded: {checkpoint_path}, resuming from epoch {start_epoch}")
    return start_epoch, losses


def validate_model(trainer, val_dataloader, device):
    """Compute the mean MSE denoising loss over the validation set."""
    trainer.model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in val_dataloader:
            if trainer.use_edge_condition:
                images, edge_conditions = batch
                edge_conditions = edge_conditions.to(device)
            else:
                images = batch
                edge_conditions = None
            images = images.to(device)
            # Sample random timesteps, add noise, and score the noise prediction
            t = trainer.scheduler.sample_timestep(images.shape[0]).to(device)
            noisy_images, noise = trainer.scheduler.add_noise(images, t)
            predicted_noise = trainer.model(noisy_images, t, edge_conditions)
            loss = trainer.mse_loss(predicted_noise, noise)
            total_loss += loss.item()
    trainer.model.train()
    return total_loss / len(val_dataloader)


def train_optimized_diffusion(args):
    """Train the optimized diffusion model."""
    logger = setup_logging()

    # Select device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    # Seed RNGs for reproducibility
    torch.manual_seed(args.seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(args.seed)

    # Create the output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save the training configuration
    config = {
        'image_size': args.image_size,
        'batch_size': args.batch_size,
        'epochs': args.epochs,
        'lr': args.lr,
        'timesteps': args.timesteps,
        'schedule_type': args.schedule_type,
        'edge_condition': args.edge_condition,
        'manhattan_weight': args.manhattan_weight,
        'augment': args.augment,
        'seed': args.seed
    }
    with open(output_dir / 'training_config.yaml', 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

    # Build the dataset
    logger.info(f"Loading dataset: {args.data_dir}")
    dataset = ICDiffusionDataset(
        image_dir=args.data_dir,
        image_size=args.image_size,
        augment=args.augment,
        use_edge_condition=args.edge_condition
    )

    # 90/10 train/validation split
    total_size = len(dataset)
    train_size = int(0.9 * total_size)
    val_size = total_size - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )

    # Data loaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2
    )
    logger.info(f"Training set size: {len(train_dataset)}, validation set size: {len(val_dataset)}")

    # Build the model
    logger.info("Creating optimized model...")
    model = ManhattanAwareUNet(
        in_channels=1,
        out_channels=1,
        use_edge_condition=args.edge_condition
    ).to(device)

    # Noise scheduler
    scheduler = OptimizedNoiseScheduler(
        num_timesteps=args.timesteps,
        schedule_type=args.schedule_type
    )

    # Trainer
    trainer = OptimizedDiffusionTrainer(
        model, scheduler, device, args.edge_condition
    )

    # Optimizer and learning-rate scheduler
    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=0.01,
        betas=(0.9, 0.999)
    )
    lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=1e-6
    )

    # Optionally resume from a checkpoint
    start_epoch = 0
    losses_history = []
    if args.resume:
        checkpoint_path = Path(args.resume)
        if checkpoint_path.exists():
            start_epoch, losses_history = load_checkpoint(
                checkpoint_path, model, optimizer, lr_scheduler
            )
        else:
            logger.warning(f"Checkpoint file not found: {checkpoint_path}")

    logger.info(f"Training for {args.epochs} epochs (starting at epoch {start_epoch})...")

    # Training loop
    best_val_loss = float('inf')
    for epoch in range(start_epoch, args.epochs):
        # One training pass; train_step is expected to return a dict with
        # 'total_loss', 'edge_loss', 'structure_loss' and 'manhattan_loss'.
        train_losses = trainer.train_step(
            optimizer, train_dataloader, args.manhattan_weight
        )

        # Validation
        val_loss = validate_model(trainer, val_dataloader, device)

        # Step the learning-rate schedule
        lr_scheduler.step()

        # Record losses
        current_lr = optimizer.param_groups[0]['lr']
        losses_history.append({
            'epoch': epoch,
            'train_loss': train_losses['total_loss'],
            'val_loss': val_loss,
            'edge_loss': train_losses['edge_loss'],
            'structure_loss': train_losses['structure_loss'],
            'manhattan_loss': train_losses['manhattan_loss'],
            'lr': current_lr
        })

        # Log progress
        logger.info(
            f"Epoch {epoch+1}/{args.epochs} | "
            f"Train Loss: {train_losses['total_loss']:.6f} | "
            f"Val Loss: {val_loss:.6f} | "
            f"Edge: {train_losses['edge_loss']:.6f} | "
            f"Structure: {train_losses['structure_loss']:.6f} | "
            f"Manhattan: {train_losses['manhattan_loss']:.6f} | "
            f"LR: {current_lr:.2e}"
        )

        # Save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_path = output_dir / "best_model.pth"
            save_checkpoint(
                model, optimizer, lr_scheduler,
                epoch, losses_history, best_model_path
            )

        # Save periodic checkpoints
        if (epoch + 1) % args.save_interval == 0:
            checkpoint_path = output_dir / f"checkpoint_epoch_{epoch+1}.pth"
            save_checkpoint(
                model, optimizer, lr_scheduler,
                epoch, losses_history, checkpoint_path
            )

        # Generate samples
        if (epoch + 1) % args.sample_interval == 0:
            sample_dir = output_dir / f"samples_epoch_{epoch+1}"
            logger.info(f"Generating samples to {sample_dir}")
            trainer.generate(
                num_samples=args.num_samples,
                image_size=args.image_size,
                save_dir=sample_dir,
                use_post_process=True
            )

    # Save the final model
    final_model_path = output_dir / "final_model.pth"
    save_checkpoint(
        model, optimizer, lr_scheduler,
        args.epochs - 1, losses_history, final_model_path
    )

    # Save the loss history
    with open(output_dir / 'loss_history.yaml', 'w') as f:
        yaml.dump(losses_history, f, default_flow_style=False)

    # Final sample generation
    logger.info("Generating final samples...")
    final_sample_dir = output_dir / "final_samples"
    trainer.generate(
        num_samples=args.num_samples * 2,  # generate extra samples for the final run
        image_size=args.image_size,
        save_dir=final_sample_dir,
        use_post_process=True
    )

    logger.info("Training complete!")
    logger.info(f"Best model: {output_dir / 'best_model.pth'}")
    logger.info(f"Final model: {final_model_path}")
    logger.info(f"Final samples: {final_sample_dir}")


def main():
    parser = argparse.ArgumentParser(description="Train the optimized IC layout diffusion model")

    # Data arguments
    parser.add_argument('--data_dir', type=str, required=True,
                        help='Training data directory')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Output directory')

    # Model arguments
    parser.add_argument('--image_size', type=int, default=256,
                        help='Image size')
    parser.add_argument('--timesteps', type=int, default=1000,
                        help='Number of diffusion timesteps')
    parser.add_argument('--schedule_type', type=str, default='cosine',
                        choices=['linear', 'cosine'],
                        help='Noise schedule type')
    parser.add_argument('--edge_condition', action='store_true',
                        help='Use edge conditioning')

    # Training arguments
    parser.add_argument('--batch_size', type=int, default=4,
                        help='Batch size')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='Learning rate')
    parser.add_argument('--manhattan_weight', type=float, default=0.1,
                        help='Manhattan regularization weight')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')

    # Training control
    parser.add_argument('--augment', action='store_true',
                        help='Enable data augmentation')
    parser.add_argument('--resume', type=str, default=None,
                        help='Checkpoint path to resume training from')
    parser.add_argument('--save_interval', type=int, default=10,
                        help='Checkpoint save interval (epochs)')
    parser.add_argument('--sample_interval', type=int, default=20,
                        help='Sample generation interval (epochs)')
    parser.add_argument('--num_samples', type=int, default=16,
                        help='Number of samples to generate each time')

    args = parser.parse_args()

    # Verify that the data directory exists before training
    if not Path(args.data_dir).exists():
        print(f"Error: data directory does not exist: {args.data_dir}")
        sys.exit(1)

    # Start training
    train_optimized_diffusion(args)


if __name__ == "__main__":
    main()
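
# Example invocations (a sketch: the script filename "train_diffusion_optimized.py"
# and all paths are placeholders, adjust them to your dataset and output location):
#
#   python train_diffusion_optimized.py \
#       --data_dir ./data/layouts \
#       --output_dir ./runs/exp1 \
#       --edge_condition --augment \
#       --batch_size 8 --epochs 100
#
# Resuming from a saved checkpoint:
#
#   python train_diffusion_optimized.py \
#       --data_dir ./data/layouts \
#       --output_dir ./runs/exp1 \
#       --resume ./runs/exp1/checkpoint_epoch_50.pth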