change some problem 7
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
@@ -94,12 +95,22 @@ def train_optimized_diffusion(args):
|
||||
|
||||
# 设备检查
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
logger.info(f"使用设备: {device}")
|
||||
logger.info(f"🚀 使用设备: {device}")
|
||||
|
||||
# 设备详细信息
|
||||
if device.type == 'cuda':
|
||||
logger.info(f"📊 GPU信息: {torch.cuda.get_device_name()}")
|
||||
logger.info(f"💾 GPU内存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
|
||||
logger.info(f"🔥 CUDA版本: {torch.version.cuda}")
|
||||
else:
|
||||
logger.info("💻 使用CPU进行训练(较慢,建议使用GPU)")
|
||||
|
||||
# 设置随机种子
|
||||
torch.manual_seed(args.seed)
|
||||
if device.type == 'cuda':
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
logger.info(f"🎲 随机种子设置为: {args.seed}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = Path(args.output_dir)
|
||||
@@ -123,20 +134,27 @@ def train_optimized_diffusion(args):
|
||||
yaml.dump(config, f, default_flow_style=False)
|
||||
|
||||
# 创建数据集
|
||||
logger.info(f"加载数据集: {args.data_dir}")
|
||||
logger.info(f"📂 开始加载数据集: {args.data_dir}")
|
||||
logger.info(f"🖼️ 图像尺寸: {args.image_size}x{args.image_size}")
|
||||
logger.info(f"🔄 数据增强: {'启用' if args.augment else '禁用'}")
|
||||
logger.info(f"📐 边缘条件: {'启用' if args.edge_condition else '禁用'}")
|
||||
|
||||
start_time = time.time()
|
||||
dataset = ICDiffusionDataset(
|
||||
image_dir=args.data_dir,
|
||||
image_size=args.image_size,
|
||||
augment=args.augment,
|
||||
use_edge_condition=args.edge_condition
|
||||
)
|
||||
load_time = time.time() - start_time
|
||||
|
||||
# 检查数据集是否为空
|
||||
if len(dataset) == 0:
|
||||
logger.error(f"数据集为空!请检查数据目录: {args.data_dir}")
|
||||
logger.error(f"❌ 数据集为空!请检查数据目录: {args.data_dir}")
|
||||
raise ValueError(f"数据集为空,在目录 {args.data_dir} 中未找到图像文件")
|
||||
|
||||
logger.info(f"找到 {len(dataset)} 个训练样本")
|
||||
logger.info(f"✅ 数据集加载完成,耗时: {load_time:.2f}秒")
|
||||
logger.info(f"📊 找到 {len(dataset)} 个训练样本")
|
||||
|
||||
# 数据集分割 - 修复空数据集问题
|
||||
total_size = len(dataset)
|
||||
@@ -181,20 +199,30 @@ def train_optimized_diffusion(args):
|
||||
val_dataloader = None
|
||||
|
||||
# 创建模型
|
||||
logger.info("创建优化模型...")
|
||||
logger.info("🏗️ 正在创建U-Net扩散模型...")
|
||||
model_start_time = time.time()
|
||||
model = ManhattanAwareUNet(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
use_edge_condition=args.edge_condition
|
||||
).to(device)
|
||||
|
||||
# 计算模型参数量
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
logger.info(f"✅ 模型创建完成,耗时: {time.time() - model_start_time:.2f}秒")
|
||||
logger.info(f"📊 模型参数: {total_params:,} 总计, {trainable_params:,} 可训练")
|
||||
|
||||
# 创建调度器
|
||||
logger.info(f"⏱️ 创建噪声调度器: {args.schedule_type}, {args.timesteps} 步")
|
||||
scheduler = OptimizedNoiseScheduler(
|
||||
num_timesteps=args.timesteps,
|
||||
schedule_type=args.schedule_type
|
||||
)
|
||||
|
||||
# 创建训练器
|
||||
logger.info("🎯 创建优化训练器...")
|
||||
trainer = OptimizedDiffusionTrainer(
|
||||
model, scheduler, device, args.edge_condition
|
||||
)
|
||||
@@ -224,30 +252,56 @@ def train_optimized_diffusion(args):
|
||||
else:
|
||||
logger.warning(f"检查点文件不存在: {checkpoint_path}")
|
||||
|
||||
logger.info(f"开始训练 {args.epochs} 个epoch (从epoch {start_epoch}开始)...")
|
||||
logger.info(f"🚀 开始训练 {args.epochs} 个epoch (从epoch {start_epoch}开始)...")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# 训练循环
|
||||
best_val_loss = float('inf')
|
||||
total_training_start = time.time()
|
||||
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
epoch_start_time = time.time()
|
||||
|
||||
# 训练
|
||||
logger.info(f"🏃 训练 Epoch {epoch+1}/{args.epochs}")
|
||||
step_start_time = time.time()
|
||||
train_losses = trainer.train_step(
|
||||
optimizer, train_dataloader, args.manhattan_weight
|
||||
)
|
||||
step_time = time.time() - step_start_time
|
||||
|
||||
# 验证 - 修复None验证集问题
|
||||
if val_dataloader is not None:
|
||||
logger.info(f"🔍 验证 Epoch {epoch+1}/{args.epochs}")
|
||||
val_loss = validate_model(trainer, val_dataloader, device)
|
||||
else:
|
||||
# 如果没有验证集,使用训练损失作为验证损失
|
||||
val_loss = train_losses['total_loss']
|
||||
logger.warning("未使用验证集 - 使用训练损失作为参考")
|
||||
logger.info("⚠️ 未使用验证集 - 使用训练损失作为参考")
|
||||
|
||||
# 学习率调度
|
||||
lr_scheduler.step()
|
||||
|
||||
# 记录损失
|
||||
# 记录损失和进度
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
epoch_time = time.time() - epoch_start_time
|
||||
|
||||
# 详细日志输出
|
||||
logger.info(f"✅ Epoch {epoch+1} 完成 (耗时: {epoch_time:.1f}s, 训练: {step_time:.1f}s)")
|
||||
logger.info(f"📉 训练损失 - Total: {train_losses['total_loss']:.6f} | "
|
||||
f"MSE: {train_losses['mse_loss']:.6f} | "
|
||||
f"Edge: {train_losses['edge_loss']:.6f} | "
|
||||
f"Manhattan: {train_losses['manhattan_loss']:.6f}")
|
||||
if val_dataloader is not None:
|
||||
logger.info(f"🎯 验证损失: {val_loss:.6f}")
|
||||
logger.info(f"📈 学习率: {current_lr:.8f}")
|
||||
|
||||
# 内存使用情况(仅GPU)
|
||||
if device.type == 'cuda':
|
||||
memory_used = torch.cuda.memory_allocated() / 1024**3
|
||||
memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
||||
logger.info(f"💾 GPU内存: {memory_used:.1f}GB / {memory_total:.1f}GB ({memory_used/memory_total*100:.1f}%)")
|
||||
|
||||
losses_history.append({
|
||||
'epoch': epoch,
|
||||
'train_loss': train_losses['total_loss'],
|
||||
@@ -258,21 +312,11 @@ def train_optimized_diffusion(args):
|
||||
'lr': current_lr
|
||||
})
|
||||
|
||||
# 日志输出
|
||||
logger.info(
|
||||
f"Epoch {epoch+1}/{args.epochs} | "
|
||||
f"Train Loss: {train_losses['total_loss']:.6f} | "
|
||||
f"Val Loss: {val_loss:.6f} | "
|
||||
f"Edge: {train_losses['edge_loss']:.6f} | "
|
||||
f"Structure: {train_losses['structure_loss']:.6f} | "
|
||||
f"Manhattan: {train_losses['manhattan_loss']:.6f} | "
|
||||
f"LR: {current_lr:.2e}"
|
||||
)
|
||||
|
||||
# 保存最佳模型 - 即使没有验证集也保存
|
||||
# 保存最佳模型
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
best_model_path = output_dir / "best_model.pth"
|
||||
logger.info(f"🏆 新的最佳模型! 验证损失: {val_loss:.6f}")
|
||||
save_checkpoint(
|
||||
model, optimizer, lr_scheduler, epoch, losses_history, best_model_path
|
||||
)
|
||||
@@ -280,6 +324,7 @@ def train_optimized_diffusion(args):
|
||||
# 定期保存检查点
|
||||
if (epoch + 1) % args.save_interval == 0:
|
||||
checkpoint_path = output_dir / f"checkpoint_epoch_{epoch+1}.pth"
|
||||
logger.info(f"💾 保存检查点: {checkpoint_path.name}")
|
||||
save_checkpoint(
|
||||
model, optimizer, lr_scheduler, epoch, losses_history, checkpoint_path
|
||||
)
|
||||
@@ -287,26 +332,36 @@ def train_optimized_diffusion(args):
|
||||
# 生成样本
|
||||
if (epoch + 1) % args.sample_interval == 0:
|
||||
sample_dir = output_dir / f"samples_epoch_{epoch+1}"
|
||||
logger.info(f"生成样本到 {sample_dir}")
|
||||
logger.info(f"🎨 生成 {args.num_samples} 个样本到 {sample_dir}")
|
||||
sample_start_time = time.time()
|
||||
trainer.generate(
|
||||
num_samples=args.num_samples,
|
||||
image_size=args.image_size,
|
||||
save_dir=sample_dir,
|
||||
use_post_process=True
|
||||
)
|
||||
sample_time = time.time() - sample_start_time
|
||||
logger.info(f"✅ 样本生成完成,耗时: {sample_time:.1f}秒")
|
||||
|
||||
logger.info("-" * 80) # 分隔线
|
||||
|
||||
# 保存最终模型
|
||||
total_training_time = time.time() - total_training_start
|
||||
final_model_path = output_dir / "final_model.pth"
|
||||
logger.info("🎉 训练完成! 保存最终模型...")
|
||||
save_checkpoint(
|
||||
model, optimizer, lr_scheduler, args.epochs-1, losses_history, final_model_path
|
||||
)
|
||||
logger.info(f"💾 最终模型已保存: {final_model_path}")
|
||||
logger.info(f"⏱️ 总训练时间: {total_training_time/3600:.2f} 小时")
|
||||
logger.info(f"⚡ 平均每epoch: {total_training_time/args.epochs:.1f} 秒")
|
||||
|
||||
# 保存损失历史
|
||||
with open(output_dir / 'loss_history.yaml', 'w') as f:
|
||||
yaml.dump(losses_history, f, default_flow_style=False)
|
||||
|
||||
# 最终生成
|
||||
logger.info("生成最终样本...")
|
||||
logger.info("🎨 生成最终样本...")
|
||||
final_sample_dir = output_dir / "final_samples"
|
||||
trainer.generate(
|
||||
num_samples=args.num_samples * 2, # 生成更多样本
|
||||
@@ -315,10 +370,19 @@ def train_optimized_diffusion(args):
|
||||
use_post_process=True
|
||||
)
|
||||
|
||||
logger.info("训练完成!")
|
||||
logger.info(f"最佳模型: {output_dir / 'best_model.pth'}")
|
||||
logger.info(f"最终模型: {final_model_path}")
|
||||
logger.info(f"最终样本: {final_sample_dir}")
|
||||
# 训练总结
|
||||
logger.info("=" * 80)
|
||||
logger.info("🎉 训练总结:")
|
||||
logger.info(f" ✅ 训练完成!")
|
||||
logger.info(f" 📊 最终验证损失: {best_val_loss:.6f}")
|
||||
logger.info(f" 📈 总训练epochs: {args.epochs}")
|
||||
logger.info(f" 🎯 训练配置: 学习率={args.lr}, 批次大小={args.batch_size}")
|
||||
logger.info(f" 📁 输出目录: {output_dir}")
|
||||
logger.info(f" 🏆 最佳模型: {output_dir / 'best_model.pth'}")
|
||||
logger.info(f" 💾 最终模型: {final_model_path}")
|
||||
logger.info(f" 🎨 最终样本: {final_sample_dir}")
|
||||
logger.info(f" 📋 损失历史: {output_dir / 'loss_history.yaml'}")
|
||||
logger.info("=" * 80)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
Reference in New Issue
Block a user