change some problem 2
This commit is contained in:
@@ -318,16 +318,16 @@ class OptimizedNoiseScheduler:
|
|||||||
if schedule_type == 'cosine':
|
if schedule_type == 'cosine':
|
||||||
# 余弦调度,通常效果更好
|
# 余弦调度,通常效果更好
|
||||||
steps = num_timesteps + 1
|
steps = num_timesteps + 1
|
||||||
x = torch.linspace(0, num_timesteps, steps, dtype=torch.float64)
|
x = torch.linspace(0, num_timesteps, steps, dtype=torch.float32)
|
||||||
alphas_cumprod = torch.cos(((x / num_timesteps) + 0.008) / 1.008 * np.pi / 2) ** 2
|
alphas_cumprod = torch.cos(((x / num_timesteps) + 0.008) / 1.008 * np.pi / 2) ** 2
|
||||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||||
self.betas = torch.clip(betas, 0, 0.999)
|
self.betas = torch.clip(betas, 0, 0.999)
|
||||||
else:
|
else:
|
||||||
# 线性调度
|
# 线性调度
|
||||||
self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
|
self.betas = torch.linspace(beta_start, beta_end, num_timesteps, dtype=torch.float32)
|
||||||
|
|
||||||
# 预计算
|
# 预计算 - 确保所有张量都是float32
|
||||||
self.alphas = 1.0 - self.betas
|
self.alphas = 1.0 - self.betas
|
||||||
self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
|
self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
|
||||||
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
||||||
|
|||||||
Reference in New Issue
Block a user