change some problem 7

This commit is contained in:
Jiao77
2025-11-20 03:09:18 +08:00
parent 3d75ed722a
commit 3258b7b6de
3 changed files with 407 additions and 100 deletions

View File

@@ -13,6 +13,7 @@
import os import os
import sys import sys
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@@ -64,15 +65,23 @@ class ICDiffusionDataset(Dataset):
def _extract_edges(self, image_tensor): def _extract_edges(self, image_tensor):
"""提取边缘条件图""" """提取边缘条件图"""
# 使用Sobel算子提取边缘 # 修复Sobel算子 - 正确的3x3 Sobel算子
sobel_x = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], if len(image_tensor.shape) == 3: # [C, H, W]
image_tensor = image_tensor.unsqueeze(0) # [1, C, H, W]
# 为单通道图像设计的Sobel算子
sobel_x = torch.tensor([[[[-1.0, 0.0, 1.0],
[-2.0, 0.0, 2.0],
[-1.0, 0.0, 1.0]]]],
dtype=image_tensor.dtype, device=image_tensor.device) dtype=image_tensor.dtype, device=image_tensor.device)
sobel_y = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], sobel_y = torch.tensor([[[[-1.0, -2.0, -1.0],
[0.0, 0.0, 0.0],
[1.0, 2.0, 1.0]]]],
dtype=image_tensor.dtype, device=image_tensor.device) dtype=image_tensor.dtype, device=image_tensor.device)
edge_x = F.conv2d(image_tensor.unsqueeze(0), sobel_x, padding=1) edge_x = F.conv2d(image_tensor, sobel_x, padding=1)
edge_y = F.conv2d(image_tensor.unsqueeze(0), sobel_y, padding=1) edge_y = F.conv2d(image_tensor, sobel_y, padding=1)
edge_magnitude = torch.sqrt(edge_x**2 + edge_y**2) edge_magnitude = torch.sqrt(edge_x**2 + edge_y**2 + 1e-8)
return torch.clamp(edge_magnitude, 0, 1) return torch.clamp(edge_magnitude, 0, 1)
@@ -170,8 +179,25 @@ def manhattan_regularization_loss(generated_image, device='cuda'):
return torch.mean(angle_penalty * edge_magnitude) return torch.mean(angle_penalty * edge_magnitude)
class SinusoidalPositionEmbeddings(nn.Module):
"""正弦位置编码用于时间步嵌入"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, time):
device = time.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None].float() * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=1)
return embeddings
class ManhattanAwareUNet(nn.Module): class ManhattanAwareUNet(nn.Module):
"""曼哈顿几何感知的U-Net架构""" """曼哈顿几何感知的U-Net架构 - 修复版"""
def __init__(self, in_channels=1, out_channels=1, time_dim=256, use_edge_condition=False): def __init__(self, in_channels=1, out_channels=1, time_dim=256, use_edge_condition=False):
super().__init__() super().__init__()
@@ -180,9 +206,9 @@ class ManhattanAwareUNet(nn.Module):
# 输入通道数(原始图像 + 可选边缘条件) # 输入通道数(原始图像 + 可选边缘条件)
input_channels = in_channels + (1 if use_edge_condition else 0) input_channels = in_channels + (1 if use_edge_condition else 0)
# 时间嵌入 # 时间嵌入 - 修复时间步编码
self.time_mlp = nn.Sequential( self.time_mlp = nn.Sequential(
nn.Linear(1, time_dim), SinusoidalPositionEmbeddings(time_dim),
nn.SiLU(), nn.SiLU(),
nn.Linear(time_dim, time_dim) nn.Linear(time_dim, time_dim)
) )
@@ -199,43 +225,34 @@ class ManhattanAwareUNet(nn.Module):
nn.SiLU() nn.SiLU()
) )
# 编码器 - 增强版 # 编码器通道配置
encoder_channels = [64, 128, 256, 512] # 修正:减少通道数避免过拟合
self.encoder = nn.ModuleList([ self.encoder = nn.ModuleList([
self._make_block(64, 128), self._make_block(encoder_channels[i], encoder_channels[i+1], stride=2 if i > 0 else 1)
self._make_block(128, 256, stride=2), for i in range(len(encoder_channels)-1)
self._make_block(256, 512, stride=2),
self._make_block(512, 1024, stride=2),
])
# 残差投影层 - 确保残差连接的通道数和空间尺寸匹配
self.residual_projections = nn.ModuleList([
nn.Conv2d(64, 128, 1), # 64 -> 128, 保持尺寸
nn.Conv2d(128, 256, 3, stride=2, padding=1), # 128 -> 256, 尺寸减半
nn.Conv2d(256, 512, 3, stride=2, padding=1), # 256 -> 512, 尺寸减半
nn.Conv2d(512, 1024, 3, stride=2, padding=1), # 512 -> 1024, 尺寸减半
]) ])
# 中间层 # 中间层
self.middle = nn.Sequential( self.middle = nn.Sequential(
nn.Conv2d(1024, 1024, 3, padding=1), nn.Conv2d(512, 512, 3, padding=1),
nn.GroupNorm(8, 1024), nn.GroupNorm(8, 512),
nn.SiLU(), nn.SiLU(),
nn.Conv2d(1024, 1024, 3, padding=1), nn.Conv2d(512, 512, 3, padding=1),
nn.GroupNorm(8, 1024), nn.GroupNorm(8, 512),
nn.SiLU(), nn.SiLU(),
) )
# 解码器 # 解码器 - 修复通道数计算
self.decoder = nn.ModuleList([ self.decoder = nn.ModuleList([
self._make_decoder_block(1024, 512), self._make_decoder_block(512, 256), # middle (512) -> 256
self._make_decoder_block(512, 256), self._make_decoder_block(512, 128), # 256+256(skip) -> 128
self._make_decoder_block(256, 128), self._make_decoder_block(256, 64), # 128+128(skip) -> 64
self._make_decoder_block(128, 64), self._make_decoder_block(128, 64), # 64+64(skip) -> 64
]) ])
# 输出层 # 输出层 - 修复输入通道数
self.output = nn.Sequential( self.output = nn.Sequential(
nn.Conv2d(64, 32, 3, padding=1), nn.Conv2d(64, 32, 3, padding=1), # 最后一层跳跃连接后是64通道
nn.GroupNorm(8, 32), nn.GroupNorm(8, 32),
nn.SiLU(), nn.SiLU(),
nn.Conv2d(32, out_channels, 3, padding=1) nn.Conv2d(32, out_channels, 3, padding=1)
@@ -243,10 +260,7 @@ class ManhattanAwareUNet(nn.Module):
# 时间融合层 - 与编码器输出通道数匹配 # 时间融合层 - 与编码器输出通道数匹配
self.time_fusion = nn.ModuleList([ self.time_fusion = nn.ModuleList([
nn.Linear(time_dim, 128), nn.Linear(time_dim, channels) for channels in encoder_channels[1:] + [512]
nn.Linear(time_dim, 256),
nn.Linear(time_dim, 512),
nn.Linear(time_dim, 1024),
]) ])
def _make_block(self, in_channels, out_channels, stride=1): def _make_block(self, in_channels, out_channels, stride=1):
@@ -276,8 +290,8 @@ class ManhattanAwareUNet(nn.Module):
if self.use_edge_condition and edge_condition is not None: if self.use_edge_condition and edge_condition is not None:
x = torch.cat([x, edge_condition], dim=1) x = torch.cat([x, edge_condition], dim=1)
# 时间嵌入 # 时间嵌入 - 使用正弦位置编码
t_emb = self.time_mlp(t.float().unsqueeze(-1)) # [B, time_dim] t_emb = self.time_mlp(t) # [B, time_dim]
# 曼哈顿几何感知的特征提取 # 曼哈顿几何感知的特征提取
h_features = F.silu(self.horiz_conv(x)) h_features = F.silu(self.horiz_conv(x))
@@ -285,45 +299,58 @@ class ManhattanAwareUNet(nn.Module):
s_features = F.silu(self.standard_conv(x)) s_features = F.silu(self.standard_conv(x))
# 融合特征 # 融合特征
x = torch.cat([h_features, v_features, s_features], dim=1) x = torch.cat([h_features, v_features, s_features], dim=1) # [B, 96, H, W]
x = self.initial_fusion(x) x = self.initial_fusion(x) # [B, 64, H, W]
# 编码器路径 # 编码器路径 - 修复跳跃连接逻辑
skips = [] skips = []
for i, (encoder, fusion) in enumerate(zip(self.encoder, self.time_fusion)): for i, (encoder, fusion) in enumerate(zip(self.encoder, self.time_fusion)):
# 保存残差连接并使用投影层匹配通道数 # 保存跳跃连接(在编码之前)
residual = self.residual_projections[i](x) skips.append(x)
# 编码
x = encoder(x) x = encoder(x)
# 融合时间信息 # 融合时间信息
t_feat = fusion(t_emb).unsqueeze(-1).unsqueeze(-1) t_feat = fusion(t_emb).unsqueeze(-1).unsqueeze(-1) # [B, channels, 1, 1]
x = x + t_feat
# 跳跃连接 - 添加投影后的残差 # 检查通道数是否匹配
skips.append(x + residual) if x.shape[1] == t_feat.shape[1]:
x = x + t_feat
else:
# 如果不匹配使用1x1卷积调整通道数
if not hasattr(self, f'time_proj_{i}'):
setattr(self, f'time_proj_{i}',
nn.Conv2d(t_feat.shape[1], x.shape[1], 1).to(x.device))
time_proj = getattr(self, f'time_proj_{i}')
x = x + time_proj(t_feat)
# 中间层 # 中间层
x = self.middle(x) x = self.middle(x)
# 解码器路径 # 解码器路径 - 修复跳跃连接逻辑
for i, (decoder, skip) in enumerate(zip(self.decoder, reversed(skips))): for i, decoder in enumerate(self.decoder):
# 获取对应的跳跃连接(反向顺序)
skip = skips[-(i+1)]
# 上采样
x = decoder(x) x = decoder(x)
# 确保跳跃连接尺寸匹配 - 如果尺寸不匹配则进行裁剪或填充 # 确保跳跃连接尺寸匹配
if x.shape[2:] != skip.shape[2:]: if x.shape[2:] != skip.shape[2:]:
# 裁剪到最小尺寸 # 使用插值调整尺寸
h_min = min(x.shape[2], skip.shape[2]) x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
w_min = min(x.shape[3], skip.shape[3])
if x.shape[2] > h_min:
x = x[:, :, :h_min, :]
if x.shape[3] > w_min:
x = x[:, :, :, :w_min]
if skip.shape[2] > h_min:
skip = skip[:, :, :h_min, :]
if skip.shape[3] > w_min:
skip = skip[:, :, :, :w_min]
x = x + skip # 跳跃连接 # 跳跃连接(需要处理通道数匹配)
if x.shape[1] == skip.shape[1]:
x = x + skip
else:
# 如果通道数不匹配使用1x1卷积调整
if not hasattr(self, f'skip_proj_{i}'):
setattr(self, f'skip_proj_{i}',
nn.Conv2d(skip.shape[1], x.shape[1], 1).to(x.device))
skip_proj = getattr(self, f'skip_proj_{i}')
x = x + skip_proj(skip)
# 输出 # 输出
x = self.output(x) x = self.output(x)
@@ -358,14 +385,11 @@ class OptimizedNoiseScheduler:
def add_noise(self, x_0, t): def add_noise(self, x_0, t):
"""向干净图像添加噪声""" """向干净图像添加噪声"""
noise = torch.randn_like(x_0) noise = torch.randn_like(x_0)
# 确保调度器张量与输入张量在同一设备上
device = x_0.device device = x_0.device
if self.sqrt_alphas_cumprod.device != device:
self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) # 确保调度器张量与输入张量在同一设备上
sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
@@ -381,18 +405,13 @@ class OptimizedNoiseScheduler:
# 预测噪声 # 预测噪声
predicted_noise = model(x_t, t) predicted_noise = model(x_t, t)
# 确保调度器张量与输入张量在同一设备上
device = x_t.device device = x_t.device
if self.alphas.device != device:
self.alphas = self.alphas.to(device)
self.betas = self.betas.to(device)
self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
# 计算系数 # 计算系数(直接使用索引并移动到设备)
alpha_t = self.alphas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) alpha_t = self.alphas[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
sqrt_alpha_t = torch.sqrt(alpha_t) sqrt_alpha_t = torch.sqrt(alpha_t)
beta_t = self.betas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) beta_t = self.betas[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
# 计算均值 # 计算均值
model_mean = (1.0 / sqrt_alpha_t) * (x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * predicted_noise) model_mean = (1.0 / sqrt_alpha_t) * (x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * predicted_noise)

View File

@@ -0,0 +1,257 @@
#!/usr/bin/env python3
"""
测试修复后的IC版图扩散模型
"""
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
# 导入修复后的模型
from ic_layout_diffusion_optimized import (
ManhattanAwareUNet,
OptimizedNoiseScheduler,
OptimizedDiffusionTrainer,
ICDiffusionDataset
)
def test_model_architecture():
"""测试模型架构修复"""
print("测试模型架构...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
# 测试不同的配置
test_configs = [
{'use_edge_condition': False, 'batch_size': 2},
{'use_edge_condition': True, 'batch_size': 2},
]
for config in test_configs:
print(f"\n测试配置: {config}")
try:
# 创建模型
model = ManhattanAwareUNet(
in_channels=1,
out_channels=1,
time_dim=256,
use_edge_condition=config['use_edge_condition']
).to(device)
# 创建测试数据
batch_size = config['batch_size']
image_size = 64 # 使用较小的图像尺寸进行快速测试
x = torch.randn(batch_size, 1, image_size, image_size).to(device)
t = torch.randint(0, 1000, (batch_size,)).to(device)
if config['use_edge_condition']:
edge_condition = torch.randn(batch_size, 1, image_size, image_size).to(device)
output = model(x, t, edge_condition)
else:
output = model(x, t)
print(f"✓ 输入形状: {x.shape}")
print(f"✓ 输出形状: {output.shape}")
print(f"✓ 时间步形状: {t.shape}")
# 检查输出形状
expected_shape = (batch_size, 1, image_size, image_size)
if output.shape == expected_shape:
print(f"✓ 输出形状正确: {output.shape}")
else:
print(f"✗ 输出形状错误: 期望 {expected_shape}, 得到 {output.shape}")
except Exception as e:
print(f"✗ 模型测试失败: {e}")
return False
print("\n✓ 所有模型架构测试通过!")
return True
def test_scheduler():
"""测试噪声调度器修复"""
print("\n测试噪声调度器...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
try:
scheduler = OptimizedNoiseScheduler(num_timesteps=1000, schedule_type='cosine')
# 测试噪声添加
x_0 = torch.randn(4, 1, 32, 32).to(device)
t = torch.randint(0, 1000, (4,)).to(device)
x_t, noise = scheduler.add_noise(x_0, t)
print(f"✓ 原始图像形状: {x_0.shape}")
print(f"✓ 噪声图像形状: {x_t.shape}")
print(f"✓ 噪声形状: {noise.shape}")
# 测试去噪步骤
model = ManhattanAwareUNet().to(device)
x_denoised = scheduler.step(model, x_t, t)
print(f"✓ 去噪图像形状: {x_denoised.shape}")
# 检查形状是否保持一致
if x_denoised.shape == x_t.shape:
print("✓ 去噪步骤形状正确")
else:
print(f"✗ 去噪步骤形状错误: 期望 {x_t.shape}, 得到 {x_denoised.shape}")
return False
except Exception as e:
print(f"✗ 调度器测试失败: {e}")
return False
print("✓ 调度器测试通过!")
return True
def test_trainer():
"""测试训练器修复"""
print("\n测试训练器...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
try:
# 创建模型和调度器
model = ManhattanAwareUNet().to(device)
scheduler = OptimizedNoiseScheduler(num_timesteps=100)
trainer = OptimizedDiffusionTrainer(model, scheduler, device, use_edge_condition=False)
# 创建虚拟数据
batch_size = 2
images = torch.randn(batch_size, 1, 32, 32).to(device)
# 测试单个训练步骤
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# 模拟数据加载器
class MockDataloader:
def __init__(self, data):
self.data = data
def __iter__(self):
yield self.data
mock_dataloader = MockDataloader(images)
# 执行训练步骤
losses = trainer.train_step(optimizer, mock_dataloader, manhattan_weight=0.1)
print(f"✓ 训练步骤完成,损失: {losses}")
# 测试生成
samples = trainer.generate(
num_samples=2,
image_size=32,
save_dir=None,
use_post_process=False
)
print(f"✓ 生成样本形状: {samples.shape}")
except Exception as e:
print(f"✗ 训练器测试失败: {e}")
import traceback
traceback.print_exc()
return False
print("✓ 训练器测试通过!")
return True
def test_data_loading():
"""测试数据加载修复"""
print("\n测试数据加载...")
# 创建临时测试目录(如果不存在)
test_dir = Path("test_images")
try:
if not test_dir.exists():
print("创建测试图像目录...")
test_dir.mkdir(exist_ok=True)
# 创建一些简单的测试图像
from PIL import Image
import numpy as np
for i in range(3):
# 创建随机灰度图像
img_array = np.random.randint(0, 255, (64, 64), dtype=np.uint8)
img = Image.fromarray(img_array, mode='L')
img.save(test_dir / f"test_{i}.png")
# 测试数据集创建
dataset = ICDiffusionDataset(
image_dir=str(test_dir),
image_size=32,
augment=False,
use_edge_condition=False
)
print(f"✓ 数据集大小: {len(dataset)}")
if len(dataset) == 0:
print("✗ 数据集为空")
return False
# 测试数据加载
sample = dataset[0]
print(f"✓ 样本形状: {sample.shape}")
# 测试边缘条件
dataset_edge = ICDiffusionDataset(
image_dir=str(test_dir),
image_size=32,
augment=False,
use_edge_condition=True
)
image, edge = dataset_edge[0]
print(f"✓ 图像形状: {image.shape}, 边缘形状: {edge.shape}")
except Exception as e:
print(f"✗ 数据加载测试失败: {e}")
import traceback
traceback.print_exc()
return False
print("✓ 数据加载测试通过!")
return True
def main():
"""运行所有测试"""
print("开始测试修复后的IC版图扩散模型...")
print("=" * 50)
all_tests_passed = True
# 运行各项测试
tests = [
("模型架构", test_model_architecture),
("噪声调度器", test_scheduler),
("训练器", test_trainer),
("数据加载", test_data_loading),
]
for test_name, test_func in tests:
print(f"\n{'='*20} {test_name} {'='*20}")
if not test_func():
all_tests_passed = False
print("\n" + "=" * 50)
if all_tests_passed:
print("🎉 所有测试通过!模型修复成功。")
else:
print("❌ 部分测试失败,需要进一步修复。")
return all_tests_passed
if __name__ == "__main__":
success = main()
exit(0 if success else 1)

View File

@@ -131,28 +131,54 @@ def train_optimized_diffusion(args):
use_edge_condition=args.edge_condition use_edge_condition=args.edge_condition
) )
# 数据集分割 # 检查数据集是否为空
if len(dataset) == 0:
logger.error(f"数据集为空!请检查数据目录: {args.data_dir}")
raise ValueError(f"数据集为空,在目录 {args.data_dir} 中未找到图像文件")
logger.info(f"找到 {len(dataset)} 个训练样本")
# 数据集分割 - 修复空数据集问题
total_size = len(dataset) total_size = len(dataset)
if total_size < 10: # 如果数据集太小,全部用于训练
logger.warning(f"数据集较小 ({total_size} 样本),全部用于训练")
train_dataset = dataset
val_dataset = None
else:
train_size = int(0.9 * total_size) train_size = int(0.9 * total_size)
val_size = total_size - train_size val_size = total_size - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size]) train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
logger.info(f"训练集: {len(train_dataset)}, 验证集: {len(val_dataset)}")
# 数据加载器 # 数据加载器 - 修复None验证集问题
if device.type == 'cuda':
train_dataloader = DataLoader( train_dataloader = DataLoader(
train_dataset, train_dataset,
batch_size=args.batch_size, batch_size=min(args.batch_size, len(train_dataset)), # 确保批次大小不超过数据集大小
shuffle=True, shuffle=True,
num_workers=4, num_workers=min(4, max(1, len(train_dataset) // args.batch_size)),
pin_memory=True pin_memory=True,
drop_last=True # 避免最后一个不完整的批次
) )
else:
# CPU模式下使用较少的worker
train_dataloader = DataLoader(
train_dataset,
batch_size=min(args.batch_size, len(train_dataset)),
shuffle=True,
num_workers=0, # CPU模式下避免多进程
drop_last=True
)
if val_dataset is not None:
val_dataloader = DataLoader( val_dataloader = DataLoader(
val_dataset, val_dataset,
batch_size=args.batch_size, batch_size=min(args.batch_size, len(val_dataset)),
shuffle=False, shuffle=False,
num_workers=2 num_workers=2
) )
else:
logger.info(f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset)}") val_dataloader = None
# 创建模型 # 创建模型
logger.info("创建优化模型...") logger.info("创建优化模型...")
@@ -209,8 +235,13 @@ def train_optimized_diffusion(args):
optimizer, train_dataloader, args.manhattan_weight optimizer, train_dataloader, args.manhattan_weight
) )
# 验证 # 验证 - 修复None验证集问题
if val_dataloader is not None:
val_loss = validate_model(trainer, val_dataloader, device) val_loss = validate_model(trainer, val_dataloader, device)
else:
# 如果没有验证集,使用训练损失作为验证损失
val_loss = train_losses['total_loss']
logger.warning("未使用验证集 - 使用训练损失作为参考")
# 学习率调度 # 学习率调度
lr_scheduler.step() lr_scheduler.step()
@@ -238,7 +269,7 @@ def train_optimized_diffusion(args):
f"LR: {current_lr:.2e}" f"LR: {current_lr:.2e}"
) )
# 保存最佳模型 # 保存最佳模型 - 即使没有验证集也保存
if val_loss < best_val_loss: if val_loss < best_val_loss:
best_val_loss = val_loss best_val_loss = val_loss
best_model_path = output_dir / "best_model.pth" best_model_path = output_dir / "best_model.pth"