stable diffusion中的UNet2DConditionModel代码解读

stable diffusion中的UNet2DConditionModel代码解读

    正在检查是否收录...

UNet2DConditionModel总体结构

图片来自于 https://zhuanlan.zhihu.com/p/635204519

stable diffusion 运行unet部分的代码。

noise_pred = self.unet( sample=latent_model_input, #(2,4,64,64) 生成的latent timestep=t, #时刻t encoder_hidden_states=prompt_embeds, #(2,77,768) #输入的prompt和negative prompt 生成的embedding timestep_cond=timestep_cond,#默认空 cross_attention_kwargs=self.cross_attention_kwargs, #默认空 added_cond_kwargs=added_cond_kwargs, #默认空 return_dict=False, )[0] 

1.time

get_time_embed使用了sinusoidal timestep embeddings,
time_embedding 使用了两个线性层和激活层进行映射,将320转换到1280。
如果还有class_labels,added_cond_kwargs等参数,也转换为embedding,并且相加。

t_emb = self.get_time_embed(sample=sample, timestep=timestep) #(2,320) emb = self.time_embedding(t_emb, timestep_cond) #(2,1280) 

2.pre-process

卷积转换,输入latent从(2,4,64,64) 到(2,320,64,64)

sample = self.conv_in(sample) #(2,320,64,64) self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) 

3.down

down_block 由三个CrossAttnDownBlock2D和一个DownBlock2D组成。输入包括:
hidden_states:latent
temb:时刻t的embdedding
encoder_hidden_states:prompt和negative prompt的embedding

网络结构

CrossAttnDownBlock2D( ResnetBlock2D() Transformer2DModel() ResnetBlock2D() Transformer2DModel() Downsample2D() #(2,320,32,32) ) CrossAttnDownBlock2D( ResnetBlock2D() Transformer2DModel() ResnetBlock2D() Transformer2DModel() Downsample2D() #(2,640,16,16) ) CrossAttnDownBlock2D( ResnetBlock2D() Transformer2DModel() ResnetBlock2D() Transformer2DModel() Downsample2D() #(2,1280,8,8) ) DownBlock2D( ResnetBlock2D() ResnetBlock2D() #(2,1280,8,8) ) 

4.mid

UNetMidBlock2DCrossAttn 包含 resnet,attn,resnet三个模块,输入输出维度不变。输入包括:
hidden_states:latent
temb,时刻t的embdedding
encoder_hidden_states:prompt和negative prompt的embedding

UNetMidBlock2DCrossAttn( ResnetBlock2D() Transformer2DModel() ResnetBlock2D() ) 

5.up

up由一个UpBlock2D和三个CrossAttnUpBlock2D组成,输入包括:
hidden_states:latent
temb: 时刻t的embdedding
encoder_hidden_states:prompt和negative prompt的embedding
res_hidden_states_tupleL:下采样时记录的残差结果。

UpBlock2D( ResnetBlock2D() ResnetBlock2D() ResnetBlock2D() Upsample2D() #(2,1280,16,16) ) CrossAttnUpBlock2D( ResnetBlock2D() Transformer2DModel() ResnetBlock2D() Transformer2DModel() ResnetBlock2D() Transformer2DModel() Downsample2D() #(2,1280,32,32) ) CrossAttnUpBlock2D( ResnetBlock2D() Transformer2DModel() ResnetBlock2D() Transformer2DModel() ResnetBlock2D() Transformer2DModel() Downsample2D() #(2,640,64,64) ) CrossAttnUpBlock2D( ResnetBlock2D() #(2,320,64,64) Transformer2DModel() ResnetBlock2D() Transformer2DModel() ResnetBlock2D() Transformer2DModel() ) 

6.post-process

卷积变换通道数,得到最终结果

 if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) #(2,4,64,64) 

时刻t,类别class等参数作用在resnet部分,都是和输入直接相加。
由prompt,negative prompt 计算得到的encoder_hidden_states,作用在attention部分,作为key和value,参与计算。

ResnetBlock2D

x在标准化、激活、卷积之后,和temb相加,再次标准化、激活、卷积之后作为残差,与x相加。

hidden_states = input_tensor hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) #激活函数 hidden_states = self.conv1(hidden_states) if self.time_emb_proj is not None: if not self.skip_time_act: temb = self.nonlinearity(temb) temb = self.time_emb_proj(temb)[:, :, None, None] #(2,320,1,1) if self.time_embedding_norm == "default": if temb is not None: hidden_states = hidden_states + temb #与temb相加 hidden_states = self.norm2(hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor return output_tensor 

Transformer2DModel attentions部分

每个attention 包括 Self-Attention 和Cross-Attention两部分。

#Self-Attention ,encoder_hidden_states=None attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, **cross_attention_kwargs, ) #Cross-Attention,encoder_hidden_states由prompt计算得来,在这里和latent交互。 attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, **cross_attention_kwargs, ) #query由norm_hidden_states计算而来, #key、value由encoder_hidden_states计算而来。 query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) #(2,8,4096,40) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) #(2,8,77,40) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) #(2,8,77,40) hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False #(2,8,4096,40) ) 

参考:stable diffusion 中使用的 UNet 2D Condition Model 结构解析(diffusers库)

总结

### UNet2DConditionModel 总体结构概述
UNet2DConditionModel 是在 Stable Diffusion 框架中用于处理和生成图像的关键组件之一。该模型通过结合卷积神经网络(CNN)和transformer注意力机制,实现了对图像潜在表示的精准操控和生成。下面是对其总体结构的详细总结:
#### 输入与初始化
- **输入**: 模型接收多个输入,包括 `latent_model_input`(图像的潜在表示)、`timestep`(时间步长,用于控制生成过程中的演进)、`prompt_embeds`(输入提示生成的embedding)、以及可选的几个条件参数(如`timestep_cond`等)。
- **初始化**: 使用 sinusoidal timestep embeddings(正弦时间步长嵌入)对时间步长进行处理,将其映射到更高维度的 `temb` 中,可能还会结合其他条件嵌入(如类别标签等)进行增强。
#### 数据预处理
- **初始卷积**: 通过一系列卷积操作将输入的latent从 `(2, 4, 64, 64)` 转换成 `(2, 320, 64, 64)`,以便后续处理。
#### 下降采样(Downsampling)
- 由三个 `CrossAttnDownBlock2D` 和一个 `DownBlock2D` 组成:
- 每个 `CrossAttnDownBlock2D` 包含两个ResnetBlock2D、两个Transformer2DModel,和一个Downsample2D。
- 逐步实现下采样,并将特征图的尺寸减小到 `(2, 1280, 8, 8)`。
#### 中间层(Mid)
- 使用 `UNetMidBlock2DCrossAttn`,通过Resnet和Attention组合维持特征图的尺寸 `(2, 1280, 8, 8)`,进一步加深和细化特征。
#### 上升采样(Upsampling)
- 由一个 `UpBlock2D` 和三个 `CrossAttnUpBlock2D` 组成:
- 逐步提升特征图的尺寸,并与下采样过程中的残差结果进行连接(skip connections),增强特征复用。
- 最终恢复到接近原始输入的尺寸 `(2, 320, 64, 64)`。
#### 数据后处理
- 经过最后的卷积层,将特征图通道的维度调整回 `(2, 4, 64, 64)`,这一步骤包括卷积和可能的归一化及激活函数处理。
#### 核心组成
- **ResnetBlock2D**: 在Resnet块中,引入时间步长嵌入 (`temb`) 与输入特征相加,增强特征的时间依赖性。
- **Transformer2DModel**: 结合 Self-Attention 和 Cross-Attention,使模型具备对输入prompt的全面理解能力,进而指导图像的生成。Self-Attention 处理latent的内部交互,Cross-Attention 利用prompt信息对latent特征进行调制。
#### 总结
UNet2DConditionModel通过结合高效的卷积结构和强大的注意力机制,在Stable Diffusion中实现了高度的特征表示能力。这种方式不仅提升了图像的生成质量,还增强了对prompt等控制条件的响应灵敏度,进而为用户提供了更为灵活的图像创作体验。 transformerpromptsatcodetpudiffusionstable diffusionpsa注意力注意力机制神经网络卷积神经网络生成质量控制生成数据预处理图像创作提示生成cto
  • 本文作者:李琛
  • 本文链接: https://wapzz.net/post-18934.html
  • 版权声明:本博客所有文章除特别声明外,均默认采用 CC BY-NC-SA 4.0 许可协议。
本站部分内容来源于网络转载,仅供学习交流使用。如涉及版权问题,请及时联系我们,我们将第一时间处理。
文章很赞!支持一下吧 还没有人为TA充电
为TA充电
还没有人为TA充电
0
  • 支付宝打赏
    支付宝扫一扫
  • 微信打赏
    微信扫一扫
感谢支持
文章很赞!支持一下吧
关于作者
2.3W+
5
0
1
WAP站长官方

PHP采集页面的四种方法

上一篇

快速体验LLaMA-Factory 私有化部署和高效微调Llama3模型(曙光超算互联网平台异构加速卡DCU)

下一篇
  • 复制图片
按住ctrl可打开默认菜单