change some problem 7
This commit is contained in:
@@ -131,28 +131,54 @@ def train_optimized_diffusion(args):
|
||||
use_edge_condition=args.edge_condition
|
||||
)
|
||||
|
||||
# 数据集分割
|
||||
# 检查数据集是否为空
|
||||
if len(dataset) == 0:
|
||||
logger.error(f"数据集为空!请检查数据目录: {args.data_dir}")
|
||||
raise ValueError(f"数据集为空,在目录 {args.data_dir} 中未找到图像文件")
|
||||
|
||||
logger.info(f"找到 {len(dataset)} 个训练样本")
|
||||
|
||||
# 数据集分割 - 修复空数据集问题
|
||||
total_size = len(dataset)
|
||||
train_size = int(0.9 * total_size)
|
||||
val_size = total_size - train_size
|
||||
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
|
||||
if total_size < 10: # 如果数据集太小,全部用于训练
|
||||
logger.warning(f"数据集较小 ({total_size} 样本),全部用于训练")
|
||||
train_dataset = dataset
|
||||
val_dataset = None
|
||||
else:
|
||||
train_size = int(0.9 * total_size)
|
||||
val_size = total_size - train_size
|
||||
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
|
||||
logger.info(f"训练集: {len(train_dataset)}, 验证集: {len(val_dataset)}")
|
||||
|
||||
# 数据加载器
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
pin_memory=True
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=2
|
||||
)
|
||||
# 数据加载器 - 修复None验证集问题
|
||||
if device.type == 'cuda':
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=min(args.batch_size, len(train_dataset)), # 确保批次大小不超过数据集大小
|
||||
shuffle=True,
|
||||
num_workers=min(4, max(1, len(train_dataset) // args.batch_size)),
|
||||
pin_memory=True,
|
||||
drop_last=True # 避免最后一个不完整的批次
|
||||
)
|
||||
else:
|
||||
# CPU模式下使用较少的worker
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=min(args.batch_size, len(train_dataset)),
|
||||
shuffle=True,
|
||||
num_workers=0, # CPU模式下避免多进程
|
||||
drop_last=True
|
||||
)
|
||||
|
||||
# val_dataset is None when the dataset is too small to split (fewer than 10 samples); report 0 instead of calling len(None), which would raise TypeError.
logger.info(f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset) if val_dataset is not None else 0}")
|
||||
if val_dataset is not None:
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=min(args.batch_size, len(val_dataset)),
|
||||
shuffle=False,
|
||||
num_workers=2
|
||||
)
|
||||
else:
|
||||
val_dataloader = None
|
||||
|
||||
# 创建模型
|
||||
logger.info("创建优化模型...")
|
||||
@@ -209,8 +235,13 @@ def train_optimized_diffusion(args):
|
||||
optimizer, train_dataloader, args.manhattan_weight
|
||||
)
|
||||
|
||||
# 验证
|
||||
val_loss = validate_model(trainer, val_dataloader, device)
|
||||
# 验证 - 修复None验证集问题
|
||||
if val_dataloader is not None:
|
||||
val_loss = validate_model(trainer, val_dataloader, device)
|
||||
else:
|
||||
# 如果没有验证集,使用训练损失作为验证损失
|
||||
val_loss = train_losses['total_loss']
|
||||
logger.warning("未使用验证集 - 使用训练损失作为参考")
|
||||
|
||||
# 学习率调度
|
||||
lr_scheduler.step()
|
||||
@@ -238,7 +269,7 @@ def train_optimized_diffusion(args):
|
||||
f"LR: {current_lr:.2e}"
|
||||
)
|
||||
|
||||
# 保存最佳模型
|
||||
# 保存最佳模型 - 即使没有验证集也保存
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
best_model_path = output_dir / "best_model.pth"
|
||||
|
||||
Reference in New Issue
Block a user