change some problem 7
This commit is contained in:
@@ -78,24 +78,24 @@ def ddim_sample(model, scheduler, num_samples, image_size, device, num_steps=50,
|
||||
# 预测噪声
|
||||
predicted_noise = model(x, t_batch)
|
||||
|
||||
# 计算原始图像的估计
|
||||
alpha_t = scheduler.alphas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
alpha_cumprod_t = scheduler.alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
beta_t = scheduler.betas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_cumprod_t = scheduler.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
# 计算原始图像的估计 - 确保调度器张量在正确设备上
|
||||
alpha_t = scheduler.alphas[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
alpha_cumprod_t = scheduler.alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
beta_t = scheduler.betas[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_cumprod_t = scheduler.sqrt_one_minus_alphas_cumprod[t].to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
# 计算x_0的估计
|
||||
x_0_pred = (x - sqrt_one_minus_alpha_cumprod_t * predicted_noise) / torch.sqrt(alpha_cumprod_t)
|
||||
|
||||
# 计算前一时间步的方向
|
||||
if i < len(timesteps) - 1:
|
||||
alpha_t_prev = scheduler.alphas[timesteps[i+1]]
|
||||
alpha_cumprod_t_prev = scheduler.alphas_cumprod[timesteps[i+1]]
|
||||
alpha_t_prev = scheduler.alphas[timesteps[i+1]].to(device)
|
||||
alpha_cumprod_t_prev = scheduler.alphas_cumprod[timesteps[i+1]].to(device)
|
||||
sqrt_alpha_cumprod_t_prev = torch.sqrt(alpha_cumprod_t_prev).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_cumprod_t_prev = torch.sqrt(1 - alpha_cumprod_t_prev).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
# 计算方差
|
||||
variance = eta * torch.sqrt(beta_t).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
variance = eta * torch.sqrt(beta_t).squeeze().squeeze().squeeze()
|
||||
|
||||
# 计算前一时间步的x
|
||||
x = sqrt_alpha_cumprod_t_prev * x_0_pred + torch.sqrt(1 - alpha_cumprod_t_prev - variance**2) * predicted_noise
|
||||
@@ -183,6 +183,18 @@ def generate_optimized_samples(args):
|
||||
schedule_type=config.get('schedule_type', args.schedule_type)
|
||||
)
|
||||
|
||||
# 确保调度器的所有张量都在正确的设备上
|
||||
if hasattr(scheduler, 'betas'):
|
||||
scheduler.betas = scheduler.betas.to(device)
|
||||
if hasattr(scheduler, 'alphas'):
|
||||
scheduler.alphas = scheduler.alphas.to(device)
|
||||
if hasattr(scheduler, 'alphas_cumprod'):
|
||||
scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(device)
|
||||
if hasattr(scheduler, 'sqrt_alphas_cumprod'):
|
||||
scheduler.sqrt_alphas_cumprod = scheduler.sqrt_alphas_cumprod.to(device)
|
||||
if hasattr(scheduler, 'sqrt_one_minus_alphas_cumprod'):
|
||||
scheduler.sqrt_one_minus_alphas_cumprod = scheduler.sqrt_one_minus_alphas_cumprod.to(device)
|
||||
|
||||
# 生成参数
|
||||
generation_config = {
|
||||
'num_samples': args.num_samples,
|
||||
|
||||
Reference in New Issue
Block a user