change some problem 7

This commit is contained in:
Jiao77
2025-11-20 05:17:11 +08:00
parent bacf8cd69d
commit afd48c2d86
2 changed files with 110 additions and 34 deletions

View File

@@ -78,24 +78,24 @@ def ddim_sample(model, scheduler, num_samples, image_size, device, num_steps=50,
# 预测噪声
predicted_noise = model(x, t_batch)
# 计算原始图像的估计
alpha_t = scheduler.alphas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
alpha_cumprod_t = scheduler.alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
beta_t = scheduler.betas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
sqrt_one_minus_alpha_cumprod_t = scheduler.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
# 计算原始图像的估计 - 确保调度器张量在正确设备上
alpha_t = scheduler.alphas[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
alpha_cumprod_t = scheduler.alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
beta_t = scheduler.betas[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
sqrt_one_minus_alpha_cumprod_t = scheduler.sqrt_one_minus_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
# 计算x_0的估计
x_0_pred = (x - sqrt_one_minus_alpha_cumprod_t * predicted_noise) / torch.sqrt(alpha_cumprod_t)
# 计算前一时间步的方向
if i < len(timesteps) - 1:
alpha_t_prev = scheduler.alphas[timesteps[i+1]]
alpha_cumprod_t_prev = scheduler.alphas_cumprod[timesteps[i+1]]
alpha_t_prev = scheduler.alphas[timesteps[i+1]].to(device)
alpha_cumprod_t_prev = scheduler.alphas_cumprod[timesteps[i+1]].to(device)
sqrt_alpha_cumprod_t_prev = torch.sqrt(alpha_cumprod_t_prev).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
sqrt_one_minus_alpha_cumprod_t_prev = torch.sqrt(1 - alpha_cumprod_t_prev).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
# 计算方差
variance = eta * torch.sqrt(beta_t).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
variance = eta * torch.sqrt(beta_t).squeeze().squeeze().squeeze()
# 计算前一时间步的x
x = sqrt_alpha_cumprod_t_prev * x_0_pred + torch.sqrt(1 - alpha_cumprod_t_prev - variance**2) * predicted_noise
@@ -183,6 +183,18 @@ def generate_optimized_samples(args):
schedule_type=config.get('schedule_type', args.schedule_type)
)
# 确保调度器的所有张量都在正确的设备上
if hasattr(scheduler, 'betas'):
scheduler.betas = scheduler.betas.to(device)
if hasattr(scheduler, 'alphas'):
scheduler.alphas = scheduler.alphas.to(device)
if hasattr(scheduler, 'alphas_cumprod'):
scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(device)
if hasattr(scheduler, 'sqrt_alphas_cumprod'):
scheduler.sqrt_alphas_cumprod = scheduler.sqrt_alphas_cumprod.to(device)
if hasattr(scheduler, 'sqrt_one_minus_alphas_cumprod'):
scheduler.sqrt_one_minus_alphas_cumprod = scheduler.sqrt_one_minus_alphas_cumprod.to(device)
# 生成参数
generation_config = {
'num_samples': args.num_samples,