change some problem 1

This commit is contained in:
Jiao77
2025-11-20 03:01:56 +08:00
parent f95a2bd2db
commit d2c75a2d14

View File

@@ -336,20 +336,36 @@ class OptimizedNoiseScheduler:
def add_noise(self, x_0, t):
"""向干净图像添加噪声"""
noise = torch.randn_like(x_0)
# 确保调度器张量与输入张量在同一设备上
device = x_0.device
if self.sqrt_alphas_cumprod.device != device:
self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
def sample_timestep(self, batch_size):
def sample_timestep(self, batch_size, device=None):
"""采样时间步"""
return torch.randint(0, self.num_timesteps, (batch_size,))
t = torch.randint(0, self.num_timesteps, (batch_size,))
if device is not None:
t = t.to(device)
return t
def step(self, model, x_t, t):
"""单步去噪"""
# 预测噪声
predicted_noise = model(x_t, t)
# 确保调度器张量与输入张量在同一设备上
device = x_t.device
if self.alphas.device != device:
self.alphas = self.alphas.to(device)
self.betas = self.betas.to(device)
self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
# 计算系数
alpha_t = self.alphas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
sqrt_alpha_t = torch.sqrt(alpha_t)
@@ -395,15 +411,31 @@ class OptimizedDiffusionTrainer:
def __init__(self, model, scheduler, device='cuda', use_edge_condition=False):
self.model = model.to(device)
self.scheduler = scheduler
self.device = device
self.use_edge_condition = use_edge_condition
# 确保调度器的所有张量都在正确的设备上
self._move_scheduler_to_device(scheduler)
self.scheduler = scheduler
# 组合损失函数
self.edge_loss = EdgeAwareLoss()
self.structure_loss = MultiScaleStructureLoss()
self.edge_loss = EdgeAwareLoss().to(device)
self.structure_loss = MultiScaleStructureLoss().to(device)
self.mse_loss = nn.MSELoss()
def _move_scheduler_to_device(self, scheduler):
"""将调度器的所有张量移动到指定设备"""
if hasattr(scheduler, 'betas'):
scheduler.betas = scheduler.betas.to(self.device)
if hasattr(scheduler, 'alphas'):
scheduler.alphas = scheduler.alphas.to(self.device)
if hasattr(scheduler, 'alphas_cumprod'):
scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(self.device)
if hasattr(scheduler, 'sqrt_alphas_cumprod'):
scheduler.sqrt_alphas_cumprod = scheduler.sqrt_alphas_cumprod.to(self.device)
if hasattr(scheduler, 'sqrt_one_minus_alphas_cumprod'):
scheduler.sqrt_one_minus_alphas_cumprod = scheduler.sqrt_one_minus_alphas_cumprod.to(self.device)
def train_step(self, optimizer, dataloader, manhattan_weight=0.1):
"""单步训练"""
self.model.train()