change some problem 2
This commit is contained in:
@@ -318,16 +318,16 @@ class OptimizedNoiseScheduler:
|
||||
if schedule_type == 'cosine':
|
||||
# 余弦调度,通常效果更好
|
||||
steps = num_timesteps + 1
|
||||
x = torch.linspace(0, num_timesteps, steps, dtype=torch.float64)
|
||||
x = torch.linspace(0, num_timesteps, steps, dtype=torch.float32)
|
||||
alphas_cumprod = torch.cos(((x / num_timesteps) + 0.008) / 1.008 * np.pi / 2) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
self.betas = torch.clip(betas, 0, 0.999)
|
||||
else:
|
||||
# 线性调度
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_timesteps, dtype=torch.float32)
|
||||
|
||||
# 预计算
|
||||
# 预计算 - 确保所有张量都是float32
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
|
||||
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
||||
|
||||
Reference in New Issue
Block a user