Previous training pipelines discard all images below a minimum resolution (e.g., 512 pixels), which can throw away a large portion of the data and lead to a loss in performance and generalization.
Instead, we provide the original (i.e., before any rescaling) height and width of the images as additional conditioning to the model, $c_{size} = (h_{original}, w_{original})$. Each component is independently embedded using a Fourier feature encoding, and these encodings are concatenated into a single vector that we feed into the model by adding it to the timestep embedding.
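A minimal sketch of this size conditioning, assuming a hypothetical `sinusoidal_embed` helper (the same construction as the `timestep_embedding` function shown later) and illustrative widths; in the real model the concatenated vector is additionally projected to the timestep-embedding width before being added:

```python
import math
import torch

def sinusoidal_embed(x: torch.Tensor, dim: int = 256, max_period: int = 10000) -> torch.Tensor:
    # Standard sinusoidal (Fourier-feature) embedding of one scalar per batch element.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = x[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

# c_size = (h_original, w_original): each component embedded independently, then concatenated.
h_orig = torch.tensor([768.0])   # original height before rescaling
w_orig = torch.tensor([512.0])   # original width before rescaling
c_size = torch.cat([sinusoidal_embed(h_orig), sinusoidal_embed(w_orig)], dim=-1)  # [1, 512]

t_emb = torch.randn(1, 512)      # placeholder timestep embedding with a matching width
conditioned = t_emb + c_size     # the size condition is added to the timestep embedding
```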
Random cropping during training can lead to incomplete generations (e.g., cut-off subjects). So we add the crop coordinates as a condition: setting $(c_{top}, c_{left}) = (0, 0)$ at inference yields object-centered samples, and tuning the two parameters simulates different amounts of cropping.
Method
During dataloading, we uniformly sample crop coordinates ctop and cleft (integers specifying the amount of pixels cropped from the top-left corner along the height and width axes, respectively) and feed them into the model as conditioning parameters via Fourier feature embeddings, similar to the size conditioning described above
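A hedged sketch of this dataloading step, assuming a PIL image and a hypothetical `crop_with_coords` helper with a fixed target size:

```python
import random
from PIL import Image

def crop_with_coords(img: Image.Image, target_h: int, target_w: int):
    """Randomly crop and also return the crop offsets (c_top, c_left) as extra conditioning."""
    w, h = img.size
    c_top = random.randint(0, max(0, h - target_h))
    c_left = random.randint(0, max(0, w - target_w))
    cropped = img.crop((c_left, c_top, c_left + target_w, c_top + target_h))
    return cropped, (c_top, c_left)

# At inference time, feeding (c_top, c_left) = (0, 0) asks the model for
# object-centered samples; larger values simulate heavier cropping.
```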
Most text-to-image models produce only square images, even though real-world images come in a wide range of aspect ratios.
- Training tricks
    - Prepare different buckets of images: every image in a bucket has the same shape, and each bucket's resolution keeps the total pixel count close to $1024^2$.
    - During training, each batch is drawn from a single bucket, and the bucket is changed between steps of the training loop (see the sketch after this list).
- Condition tricks
    - Similar to the size and crop-parameter conditions, the target shape $(h_{target}, w_{target})$ is embedded into a Fourier space and fed to the model.
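A minimal sketch of the bucketing trick, with illustrative (not official) bucket resolutions that are multiples of 64 and keep the pixel count near $1024^2$; `samples` is assumed to be a list of `(height, width, path)` tuples:

```python
import random
from collections import defaultdict

# Illustrative bucket resolutions (height, width), each with roughly 1024^2 pixels.
BUCKETS = [(640, 1536), (768, 1344), (832, 1216), (896, 1152),
           (1024, 1024),
           (1152, 896), (1216, 832), (1344, 768), (1536, 640)]

def assign_bucket(h: int, w: int):
    """Pick the bucket whose aspect ratio is closest to the image's."""
    return min(BUCKETS, key=lambda b: abs(b[0] / b[1] - h / w))

def make_batches(samples, batch_size):
    """Group samples by bucket; every batch has one shape, and buckets alternate per step."""
    by_bucket = defaultdict(list)
    for h, w, path in samples:
        by_bucket[assign_bucket(h, w)].append(path)
    batches = []
    for bucket, paths in by_bucket.items():
        random.shuffle(paths)
        for i in range(0, len(paths) - batch_size + 1, batch_size):
            batches.append((bucket, paths[i:i + batch_size]))
    random.shuffle(batches)  # a different bucket may be used at each training step
    return batches
```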
Difficulty with Intricate Structures:
The model sometimes struggles with synthesizing fine details in complex structures (e.g., human hands) due to high variance and the challenge of extracting accurate 3D shape information.
Imperfect Photorealism:
While the generated images are very realistic, they may lack certain nuances such as subtle lighting effects or fine texture variations.
Biases in Training Data:
The heavy reliance on large-scale datasets can inadvertently introduce social and racial biases, which may be reflected in the generated outputs.
Concept Bleeding:
The model can sometimes merge or incorrectly bind attributes between distinct objects (e.g., mixing up the colors of a hat and gloves, or the orange of a sweater bleeding into a pair of sunglasses).
Text Rendering Issues:
The model encounters challenges in rendering long, legible text, occasionally producing random characters or inconsistent text output.
```python
class GeneralConditioner(nn.Module):
    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat"}  # , 5: "concat"}
    KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1, "cond_view": 1, "cond_motion": 1}

    def __init__(self, emb_models: Union[List, ListConfig]):
        super().__init__()
        embedders = []
        for n, embconfig in enumerate(emb_models):
            embedder = instantiate_from_config(embconfig)
            assert isinstance(
                embedder, AbstractEmbModel
            ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
            embedder.is_trainable = embconfig.get("is_trainable", False)
            embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
            if not embedder.is_trainable:
                embedder.train = disabled_train
                for param in embedder.parameters():
                    param.requires_grad = False
                embedder.eval()
            print(
                f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
            )
            if "input_key" in embconfig:
                embedder.input_key = embconfig["input_key"]
            elif "input_keys" in embconfig:
                embedder.input_keys = embconfig["input_keys"]
            else:
                raise KeyError(
                    f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
                )
            embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
            if embedder.legacy_ucg_val is not None:
                embedder.ucg_prng = np.random.RandomState()
            embedders.append(embedder)
        self.embedders = nn.ModuleList(embedders)

    def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
        assert embedder.legacy_ucg_val is not None
        p = embedder.ucg_rate
        val = embedder.legacy_ucg_val
        for i in range(len(batch[embedder.input_key])):
            if embedder.ucg_prng.choice(2, p=[1 - p, p]):
                batch[embedder.input_key][i] = val
        return batch

    def forward(self, batch: Dict, force_zero_embeddings: Optional[List] = None) -> Dict:
        output = dict()
        if force_zero_embeddings is None:
            force_zero_embeddings = []
        for embedder in self.embedders:
            embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
            with embedding_context():
                if hasattr(embedder, "input_key") and (embedder.input_key is not None):
                    if embedder.legacy_ucg_val is not None:
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    emb_out = embedder(batch[embedder.input_key])
                elif hasattr(embedder, "input_keys"):
                    emb_out = embedder(*[batch[k] for k in embedder.input_keys])
            assert isinstance(
                emb_out, (torch.Tensor, list, tuple)
            ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
            if not isinstance(emb_out, (list, tuple)):
                emb_out = [emb_out]
            for emb in emb_out:
                if embedder.input_key in ["cond_view", "cond_motion"]:
                    out_key = embedder.input_key
                else:
                    out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
                if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                    emb = (
                        expand_dims_like(
                            torch.bernoulli(
                                (1.0 - embedder.ucg_rate)
                                * torch.ones(emb.shape[0], device=emb.device)
                            ),
                            emb,
                        )
                        * emb
                    )
                if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings:
                    emb = torch.zeros_like(emb)
                if out_key in output:
                    output[out_key] = torch.cat(
                        (output[out_key], emb), self.KEY2CATDIM[out_key]
                    )
                else:
                    output[out_key] = emb
        return output

    def get_unconditional_conditioning(
        self,
        batch_c: Dict,
        batch_uc: Optional[Dict] = None,
        force_uc_zero_embeddings: Optional[List[str]] = None,
        force_cond_zero_embeddings: Optional[List[str]] = None,
    ):
        if force_uc_zero_embeddings is None:
            force_uc_zero_embeddings = []
        ucg_rates = list()
        for embedder in self.embedders:
            ucg_rates.append(embedder.ucg_rate)
            embedder.ucg_rate = 0.0
        c = self(batch_c, force_cond_zero_embeddings)
        uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
        for embedder, rate in zip(self.embedders, ucg_rates):
            embedder.ucg_rate = rate
        return c, uc
```
The `GeneralConditioner` converts each raw input into an embedding using a dedicated embedding model and supports three kinds of condition types (a toy example of this routing follows the list):
1. vector: similar to time embedding
2. crossattn: similar to text embedding which is a sequence
3. concat: a tensor of shape [B, C, H, W] that is concatenated with the model input, e.g., depth, low-resolution, or masked-image conditions
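A toy illustration (not the repo's API) of how `GeneralConditioner` routes embedder outputs by tensor rank via `OUTPUT_DIM2KEYS` and concatenates them per key; the shapes below are assumptions for illustration:

```python
import torch

OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat"}
KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}

emb_outputs = [
    torch.randn(4, 512),          # e.g. pooled text embedding            -> "vector"
    torch.randn(4, 77, 2048),     # e.g. token-level text embedding       -> "crossattn"
    torch.randn(4, 4, 128, 128),  # e.g. depth / low-res / mask latents   -> "concat"
    torch.randn(4, 256),          # e.g. size/crop Fourier features       -> "vector" (concatenated)
]

cond = {}
for emb in emb_outputs:
    key = OUTPUT_DIM2KEYS[emb.dim()]
    cond[key] = emb if key not in cond else torch.cat((cond[key], emb), KEY2CATDIM[key])

print({k: tuple(v.shape) for k, v in cond.items()})
# {'vector': (4, 768), 'crossattn': (4, 77, 2048), 'concat': (4, 4, 128, 128)}
```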
The CLIP text embedder provides three different embedding representations (a hedged sketch follows below):
1. last hidden state: [B, S, D]
2. pooled hidden state: [B, D]
3. hidden state of a given layer index: [B, S, D]
In this model, the hidden state of layer index 11 is used.
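A hedged sketch of these three representations using the Hugging Face `transformers` CLIP text model; the checkpoint name, prompt, and layer index are illustrative assumptions, not the repo's exact wrapper:

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").eval()

tokens = tokenizer(["a photo of a cat"], padding="max_length", max_length=77, return_tensors="pt")
with torch.no_grad():
    out = text_model(**tokens, output_hidden_states=True)

last = out.last_hidden_state      # [B, S, D] last hidden state
pooled = out.pooler_output        # [B, D]    pooled hidden state
layer_11 = out.hidden_states[11]  # [B, S, D] hidden state of layer index 11
```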
```python
class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
    LAYERS = ["pooled", "last", "penultimate"]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        always_return_pooled=False,
        legacy=True,
    ):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
            pretrained=version,
        )
        del model.visual
        self.model = model
        self.device = device
        self.max_length = max_length
        self.return_pooled = always_return_pooled
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()
        self.legacy = legacy

    @autocast
    def forward(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        if not self.return_pooled and self.legacy:
            return z
        if self.return_pooled:
            assert not self.legacy
            return z[self.layer], z["pooled"]
        return z[self.layer]

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        if self.legacy:
            x = x[self.layer]
            x = self.model.ln_final(x)
            return x
        else:
            # x is a dict and will stay a dict
            o = x["last"]
            o = self.model.ln_final(o)
            pooled = self.pool(o, text)
            x["pooled"] = pooled
            return x

    def pool(self, x, text):
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = (
            x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
            @ self.model.text_projection
        )
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        outputs = {}
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - 1:
                outputs["penultimate"] = x.permute(1, 0, 2)  # LND -> NLD
            if (
                self.model.transformer.grad_checkpointing
                and not torch.jit.is_scripting()
            ):
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        outputs["last"] = x.permute(1, 0, 2)  # LND -> NLD
        return outputs
```
This embedder uses OpenCLIP ViT-bigG-14 and also supports three types of embeddings. Here the 'penultimate' (second-to-last) layer is chosen and the pooled hidden state is returned as well.
If `always_return_pooled` is true, the embedder outputs two hidden states, the original sequence and the pooled vector, and the two are treated differently:
The pooled embedding is treated as a vector embedding and processed the same way as the timestep embedding.
The original hidden state is treated as a sequence embedding and fed into the U-Net through cross-attention.
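A hedged usage sketch of the embedder above, assuming the full class from the repo (the `freeze()` helper is not shown in the excerpt) and an open_clip ViT-bigG-14 checkpoint; the pretrained tag, CPU device, and 1280-dim width are assumptions:

```python
embedder = FrozenOpenCLIPEmbedder2(
    arch="ViT-bigG-14",
    version="laion2b_s39b_b160k",  # assumed open_clip pretrained tag for ViT-bigG-14
    device="cpu",                  # keep everything on CPU for this sketch
    layer="penultimate",
    always_return_pooled=True,
    legacy=False,
)
z, pooled = embedder(["a photo of a cat"])
# z:      assumed [1, 77, 1280] -> sequence embedding, routed to "crossattn"
# pooled: assumed [1, 1280]     -> vector embedding, routed to "vector"
```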
```python
class ConcatTimestepEmbedderND(AbstractEmbModel):
    """embeds each dimension independently and concatenates them"""

    def __init__(self, outdim):
        super().__init__()
        self.timestep = Timestep(outdim)
        self.outdim = outdim

    def forward(self, x):
        if x.ndim == 1:
            x = x[:, None]
        assert len(x.shape) == 2
        b, dims = x.shape[0], x.shape[1]
        x = rearrange(x, "b d -> (b d)")
        emb = self.timestep(x)
        emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
        return emb
```
This embedder handles the 'scalar' conditions, including the class label, timestep, cropping parameters, original image size, and target image size.
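A usage sketch, assuming the class above and the repo's `Timestep` helper are importable; `outdim=256` per scalar is an assumed width:

```python
import torch

# Embed c_size = (h_original, w_original) for a batch of 2 samples.
size_embedder = ConcatTimestepEmbedderND(outdim=256)
c_size = torch.tensor([[1024.0, 1024.0],
                       [ 768.0,  512.0]])  # (h_original, w_original) per sample
emb = size_embedder(c_size)                # two scalars x 256 dims each -> shape [2, 512]
```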
```python
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, "b -> b d", d=dim)
    return embedding
```
Here are the formulas of this sinusoidal embedding in mathematical notation, written with exp and log:
Frequency Calculation:
Let
\[
H = \left\lfloor \frac{D}{2} \right\rfloor,
\]
where \( D \) is the embedding dimension. For each index \( i \) (with \( 0 \le i < H \)), define the frequency as
\[
f_i = \exp\!\left(-\frac{i \cdot \log(\text{max\_period})}{H}\right) = \text{max\_period}^{-i/H},
\]
so that for a (possibly fractional) timestep \( t \) the embedding is
\[
\mathrm{emb}(t) = \big[\cos(t f_0), \dots, \cos(t f_{H-1}), \sin(t f_0), \dots, \sin(t f_{H-1})\big],
\]
with a trailing zero appended when \( D \) is odd.
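A quick numerical check that this formula matches the `timestep_embedding` code above (assuming that function is in scope):

```python
import math
import torch

D, max_period = 8, 10000
H = D // 2
t = torch.tensor([5.0])

# Build the embedding directly from the formula: [cos(t * f_i), sin(t * f_i)].
freqs = torch.tensor([math.exp(-math.log(max_period) * i / H) for i in range(H)])
manual = torch.cat([torch.cos(t * freqs), torch.sin(t * freqs)])[None]

assert torch.allclose(manual, timestep_embedding(t, D))
```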
```python
class StandardDiffusionLoss(nn.Module):
    def __init__(
        self,
        sigma_sampler_config: dict,
        loss_weighting_config: dict,
        loss_type: str = "l2",
        offset_noise_level: float = 0.0,
        batch2model_keys: Optional[Union[str, List[str]]] = None,
    ):
        super().__init__()
        assert loss_type in ["l2", "l1", "lpips"]

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
        self.loss_weighting = instantiate_from_config(loss_weighting_config)

        self.loss_type = loss_type
        self.offset_noise_level = offset_noise_level

        if loss_type == "lpips":
            self.lpips = LPIPS().eval()

        if not batch2model_keys:
            batch2model_keys = []
        if isinstance(batch2model_keys, str):
            batch2model_keys = [batch2model_keys]
        self.batch2model_keys = set(batch2model_keys)

    def get_noised_input(
        self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
    ) -> torch.Tensor:
        noised_input = input + noise * sigmas_bc
        return noised_input

    def forward(
        self,
        network: nn.Module,
        denoiser: Denoiser,
        conditioner: GeneralConditioner,
        input: torch.Tensor,
        batch: Dict,
    ) -> torch.Tensor:
        cond = conditioner(batch)
        return self._forward(network, denoiser, cond, input, batch)

    def _forward(
        self,
        network: nn.Module,
        denoiser: Denoiser,
        cond: Dict,
        input: torch.Tensor,
        batch: Dict,
    ) -> Tuple[torch.Tensor, Dict]:
        additional_model_inputs = {
            key: batch[key] for key in self.batch2model_keys.intersection(batch)
        }
        sigmas = self.sigma_sampler(input.shape[0]).to(input)

        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
            offset_shape = (
                (input.shape[0], 1, input.shape[2])
                if self.n_frames is not None
                else (input.shape[0], input.shape[1])
            )
            noise = noise + self.offset_noise_level * append_dims(
                torch.randn(offset_shape, device=input.device),
                input.ndim,
            )
        sigmas_bc = append_dims(sigmas, input.ndim)
        noised_input = self.get_noised_input(sigmas_bc, noise, input)

        model_output = denoiser(network, noised_input, sigmas, cond, **additional_model_inputs)
        w = append_dims(self.loss_weighting(sigmas), input.ndim)
        return self.get_loss(model_output, input, w)

    def get_loss(self, model_output, target, w):
        if self.loss_type == "l2":
            return torch.mean(
                (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
            )
        elif self.loss_type == "l1":
            return torch.mean(
                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
            )
        elif self.loss_type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
            return loss
        else:
            raise NotImplementedError(f"Unknown loss type {self.loss_type}")
```
For EDM-style training, the sigma sampler is a class that draws noise levels $\sigma$ from a continuous distribution rather than from a fixed discrete schedule.
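For reference, a minimal sketch of such a continuous sampler in the spirit of EDM, where $\log \sigma$ is drawn from a normal distribution; the defaults $P_{mean} = -1.2$, $P_{std} = 1.2$ follow the EDM paper but are assumptions here, not necessarily the repo's configuration:

```python
import torch

class LogNormalSigmaSampler:
    """Draw continuous noise levels sigma with log(sigma) ~ N(p_mean, p_std^2),
    as proposed in EDM (Karras et al., 2022)."""

    def __init__(self, p_mean: float = -1.2, p_std: float = 1.2):
        self.p_mean = p_mean
        self.p_std = p_std

    def __call__(self, n_samples: int) -> torch.Tensor:
        log_sigma = self.p_mean + self.p_std * torch.randn(n_samples)
        return log_sigma.exp()

sigmas = LogNormalSigmaSampler()(4)  # four continuous noise levels; values vary per draw
```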