AIGC Notes--Building a VQVAE Model


1--The VQVAE Model

        Content generated by VAE models is often of limited quality. One possible reason is that a VAE encodes images into continuous variables (mapped to a standard normal distribution), whereas encoding images into discrete variables may work better, since in everyday life we tend to describe things with discrete attributes (a person's height or build, for example, is described in discrete terms such as tall/short or fat/thin).

        A VQVAE model has three key modules: the Encoder, the Decoder, and the Codebook.

        The Encoder encodes the input into feature vectors; for each feature vector, the model computes its similarity (L2 distance) to the Embedding vectors in the Codebook, replaces the feature vector with the most similar Embedding vector, and feeds the result into the Decoder to reconstruct the input, as sketched below.
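        A minimal sketch of this nearest-neighbor lookup (the shapes here are illustrative and not tied to the full model in section 2):

import torch

feats = torch.randn(16, 64)        # (N, D) flattened encoder features
codebook = torch.randn(512, 64)    # (K, D) codebook embeddings

dists = torch.cdist(feats, codebook)   # (N, K) pairwise L2 distances
indices = dists.argmin(dim=1)          # index of the nearest embedding per feature
quantized = codebook[indices]          # (N, D) quantized replacements

        The full code below computes squared distances via the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2ab instead of torch.cdist; both give the same argmin.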

        The VQVAE loss consists of the reconstruction loss between the source image and the reconstructed image, plus the quantization loss vq_loss from the quantization step against the Codebook.
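        For reference, this matches the objective in the original VQ-VAE paper (van den Oord et al., 2017), where sg[.] denotes the stop-gradient operator and beta is the commitment cost:

L = \log p(x \mid z_q(x)) + \| \mathrm{sg}[z_e(x)] - e \|_2^2 + \beta \, \| z_e(x) - \mathrm{sg}[e] \|_2^2

        The first term is the reconstruction loss; the second and third terms together form the vq_loss computed in the code below.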

        For a detailed introduction to VQ-VAE, see: 轻松理解 VQ-VAE

2--A Simple Code Example

import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # third term of the loss in the paper
        q_latent_loss = F.mse_loss(quantized, inputs.detach())  # second term of the loss in the paper
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        quantized = inputs + (quantized - inputs).detach()  # gradient copy (straight-through)
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

class VectorQuantizerEMA(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape  # B(256) H(8) W(8) C(64)

        # Flatten input BHWC -> BHW, C
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances to every embedding in the embedding space
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)  # pick the most similar embedding
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)  # map indices to one-hot vectors

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)  # look up the embedding for each index

        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)

            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))  # Eq. (8) in the paper

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # loss between the encoder output (inputs) and the decoder input (quantized)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()  # trick: copy the gradient of the decoder input to the encoder output
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

class Residual(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(in_channels=in_channels, out_channels=num_residual_hiddens,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(in_channels=num_residual_hiddens, out_channels=num_hiddens,
                      kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        return x + self._block(x)

class ResidualStack(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
                                      for _ in range(self._num_residual_layers)])

    def forward(self, x):
        for i in range(self._num_residual_layers):
            x = self._layers[i](x)
        return F.relu(x)

class Encoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Encoder, self).__init__()
        self._conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=num_hiddens//2,
                                 kernel_size=4, stride=2, padding=1)
        self._conv_2 = nn.Conv2d(in_channels=num_hiddens//2, out_channels=num_hiddens,
                                 kernel_size=4, stride=2, padding=1)
        self._conv_3 = nn.Conv2d(in_channels=num_hiddens, out_channels=num_hiddens,
                                 kernel_size=3, stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens, num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = F.relu(x)
        x = self._conv_2(x)
        x = F.relu(x)
        x = self._conv_3(x)
        return self._residual_stack(x)

class Decoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()
        self._conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=num_hiddens,
                                 kernel_size=3, stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens, num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)
        self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens, out_channels=num_hiddens//2,
                                                kernel_size=4, stride=2, padding=1)
        self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens//2, out_channels=3,
                                                kernel_size=4, stride=2, padding=1)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = self._residual_stack(x)
        x = self._conv_trans_1(x)
        x = F.relu(x)
        return self._conv_trans_2(x)

class Model(nn.Module):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(Model, self).__init__()
        self._encoder = Encoder(3, num_hiddens, num_residual_layers, num_residual_hiddens)
        self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens, out_channels=embedding_dim,
                                      kernel_size=1, stride=1)
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim, commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
        self._decoder = Decoder(embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)

    def forward(self, x):
        # x.shape: B(256) C(3) H(32) W(32)
        z = self._encoder(x)
        z = self._pre_vq_conv(z)
        loss, quantized, perplexity, _ = self._vq_vae(z)
        x_recon = self._decoder(quantized)  # decode back to an image: B(256) C(3) H(32) W(32)
        return loss, x_recon, perplexity
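        A quick smoke test of the model above (the hyperparameter values are illustrative choices, not mandated by the code; the spatial shapes match the B/C/H/W comments):

if __name__ == '__main__':
    model = Model(num_hiddens=128, num_residual_layers=2, num_residual_hiddens=32,
                  num_embeddings=512, embedding_dim=64, commitment_cost=0.25, decay=0.99)
    x = torch.randn(4, 3, 32, 32)             # dummy batch of 32x32 RGB images
    vq_loss, x_recon, perplexity = model(x)
    print(x_recon.shape)                      # torch.Size([4, 3, 32, 32])

        With decay=0.99 the EMA quantizer is selected; with the default decay=0 the gradient-based VectorQuantizer is used instead.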

For the complete code, see: liujf69/VQ-VAE

3--Notes on Selected Details:

Computing the reconstruction loss:

        Compute the MSE loss between the source image and the reconstructed image, normalized by the variance of the training data:

vq_loss, data_recon, perplexity = self.model(data)
recon_error = F.mse_loss(data_recon, data) / self.data_variance
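        A sketch of how the two losses combine into one optimization step (the Adam optimizer and the data_variance value here are illustrative assumptions; in practice data_variance is the pixel variance of the training set):

import torch
import torch.nn.functional as F

model = Model(num_hiddens=128, num_residual_layers=2, num_residual_hiddens=32,
              num_embeddings=512, embedding_dim=64, commitment_cost=0.25)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data = torch.randn(4, 3, 32, 32)   # stand-in batch
data_variance = 1.0                # assumption; normally computed from the dataset

optimizer.zero_grad()
vq_loss, data_recon, perplexity = model(data)
recon_error = F.mse_loss(data_recon, data) / data_variance
loss = recon_error + vq_loss       # total loss: reconstruction + quantization
loss.backward()
optimizer.step()

        The returned perplexity measures how uniformly the codebook entries are being used and is typically just logged for monitoring.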

Computing the VQ quantization loss:

        Here, inputs is the Encoder's output and quantized is the Codebook vector closest to inputs. Note the direction of each detach(): in e_latent_loss the quantized tensor is detached, so this term only pulls the Encoder output toward the chosen embedding (the commitment term); in q_latent_loss the inputs tensor is detached, so this term only moves the Codebook embedding toward the Encoder output.

# Loss
e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # third term of the loss in the paper
q_latent_loss = F.mse_loss(quantized, inputs.detach())  # second term of the loss in the paper
loss = q_latent_loss + self._commitment_cost * e_latent_loss

Copying the Decoder's gradient back to the Encoder (the straight-through estimator): inputs is the Encoder's output and quantized is the Decoder's input;

quantized = inputs + (quantized - inputs).detach()  # gradient copy (straight-through)
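        A tiny self-contained demonstration of the trick (torch.round stands in for the codebook lookup): the forward value equals the quantized tensor, but the backward pass treats quantization as the identity, so the Decoder's gradient reaches the Encoder unchanged.

import torch

z_e = torch.randn(5, requires_grad=True)   # stand-in for the Encoder output
z_q = torch.round(z_e)                     # non-differentiable "quantization"
z_st = z_e + (z_q - z_e).detach()          # forward: z_q; backward: identity w.r.t. z_e
z_st.sum().backward()
print(z_e.grad)                            # tensor of ones, as if quantization were skipped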
