The overall framework consists of two main blocks: the DDPM and the VAE.
The VAE comes in two variants: VQ-GAN or KL-VAE. Refer to VQ-GAN for more details; here we only look at the structural details of the KL-VAE. Later versions of Stable Diffusion use only the KL-VAE.
It assumes that the gradient norm of the GAN loss should match that of the reconstruction loss, keeping the two on the same scale. Finally, the result is multiplied by another weight that controls the importance of the GAN loss relative to the other losses.
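A minimal sketch of this adaptive weighting, written as a standalone function (paraphrased; the actual implementation lives in the LPIPSWithDiscriminator loss and also handles the no-last-layer case):

```python
import torch

def calculate_adaptive_weight(nll_loss, g_loss, last_layer, discriminator_weight=1.0):
    # balance the GAN gradient against the reconstruction (NLL) gradient at the decoder's last layer
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)   # match gradient norms
    d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
    return d_weight * discriminator_weight                            # extra user-controlled weight
```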
Reconstruction Loss
It uses the \(L_1\) loss as the reconstruction loss. If perceptual_weight > 0, it also adds a perceptual loss, which constrains the feature-map distance between the reconstruction and the input using a pretrained CNN feature extractor such as VGG.
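A small sketch of how the two terms combine; perceptual_loss_fn here is an assumed callable (e.g. an LPIPS module with a VGG backbone), not the exact library API:

```python
import torch

def reconstruction_loss(inputs, reconstructions, perceptual_loss_fn=None, perceptual_weight=1.0):
    # pixel-wise L1 reconstruction term
    rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
    if perceptual_weight > 0 and perceptual_loss_fn is not None:
        # perceptual term: feature-map distance from a pretrained network (e.g. LPIPS/VGG)
        p_loss = perceptual_loss_fn(inputs.contiguous(), reconstructions.contiguous())
        rec_loss = rec_loss + perceptual_weight * p_loss
    return rec_loss.mean()
```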
KL Divergence
posteriors.kl() computes the KL divergence between the posterior and the standard Gaussian \(\mathcal{N}(0, I)\). It is defined in the Gaussian sampler DiagonalGaussianDistribution. If deterministic sampling is used, this loss returns 0.
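A condensed sketch of DiagonalGaussianDistribution (paraphrased; the real class has a few more options), showing the closed-form KL against \(\mathcal{N}(0, I)\) and the deterministic shortcut:

```python
import torch

class DiagonalGaussianDistribution:
    def __init__(self, parameters, deterministic=False):
        # parameters holds mean and log-variance concatenated along the channel dim
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)

    def sample(self):
        return self.mean + self.std * torch.randn_like(self.mean)

    def kl(self):
        # KL( N(mean, var) || N(0, I) ); zero in the deterministic case
        if self.deterministic:
            return torch.Tensor([0.])
        return 0.5 * torch.sum(self.mean ** 2 + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
```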
GAN Loss
Since the goal here is to train the generator (the VAE), the generator-side adversarial loss has no real-sample term. For a conditional GAN, the conditioning information is simply appended to the discriminator input; otherwise this is a fairly standard adversarial loss.
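A minimal sketch of the generator-side term, assuming logits_fake are the discriminator outputs on the reconstructions:

```python
# generator adversarial term: only the fake (reconstruction) logits are involved
g_loss = -torch.mean(logits_fake)
```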
Given the discriminator outputs for real samples, \( D(x) \), and for generated samples, \( D(\tilde{x}) \), the hinge loss for the discriminator is defined as

$$
L_D^{\text{hinge}} = \tfrac{1}{2}\Big(\mathbb{E}_x\big[\max(0,\,1 - D(x))\big] + \mathbb{E}_{\tilde{x}}\big[\max(0,\,1 + D(\tilde{x}))\big]\Big),
$$

and the vanilla (non-saturating) version as

$$
L_D^{\text{vanilla}} = \tfrac{1}{2}\Big(\mathbb{E}_x\big[\text{softplus}(-D(x))\big] + \mathbb{E}_{\tilde{x}}\big[\text{softplus}(D(\tilde{x}))\big]\Big).
$$
These formulas correspond exactly to the implementation in the code:
For the hinge loss, the code computes \(\text{loss\_real} = \text{mean}(\text{ReLU}(1 - \text{logits\_real}))\) and \(\text{loss\_fake} = \text{mean}(\text{ReLU}(1 + \text{logits\_fake}))\), then takes their average.
For the vanilla loss, the code computes \(\text{mean}(\text{softplus}(-\text{logits\_real}))\) and \(\text{mean}(\text{softplus}(\text{logits\_fake}))\) and averages them.
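Written out, the two discriminator losses described above look like this:

```python
import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    return 0.5 * (loss_real + loss_fake)

def vanilla_d_loss(logits_real, logits_fake):
    return 0.5 * (torch.mean(F.softplus(-logits_real)) +
                  torch.mean(F.softplus(logits_fake)))
```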
Here we plot the curve of the above two types of losses for intuitive comparison.
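An illustrative way to produce such a comparison (showing only the real-sample branch; the fake-sample branch is the mirror image):

```python
import numpy as np
import matplotlib.pyplot as plt

logits = np.linspace(-4, 4, 200)
hinge_real = np.maximum(0.0, 1.0 - logits)       # ReLU(1 - D(x)) for real samples
vanilla_real = np.log1p(np.exp(-logits))          # softplus(-D(x)) for real samples

plt.plot(logits, hinge_real, label="hinge (real)")
plt.plot(logits, vanilla_real, label="vanilla / softplus (real)")
plt.xlabel("discriminator logit")
plt.ylabel("loss")
plt.legend()
plt.show()
```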
class DDPM(pl.LightningModule):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(self,
                 unet_config,
                 timesteps=1000,
                 beta_schedule="linear",
                 loss_type="l2",
                 ckpt_path=None,
                 ignore_keys=[],
                 load_only_unet=False,
                 monitor="val/loss",
                 use_ema=True,
                 first_stage_key="image",
                 image_size=256,
                 channels=3,
                 log_every_t=100,
                 clip_denoised=True,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 cosine_s=8e-3,
                 given_betas=None,
                 original_elbo_weight=0.,
                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
                 l_simple_weight=1.,
                 conditioning_key=None,
                 parameterization="eps",  # all assuming fixed variance schedules
                 scheduler_config=None,
                 use_positional_encodings=False,
                 learn_logvar=False,
                 logvar_init=0.,
                 ):
        super().__init__()
        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.first_stage_key = first_stage_key
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.use_positional_encodings = use_positional_encodings
        self.model = DiffusionWrapper(unet_config, conditioning_key)
        count_params(self.model, verbose=True)
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config

        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight

        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)

        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

        self.loss_type = loss_type

        self.learn_logvar = learn_logvar
        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        if self.learn_logvar:
            self.logvar = nn.Parameter(self.logvar, requires_grad=True)
The DDPM class inherits from pl.LightningModule and implements the forward (diffusion), reverse (denoising), training, and sampling processes.
In the __init__ part:
timesteps=1000 is the number of diffusion steps, usually 1000.
beta_schedule='linear' is the noise schedule that controls \(\beta_t\).
parameterization="eps" is the training target: 'eps' means the network predicts the noise, while 'x0' means it predicts the denoised image.
use_ema=True enables an exponential moving average of the model weights to stabilize training; LitEma(self.model) defines the EMA model.
logvar_init=0 is the initial value of the log-variance used in training; it can either be learned or kept fixed.
self.model = DiffusionWrapper(unet_config, conditioning_key) defines the U-Net used in the diffusion process.
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                      linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if exists(given_betas):
        betas = given_betas
    else:
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start,
                                   linear_end=linear_end, cosine_s=cosine_s)
    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

    timesteps, = betas.shape
    self.num_timesteps = int(timesteps)
    self.linear_start = linear_start
    self.linear_end = linear_end
    assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

    to_torch = partial(torch.tensor, dtype=torch.float32)

    self.register_buffer('betas', to_torch(betas))
    self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
    self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

    # calculations for diffusion q(x_t | x_{t-1}) and others
    self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
    self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
    self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
    self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
    self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    # calculations for posterior q(x_{t-1} | x_t, x_0)
    posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
            1. - alphas_cumprod) + self.v_posterior * betas
    # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
    self.register_buffer('posterior_variance', to_torch(posterior_variance))
    # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
    self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
    self.register_buffer('posterior_mean_coef1', to_torch(
        betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
    self.register_buffer('posterior_mean_coef2', to_torch(
        (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

    if self.parameterization == "eps":
        lvlb_weights = self.betas ** 2 / (
                2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
    elif self.parameterization == "x0":
        lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
    else:
        raise NotImplementedError("mu not supported")
    # TODO how to choose this term
    lvlb_weights[0] = lvlb_weights[1]
    self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
    assert not torch.isnan(self.lvlb_weights).all()
register_schedule mainly computes the noise schedule, including \(\beta_t\) (noise), \(\alpha_t\) (retained signal from the original image), the cumulative product \(\bar{\alpha}_t\), and the mean and variance of the posterior distribution \(q(x_{t-1} \mid x_t, x_0)\).
Here betas is \(\beta_t\), alphas is \(\alpha_t\), and alphas_cumprod is \(\bar{\alpha}_t\).
Here it uses the make_beta_schedule function, sketched below.
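A paraphrased re-implementation (not the verbatim library code; clipping and dtype details may differ slightly):

```python
import numpy as np
import torch

def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if schedule == "linear":
        # evenly spaced in sqrt(beta), then squared
        betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    elif schedule == "cosine":
        # cosine schedule: betas derived from a squared-cosine alpha_bar curve
        timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        alphas_bar = torch.cos(timesteps / (1 + cosine_s) * np.pi / 2) ** 2
        alphas_bar = alphas_bar / alphas_bar[0]
        betas = (1 - alphas_bar[1:] / alphas_bar[:-1]).clamp(0, 0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
```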
The function make_beta_schedule is designed to generate the noise schedule \(\beta_t\) used in the diffusion process. Different scheduling methods impact the convergence and final image quality of the diffusion model.
Input Parameters
schedule: Specifies the method for computing \(\beta_t\).
n_timestep: Total number of diffusion steps T, typically 1000.
linear_start & linear_end:
Linear schedule: \(\beta_t\) gradually increases from linear_start to linear_end.
Square root schedule: \(\beta_t\) increases based on a square root transformation.
cosine_s=8e-3:
Used in cosine schedule, typically for Stable Diffusion, improving stability.
The choice of schedule impacts model convergence, image quality, and noise control. Experiments are often necessary to determine the most suitable \(\beta_t\) scheduling method.
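To get a feel for the differences, one can plot \(\beta_t\) for each option (here calling the sketched make_beta_schedule from above):

```python
import matplotlib.pyplot as plt

for name in ["linear", "cosine", "sqrt_linear", "sqrt"]:
    plt.plot(make_beta_schedule(name, 1000), label=name)
plt.xlabel("t")
plt.ylabel(r"$\beta_t$")
plt.legend()
plt.show()
```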
def q_sample(self, x_start, t, noise=None):
    noise = default(noise, lambda: torch.randn_like(x_start))
    return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

def forward(self, x, *args, **kwargs):
    # input image shape is: [b, c, h, w]
    t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
    # randomly pick a time step t, as the argument of the model
    return self.p_losses(x, t, *args, **kwargs)

def p_losses(self, x_start, t, noise=None):
    noise = default(noise, lambda: torch.randn_like(x_start))
    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    model_out = self.model(x_noisy, t)

    loss_dict = {}
    if self.parameterization == "eps":
        target = noise
    elif self.parameterization == "x0":
        target = x_start
    else:
        raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")

    loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

    log_prefix = 'train' if self.training else 'val'

    loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
    loss_simple = loss.mean() * self.l_simple_weight

    loss_vlb = (self.lvlb_weights[t] * loss).mean()
    loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

    loss = loss_simple + self.original_elbo_weight * loss_vlb

    loss_dict.update({f'{log_prefix}/loss': loss})

    return loss, loss_dict
There are two types of target: x_0 and eps, corresponding to the noiseless image and the noise, respectively. We usually use the noise prediction. The loss is then an \(L_1\) or \(L_2\) distance.
The simple loss is quite straightforward, but it also has a re-weighted copy, loss_vlb, which uses the lvlb_weights shown below.
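For the eps parameterization, these weights are the lvlb_weights buffer already computed in register_schedule above:

```python
lvlb_weights = self.betas ** 2 / (
        2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
```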
This code snippet is part of the loss computation in a diffusion model when the model is parameterized to predict the noise (often referred to as "eps" parameterization). Here’s what it does:
The weights help balance the contribution of each timestep's term.
Accounting for the Noise Schedule and Posterior Variance:
self.betas represents the noise schedule parameters at each timestep.
self.posterior_variance is the variance of the reverse (posterior) process \( q(x_{t-1}\mid x_t, x_0) \).
alphas (converted to a torch tensor) and self.alphas_cumprod (the cumulative product of \( 1-\beta_t \)) are used to capture how the signal decays over time.
The formula
$$
\text{lvlb\_weights}_t = \frac{\beta_t^2}{2 \cdot \text{posterior\_variance}_t \cdot \alpha_t \cdot (1 - \bar{\alpha}_t)}
$$
adjusts the loss term based on these factors, ensuring that the loss is appropriately scaled at each timestep.
Usage in the Loss Function:
When the model predicts \(\epsilon\) (the noise added to \( x_0 \)), the training loss often reduces to a weighted mean squared error (MSE) between the predicted and actual noise. The lvlb_weights are used to weight this error term, so that the model’s learning objective better approximates the true variational lower bound of the data likelihood.
We can visualize the weights as a function of \(t\) to see where they are concentrated.
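A standalone sketch (assuming the default linear schedule with T = 1000) that recomputes and plots the weights:

```python
import numpy as np
import matplotlib.pyplot as plt

T = 1000
betas = np.linspace(1e-4 ** 0.5, 2e-2 ** 0.5, T) ** 2
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

lvlb_weights = betas ** 2 / (2 * np.maximum(posterior_variance, 1e-20) * alphas * (1 - alphas_cumprod))
lvlb_weights[0] = lvlb_weights[1]  # same fix-up as in register_schedule

plt.plot(lvlb_weights)
plt.yscale("log")
plt.xlabel("t")
plt.ylabel("lvlb_weights")
plt.show()
```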
It appears that the weights are concentrated at \(t\) close to zero, which means the model focuses more on the denoising steps closest to the noiseless image.
So far, we have covered the DDPM process. Let's go further into the latent DDPM process.
def q_mean_variance(self, x_start, t):
    """
    Get the distribution q(x_t | x_0).
    :param x_start: the [N x C x ...] tensor of noiseless inputs.
    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
    :return: A tuple (mean, variance, log_variance), all of x_start's shape.
    """
    mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
    variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
    log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
    return mean, variance, log_variance

def predict_start_from_noise(self, x_t, t, noise):
    return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise)

def q_posterior(self, x_start, x_t, t):
    posterior_mean = (extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
                      extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t)
    posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
    posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
    return posterior_mean, posterior_variance, posterior_log_variance_clipped

def p_mean_variance(self, x, t, clip_denoised: bool):
    model_out = self.model(x, t)
    if self.parameterization == "eps":
        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    elif self.parameterization == "x0":
        x_recon = model_out
    if clip_denoised:
        x_recon.clamp_(-1., 1.)

    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
    return model_mean, posterior_variance, posterior_log_variance

@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
    b, *_, device = *x.shape, x.device
    model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
    noise = noise_like(x.shape, device, repeat_noise)
    # no noise when t == 0
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):
    device = self.betas.device
    b = shape[0]
    img = torch.randn(shape, device=device)
    intermediates = [img]
    for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
        img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                            clip_denoised=self.clip_denoised)
        if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
            intermediates.append(img)
    if return_intermediates:
        return img, intermediates
    return img

@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
    image_size = self.image_size
    channels = self.channels
    return self.p_sample_loop((batch_size, channels, image_size, image_size),
                              return_intermediates=return_intermediates)
We give the following explanation of the DDPM sampling steps and relate them to the original DDPM reverse process.
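For reference, under the eps parameterization the chain predict_start_from_noise → q_posterior → p_sample realizes the standard DDPM reverse update

$$
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sqrt{\tilde{\beta}_t}\, z, \qquad z \sim \mathcal{N}(0, I),
$$

where \(\tilde{\beta}_t\) is posterior_variance, and no noise is added at \(t = 0\) (the nonzero_mask in p_sample).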
Next, we focus on how the latent DM injects the condition into the network. The training objective is not explicitly guided by the conditioning; the condition information is simply fed into the network, so conditional generation emerges implicitly from the architecture.
The diffusion model accepts three main types of conditioning.
The conditioning_key can be "concat", "crossattn", "adm", or "hybrid" (hybrid combines concat and crossattn).
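A self-contained sketch of how the wrapper dispatches on conditioning_key (modeled on DiffusionWrapper in ddpm.py; treat the details as approximate):

```python
import torch
import torch.nn as nn

class DiffusionWrapperSketch(nn.Module):
    """Hedged sketch of how the wrapper routes conditioning into the U-Net."""
    def __init__(self, diffusion_model, conditioning_key=None):
        super().__init__()
        assert conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
        self.diffusion_model = diffusion_model     # the U-Net
        self.conditioning_key = conditioning_key

    def forward(self, x, t, c_concat=None, c_crossattn=None):
        if self.conditioning_key is None:
            return self.diffusion_model(x, t)
        if self.conditioning_key == 'concat':
            # spatial conditioning: stack condition channels onto the input
            return self.diffusion_model(torch.cat([x] + c_concat, dim=1), t)
        if self.conditioning_key == 'crossattn':
            # sequence conditioning: pass as cross-attention context
            return self.diffusion_model(x, t, context=torch.cat(c_crossattn, dim=1))
        if self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            return self.diffusion_model(xc, t, context=torch.cat(c_crossattn, dim=1))
        if self.conditioning_key == 'adm':
            # single embedding: handed to the U-Net as a class-label-style input y
            return self.diffusion_model(x, t, y=c_crossattn[0])
        raise NotImplementedError(self.conditioning_key)
```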
Concat
In the "Concat" format, the condition information is concated in to the original input, which require the same shape with the input \(x\). It can handles the conditions like
segmentation map
edge map
depth map
any spatial condition with the same shape
CrossAttention
In this mode, the condition should be a sequence of shape [B, L, D], which usually covers conditions such as
text embeddings (e.g., from the CLIP text encoder)
class embeddings
any conditions in sequence format
This information is fed into the attention blocks of the U-Net via cross-attention.
ADM
The condition has shape [B, D], i.e., a single token embedding. It can handle conditions like class labels.
The information is assigned to the class label \(y\).
In the network, the class-label condition is added to the timestep embedding, as sketched below.
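A minimal runnable sketch of this "adm" path (module names here are hypothetical stand-ins for the U-Net's time_embed and label_emb):

```python
import torch
import torch.nn as nn

model_channels, num_classes, time_embed_dim = 192, 1000, 768

time_embed = nn.Sequential(nn.Linear(model_channels, time_embed_dim), nn.SiLU(),
                           nn.Linear(time_embed_dim, time_embed_dim))
label_emb = nn.Embedding(num_classes, time_embed_dim)

t_feat = torch.randn(4, model_channels)      # stands in for the sinusoidal timestep embedding
y = torch.randint(0, num_classes, (4,))      # class labels, shape [B]

emb = time_embed(t_feat) + label_emb(y)      # [B, time_embed_dim], consumed by every residual block
```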
Input x Context
[b, n, dim] [b, m, dim]
│ │
↓ ├─────────────┐
to_q Linear ↓ ↓
│ to_k Linear to_v Linear
↓ │ │
[b, n, h*d] [b, m, h*d] [b, m, h*d]
│ │ │
↓ ↓ ↓
Reshape to heads Reshape to heads Reshape
[b*h, n, d] [b*h, m, d] [b*h, m, d]
│ │ │
│ │ │
└──→ Q·K^T * scale ←─┘ │
[b*h, n, m] │
│ │
↓ │
softmax(dim=-1) │
[b*h, n, m] │
│ │
└──────→ · V ←─────────┘
│
↓
[b*h, n, d]
│
↓
Reshape to batch
[b, n, h*d]
│
↓
Linear Output
│
↓
[b, n, query_dim]
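A compact sketch of the cross-attention block the diagram describes (modeled on CrossAttention in ldm/modules/attention.py, with dropout and masking omitted; assumes einops is available):

```python
import torch
import torch.nn as nn
from einops import rearrange

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = context_dim if context_dim is not None else query_dim
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None):
        h = self.heads
        context = context if context is not None else x            # falls back to self-attention
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale   # Q·K^T * scale
        attn = sim.softmax(dim=-1)
        out = torch.einsum('b i j, b j d -> b i d', attn, v)             # · V
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)

# usage: latent tokens attend to a text sequence
# attn = CrossAttention(query_dim=320, context_dim=768)
# out = attn(torch.randn(2, 4096, 320), torch.randn(2, 77, 768))   # -> [2, 4096, 320]
```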
Based on the above basic block, the overall structure is
"""SAMPLING ONLY."""importtorchimportnumpyasnpfromtqdmimporttqdmfromfunctoolsimportpartialfromldm.modules.diffusionmodules.utilimportmake_ddim_sampling_parameters,make_ddim_timesteps,noise_likeclassDDIMSampler(object):def__init__(self,model,schedule="linear",**kwargs):super().__init__()self.model=modelself.ddpm_num_timesteps=model.num_timestepsself.schedule=scheduledefregister_buffer(self,name,attr):iftype(attr)==torch.Tensor:ifattr.device!=torch.device("cuda"):attr=attr.to(torch.device("cuda"))setattr(self,name,attr)defmake_schedule(self,ddim_num_steps,ddim_discretize="uniform",ddim_eta=0.,verbose=True):self.ddim_timesteps=make_ddim_timesteps(ddim_discr_method=ddim_discretize,num_ddim_timesteps=ddim_num_steps,num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)alphas_cumprod=self.model.alphas_cumprodassertalphas_cumprod.shape[0]==self.ddpm_num_timesteps,'alphas have to be defined for each timestep'to_torch=lambdax:x.clone().detach().to(torch.float32).to(self.model.device)self.register_buffer('betas',to_torch(self.model.betas))self.register_buffer('alphas_cumprod',to_torch(alphas_cumprod))self.register_buffer('alphas_cumprod_prev',to_torch(self.model.alphas_cumprod_prev))# calculations for diffusion q(x_t | x_{t-1}) and othersself.register_buffer('sqrt_alphas_cumprod',to_torch(np.sqrt(alphas_cumprod.cpu())))self.register_buffer('sqrt_one_minus_alphas_cumprod',to_torch(np.sqrt(1.-alphas_cumprod.cpu())))self.register_buffer('log_one_minus_alphas_cumprod',to_torch(np.log(1.-alphas_cumprod.cpu())))self.register_buffer('sqrt_recip_alphas_cumprod',to_torch(np.sqrt(1./alphas_cumprod.cpu())))self.register_buffer('sqrt_recipm1_alphas_cumprod',to_torch(np.sqrt(1./alphas_cumprod.cpu()-1)))# ddim sampling parametersddim_sigmas,ddim_alphas,ddim_alphas_prev=make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),ddim_timesteps=self.ddim_timesteps,eta=ddim_eta,verbose=verbose)self.register_buffer('ddim_sigmas',ddim_sigmas)self.register_buffer('ddim_alphas',ddim_alphas)self.register_buffer('ddim_alphas_prev',ddim_alphas_prev)self.register_buffer('ddim_sqrt_one_minus_alphas',np.sqrt(1.-ddim_alphas))sigmas_for_original_sampling_steps=ddim_eta*torch.sqrt((1-self.alphas_cumprod_prev)/(1-self.alphas_cumprod)*(1-self.alphas_cumprod/self.alphas_cumprod_prev))self.register_buffer('ddim_sigmas_for_original_num_steps',sigmas_for_original_sampling_steps)@torch.no_grad()defsample(self,S,batch_size,shape,conditioning=None,callback=None,normals_sequence=None,img_callback=None,quantize_x0=False,eta=0.,mask=None,x0=None,temperature=1.,noise_dropout=0.,score_corrector=None,corrector_kwargs=None,verbose=True,x_T=None,log_every_t=100,unconditional_guidance_scale=1.,unconditional_conditioning=None,# this has to come in the same format as the conditioning, # e.g. 
as encoded tokens, ...**kwargs):ifconditioningisnotNone:ifisinstance(conditioning,dict):cbs=conditioning[list(conditioning.keys())[0]].shape[0]ifcbs!=batch_size:print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")else:ifconditioning.shape[0]!=batch_size:print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")self.make_schedule(ddim_num_steps=S,ddim_eta=eta,verbose=verbose)# samplingC,H,W=shapesize=(batch_size,C,H,W)print(f'Data shape for DDIM sampling is {size}, eta {eta}')samples,intermediates=self.ddim_sampling(conditioning,size,callback=callback,img_callback=img_callback,quantize_denoised=quantize_x0,mask=mask,x0=x0,ddim_use_original_steps=False,noise_dropout=noise_dropout,temperature=temperature,score_corrector=score_corrector,corrector_kwargs=corrector_kwargs,x_T=x_T,log_every_t=log_every_t,unconditional_guidance_scale=unconditional_guidance_scale,unconditional_conditioning=unconditional_conditioning,)returnsamples,intermediates@torch.no_grad()defddim_sampling(self,cond,shape,x_T=None,ddim_use_original_steps=False,callback=None,timesteps=None,quantize_denoised=False,mask=None,x0=None,img_callback=None,log_every_t=100,temperature=1.,noise_dropout=0.,score_corrector=None,corrector_kwargs=None,unconditional_guidance_scale=1.,unconditional_conditioning=None,):device=self.model.betas.deviceb=shape[0]ifx_TisNone:img=torch.randn(shape,device=device)else:img=x_TiftimestepsisNone:timesteps=self.ddpm_num_timestepsifddim_use_original_stepselseself.ddim_timestepseliftimestepsisnotNoneandnotddim_use_original_steps:subset_end=int(min(timesteps/self.ddim_timesteps.shape[0],1)*self.ddim_timesteps.shape[0])-1timesteps=self.ddim_timesteps[:subset_end]intermediates={'x_inter':[img],'pred_x0':[img]}time_range=reversed(range(0,timesteps))ifddim_use_original_stepselsenp.flip(timesteps)total_steps=timestepsifddim_use_original_stepselsetimesteps.shape[0]print(f"Running DDIM Sampling with {total_steps} timesteps")iterator=tqdm(time_range,desc='DDIM Sampler',total=total_steps)fori,stepinenumerate(iterator):index=total_steps-i-1ts=torch.full((b,),step,device=device,dtype=torch.long)ifmaskisnotNone:assertx0isnotNoneimg_orig=self.model.q_sample(x0,ts)# TODO: deterministic forward 
pass?img=img_orig*mask+(1.-mask)*imgouts=self.p_sample_ddim(img,cond,ts,index=index,use_original_steps=ddim_use_original_steps,quantize_denoised=quantize_denoised,temperature=temperature,noise_dropout=noise_dropout,score_corrector=score_corrector,corrector_kwargs=corrector_kwargs,unconditional_guidance_scale=unconditional_guidance_scale,unconditional_conditioning=unconditional_conditioning)img,pred_x0=outsifcallback:callback(i)ifimg_callback:img_callback(pred_x0,i)ifindex%log_every_t==0orindex==total_steps-1:intermediates['x_inter'].append(img)intermediates['pred_x0'].append(pred_x0)returnimg,intermediates@torch.no_grad()defp_sample_ddim(self,x,c,t,index,repeat_noise=False,use_original_steps=False,quantize_denoised=False,temperature=1.,noise_dropout=0.,score_corrector=None,corrector_kwargs=None,unconditional_guidance_scale=1.,unconditional_conditioning=None):b,*_,device=*x.shape,x.deviceifunconditional_conditioningisNoneorunconditional_guidance_scale==1.:e_t=self.model.apply_model(x,t,c)else:x_in=torch.cat([x]*2)t_in=torch.cat([t]*2)c_in=torch.cat([unconditional_conditioning,c])e_t_uncond,e_t=self.model.apply_model(x_in,t_in,c_in).chunk(2)e_t=e_t_uncond+unconditional_guidance_scale*(e_t-e_t_uncond)ifscore_correctorisnotNone:assertself.model.parameterization=="eps"e_t=score_corrector.modify_score(self.model,e_t,x,t,c,**corrector_kwargs)alphas=self.model.alphas_cumprodifuse_original_stepselseself.ddim_alphasalphas_prev=self.model.alphas_cumprod_previfuse_original_stepselseself.ddim_alphas_prevsqrt_one_minus_alphas=self.model.sqrt_one_minus_alphas_cumprodifuse_original_stepselseself.ddim_sqrt_one_minus_alphassigmas=self.model.ddim_sigmas_for_original_num_stepsifuse_original_stepselseself.ddim_sigmas# select parameters corresponding to the currently considered timestepa_t=torch.full((b,1,1,1),alphas[index],device=device)a_prev=torch.full((b,1,1,1),alphas_prev[index],device=device)sigma_t=torch.full((b,1,1,1),sigmas[index],device=device)sqrt_one_minus_at=torch.full((b,1,1,1),sqrt_one_minus_alphas[index],device=device)# current prediction for x_0pred_x0=(x-sqrt_one_minus_at*e_t)/a_t.sqrt()ifquantize_denoised:pred_x0,_,*_=self.model.first_stage_model.quantize(pred_x0)# direction pointing to x_tdir_xt=(1.-a_prev-sigma_t**2).sqrt()*e_tnoise=sigma_t*noise_like(x.shape,device,repeat_noise)*temperatureifnoise_dropout>0.:noise=torch.nn.functional.dropout(noise,p=noise_dropout)x_prev=a_prev.sqrt()*pred_x0+dir_xt+noisereturnx_prev,pred_x0
In DDIM (Denoising Diffusion Implicit Models), the reverse (denoising) step from \(x_t\) to \(x_{t-1}\) first forms a prediction of the clean sample,

$$
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t,t)}{\sqrt{\bar{\alpha}_t}},
$$
where:
- \(x_t\) is the current noisy image.
- \(\epsilon_\theta(x_t,t)\) is the model’s predicted noise.
- \(\bar{\alpha}_t\) is the cumulative product of the \(\alpha\) values up to time \(t\).
The new sample is then

$$
x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t,t) + \sigma_t z,
$$

where:
- \(\sqrt{\bar{\alpha}_{t-1}}\) scales the predicted \(x_0\).
- \(\sqrt{1-\bar{\alpha}_{t-1} - \sigma_t^2}\) scales the noise prediction.
- \(\sigma_t\) (controlled via the parameter \(\eta\)) governs the stochasticity; if \(\sigma_t=0\) (i.e. \(\eta=0\)), the process is deterministic.
- \(z\) is standard Gaussian noise.
The code implements this DDIM denoising process primarily in the function p_sample_ddim. Let’s break down the correspondence:
Classifier-free guidance: when unconditional_conditioning is given and the guidance scale is not 1, the code computes \(e_t = e_t^{\text{uncond}} + s\,(e_t^{\text{cond}} - e_t^{\text{uncond}})\). This branch adjusts the noise prediction \(e_t\) to steer the generation process in conditional settings. Although not part of the base DDIM equations, it is a common extension for improving sample quality.
DDIM Timesteps and Schedule:
The function make_schedule precomputes the necessary parameters (such as ddim_alphas, ddim_alphas_prev, ddim_sigmas, etc.) based on the chosen DDIM discretization schedule.
These parameters represent the cumulative products \(\bar{\alpha}_t\) and their variants, ensuring that the updates in p_sample_ddim use the correct coefficients at each timestep.
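A hedged sketch of how these parameters are typically selected (mirroring make_ddim_sampling_parameters; exact details may differ):

```python
import numpy as np

def make_ddim_sampling_parameters_sketch(alphacums, ddim_timesteps, eta):
    # pick the cumulative alphas at the subsampled DDIM steps
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
    # sigma_t from the DDIM paper; eta = 0 gives a fully deterministic sampler
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    return sigmas, alphas, alphas_prev
```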
In summary
Prediction of \(x_0\):
The code computes
$$
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t,t)}{\sqrt{\bar{\alpha}_t}}
$$
via pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() in p_sample_ddim.
Directional Component and Noise Addition:
The update for \(x_{t-1}\) is given by
$$
x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_{t-1} - \sigma_t^2}\,\epsilon_\theta(x_t,t) + \sigma_t\,z,
$$
which is implemented as dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t followed by x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise.
Control of Stochasticity:
The parameter \(\eta\) (used when creating ddim_sigmas) controls \(\sigma_t\). When \(\eta=0\), \(\sigma_t=0\) and the process is deterministic.
Additional Conditioning:
The code includes mechanisms for conditional generation and classifier-free guidance, which modify the noise prediction before performing the update.
This detailed comparison shows how the DDIM denoising process in the code mirrors the theoretical formulas, translating the mathematical update steps into practical sampling operations.