Skip Layer Guidance (SLG):
- SD3.5 introduces this novel technique that selectively skips specific transformer layers (typically 7-9)
- Only active during 1-20% of the sampling process
- Uses a separate guidance scale (2.5 in SD3.5 Medium) distinct from the main CFG scale
- Significantly improves image coherence and reduces artifacts
MMDiTX vs MMDiT:
- SD3.5 uses the enhanced MMDiTX architecture with more flexible configuration options
- Added support for cross-block attention with x_block_self_attn
- Improved normalization with RMSNorm support as an alternative to LayerNorm (see the sketch after this list)
- Better parameter management and more modular design
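As a side note on the RMSNorm option, here is a minimal illustrative implementation (not the MMDiTX class; the class name and defaults are ours). RMSNorm rescales by the root-mean-square only, without subtracting the mean, which makes it slightly cheaper than LayerNorm:

```py3
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Minimal RMSNorm sketch: scale by 1/RMS, then a learned per-channel weight."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight
```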
Sampling Improvements
Default Samplers:
- SD3: Uses euler sampler as default
- SD3.5: Uses dpmpp_2m (DPM-Solver++) sampler for better quality
Noise Scheduling:
- SD3: Uses shift=1.0
- SD3.5: Uses shift=3.0 for improved noise distribution (see the time-shift sketch after this list)
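The effect of the shift can be sketched with the commonly used flow-matching time-shift formula (a minimal illustration; the function name `shift_sigma` is ours, not the repo's API):

```py3
# sigma_shifted = shift * sigma / (1 + (shift - 1) * sigma)
# shift=1.0 leaves the schedule unchanged (SD3); shift=3.0 (SD3.5) pushes
# more of the sampling trajectory toward the high-noise region.
def shift_sigma(sigma: float, shift: float) -> float:
    return shift * sigma / (1 + (shift - 1) * sigma)

print(shift_sigma(0.5, 1.0))  # 0.5  (SD3: identity)
print(shift_sigma(0.5, 3.0))  # 0.75 (SD3.5: shifted toward higher noise)
```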
ControlNet Integration:
- Native support for various ControlNet types (blur, canny, depth)
- Dedicated ControlNetEmbedder class for processing control inputs
- Support for 8-bit and 2-bit ControlNet variations
Attention Mechanisms:
- More configurable attention with qk_norm options
- Enhanced cross-attention capabilities
- Better handling of long-range dependencies
Technical Implementation
Code Quality:
- More modular design in SD3.5
- Better type hinting and parameter validation
- Enhanced error handling and debugging capabilities
Performance:
- More efficient attention mechanisms
- Better memory management
- Support for different precision modes
In this article, we will study the differences in architecture, such as skip layer guidance and MM-DiTX. We will also explore how ControlNet is implemented in SD3.5.
As for the elements that SD3.5 shares with SD3, including the VAE, prompt processing, and the sampling scheme, the differences are not significant; please refer to the previous article, stable diffusion 3 reading, for more information.
The fingers do appear to look better, which could be taken as evidence for the claimed benefit of improved anatomy. However, other aspects of the image change as well.
Skip layer guidance is only active during the initial 1-20% of the sampling process; while active, it targets layers [7, 8, 9] and scales the CFG to 4.0.
In the MM-DiTX implementation, the skip layers are treated as identity functions:
```py3
for i, block in enumerate(self.joint_blocks):
    if i in skip_layers:
        continue
    context, x = block(context, x, c=c_mod)
```
Both pos_out and skip_layer_out use the same positive condition but differ in their treatment of skip layers. If we consider the skipped layers as a negative condition, this effectively pushes the sample away from that negative influence. What does this negative influence represent when removing layers 7, 8, and 9 (or any specific layers)? If we assume that specific layers are responsible for different features in the image—for example, if layers 7, 8, and 9 handle finer details—then the negative condition would produce images with poor fine structure. Therefore, moving away from this negative influence results in images with enhanced fine details and better structural integrity.
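To make this concrete, here is a minimal sketch of how the three model outputs could be combined. The function and parameter names (`guided_denoise`, `slg_start`, `slg_end`, and the model's `skip_layers` argument) are assumptions for illustration, not the reference implementation:

```py3
import torch

def guided_denoise(model, x, sigma, cond, uncond, cfg_scale=4.0, slg_scale=2.5,
                   step=0, num_steps=50, skip_layers=(7, 8, 9),
                   slg_start=0.01, slg_end=0.20):
    pos_out = model(x, sigma, cond)
    neg_out = model(x, sigma, uncond)
    # standard classifier-free guidance
    out = neg_out + cfg_scale * (pos_out - neg_out)
    # skip layer guidance: only active during the early part of sampling
    if slg_start * num_steps <= step <= slg_end * num_steps:
        skip_layer_out = model(x, sigma, cond, skip_layers=skip_layers)
        # push the sample away from the "layers removed" prediction
        out = out + slg_scale * (pos_out - skip_layer_out)
    return out
```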
Turning to the attention blocks in MMDiTX, pre_attention prepares the QKV tensors with adaLN modulation:

```py3
def pre_attention(self, x: torch.Tensor, c: torch.Tensor):
    assert x is not None, "pre_attention called with None input"
    if not self.pre_only:
        if not self.scale_mod_only:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.adaLN_modulation(c).chunk(6, dim=1)
            )
        else:
            shift_msa = None
            shift_mlp = None
            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(
                4, dim=1
            )
        qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
        return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
    else:
        if not self.scale_mod_only:
            shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
        else:
            shift_msa = None
            scale_msa = self.adaLN_modulation(c)
        qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
        return qkv, None
```
Compared with pre_attention, pre_attention_x processes x twice, producing a second set of QKV tensors for the extra self-attention. The corresponding post_attention_x then applies both attention results to x:
```py3
def post_attention_x(
    self,
    attn,
    attn2,
    x,
    gate_msa,
    shift_mlp,
    scale_mlp,
    gate_mlp,
    gate_msa2,
    attn1_dropout: float = 0.0,
):
    assert not self.pre_only
    if attn1_dropout > 0.0:
        # Use torch.bernoulli to implement dropout, only dropout the batch dimension
        attn1_dropout = torch.bernoulli(
            torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device)
        )
        attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout
    else:
        attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
    x = x + attn_
    attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
    x = x + attn2_
    mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
    x = x + mlp_
    return x
```
The SD3.5 version of block_mixing branches on x_block_self_attn:

```py3
def block_mixing(context, x, context_block, x_block, c):
    assert context is not None, "block_mixing called with None context"
    context_qkv, context_intermediates = context_block.pre_attention(context, c)
    if x_block.x_block_self_attn:
        x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
    else:
        x_qkv, x_intermediates = x_block.pre_attention(x, c)
    q, k, v = tuple(
        torch.cat(tuple(qkv[i] for qkv in [context_qkv, x_qkv]), dim=1)
        for i in range(3)
    )
    attn = attention(q, k, v, x_block.attn.num_heads)
    context_attn, x_attn = (
        attn[:, : context_qkv[0].shape[1]],
        attn[:, context_qkv[0].shape[1] :],
    )
    if not context_block.pre_only:
        context = context_block.post_attention(context_attn, *context_intermediates)
    else:
        context = None
    if x_block.x_block_self_attn:
        x_q2, x_k2, x_v2 = x_qkv2
        attn2 = attention(x_q2, x_k2, x_v2, x_block.attn2.num_heads)
        x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
    else:
        x = x_block.post_attention(x_attn, *x_intermediates)
    return context, x
```
For comparison, the SD3 version:

```py3
def block_mixing(context, x, context_block, x_block, c):
    assert context is not None, "block_mixing called with None context"
    context_qkv, context_intermediates = context_block.pre_attention(context, c)
    x_qkv, x_intermediates = x_block.pre_attention(x, c)
    o = []
    for t in range(3):
        o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
    q, k, v = tuple(o)
    attn = attention(q, k, v, x_block.attn.num_heads)
    context_attn, x_attn = (
        attn[:, : context_qkv[0].shape[1]],
        attn[:, context_qkv[0].shape[1] :],
    )
    if not context_block.pre_only:
        context = context_block.post_attention(context_attn, *context_intermediates)
    else:
        context = None
    x = x_block.post_attention(x_attn, *x_intermediates)
    return context, x
```
If x_block_self_attn is False, block_mixing behaves the same as the old version. If it is True, the effective code path is:
```py3
def block_mixing(context, x, context_block, x_block, c):
    assert context is not None, "block_mixing called with None context"
    context_qkv, context_intermediates = context_block.pre_attention(context, c)
    x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
    q, k, v = tuple(
        torch.cat(tuple(qkv[i] for qkv in [context_qkv, x_qkv]), dim=1)
        for i in range(3)
    )
    attn = attention(q, k, v, x_block.attn.num_heads)
    context_attn, x_attn = (
        attn[:, : context_qkv[0].shape[1]],
        attn[:, context_qkv[0].shape[1] :],
    )
    if not context_block.pre_only:
        context = context_block.post_attention(context_attn, *context_intermediates)
    else:
        context = None
    x_q2, x_k2, x_v2 = x_qkv2
    attn2 = attention(x_q2, x_k2, x_v2, x_block.attn2.num_heads)
    x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
    return context, x
```
This adds another attention over a copy of \(x\). The main purpose is to increase the capacity of the latent path; in SD3, the 'context' and 'latent' paths are symmetric.
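To illustrate the idea (this is a self-contained toy, not the MMDiTX code; all class and parameter names here are made up for the example), the latent tokens get both a joint attention with the context and an extra self-attention of their own:

```py3
import torch
import torch.nn as nn

class DualAttentionLatentBlock(nn.Module):
    """Toy block: joint attention with context + extra latent-only self-attention."""

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.joint_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # joint attention: latent tokens attend over [context ; latent]
        joint = torch.cat([context, x], dim=1)
        joint_out, _ = self.joint_attn(self.norm1(x), joint, joint)
        x = x + joint_out
        # extra self-attention over the latent tokens only (the x_block_self_attn idea)
        h = self.norm2(x)
        self_out, _ = self.self_attn(h, h, h)
        x = x + self_out
        # feed-forward
        x = x + self.mlp(self.norm3(x))
        return x

block = DualAttentionLatentBlock(dim=64)
x = torch.randn(2, 16, 64)        # latent tokens
context = torch.randn(2, 8, 64)   # context (text) tokens
print(block(x, context).shape)    # torch.Size([2, 16, 64])
```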
The core forward pass, forward_core_with_concat, ties the blocks together:

```py3
def forward_core_with_concat(
    self,
    x: torch.Tensor,
    c_mod: torch.Tensor,
    context: Optional[torch.Tensor] = None,
    skip_layers: Optional[List] = [],
    controlnet_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if self.register_length > 0:
        context = torch.cat(
            (
                repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
                context if context is not None else torch.Tensor([]).type_as(x),
            ),
            1,
        )
    # context is B, L', D
    # x is B, L, D
    for i, block in enumerate(self.joint_blocks):
        if i in skip_layers:
            continue
        context, x = block(context, x, c=c_mod)
        if controlnet_hidden_states is not None:
            controlnet_block_interval = len(self.joint_blocks) // len(
                controlnet_hidden_states
            )
            x = x + controlnet_hidden_states[i // controlnet_block_interval]
    x = self.final_layer(x, c_mod)  # (N, T, patch_size ** 2 * out_channels)
    return x
```
The difference is that in SD3.5 there is a controlnet_hidden_states input. After each MM-DiT block, the corresponding ControlNet hidden state is added to x, which is equivalent to adding an increment at each block (see the control-net article for more details on ControlNet). Since the ControlNet is decoupled from the backbone, we can train the diffusion model first and then train the ControlNet with the diffusion model frozen, which simplifies training.
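As a toy illustration of the block-interval indexing above (the block and state counts here are made-up numbers, not SD3.5's actual configuration):

```py3
import torch

num_joint_blocks = 24                  # assumed count, for illustration only
num_control_states = 12                # assumed number of ControlNet outputs
controlnet_hidden_states = [torch.randn(1, 16, 64) for _ in range(num_control_states)]
interval = num_joint_blocks // num_control_states   # -> 2

x = torch.zeros(1, 16, 64)
for i in range(num_joint_blocks):
    # x = block(x, ...) would run here; we only show the ControlNet increment
    x = x + controlnet_hidden_states[i // interval]  # each state reused `interval` times
```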
The ControlNet condition (usually a depth map, blur, or canny image) is processed in the same way as the input image and then packed into a special condition under the key controlnet_cond.