change some problem 2

This commit is contained in:
Jiao77
2025-11-20 03:03:10 +08:00
parent d2c75a2d14
commit 0a45856b14

View File

@@ -318,16 +318,16 @@ class OptimizedNoiseScheduler:
if schedule_type == 'cosine': if schedule_type == 'cosine':
# 余弦调度,通常效果更好 # 余弦调度,通常效果更好
steps = num_timesteps + 1 steps = num_timesteps + 1
x = torch.linspace(0, num_timesteps, steps, dtype=torch.float64) x = torch.linspace(0, num_timesteps, steps, dtype=torch.float32)
alphas_cumprod = torch.cos(((x / num_timesteps) + 0.008) / 1.008 * np.pi / 2) ** 2 alphas_cumprod = torch.cos(((x / num_timesteps) + 0.008) / 1.008 * np.pi / 2) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0] alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
self.betas = torch.clip(betas, 0, 0.999) self.betas = torch.clip(betas, 0, 0.999)
else: else:
# 线性调度 # 线性调度
self.betas = torch.linspace(beta_start, beta_end, num_timesteps) self.betas = torch.linspace(beta_start, beta_end, num_timesteps, dtype=torch.float32)
# 预计算 # 预计算 - 确保所有张量都是float32
self.alphas = 1.0 - self.betas self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, axis=0) self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)