diff --git a/tools/diffusion/ic_layout_diffusion_optimized.py b/tools/diffusion/ic_layout_diffusion_optimized.py
index f963ec3..98b2abb 100644
--- a/tools/diffusion/ic_layout_diffusion_optimized.py
+++ b/tools/diffusion/ic_layout_diffusion_optimized.py
@@ -336,20 +336,35 @@ class OptimizedNoiseScheduler:
     def add_noise(self, x_0, t):
         """向干净图像添加噪声"""
         noise = torch.randn_like(x_0)
+        # Keep the schedule tensors on the same device as the input. Tensor.to()
+        # returns the tensor unchanged when it is already on that device.
+        device = x_0.device
+        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
+        self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
+
         sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
         sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
 
         return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
 
-    def sample_timestep(self, batch_size):
+    def sample_timestep(self, batch_size, device=None):
         """采样时间步"""
-        return torch.randint(0, self.num_timesteps, (batch_size,))
+        # Sample directly on the requested device; device=None keeps torch's
+        # default device, matching the previous behaviour.
+        return torch.randint(0, self.num_timesteps, (batch_size,), device=device)
 
     def step(self, model, x_t, t):
         """单步去噪"""
         # 预测噪声
         predicted_noise = model(x_t, t)
 
+        # Move each schedule tensor independently: add_noise() may already have
+        # moved some of them, so checking a single tensor's device is not enough.
+        device = x_t.device
+        self.alphas = self.alphas.to(device)
+        self.betas = self.betas.to(device)
+        self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
+
         # 计算系数
         alpha_t = self.alphas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
         sqrt_alpha_t = torch.sqrt(alpha_t)
@@ -395,15 +410,34 @@ class OptimizedDiffusionTrainer:
 
     def __init__(self, model, scheduler, device='cuda', use_edge_condition=False):
         self.model = model.to(device)
-        self.scheduler = scheduler
         self.device = device
         self.use_edge_condition = use_edge_condition
 
+        # Make sure every scheduler tensor lives on the training device.
+        self._move_scheduler_to_device(scheduler)
+        self.scheduler = scheduler
+
         # 组合损失函数
-        self.edge_loss = EdgeAwareLoss()
-        self.structure_loss = MultiScaleStructureLoss()
+        self.edge_loss = EdgeAwareLoss().to(device)
+        self.structure_loss = MultiScaleStructureLoss().to(device)
         self.mse_loss = nn.MSELoss()
 
+    def _move_scheduler_to_device(self, scheduler):
+        """Move the scheduler's schedule tensors to ``self.device`` in place.
+
+        Attributes the scheduler does not define are skipped.
+        """
+        tensor_names = (
+            'betas',
+            'alphas',
+            'alphas_cumprod',
+            'sqrt_alphas_cumprod',
+            'sqrt_one_minus_alphas_cumprod',
+        )
+        for name in tensor_names:
+            if hasattr(scheduler, name):
+                setattr(scheduler, name, getattr(scheduler, name).to(self.device))
+
     def train_step(self, optimizer, dataloader, manhattan_weight=0.1):
         """单步训练"""
         self.model.train()