From 0a45856b14c8403fadbcf8a7690bbcec1a86430f Mon Sep 17 00:00:00 2001 From: Jiao77 Date: Thu, 20 Nov 2025 03:03:10 +0800 Subject: [PATCH] change some problem 2 --- tools/diffusion/ic_layout_diffusion_optimized.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/diffusion/ic_layout_diffusion_optimized.py b/tools/diffusion/ic_layout_diffusion_optimized.py index 98b2abb..b631ee4 100644 --- a/tools/diffusion/ic_layout_diffusion_optimized.py +++ b/tools/diffusion/ic_layout_diffusion_optimized.py @@ -318,16 +318,16 @@ class OptimizedNoiseScheduler: if schedule_type == 'cosine': # 余弦调度,通常效果更好 steps = num_timesteps + 1 - x = torch.linspace(0, num_timesteps, steps, dtype=torch.float64) + x = torch.linspace(0, num_timesteps, steps, dtype=torch.float32) alphas_cumprod = torch.cos(((x / num_timesteps) + 0.008) / 1.008 * np.pi / 2) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) self.betas = torch.clip(betas, 0, 0.999) else: # 线性调度 - self.betas = torch.linspace(beta_start, beta_end, num_timesteps) + self.betas = torch.linspace(beta_start, beta_end, num_timesteps, dtype=torch.float32) - # 预计算 + # 预计算 - 确保所有张量都是float32 self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, axis=0) self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)