change some problem 7

Jiao77
2025-11-20 03:09:18 +08:00
parent 3d75ed722a
commit 3258b7b6de
3 changed files with 407 additions and 100 deletions


@@ -131,28 +131,54 @@ def train_optimized_diffusion(args):
         use_edge_condition=args.edge_condition
     )
 
-    # Dataset split
+    # Check whether the dataset is empty
+    if len(dataset) == 0:
+        logger.error(f"Dataset is empty! Please check the data directory: {args.data_dir}")
+        raise ValueError(f"Dataset is empty: no image files found in directory {args.data_dir}")
+
+    logger.info(f"Found {len(dataset)} training samples")
+
+    # Dataset split - fix the empty-dataset issue
     total_size = len(dataset)
-    train_size = int(0.9 * total_size)
-    val_size = total_size - train_size
-    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
+    if total_size < 10:  # If the dataset is too small, use all of it for training
+        logger.warning(f"Dataset is small ({total_size} samples); using all of it for training")
+        train_dataset = dataset
+        val_dataset = None
+    else:
+        train_size = int(0.9 * total_size)
+        val_size = total_size - train_size
+        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
 
-    logger.info(f"Training set: {len(train_dataset)}, validation set: {len(val_dataset)}")
-
-    # Data loaders
-    train_dataloader = DataLoader(
-        train_dataset,
-        batch_size=args.batch_size,
-        shuffle=True,
-        num_workers=4,
-        pin_memory=True
-    )
-
-    val_dataloader = DataLoader(
-        val_dataset,
-        batch_size=args.batch_size,
-        shuffle=False,
-        num_workers=2
-    )
+    # Data loaders - fix the None-validation-set issue
+    if device.type == 'cuda':
+        train_dataloader = DataLoader(
+            train_dataset,
+            batch_size=min(args.batch_size, len(train_dataset)),  # ensure the batch size does not exceed the dataset size
+            shuffle=True,
+            num_workers=min(4, max(1, len(train_dataset) // args.batch_size)),
+            pin_memory=True,
+            drop_last=True  # avoid a final incomplete batch
+        )
+    else:
+        # Use fewer workers in CPU mode
+        train_dataloader = DataLoader(
+            train_dataset,
+            batch_size=min(args.batch_size, len(train_dataset)),
+            shuffle=True,
+            num_workers=0,  # avoid multiprocessing in CPU mode
+            drop_last=True
+        )
+
+    logger.info(f"Training set size: {len(train_dataset)}, validation set size: {len(val_dataset)}")
+
+    if val_dataset is not None:
+        val_dataloader = DataLoader(
+            val_dataset,
+            batch_size=min(args.batch_size, len(val_dataset)),
+            shuffle=False,
+            num_workers=2
+        )
+    else:
+        val_dataloader = None
 
     # Create the model
     logger.info("Creating optimized model...")
@@ -209,8 +235,13 @@ def train_optimized_diffusion(args):
             optimizer, train_dataloader, args.manhattan_weight
         )
 
-        # Validation
-        val_loss = validate_model(trainer, val_dataloader, device)
+        # Validation - fix the None-validation-set issue
+        if val_dataloader is not None:
+            val_loss = validate_model(trainer, val_dataloader, device)
+        else:
+            # If there is no validation set, use the training loss as the validation loss
+            val_loss = train_losses['total_loss']
+            logger.warning("No validation set used - using the training loss as a reference")
 
         # Learning rate scheduling
         lr_scheduler.step()
@@ -238,7 +269,7 @@ def train_optimized_diffusion(args):
f"LR: {current_lr:.2e}"
)
# 保存最佳模型
# 保存最佳模型 - 即使没有验证集也保存
if val_loss < best_val_loss:
best_val_loss = val_loss
best_model_path = output_dir / "best_model.pth"
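
For readers without the full repository, below is a minimal, self-contained sketch of the guarded split / loader / validation-fallback pattern this commit adopts. It uses a toy TensorDataset and placeholder "losses" rather than the project's real dataset, trainer, or validate_model, so names and values are illustrative only.

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

dataset = TensorDataset(torch.randn(8, 3, 32, 32))  # deliberately tiny toy dataset
batch_size = 4

if len(dataset) == 0:
    raise ValueError("Dataset is empty")

if len(dataset) < 10:  # too small to split: train on everything, no validation set
    train_dataset, val_dataset = dataset, None
else:
    train_size = int(0.9 * len(dataset))
    train_dataset, val_dataset = random_split(dataset, [train_size, len(dataset) - train_size])

train_loader = DataLoader(train_dataset, batch_size=min(batch_size, len(train_dataset)),
                          shuffle=True, drop_last=True)
val_loader = (DataLoader(val_dataset, batch_size=min(batch_size, len(val_dataset)), shuffle=False)
              if val_dataset is not None else None)

train_loss = sum(batch[0].mean().item() for batch in train_loader)  # placeholder training "loss"

# Fall back to the training loss when there is no validation loader,
# mirroring the val_loss handling added in this commit.
if val_loader is not None:
    val_loss = sum(batch[0].mean().item() for batch in val_loader)
else:
    val_loss = train_loss

print(f"train={train_loss:.4f} val={val_loss:.4f}")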