change some problem 7
@@ -78,24 +78,24 @@ def ddim_sample(model, scheduler, num_samples, image_size, device, num_steps=50,
         # Predict the noise
         predicted_noise = model(x, t_batch)
 
-        # Compute the estimate of the original image
-        alpha_t = scheduler.alphas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
-        alpha_cumprod_t = scheduler.alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
-        beta_t = scheduler.betas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
-        sqrt_one_minus_alpha_cumprod_t = scheduler.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+        # Compute the estimate of the original image - ensure the scheduler tensors are on the correct device
+        alpha_t = scheduler.alphas[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+        alpha_cumprod_t = scheduler.alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+        beta_t = scheduler.betas[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+        sqrt_one_minus_alpha_cumprod_t = scheduler.sqrt_one_minus_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
 
         # Compute the estimate of x_0
         x_0_pred = (x - sqrt_one_minus_alpha_cumprod_t * predicted_noise) / torch.sqrt(alpha_cumprod_t)
 
         # Compute the direction to the previous timestep
         if i < len(timesteps) - 1:
-            alpha_t_prev = scheduler.alphas[timesteps[i+1]]
-            alpha_cumprod_t_prev = scheduler.alphas_cumprod[timesteps[i+1]]
+            alpha_t_prev = scheduler.alphas[timesteps[i+1]].to(device)
+            alpha_cumprod_t_prev = scheduler.alphas_cumprod[timesteps[i+1]].to(device)
             sqrt_alpha_cumprod_t_prev = torch.sqrt(alpha_cumprod_t_prev).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
             sqrt_one_minus_alpha_cumprod_t_prev = torch.sqrt(1 - alpha_cumprod_t_prev).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
 
             # Compute the variance
-            variance = eta * torch.sqrt(beta_t).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+            variance = eta * torch.sqrt(beta_t).squeeze().squeeze().squeeze()
 
             # Compute x at the previous timestep
             x = sqrt_alpha_cumprod_t_prev * x_0_pred + torch.sqrt(1 - alpha_cumprod_t_prev - variance**2) * predicted_noise
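Review note: besides routing the per-step scheduler lookups through .to(device), this hunk also fixes the shape of variance. beta_t already carries broadcast dimensions from the .unsqueeze(-1) chain a few lines up, so the old code, which unsqueezed it again, appears to have produced a higher-rank tensor that broadcast x up with it; the new .squeeze() chain collapses variance to a scalar so that 1 - alpha_cumprod_t_prev - variance**2 keeps the expected shape. For reference, the DDIM update this loop implements (Song et al., 2021) is

\[
x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0
        + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2}\;\epsilon_\theta(x_t, t) + \sigma_t z,
\qquad
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}} ,
\]

where the paper takes \(\sigma_t = \eta\,\sqrt{(1-\bar{\alpha}_{t-1})/(1-\bar{\alpha}_t)}\,\sqrt{1-\bar{\alpha}_t/\bar{\alpha}_{t-1}}\). The code's \(\sigma_t = \eta\sqrt{\beta_t}\) is a simpler stand-in and the stochastic \(\sigma_t z\) term is omitted; both choices coincide with the paper's deterministic sampler at \(\eta = 0\).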
@@ -183,6 +183,18 @@ def generate_optimized_samples(args):
         schedule_type=config.get('schedule_type', args.schedule_type)
     )
 
+    # Make sure all of the scheduler's tensors are on the correct device
+    if hasattr(scheduler, 'betas'):
+        scheduler.betas = scheduler.betas.to(device)
+    if hasattr(scheduler, 'alphas'):
+        scheduler.alphas = scheduler.alphas.to(device)
+    if hasattr(scheduler, 'alphas_cumprod'):
+        scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(device)
+    if hasattr(scheduler, 'sqrt_alphas_cumprod'):
+        scheduler.sqrt_alphas_cumprod = scheduler.sqrt_alphas_cumprod.to(device)
+    if hasattr(scheduler, 'sqrt_one_minus_alphas_cumprod'):
+        scheduler.sqrt_one_minus_alphas_cumprod = scheduler.sqrt_one_minus_alphas_cumprod.to(device)
+
     # Generation parameters
     generation_config = {
         'num_samples': args.num_samples,
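Review note: the five hasattr branches differ only in the attribute name; if the scheduler ever caches more tensors, a data-driven loop stays in sync automatically. A minimal sketch, not part of this commit (the helper name scheduler_to is made up; the attribute list mirrors the checks above):

    def scheduler_to(scheduler, device):
        # Move every cached scheduler tensor to `device` in place.
        for name in ('betas', 'alphas', 'alphas_cumprod',
                     'sqrt_alphas_cumprod', 'sqrt_one_minus_alphas_cumprod'):
            tensor = getattr(scheduler, name, None)
            if tensor is not None:
                setattr(scheduler, name, tensor.to(device))
        return scheduler

With such a helper, both this call site and the per-step .to(device) calls in ddim_sample could rely on the tensors already living on the right device.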
@@ -5,6 +5,7 @@
 
 import os
 import sys
+import time
 import torch
 import torch.nn as nn
 import torch.optim as optim
@@ -94,12 +95,22 @@ def train_optimized_diffusion(args):
 
     # Device check
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    logger.info(f"Using device: {device}")
+    logger.info(f"🚀 Using device: {device}")
 
+    # Detailed device information
+    if device.type == 'cuda':
+        logger.info(f"📊 GPU info: {torch.cuda.get_device_name()}")
+        logger.info(f"💾 GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
+        logger.info(f"🔥 CUDA version: {torch.version.cuda}")
+    else:
+        logger.info("💻 Training on CPU (slower; a GPU is recommended)")
+
     # Set the random seed
     torch.manual_seed(args.seed)
     if device.type == 'cuda':
         torch.cuda.manual_seed(args.seed)
+    logger.info(f"🎲 Random seed set to: {args.seed}")
+    logger.info("=" * 60)
 
     # Create the output directory
     output_dir = Path(args.output_dir)
@@ -123,20 +134,27 @@ def train_optimized_diffusion(args):
         yaml.dump(config, f, default_flow_style=False)
 
     # Create the dataset
-    logger.info(f"Loading dataset: {args.data_dir}")
+    logger.info(f"📂 Loading dataset: {args.data_dir}")
+    logger.info(f"🖼️ Image size: {args.image_size}x{args.image_size}")
+    logger.info(f"🔄 Data augmentation: {'enabled' if args.augment else 'disabled'}")
+    logger.info(f"📐 Edge condition: {'enabled' if args.edge_condition else 'disabled'}")
+
+    start_time = time.time()
     dataset = ICDiffusionDataset(
         image_dir=args.data_dir,
         image_size=args.image_size,
         augment=args.augment,
         use_edge_condition=args.edge_condition
     )
+    load_time = time.time() - start_time
 
     # Check for an empty dataset
     if len(dataset) == 0:
-        logger.error(f"Dataset is empty! Please check the data directory: {args.data_dir}")
+        logger.error(f"❌ Dataset is empty! Please check the data directory: {args.data_dir}")
         raise ValueError(f"Dataset is empty; no image files were found in directory {args.data_dir}")
 
-    logger.info(f"Found {len(dataset)} training samples")
+    logger.info(f"✅ Dataset loaded in {load_time:.2f}s")
+    logger.info(f"📊 Found {len(dataset)} training samples")
 
     # Dataset split - fix for the empty-dataset issue
     total_size = len(dataset)
@@ -181,20 +199,30 @@ def train_optimized_diffusion(args):
         val_dataloader = None
 
     # Create the model
-    logger.info("Creating the optimized model...")
+    logger.info("🏗️ Creating the U-Net diffusion model...")
+    model_start_time = time.time()
     model = ManhattanAwareUNet(
         in_channels=1,
         out_channels=1,
         use_edge_condition=args.edge_condition
     ).to(device)
 
+    # Count the model parameters
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.info(f"✅ Model created in {time.time() - model_start_time:.2f}s")
+    logger.info(f"📊 Model parameters: {total_params:,} total, {trainable_params:,} trainable")
 
     # Create the scheduler
+    logger.info(f"⏱️ Creating noise scheduler: {args.schedule_type}, {args.timesteps} steps")
     scheduler = OptimizedNoiseScheduler(
         num_timesteps=args.timesteps,
         schedule_type=args.schedule_type
     )
 
     # Create the trainer
+    logger.info("🎯 Creating the optimized trainer...")
     trainer = OptimizedDiffusionTrainer(
         model, scheduler, device, args.edge_condition
     )
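Review note: the new parameter counts also translate directly into a rough memory figure (4 bytes per fp32 parameter), which would read nicely next to the GPU-memory logging added earlier. A hypothetical one-line extension, not in the commit:

    logger.info(f"Approx. model size: {total_params * 4 / 1024**2:.1f} MB (fp32)")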
@@ -224,30 +252,56 @@ def train_optimized_diffusion(args):
     else:
         logger.warning(f"Checkpoint file does not exist: {checkpoint_path}")
 
-    logger.info(f"Starting training for {args.epochs} epochs (from epoch {start_epoch})...")
+    logger.info(f"🚀 Starting training for {args.epochs} epochs (from epoch {start_epoch})...")
+    logger.info("=" * 80)
 
     # Training loop
     best_val_loss = float('inf')
+    total_training_start = time.time()
+
     for epoch in range(start_epoch, args.epochs):
+        epoch_start_time = time.time()
+
         # Train
+        logger.info(f"🏃 Training epoch {epoch+1}/{args.epochs}")
+        step_start_time = time.time()
         train_losses = trainer.train_step(
             optimizer, train_dataloader, args.manhattan_weight
         )
+        step_time = time.time() - step_start_time
 
         # Validate - fix for the None validation set
         if val_dataloader is not None:
+            logger.info(f"🔍 Validating epoch {epoch+1}/{args.epochs}")
             val_loss = validate_model(trainer, val_dataloader, device)
         else:
             # If there is no validation set, use the training loss as the validation loss
             val_loss = train_losses['total_loss']
-            logger.warning("No validation set in use - using the training loss as a reference")
+            logger.info("⚠️ No validation set in use - using the training loss as a reference")
 
         # Learning-rate scheduling
         lr_scheduler.step()
 
-        # Record the losses
+        # Record the losses and progress
        current_lr = optimizer.param_groups[0]['lr']
+        epoch_time = time.time() - epoch_start_time
+
+        # Detailed log output
+        logger.info(f"✅ Epoch {epoch+1} finished (took {epoch_time:.1f}s, training {step_time:.1f}s)")
+        logger.info(f"📉 Train losses - Total: {train_losses['total_loss']:.6f} | "
+                    f"MSE: {train_losses['mse_loss']:.6f} | "
+                    f"Edge: {train_losses['edge_loss']:.6f} | "
+                    f"Manhattan: {train_losses['manhattan_loss']:.6f}")
+        if val_dataloader is not None:
+            logger.info(f"🎯 Validation loss: {val_loss:.6f}")
+        logger.info(f"📈 Learning rate: {current_lr:.8f}")
+
+        # Memory usage (GPU only)
+        if device.type == 'cuda':
+            memory_used = torch.cuda.memory_allocated() / 1024**3
+            memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
+            logger.info(f"💾 GPU memory: {memory_used:.1f}GB / {memory_total:.1f}GB ({memory_used/memory_total*100:.1f}%)")
 
         losses_history.append({
             'epoch': epoch,
             'train_loss': train_losses['total_loss'],
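Review note: the epoch/step timing introduced here (and the sample timing further down) repeats the same time.time() bracketing; a tiny context manager would factor it out. A sketch under that assumption - log_duration is hypothetical, not in the codebase:

    import time
    from contextlib import contextmanager

    @contextmanager
    def log_duration(logger, label):
        # Log how long the wrapped block took, e.g. "training epoch 3: 12.4s".
        start = time.time()
        try:
            yield
        finally:
            logger.info(f"{label}: {time.time() - start:.1f}s")

Usage would then be, for example:

    with log_duration(logger, f"training epoch {epoch+1}"):
        train_losses = trainer.train_step(optimizer, train_dataloader, args.manhattan_weight)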
@@ -258,21 +312,11 @@ def train_optimized_diffusion(args):
             'lr': current_lr
         })
 
-        # Log output
-        logger.info(
-            f"Epoch {epoch+1}/{args.epochs} | "
-            f"Train Loss: {train_losses['total_loss']:.6f} | "
-            f"Val Loss: {val_loss:.6f} | "
-            f"Edge: {train_losses['edge_loss']:.6f} | "
-            f"Structure: {train_losses['structure_loss']:.6f} | "
-            f"Manhattan: {train_losses['manhattan_loss']:.6f} | "
-            f"LR: {current_lr:.2e}"
-        )
-
-        # Save the best model - save even when there is no validation set
+        # Save the best model
         if val_loss < best_val_loss:
             best_val_loss = val_loss
             best_model_path = output_dir / "best_model.pth"
+            logger.info(f"🏆 New best model! Validation loss: {val_loss:.6f}")
             save_checkpoint(
                 model, optimizer, lr_scheduler, epoch, losses_history, best_model_path
             )
@@ -280,6 +324,7 @@ def train_optimized_diffusion(args):
         # Periodically save a checkpoint
         if (epoch + 1) % args.save_interval == 0:
             checkpoint_path = output_dir / f"checkpoint_epoch_{epoch+1}.pth"
+            logger.info(f"💾 Saving checkpoint: {checkpoint_path.name}")
             save_checkpoint(
                 model, optimizer, lr_scheduler, epoch, losses_history, checkpoint_path
             )
@@ -287,26 +332,36 @@ def train_optimized_diffusion(args):
         # Generate samples
         if (epoch + 1) % args.sample_interval == 0:
             sample_dir = output_dir / f"samples_epoch_{epoch+1}"
-            logger.info(f"Generating samples to {sample_dir}")
+            logger.info(f"🎨 Generating {args.num_samples} samples to {sample_dir}")
+            sample_start_time = time.time()
             trainer.generate(
                 num_samples=args.num_samples,
                 image_size=args.image_size,
                 save_dir=sample_dir,
                 use_post_process=True
             )
+            sample_time = time.time() - sample_start_time
+            logger.info(f"✅ Sample generation finished in {sample_time:.1f}s")
+
+        logger.info("-" * 80)  # separator line
 
     # Save the final model
+    total_training_time = time.time() - total_training_start
     final_model_path = output_dir / "final_model.pth"
+    logger.info("🎉 Training finished! Saving the final model...")
     save_checkpoint(
         model, optimizer, lr_scheduler, args.epochs-1, losses_history, final_model_path
     )
+    logger.info(f"💾 Final model saved: {final_model_path}")
+    logger.info(f"⏱️ Total training time: {total_training_time/3600:.2f} hours")
+    logger.info(f"⚡ Average per epoch: {total_training_time/args.epochs:.1f} seconds")
 
     # Save the loss history
     with open(output_dir / 'loss_history.yaml', 'w') as f:
         yaml.dump(losses_history, f, default_flow_style=False)
 
     # Final generation
-    logger.info("Generating final samples...")
+    logger.info("🎨 Generating final samples...")
     final_sample_dir = output_dir / "final_samples"
     trainer.generate(
         num_samples=args.num_samples * 2,  # generate more samples
@@ -315,10 +370,19 @@ def train_optimized_diffusion(args):
         use_post_process=True
     )
 
-    logger.info("Training finished!")
-    logger.info(f"Best model: {output_dir / 'best_model.pth'}")
-    logger.info(f"Final model: {final_model_path}")
-    logger.info(f"Final samples: {final_sample_dir}")
+    # Training summary
+    logger.info("=" * 80)
+    logger.info("🎉 Training summary:")
+    logger.info(f"  ✅ Training finished!")
+    logger.info(f"  📊 Final validation loss: {best_val_loss:.6f}")
+    logger.info(f"  📈 Total training epochs: {args.epochs}")
+    logger.info(f"  🎯 Training config: learning rate={args.lr}, batch size={args.batch_size}")
+    logger.info(f"  📁 Output directory: {output_dir}")
+    logger.info(f"  🏆 Best model: {output_dir / 'best_model.pth'}")
+    logger.info(f"  💾 Final model: {final_model_path}")
+    logger.info(f"  🎨 Final samples: {final_sample_dir}")
+    logger.info(f"  📋 Loss history: {output_dir / 'loss_history.yaml'}")
+    logger.info("=" * 80)
 
 
 def main():