change some problem 1
This commit is contained in:
@@ -336,20 +336,36 @@ class OptimizedNoiseScheduler:
|
||||
def add_noise(self, x_0, t):
|
||||
"""向干净图像添加噪声"""
|
||||
noise = torch.randn_like(x_0)
|
||||
# 确保调度器张量与输入张量在同一设备上
|
||||
device = x_0.device
|
||||
if self.sqrt_alphas_cumprod.device != device:
|
||||
self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
|
||||
self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
|
||||
|
||||
sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
|
||||
|
||||
def sample_timestep(self, batch_size):
|
||||
def sample_timestep(self, batch_size, device=None):
|
||||
"""采样时间步"""
|
||||
return torch.randint(0, self.num_timesteps, (batch_size,))
|
||||
t = torch.randint(0, self.num_timesteps, (batch_size,))
|
||||
if device is not None:
|
||||
t = t.to(device)
|
||||
return t
|
||||
|
||||
def step(self, model, x_t, t):
|
||||
"""单步去噪"""
|
||||
# 预测噪声
|
||||
predicted_noise = model(x_t, t)
|
||||
|
||||
# 确保调度器张量与输入张量在同一设备上
|
||||
device = x_t.device
|
||||
if self.alphas.device != device:
|
||||
self.alphas = self.alphas.to(device)
|
||||
self.betas = self.betas.to(device)
|
||||
self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
|
||||
|
||||
# 计算系数
|
||||
alpha_t = self.alphas[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
sqrt_alpha_t = torch.sqrt(alpha_t)
|
||||
@@ -395,15 +411,31 @@ class OptimizedDiffusionTrainer:
|
||||
|
||||
def __init__(self, model, scheduler, device='cuda', use_edge_condition=False):
|
||||
self.model = model.to(device)
|
||||
self.scheduler = scheduler
|
||||
self.device = device
|
||||
self.use_edge_condition = use_edge_condition
|
||||
|
||||
# 确保调度器的所有张量都在正确的设备上
|
||||
self._move_scheduler_to_device(scheduler)
|
||||
self.scheduler = scheduler
|
||||
|
||||
# 组合损失函数
|
||||
self.edge_loss = EdgeAwareLoss()
|
||||
self.structure_loss = MultiScaleStructureLoss()
|
||||
self.edge_loss = EdgeAwareLoss().to(device)
|
||||
self.structure_loss = MultiScaleStructureLoss().to(device)
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def _move_scheduler_to_device(self, scheduler):
|
||||
"""将调度器的所有张量移动到指定设备"""
|
||||
if hasattr(scheduler, 'betas'):
|
||||
scheduler.betas = scheduler.betas.to(self.device)
|
||||
if hasattr(scheduler, 'alphas'):
|
||||
scheduler.alphas = scheduler.alphas.to(self.device)
|
||||
if hasattr(scheduler, 'alphas_cumprod'):
|
||||
scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(self.device)
|
||||
if hasattr(scheduler, 'sqrt_alphas_cumprod'):
|
||||
scheduler.sqrt_alphas_cumprod = scheduler.sqrt_alphas_cumprod.to(self.device)
|
||||
if hasattr(scheduler, 'sqrt_one_minus_alphas_cumprod'):
|
||||
scheduler.sqrt_one_minus_alphas_cumprod = scheduler.sqrt_one_minus_alphas_cumprod.to(self.device)
|
||||
|
||||
def train_step(self, optimizer, dataloader, manhattan_weight=0.1):
|
||||
"""单步训练"""
|
||||
self.model.train()
|
||||
|
||||
Reference in New Issue
Block a user