AIGC Notes--Stable Diffusion Source Code Analysis: UNetModel
1--Preface
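This note uses the open-source project released with the paper "High-Resolution Image Synthesis with Latent Diffusion Models" to dissect the classic components of Stable Diffusion. The goal is to consolidate what I have learned.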
2--UNetModel
A small demo that can be stepped through in a debugger: SD_UNet
Taking text-to-image generation as the example, this section dissects the core building blocks of UNetModel.
2-1--Forward Overview
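In the provided text-to-image demo, only three arguments are actually passed in: x, timesteps and context.
x is the latent being denoised. At the first step it is randomly initialized Gaussian noise (shape: [B*2, 4, 64, 64]). The factor of 2 comes from Classifier-Free Diffusion Guidance, which runs the conditional and unconditional branches in a single batch.
timesteps holds the timestep of the current denoising iteration for every sample (shape: [B*2]).
context is the text prompt encoded by CLIP (shape: [B*2, 77, 768]). For classifier-free guidance, half of the batch holds the embedding of the empty prompt.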
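The sketch below shows how a sampler would typically use the doubled B*2 batch for classifier-free guidance. It makes several assumptions that are not code from the demo:

- the function name `cfg_denoise` and the stand-in model are invented for illustration;
- the default guidance scale of 7.5 is an assumed value;
- the unconditional half of the batch is assumed to come first.

```python
import torch


def cfg_denoise(model, x, t, cond, uncond, guidance_scale=7.5):
    """Classifier-free guidance with one doubled batch (illustrative sketch).

    model: any callable (x, timesteps, context) -> noise prediction, e.g. a UNetModel.
    x: [B, 4, 64, 64] latent, t: [B] timesteps,
    cond / uncond: [B, 77, 768] CLIP embeddings of the prompt / the empty prompt.
    """
    x_in = torch.cat([x, x], dim=0)          # [B*2, 4, 64, 64]
    t_in = torch.cat([t, t], dim=0)          # [B*2]
    c_in = torch.cat([uncond, cond], dim=0)  # [B*2, 77, 768], unconditional half first (assumed)
    eps = model(x_in, t_in, c_in)            # one UNet pass covers both branches
    eps_uncond, eps_cond = eps.chunk(2, dim=0)
    # Push the prediction away from the unconditional one, toward the prompt.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


if __name__ == "__main__":
    dummy_unet = lambda x, t, c: torch.zeros_like(x)  # hypothetical stand-in for UNetModel
    B = 3
    out = cfg_denoise(dummy_unet, torch.randn(B, 4, 64, 64), torch.full((B,), 981),
                      torch.randn(B, 77, 768), torch.randn(B, 77, 768))
    print(out.shape)  # torch.Size([3, 4, 64, 64])
```

With B = 3, the UNet sees a batch of 6. This matches the `[6, ...]` shapes in the code comments later in this note.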
```python
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
    """
    Apply the model to an input batch.
    :param x: an [N x C x ...] Tensor of inputs.
    :param timesteps: a 1-D batch of timesteps.
    :param context: conditioning plugged in via crossattn
    :param y: an [N] Tensor of labels, if class-conditional.
    :return: an [N x C x ...] Tensor of outputs.
    """
    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"
    hs = []
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)  # Create sinusoidal timestep embeddings, [B*2] -> [B*2, 320]
    emb = self.time_embed(t_emb)  # MLP, [B*2, 320] -> [B*2, 1280]
    if self.num_classes is not None:
        assert y.shape == (x.shape[0],)
        emb = emb + self.label_emb(y)
    h = x.type(self.dtype)
    for module in self.input_blocks:  # downsampling: [B*2, 4, 64, 64] -> [B*2, 1280, 8, 8]
        h = module(h, emb, context)
        hs.append(h)  # keep every input-block output for the skip connections
    h = self.middle_block(h, emb, context)
    for module in self.output_blocks:  # upsampling: [B*2, 1280, 8, 8] -> [B*2, 320, 64, 64]
        h = th.cat([h, hs.pop()], dim=1)  # concatenate the matching skip features along channels
        h = module(h, emb, context)
    h = h.type(x.dtype)
    if self.predict_codebook_ids:
        return self.id_predictor(h)
    else:
        return self.out(h)  # [B*2, 320, 64, 64] -> [B*2, 4, 64, 64]
```

2-2--Timestep Embedding Generation
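The functions timestep_embedding() and self.time_embed() encode each incoming timestep:

- timestep_embedding() produces a sinusoidal embedding of size model_channels (320), analogous to Transformer positional encoding;
- self.time_embed() projects that embedding to time_embed_dim (1280).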
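As a worked example, the pure-Python sketch below mirrors the non-repeat branch of timestep_embedding() for a single timestep. The name `sinusoidal_embedding` is hypothetical, not part of the demo.

With half = dim // 2, frequency i is exp(-ln(max_period) * i / half). The frequencies therefore fall geometrically from 1 toward 1/max_period. The first half of the vector holds cosines and the second half holds sines.

```python
import math


def sinusoidal_embedding(t, dim, max_period=10000):
    """Pure-Python mirror of timestep_embedding() (repeat_only=False) for one timestep."""
    half = dim // 2
    # Frequencies decay geometrically from 1 toward 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        emb.append(0.0)  # odd dims are padded with one zero, as in the original
    return emb


# dim=4, t=1: half=2, freqs = [1, 10000 ** -0.5] = [1, 0.01]
print(sinusoidal_embedding(1, 4))
print(len(sinusoidal_embedding(981, 320)))  # 320, the size fed into self.time_embed
```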
timestep_embedding() is defined as follows, while self.time_embed() is a small MLP (Linear, SiLU, Linear):
```python
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]  # [N, 1] * [1, half] -> [N, half]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # [N, 2 * half]
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding
```

```python
# Built in UNetModel.__init__ (time_embed_dim = model_channels * 4 = 1280)
self.time_embed = nn.Sequential(
    linear(model_channels, time_embed_dim),
    nn.SiLU(),
    linear(time_embed_dim, time_embed_dim),
)
```

2-3--self.input_blocks: Downsampling
In forward(), self.input_blocks progressively downsamples the spatial resolution of the input noise. The shape changes from [B*2, 4, 64, 64] to [B*2, 1280, 8, 8].
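The shape change can be traced without running the network. The sketch below replays the input-block loop of UNetModel.__init__ in plain Python. `trace_input_blocks` is a hypothetical helper.

The hyperparameters are assumptions chosen to be consistent with the printed modules:

- model_channels=320
- channel_mult=(1, 2, 4, 4)
- num_res_blocks=2
- attention_resolutions=(4, 2, 1)

```python
def trace_input_blocks(model_channels=320, channel_mult=(1, 2, 4, 4),
                       num_res_blocks=2, attention_resolutions=(4, 2, 1), size=64):
    """Replay the input-block loop; return (channels, resolution, has_attention) per module."""
    trace = [(model_channels, size, False)]  # Module 0: 3x3 conv, 4 -> 320 channels
    ds = 1  # current downsampling factor relative to the 64x64 latent
    for level, mult in enumerate(channel_mult):
        ch = model_channels * mult
        for _ in range(num_res_blocks):
            # ResBlock, plus a SpatialTransformer when ds is an attention resolution
            trace.append((ch, size, ds in attention_resolutions))
        if level != len(channel_mult) - 1:
            size //= 2  # Downsample: stride-2 conv, channel count unchanged
            ds *= 2
            trace.append((ch, size, False))
    return trace


for i, (ch, res, attn) in enumerate(trace_input_blocks()):
    print(f"Module{i}: [B*2, {ch}, {res}, {res}]" + ("  + SpatialTransformer" if attn else ""))
```

Under these assumptions the trace yields 12 modules. Attention appears in Modules 1, 2, 4, 5, 7 and 8, and the resolution halves at Modules 3, 6 and 9. This agrees with the module listing that follows.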
The downsampling path consists of 12 modules:
```
ModuleList(
  (0): TimestepEmbedSequential(
    (0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (1-2): 2 x TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=320, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Identity()
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
      (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=320, out_features=320, bias=False)
            (to_k): Linear(in_features=320, out_features=320, bias=False)
            (to_v): Linear(in_features=320, out_features=320, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=320, out_features=320, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=320, out_features=2560, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=1280, out_features=320, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=320, out_features=320, bias=False)
            (to_k): Linear(in_features=768, out_features=320, bias=False)
            (to_v): Linear(in_features=768, out_features=320, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=320, out_features=320, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (3): TimestepEmbedSequential(
    (0): Downsample(
      (op): Conv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
  )
  (4): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=640, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(320, 640, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
      (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=640, out_features=640, bias=False)
            (to_v): Linear(in_features=640, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=640, out_features=5120, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=2560, out_features=640, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=768, out_features=640, bias=False)
            (to_v): Linear(in_features=768, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (5): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=640, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Identity()
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
      (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=640, out_features=640, bias=False)
            (to_v): Linear(in_features=640, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=640, out_features=5120, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=2560, out_features=640, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=768, out_features=640, bias=False)
            (to_v): Linear(in_features=768, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (6): TimestepEmbedSequential(
    (0): Downsample(
      (op): Conv2d(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
  )
  (7): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(640, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(640, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
      (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=1280, out_features=1280, bias=False)
            (to_v): Linear(in_features=1280, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=1280, out_features=10240, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=5120, out_features=1280, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=768, out_features=1280, bias=False)
            (to_v): Linear(in_features=768, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (8): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Identity()
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
      (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=1280, out_features=1280, bias=False)
            (to_v): Linear(in_features=1280, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=1280, out_features=10240, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=5120, out_features=1280, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=768, out_features=1280, bias=False)
            (to_v): Linear(in_features=768, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (9): TimestepEmbedSequential(
    (0): Downsample(
      (op): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
  )
  (10-11): 2 x TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Identity()
    )
  )
)
```

All 12 modules are wrapped in the TimestepEmbedSequential class. It passes each child layer the inputs that layer needs:

- the timestep embedding for a ResBlock;
- the prompt context for a SpatialTransformer;
- only x for plain layers such as convolutions.
```python
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):  # e.g. ResBlock: also gets the timestep embedding
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):  # also gets the text context
                x = layer(x, context)
            else:  # e.g. Conv2d, Downsample, Upsample: only the feature map
                x = layer(x)
        return x
```

2-3-1--Module0
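Module 0 is a single 2D convolution (3x3 kernel, stride 1, padding 1). It extracts initial features by mapping the 4-channel noise latent to 320 channels, without changing the 64x64 resolution.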
```python
# In __init__:
self.input_blocks = nn.ModuleList(
    [
        TimestepEmbedSequential(
            conv_nd(dims, in_channels, model_channels, 3, padding=1)
        )
    ]
)

# print(self.input_blocks[0])
TimestepEmbedSequential(
  (0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
```

2-3-2--Module1 and Module2
Module1 and Module2 have the same structure: each consists of a ResBlock followed by a SpatialTransformer.
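The number of attention heads and the per-head width follow the rule in this section's init code. The sketch below is a plain-Python mirror of that rule; `attention_heads` is a hypothetical helper.

It assumes the settings num_heads=8 and num_head_channels=-1. These are consistent with the comment "dim_head: 40, heads: 8" in CrossAttention.

The `legacy` branch is omitted: when use_spatial_transformer is True it produces the same dim_head.

```python
def attention_heads(ch, num_heads=8, num_head_channels=-1):
    """Head count and per-head width, mirroring the rule in UNetModel.__init__."""
    if num_head_channels == -1:
        dim_head = ch // num_heads  # fixed head count, width scales with ch
    else:
        num_heads = ch // num_head_channels  # fixed head width, count scales with ch
        dim_head = num_head_channels
    return num_heads, dim_head


for ch in (320, 640, 1280):
    heads, dim_head = attention_heads(ch)
    print(ch, heads, dim_head, heads * dim_head)  # inner_dim equals ch here
```

Because heads * dim_head equals ch, to_q/to_k/to_v map into a space of the same width as the feature map. This is why the printed Linear layers are 320x320, 640x640 and 1280x1280.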
```python
# In __init__, inside `for level, mult in enumerate(channel_mult):`
for _ in range(num_res_blocks):
    layers = [
        ResBlock(
            ch,
            time_embed_dim,
            dropout,
            out_channels=mult * model_channels,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        )
    ]
    ch = mult * model_channels
    if ds in attention_resolutions:  # add attention only at the configured resolutions
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        layers.append(
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
            )
        )
    self.input_blocks.append(TimestepEmbedSequential(*layers))
    self._feature_size += ch
    input_block_chans.append(ch)

# print(self.input_blocks[1])
TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=320, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
  (1): SpatialTransformer(
    (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
    (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
    (transformer_blocks): ModuleList(
      (0): BasicTransformerBlock(
        (attn1): CrossAttention(
          (to_q): Linear(in_features=320, out_features=320, bias=False)
          (to_k): Linear(in_features=320, out_features=320, bias=False)
          (to_v): Linear(in_features=320, out_features=320, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=320, out_features=320, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (ff): FeedForward(
          (net): Sequential(
            (0): GEGLU(
              (proj): Linear(in_features=320, out_features=2560, bias=True)
            )
            (1): Dropout(p=0.0, inplace=False)
            (2): Linear(in_features=1280, out_features=320, bias=True)
          )
        )
        (attn2): CrossAttention(
          (to_q): Linear(in_features=320, out_features=320, bias=False)
          (to_k): Linear(in_features=768, out_features=320, bias=False)
          (to_v): Linear(in_features=768, out_features=320, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=320, out_features=320, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      )
    )
    (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  )
)

# print(self.input_blocks[2])
TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=320, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
  (1): SpatialTransformer(
    (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
    (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
    (transformer_blocks): ModuleList(
      (0): BasicTransformerBlock(
        (attn1): CrossAttention(
          (to_q): Linear(in_features=320, out_features=320, bias=False)
          (to_k): Linear(in_features=320, out_features=320, bias=False)
          (to_v): Linear(in_features=320, out_features=320, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=320, out_features=320, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (ff): FeedForward(
          (net): Sequential(
            (0): GEGLU(
              (proj): Linear(in_features=320, out_features=2560, bias=True)
            )
            (1): Dropout(p=0.0, inplace=False)
            (2): Linear(in_features=1280, out_features=320, bias=True)
          )
        )
        (attn2): CrossAttention(
          (to_q): Linear(in_features=320, out_features=320, bias=False)
          (to_k): Linear(in_features=768, out_features=320, bias=False)
          (to_v): Linear(in_features=768, out_features=320, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=320, out_features=320, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      )
    )
    (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  )
)
```

2-3-3--Module3
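Module3 is a downsampling layer: a 3x3 2D convolution with stride 2 (320 channels in and out). It halves the resolution from 64x64 to 32x32.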
```python
# In __init__, after the ResBlocks of every level except the last one:
if level != len(channel_mult) - 1:
    out_ch = ch
    self.input_blocks.append(
        TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                out_channels=out_ch,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                down=True,
            )
            if resblock_updown
            else Downsample(  # the printed model uses this branch
                ch, conv_resample, dims=dims, out_channels=out_ch
            )
        )
    )

# print(self.input_blocks[3])
TimestepEmbedSequential(
  (0): Downsample(
    (op): Conv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  )
)
```

2-3-4--Module4, Module5, Module7 and Module8
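These have the same structure as Module1 and Module2: a ResBlock followed by a SpatialTransformer. Only the channel dimensions differ: 640 for Module4/5 and 1280 for Module7/8.

Module4 and Module7 also change the channel count inside the ResBlock (320 to 640 and 640 to 1280). Their skip_connection is therefore a 1x1 convolution instead of Identity.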
2-3-5--Module6 and Module9
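These have the same structure as Module3, a stride-2 downsampling 2D convolution:

- Module6 has 640 channels (32x32 to 16x16);
- Module9 has 1280 channels (16x16 to 8x8).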
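The halving follows from the standard output-size formula of a 2D convolution with dilation 1:

floor((size + 2 * padding - kernel) / stride) + 1

With kernel 3 and padding 1, stride 1 keeps the size (as in Module0 and the ResBlock convolutions), while stride 2 halves an even size. The helper `conv_out_size` below is hypothetical and only applies that formula to the three Downsample layers.

```python
def conv_out_size(size, kernel=3, stride=1, padding=1):
    """Output height/width of a 2D convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


sizes = [64]
for _ in range(3):  # Module3, Module6, Module9
    sizes.append(conv_out_size(sizes[-1], stride=2))
print(sizes)  # [64, 32, 16, 8]
```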
2-3-6--Module10 and Module11
Module10 and Module11 have the same structure: each consists of a single ResBlock, with no SpatialTransformer.
```python
# print(self.input_blocks[10])
TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
)

# print(self.input_blocks[11])
TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
)
```

2-3-7--ResBlock
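ResBlock takes the feature map x and the timestep embedding as inputs. It works in three steps:

1. It projects the embedding to the block's channel count (SiLU + Linear).
2. It adds the projected embedding to the convolved features, broadcasting over height and width.
3. It wraps the result in a residual connection.

The core code is as follows: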
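Before the full source, here is a stripped-down sketch of the same idea. `TinyResBlock` is a hypothetical simplification, not the repository class.

It keeps:

- the GroupNorm/SiLU/conv stacks;
- the SiLU + Linear projection of the embedding;
- the broadcast add;
- a 1x1 skip convolution when the channel count changes.

It drops up/down-sampling, use_scale_shift_norm, dropout, gradient checkpointing, and the zero initialization of the last convolution.

```python
import torch
from torch import nn


class TinyResBlock(nn.Module):
    """Simplified ResBlock: timestep embedding injected by a broadcast add."""

    def __init__(self, channels, emb_channels, out_channels):
        super().__init__()
        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels), nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1))
        self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(emb_channels, out_channels))
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, out_channels), nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1))
        # 1x1 conv on the skip path only when the channel count changes
        self.skip = (nn.Identity() if channels == out_channels
                     else nn.Conv2d(channels, out_channels, 1))

    def forward(self, x, emb):
        h = self.in_layers(x)                            # [N, C_out, H, W]
        emb_out = self.emb_layers(emb)[..., None, None]  # [N, C_out] -> [N, C_out, 1, 1]
        h = self.out_layers(h + emb_out)                 # add broadcasts over H and W
        return self.skip(x) + h


if __name__ == "__main__":
    block = TinyResBlock(320, 1280, 640)
    x, emb = torch.randn(2, 320, 16, 16), torch.randn(2, 1280)
    print(block(x, emb).shape)  # torch.Size([2, 640, 16, 16])
```

In the original, zero_module() zeroes the last convolution of out_layers. A freshly initialized ResBlock therefore returns exactly skip_connection(x), so the residual branch starts out as a no-op. Zeroing `out_layers[-1]` in the sketch reproduces that behavior.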
```python
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)  # [6, 320, 64, 64] -> [6, 320, 64, 64]
        emb_out = self.emb_layers(emb).type(h.dtype)  # [6, 1280] -> [6, 320]
        while len(emb_out.shape) < len(h.shape):  # [6, 320] -> [6, 320, 1, 1]
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out  # [6, 320, 64, 64] + [6, 320, 1, 1] -> [6, 320, 64, 64]
            h = self.out_layers(h)  # [6, 320, 64, 64]
        return self.skip_connection(x) + h
```

2-3-8--SpatialTransformer
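SpatialTransformer takes the feature map x and the text features context as inputs. Through cross-attention, every spatial position of x attends to the 77 text tokens. This injects the prompt information into x and is how the text conditions image generation. The core code is as follows: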
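Before the full source, the shape bookkeeping of CrossAttention can be checked with a minimal sketch of its multi-head core, with `einops.rearrange` replaced by `reshape`/`permute`.

The following are assumptions for a quick run, not part of the demo:

- the helper name `cross_attention`;
- a small 8x8 feature map (64 image tokens instead of 4096).

The output projection to_out and the mask branch are omitted.

```python
import torch
from torch import nn


def cross_attention(x, context, to_q, to_k, to_v, heads):
    """Queries come from image tokens x; keys and values come from text tokens context."""
    b, n, _ = x.shape
    q, k, v = to_q(x), to_k(context), to_v(context)  # [b, n, inner], [b, m, inner]
    d = q.shape[-1] // heads

    def split(t):  # 'b n (h d) -> (b h) n d'
        return t.reshape(b, -1, heads, d).permute(0, 2, 1, 3).reshape(b * heads, -1, d)

    q, k, v = split(q), split(k), split(v)
    attn = (q @ k.transpose(1, 2) * d ** -0.5).softmax(dim=-1)  # [(b h), n, m]
    out = attn @ v                                               # [(b h), n, d]
    # '(b h) n d -> b n (h d)'
    out = out.reshape(b, heads, n, d).permute(0, 2, 1, 3).reshape(b, n, heads * d)
    return out, attn


if __name__ == "__main__":
    to_q = nn.Linear(320, 320, bias=False)  # query_dim 320 -> inner_dim 8 * 40
    to_k = nn.Linear(768, 320, bias=False)  # CLIP context_dim 768 -> inner_dim
    to_v = nn.Linear(768, 320, bias=False)
    x = torch.randn(2, 64, 320)             # 8x8 feature map flattened to 64 tokens
    context = torch.randn(2, 77, 768)       # 77 text tokens
    out, attn = cross_attention(x, context, to_q, to_k, to_v, heads=8)
    print(out.shape, attn.shape)            # torch.Size([2, 64, 320]) torch.Size([16, 64, 77])
```

The attention map has one row per image token and one column per text token. The output keeps the image token count, which is why `rearrange` can fold it back to `b c h w` afterwards.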
```python
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat

from util import checkpoint


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads  # dim_head: 40, heads: 8 (320-channel level)
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads  # 8

        q = self.to_q(x)  # [6, 4096, 320] -> [6, 4096, 320]
        context = default(context, x)  # return context [6, 77, 768]
        k = self.to_k(context)  # [6, 77, 768] -> [6, 77, 320]
        v = self.to_v(context)  # [6, 77, 768] -> [6, 77, 320]

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        # [6, 4096, 320] -> [48, 4096, 40]
        # [6, 77, 320] -> [48, 77, 40]

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        # [48, 4096, 40] * [48, 77, 40] -> [48, 4096, 77]

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)  # softmax

        out = einsum('b i j, b j d -> b i d', attn, v)
        # [48, 4096, 77] * [48, 77, 40] -> [48, 4096, 40]
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        # [48, 4096, 40] -> [6, 4096, 320]
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x  # self Attention, [6, 4096, 320] -> [6, 4096, 320]
        x = self.attn2(self.norm2(x), context=context) + x  # cross Attention, [6, 4096, 320] -> [6, 4096, 320]
        x = self.ff(self.norm3(x)) + x  # FFN, [6, 4096, 320] -> [6, 4096, 320]
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                for d in range(depth)]
        )

        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape  # [6, 320, 64, 64]
        x_in = x
        x = self.norm(x)  # [6, 320, 64, 64]
        x = self.proj_in(x)  # [6, 320, 64, 64]
        x = rearrange(x, 'b c h w -> b (h w) c')  # [6, 4096, 320]
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)  # [6, 320, 64, 64]
        x = self.proj_out(x)  # [6, 320, 64, 64]
        return x + x_in
```

2-4--self.middle_block
self.middle_block consists of two ResBlocks with a SpatialTransformer between them. All three operate at 1280 channels on the 8x8 feature map:
```
TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
  (1): SpatialTransformer(
    (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
    (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
    (transformer_blocks): ModuleList(
      (0): BasicTransformerBlock(
        (attn1): CrossAttention(
          (to_q): Linear(in_features=1280, out_features=1280, bias=False)
          (to_k): Linear(in_features=1280, out_features=1280, bias=False)
          (to_v): Linear(in_features=1280, out_features=1280, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=1280, out_features=1280, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (ff): FeedForward(
          (net): Sequential(
            (0): GEGLU(
              (proj): Linear(in_features=1280, out_features=10240, bias=True)
            )
            (1): Dropout(p=0.0, inplace=False)
            (2): Linear(in_features=5120, out_features=1280, bias=True)
          )
        )
        (attn2): CrossAttention(
          (to_q): Linear(in_features=1280, out_features=1280, bias=False)
          (to_k): Linear(in_features=768, out_features=1280, bias=False)
          (to_v): Linear(in_features=768, out_features=1280, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=1280, out_features=1280, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      )
    )
    (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  )
  (2): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
)
```

2-5--self.output_blocks: Upsampling
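In forward(), self.output_blocks upsamples the feature map back to the latent resolution. Across the 12 output blocks the shape changes from [B*2, 1280, 8, 8] to [B*2, 320, 64, 64]; the final self.out layer at the end of forward() then maps the 320 channels to the 4 latent channels, giving [B*2, 4, 64, 64].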
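In the printout that follows, each Upsample module only lists a Conv2d; the resolution doubling itself happens in its forward pass. To our understanding of the LDM source, Upsample.forward first calls F.interpolate(x, scale_factor=2, mode="nearest") and then applies that 3x3 conv. The pure-Python sketch below illustrates only the nearest-neighbour step on a single channel (the helper name upsample_nearest_2x is hypothetical) and traces the spatial size through the three Upsample modules in output blocks 2, 5 and 8.

```python
def upsample_nearest_2x(grid):
    """Nearest-neighbour 2x upsampling of one channel (a list of rows).

    Every value is copied into a 2x2 block, which is what nearest-mode
    interpolation with scale_factor=2 does per channel (sketch only).
    """
    out = []
    for row in grid:
        # Repeat every element twice along the width.
        wide = [v for v in row for _ in range(2)]
        # Emit the widened row twice to double the height.
        out.append(wide)
        out.append(list(wide))
    return out


# Stand-in for one 8x8 channel of the middle_block output.
grid = [[0.0] * 8 for _ in range(8)]
sizes = [len(grid)]
for _ in range(3):  # one pass per Upsample module (output blocks 2, 5, 8)
    grid = upsample_nearest_2x(grid)
    sizes.append(len(grid))
print(sizes)  # [8, 16, 32, 64]
```

The convolution that follows the interpolation (omitted here) smooths the blocky copies; it keeps the channel count unchanged, as the Conv2d(1280, 1280, ...) and Conv2d(640, 640, ...) entries show.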
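The upsampling path consists of 12 modules and mirrors the downsampling path. Before each module, forward() concatenates h with the skip feature popped from hs (th.cat([h, hs.pop()], dim=1)), so each ResBlock receives the summed channel count (2560, 1920, 1280, 960 or 640) and uses a 1x1 Conv2d as its skip_connection to bring the residual branch to the output width. Blocks 2, 5 and 8 end with an Upsample module, which doubles the spatial resolution (8 → 16 → 32 → 64). The modules are: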
ModuleList( (0-1): 2 x TimestepEmbedSequential( (0): ResBlock( (in_layers): Sequential( (0): GroupNorm32(32, 2560, eps=1e-05, affine=True) (1): SiLU() (2): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (h_upd): Identity() (x_upd): Identity() (emb_layers): Sequential( (0): SiLU() (1): Linear(in_features=1280, out_features=1280, bias=True) ) (out_layers): Sequential( (0): GroupNorm32(32, 1280, eps=1e-05, affine=True) (1): SiLU() (2): Dropout(p=0, inplace=False) (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (skip_connection): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1)) ) ) (2): TimestepEmbedSequential( (0): ResBlock( (in_layers): Sequential( (0): GroupNorm32(32, 2560, eps=1e-05, affine=True) (1): SiLU() (2): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (h_upd): Identity() (x_upd): Identity() (emb_layers): Sequential( (0): SiLU() (1): Linear(in_features=1280, out_features=1280, bias=True) ) (out_layers): Sequential( (0): GroupNorm32(32, 1280, eps=1e-05, affine=True) (1): SiLU() (2): Dropout(p=0, inplace=False) (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (skip_connection): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1)) ) (1): Upsample( (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (3-4): 2 x TimestepEmbedSequential( (0): ResBlock( (in_layers): Sequential( (0): GroupNorm32(32, 2560, eps=1e-05, affine=True) (1): SiLU() (2): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (h_upd): Identity() (x_upd): Identity() (emb_layers): Sequential( (0): SiLU() (1): Linear(in_features=1280, out_features=1280, bias=True) ) (out_layers): Sequential( (0): GroupNorm32(32, 1280, eps=1e-05, affine=True) (1): SiLU() (2): Dropout(p=0, inplace=False) (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (skip_connection): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1)) ) (1): 
SpatialTransformer( (norm): GroupNorm(32, 1280, eps=1e-06, affine=True) (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1)) (transformer_blocks): ModuleList( (0): BasicTransformerBlock( (attn1): CrossAttention( (to_q): Linear(in_features=1280, out_features=1280, bias=False) (to_k): Linear(in_features=1280, out_features=1280, bias=False) (to_v): Linear(in_features=1280, out_features=1280, bias=False) (to_out): Sequential( (0): Linear(in_features=1280, out_features=1280, bias=True) (1): Dropout(p=0.0, inplace=False) ) ) (ff): FeedForward( (net): Sequential( (0): GEGLU( (proj): Linear(in_features=1280, out_features=10240, bias=True) ) (1): Dropout(p=0.0, inplace=False) (2): Linear(in_features=5120, out_features=1280, bias=True) ) ) (attn2): CrossAttention( (to_q): Linear(in_features=1280, out_features=1280, bias=False) (to_k): Linear(in_features=768, out_features=1280, bias=False) (to_v): Linear(in_features=768, out_features=1280, bias=False) (to_out): Sequential( (0): Linear(in_features=1280, out_features=1280, bias=True) (1): Dropout(p=0.0, inplace=False) ) ) (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True) (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True) (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True) ) ) (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1)) ) ) (5): TimestepEmbedSequential( (0): ResBlock( (in_layers): Sequential( (0): GroupNorm32(32, 1920, eps=1e-05, affine=True) (1): SiLU() (2): Conv2d(1920, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (h_upd): Identity() (x_upd): Identity() (emb_layers): Sequential( (0): SiLU() (1): Linear(in_features=1280, out_features=1280, bias=True) ) (out_layers): Sequential( (0): GroupNorm32(32, 1280, eps=1e-05, affine=True) (1): SiLU() (2): Dropout(p=0, inplace=False) (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (skip_connection): Conv2d(1920, 1280, kernel_size=(1, 1), stride=(1, 1)) ) (1): 
SpatialTransformer( (norm): GroupNorm(32, 1280, eps=1e-06, affine=True) (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1)) (transformer_blocks): ModuleList( (0): BasicTransformerBlock( (attn1): CrossAttention( (to_q): Linear(in_features=1280, out_features=1280, bias=False) (to_k): Linear(in_features=1280, out_features=1280, bias=False) (to_v): Linear(in_features=1280, out_features=1280, bias=False) (to_out): Sequential( (0): Linear(in_features=1280, out_features=1280, bias=True) (1): Dropout(p=0.0, inplace=False) ) ) (ff): FeedForward( (net): Sequential( (0): GEGLU( (proj): Linear(in_features=1280, out_features=10240, bias=True) ) (1): Dropout(p=0.0, inplace=False) (2): Linear(in_features=5120, out_features=1280, bias=True) ) ) (attn2): CrossAttention( (to_q): Linear(in_features=1280, out_features=1280, bias=False) (to_k): Linear(in_features=768, out_features=1280, bias=False) (to_v): Linear(in_features=768, out_features=1280, bias=False) (to_out): Sequential( (0): Linear(in_features=1280, out_features=1280, bias=True) (1): Dropout(p=0.0, inplace=False) ) ) (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True) (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True) (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True) ) ) (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1)) ) (2): Upsample( (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (6): TimestepEmbedSequential( (0): ResBlock( (in_layers): Sequential( (0): GroupNorm32(32, 1920, eps=1e-05, affine=True) (1): SiLU() (2): Conv2d(1920, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (h_upd): Identity() (x_upd): Identity() (emb_layers): Sequential( (0): SiLU() (1): Linear(in_features=1280, out_features=640, bias=True) ) (out_layers): Sequential( (0): GroupNorm32(32, 640, eps=1e-05, affine=True) (1): SiLU() (2): Dropout(p=0, inplace=False) (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) 
(skip_connection): Conv2d(1920, 640, kernel_size=(1, 1), stride=(1, 1)) ) (1): SpatialTransformer( (norm): GroupNorm(32, 640, eps=1e-06, affine=True) (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1)) (transformer_blocks): ModuleList( (0): BasicTransformerBlock( (attn1): CrossAttention( (to_q): Linear(in_features=640, out_features=640, bias=False) (to_k): Linear(in_features=640, out_features=640, bias=False) (to_v): Linear(in_features=640, out_features=640, bias=False) (to_out): Sequential( (0): Linear(in_features=640, out_features=640, bias=True) (1): Dropout(p=0.0, inplace=False) ) ) (ff): FeedForward( (net): Sequential( (0): GEGLU( (proj): Linear(in_features=640, out_features=5120, bias=True) ) (1): Dropout(p=0.0, inplace=False) (2): Linear(in_features=2560, out_features=640, bias=True) ) ) (attn2): CrossAttention( (to_q): Linear(in_features=640, out_features=640, bias=False) (to_k): Linear(in_features=768, out_features=640, bias=False) (to_v): Linear(in_features=768, out_features=640, bias=False) (to_out): Sequential( (0): Linear(in_features=640, out_features=640, bias=True) (1): Dropout(p=0.0, inplace=False) ) ) (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True) (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True) (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True) ) ) (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1)) ) ) (7): TimestepEmbedSequential( (0): ResBlock( (in_layers): Sequential( (0): GroupNorm32(32, 1280, eps=1e-05, affine=True) (1): SiLU() (2): Conv2d(1280, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (h_upd): Identity() (x_upd): Identity() (emb_layers): Sequential( (0): SiLU() (1): Linear(in_features=1280, out_features=640, bias=True) ) (out_layers): Sequential( (0): GroupNorm32(32, 640, eps=1e-05, affine=True) (1): SiLU() (2): Dropout(p=0, inplace=False) (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (skip_connection): Conv2d(1280, 640, 
kernel_size=(1, 1), stride=(1, 1)) ) (1): SpatialTransformer( (norm): GroupNorm(32, 640, eps=1e-06, affine=True) (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1)) (transformer_blocks): ModuleList( (0): BasicTransformerBlock( (attn1): CrossAttention( (to_q): Linear(in_features=640, out_features=640, bias=False) (to_k): Linear(in_features=640, out_features=640, bias=False) (to_v): Linear(in_features=640, out_features=640, bias=False) (to_out): Sequential( (0): Linear(in_features=640, out_features=640, bias=True) (1): Dropout(p=0.0, inplace=False) ) ) (ff): FeedForward( (net): Sequential( (0): GEGLU( (proj): Linear(in_features=640, out_features=5120, bias=True) ) (1): Dropout(p=0.0, inplace=False) (2): Linear(in_features=2560, out_features=640, bias=True) ) ) (attn2): CrossAttention( (to_q): Linear(in_features=640, out_features=640, bias=False) (to_k): Linear(in_features=768, out_features=640, bias=False) (to_v): Linear(in_features=768, out_features=640, bias=False) (to_out): Sequential( (0): Linear(in_features=640, out_features=640, bias=True) (1): Dropout(p=0.0, inplace=False) ) ) (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True) (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True) (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True) ) ) (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1)) ) ) (8): TimestepEmbedSequential( (0): ResBlock( (in_layers): Sequential( (0): GroupNorm32(32, 960, eps=1e-05, affine=True) (1): SiLU() (2): Conv2d(960, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (h_upd): Identity() (x_upd): Identity() (emb_layers): Sequential( (0): SiLU() (1): Linear(in_features=1280, out_features=640, bias=True) ) (out_layers): Sequential( (0): GroupNorm32(32, 640, eps=1e-05, affine=True) (1): SiLU() (2): Dropout(p=0, inplace=False) (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (skip_connection): Conv2d(960, 640, kernel_size=(1, 1), stride=(1, 1)) ) (1): 
SpatialTransformer( (norm): GroupNorm(32, 640, eps=1e-06, affine=True) (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1)) (transformer_blocks): ModuleList( (0): BasicTransformerBlock( (attn1): CrossAttention( (to_q): Linear(in_features=640, out_features=640, bias=False) (to_k): Linear(in_features=640, out_features=640, bias=False) (to_v): Linear(in_features=640, out_features=640, bias=False) (to_out): Sequential( (0): Linear(in_features=640, out_features=640, bias=True) (1): Dropout(p=0.0, inplace=False) ) ) (ff): FeedForward( (net): Sequential( (0): GEGLU( (proj): Linear(in_features=640, out_features=5120, bias=True) ) (1): Dropout(p=0.0, inplace=False) (2): Linear(in_features=2560, out_features=640, bias=True) ) ) (attn2): CrossAttention( (to_q): Linear(in_features=640, out_features=640, bias=False) (to_k): Linear(in_features=768, out_features=640, bias=False) (to_v): Linear(in_features=768, out_features=640, bias=False) (to_out): Sequential( (0): Linear(in_features=640, out_features=640, bias=True) (1): Dropout(p=0.0, inplace=False) ) ) (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True) (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True) (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True) ) ) (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1)) ) (2): Upsample( (conv): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (9): TimestepEmbedSequential( (0): ResBlock( (in_layers): Sequential( (0): GroupNorm32(32, 960, eps=1e-05, affine=True) (1): SiLU() (2): Conv2d(960, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (h_upd): Identity() (x_upd): Identity() (emb_layers): Sequential( (0): SiLU() (1): Linear(in_features=1280, out_features=320, bias=True) ) (out_layers): Sequential( (0): GroupNorm32(32, 320, eps=1e-05, affine=True) (1): SiLU() (2): Dropout(p=0, inplace=False) (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (skip_connection): Conv2d(960, 320, 
kernel_size=(1, 1), stride=(1, 1)) ) (1): SpatialTransformer( (norm): GroupNorm(32, 320, eps=1e-06, affine=True) (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1)) (transformer_blocks): ModuleList( (0): BasicTransformerBlock( (attn1): CrossAttention( (to_q): Linear(in_features=320, out_features=320, bias=False) (to_k): Linear(in_features=320, out_features=320, bias=False) (to_v): Linear(in_features=320, out_features=320, bias=False) (to_out): Sequential( (0): Linear(in_features=320, out_features=320, bias=True) (1): Dropout(p=0.0, inplace=False) ) ) (ff): FeedForward( (net): Sequential( (0): GEGLU( (proj): Linear(in_features=320, out_features=2560, bias=True) ) (1): Dropout(p=0.0, inplace=False) (2): Linear(in_features=1280, out_features=320, bias=True) ) ) (attn2): CrossAttention( (to_q): Linear(in_features=320, out_features=320, bias=False) (to_k): Linear(in_features=768, out_features=320, bias=False) (to_v): Linear(in_features=768, out_features=320, bias=False) (to_out): Sequential( (0): Linear(in_features=320, out_features=320, bias=True) (1): Dropout(p=0.0, inplace=False) ) ) (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True) (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True) (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True) ) ) (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1)) ) ) (10-11): 2 x TimestepEmbedSequential( (0): ResBlock( (in_layers): Sequential( (0): GroupNorm32(32, 640, eps=1e-05, affine=True) (1): SiLU() (2): Conv2d(640, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (h_upd): Identity() (x_upd): Identity() (emb_layers): Sequential( (0): SiLU() (1): Linear(in_features=1280, out_features=320, bias=True) ) (out_layers): Sequential( (0): GroupNorm32(32, 320, eps=1e-05, affine=True) (1): SiLU() (2): Dropout(p=0, inplace=False) (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (skip_connection): Conv2d(640, 320, kernel_size=(1, 1), stride=(1, 1)) ) (1): 
SpatialTransformer( (norm): GroupNorm(32, 320, eps=1e-06, affine=True) (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1)) (transformer_blocks): ModuleList( (0): BasicTransformerBlock( (attn1): CrossAttention( (to_q): Linear(in_features=320, out_features=320, bias=False) (to_k): Linear(in_features=320, out_features=320, bias=False) (to_v): Linear(in_features=320, out_features=320, bias=False) (to_out): Sequential( (0): Linear(in_features=320, out_features=320, bias=True) (1): Dropout(p=0.0, inplace=False) ) ) (ff): FeedForward( (net): Sequential( (0): GEGLU( (proj): Linear(in_features=320, out_features=2560, bias=True) ) (1): Dropout(p=0.0, inplace=False) (2): Linear(in_features=1280, out_features=320, bias=True) ) ) (attn2): CrossAttention( (to_q): Linear(in_features=320, out_features=320, bias=False) (to_k): Linear(in_features=768, out_features=320, bias=False) (to_v): Linear(in_features=768, out_features=320, bias=False) (to_out): Sequential( (0): Linear(in_features=320, out_features=320, bias=True) (1): Dropout(p=0.0, inplace=False) ) ) (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True) (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True) (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True) ) ) (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1)) ) ) ) 
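The input widths in the printout above (2560, 1920, 1280, 960, 640) follow from the skip-connection stack hs filled by self.input_blocks. The sketch below reproduces that bookkeeping in pure Python. It assumes the SD v1 settings model_channels=320, channel_mult=(1, 2, 4, 4) and num_res_blocks=2, which are consistent with the printed layers but are not shown in this section; the function names are hypothetical.

```python
# Assumed config, consistent with the printed SD v1 UNet.
MODEL_CHANNELS = 320
CHANNEL_MULT = (1, 2, 4, 4)
NUM_RES_BLOCKS = 2


def skip_channels():
    """Channel count of every tensor input_blocks appends to hs."""
    chans = [MODEL_CHANNELS]  # input_blocks[0]: the initial conv
    ch = MODEL_CHANNELS
    for level, mult in enumerate(CHANNEL_MULT):
        for _ in range(NUM_RES_BLOCKS):
            ch = MODEL_CHANNELS * mult  # ResBlock (+ SpatialTransformer)
            chans.append(ch)
        if level != len(CHANNEL_MULT) - 1:
            chans.append(ch)  # Downsample keeps the channel count
    return chans


def output_block_channels():
    """(in_channels, out_channels) of the ResBlock in each output block."""
    hs = skip_channels()
    ch = MODEL_CHANNELS * CHANNEL_MULT[-1]  # middle_block output: 1280
    result = []
    for mult in reversed(CHANNEL_MULT):
        for _ in range(NUM_RES_BLOCKS + 1):
            in_ch = ch + hs.pop()  # th.cat([h, hs.pop()], dim=1)
            ch = MODEL_CHANNELS * mult
            result.append((in_ch, ch))
    return result


for i, (cin, cout) in enumerate(output_block_channels()):
    # in != out in every block, hence the 1x1 Conv2d skip_connection.
    print(i, cin, "->", cout)
```

Each output level runs num_res_blocks + 1 = 3 blocks because it must consume one extra skip tensor per level: the output of the Downsample (or the initial conv) on the encoder side.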
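Every FeedForward in the printout pairs a GEGLU projection of Linear(1280, 10240) with a following Linear(5120, 1280); the halving comes from the GEGLU gate. The pure-Python sketch below assumes, following our reading of the LDM GEGLU module, that the projected vector is split into a value half and a gate half and returned as value * gelu(gate). The multiplier of 4 is inferred from 5120 / 1280, and the helper names are hypothetical.

```python
import math

DIM = 1280           # SpatialTransformer width at the 8x8 and 16x16 levels
MULT = 4             # assumed FeedForward multiplier (5120 / 1280)
INNER = DIM * MULT   # 5120, the input width of the second Linear


def gelu(x):
    """Exact GELU: x times the standard normal CDF of x."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def geglu(projected):
    """Split a 2*INNER vector into value and gate halves; return value * gelu(gate)."""
    half = len(projected) // 2
    value, gate = projected[:half], projected[half:]
    return [v * gelu(g) for v, g in zip(value, gate)]


projected = [0.1] * (2 * INNER)  # stands in for proj(x): 10240 values per token
hidden = geglu(projected)
print(len(projected), len(hidden))  # 10240 5120
```

The same ratio holds at the 640 and 320 levels (5120 → 2560 and 2560 → 1280 in the printout).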