How to Modify a Large Model's Position Encoding -- Taking LLaMA as an Example


I've been reading about RoPE lately; some methods can support longer contexts without any extra training simply by modifying the position encoding. But given a model that has already been trained, how do we actually change its position encoding? I looked through the relevant code and am writing it down here. I won't go into the theory in detail; here are a few related posts:
Understanding Rotary Position Embedding (RoPE) in Ten Minutes (十分钟读懂旋转编码(RoPE))
The Road to Transformer Upgrades, 11: Taking the β-base Positions All the Way (Transformer升级之路:11、将β进制位置进行到底)
The Road to Transformer Upgrades, 10: RoPE Is a β-base Encoding (Transformer升级之路:10、RoPE是一种β进制编码)

NTK

The figure below shows the derivation behind NTK scaling; it is a screenshot from "The Road to Transformer Upgrades, 10: RoPE Is a β-base Encoding".

Looking at the formulas above, I wondered: why do we need to establish a relationship between $\lambda$ and $k$?

Because we want to change the base from $\beta$ to $\beta\lambda$. The factor $k$ is something we know in advance: if we want to scale positions down by a factor of 10, we simply set $k = 10$. But with the NTK approach, where the scaling is spread across the dimensions through the base change, it is no longer obvious how $\lambda$ should be set. So we need to establish the relationship between $\lambda$ and $k$; from the figure above, $\lambda = k^{2/(d-2)}$.
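The derivation in that figure can be sketched as follows (this is my paraphrase of the β-base argument in the linked post, not the original screenshot). RoPE uses $d/2$ rotation frequencies $\theta_i = \beta^{-i}$ with $\beta = 10000^{2/d}$. NTK scaling replaces $\beta$ with $\beta\lambda$ and requires the lowest frequency ($i = d/2 - 1$) to behave as if positions were divided by $k$, while the highest frequency ($i = 0$) stays unchanged:

$$\frac{n}{(\beta\lambda)^{d/2-1}} = \frac{n/k}{\beta^{d/2-1}} \;\Longrightarrow\; \lambda^{d/2-1} = k \;\Longrightarrow\; \lambda = k^{2/(d-2)}$$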
Now let's work out how to turn plain RoPE into the NTK form.
Below is LLaMA's RoPE implementation (from Hugging Face transformers):

```python
class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        t = t / self.scaling_factor
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
        self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)

    @property
    def sin_cached(self):
        logger.warning_once(
            "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
        )
        return self._sin_cached

    @property
    def cos_cached(self):
        logger.warning_once(
            "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
        )
        return self._cos_cached

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
```
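To connect this implementation with the β-base notation above: the `inv_freq` line computes

$$\theta_i = \frac{1}{\text{base}^{2i/d}} = \beta^{-i}, \qquad \beta = 10000^{2/d}, \quad i = 0, 1, \dots, \tfrac{d}{2}-1,$$

and `forward` builds the angles $n\theta_i$ for every position id $n$ before taking their cos and sin.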

Given the relationship between $\lambda$ and $k$, how do we implement it in code? We only need to change the value of $\beta\lambda$, where $\beta = 10000^{2/d}$.
Reference code: (link)

```python
import transformers

old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__

def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
    # The method is just these three lines
    max_position_embeddings = 16384
    k = 8  # Alpha value
    base = base * k ** (dim / (dim - 2))  # Base change formula
    old_init(self, dim, max_position_embeddings, base, device)

transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init
```
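A minimal usage sketch, assuming the patch above has already been run and using a placeholder checkpoint name: the patch has to be in place before the model is constructed, because `LlamaRotaryEmbedding.__init__` runs inside `from_pretrained`.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Run the ntk_scaled_init patch above first, then build the model:
# LlamaRotaryEmbedding instances are created during from_pretrained,
# so they pick up the NTK-scaled base.
model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
```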

Why `base = base * k ** (dim / (dim-2))`? The original base is 10000 and $\beta = 10000^{2/d}$, i.e. $\beta = \text{base}^{2/d}$; the code only ever changes the value of base. Meanwhile $\lambda = k^{2/(d-2)}$, so we need to fold $k$ into base, i.e. write $\beta\lambda$ in the form $(\text{base} \cdot k^{\,?})^{2/d}$. Since the exponent can be rewritten as $\frac{2}{d-2} = \frac{d}{d-2} \cdot \frac{2}{d}$, we get $\lambda = (k^{d/(d-2)})^{2/d}$. And because the $2/d$ exponent is already applied inside the RoPE code:

```python
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
```

we simply assign the new base as `base * k ** (dim / (dim-2))`, which makes the computed inverse frequencies equal to $(\beta\lambda)^{-i}$.
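As a quick sanity check (my own sketch, not from the original post), the snippet below verifies that `base * k ** (d / (d - 2))` leaves the highest frequency unchanged and slows the lowest frequency down by exactly $k$:

```python
import torch

d, base, k = 128, 10000.0, 8.0

exponents = torch.arange(0, d, 2).float() / d   # the 2i/d exponents, i = 0 .. d/2 - 1
inv_freq = 1.0 / (base ** exponents)            # original theta_i = beta^{-i}
new_base = base * k ** (d / (d - 2))            # NTK base change
inv_freq_ntk = 1.0 / (new_base ** exponents)    # NTK theta_i = (beta * lambda)^{-i}

print(inv_freq_ntk[0] / inv_freq[0])    # 1.0 -> highest frequency untouched
print(inv_freq[-1] / inv_freq_ntk[-1])  # 8.0 -> lowest frequency scaled by k
```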

Dynamic NTK

Dynamic NTK is a simple modification on top of NTK that makes the scaling more flexible: instead of fixing $k$ in advance, the scaling factor is recomputed from the current sequence length, so sequences within the pretraining length are left untouched.
The screenshot comes from: "RoPE到底是何方神圣(数学推理+优化方法)" (What Exactly Is RoPE: Mathematical Derivation + Optimization Methods).

Code implementation:

```python
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def forward(self, x, position_ids):
        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_position_embeddings:  # apply NTK only when the length exceeds the pretraining threshold
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation

        cos, sin = super().forward(x, position_ids)
        return cos, sin
```
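For the transformers version this code comes from, you normally don't need to patch this class by hand: it is selected through the `rope_scaling` field of the model config. A minimal sketch, assuming a placeholder checkpoint and that your transformers version still uses the `{"type", "factor"}` format:

```python
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
config = AutoConfig.from_pretrained(model_name)
# "dynamic" selects LlamaDynamicNTKScalingRotaryEmbedding; "factor" is passed as scaling_factor.
config.rope_scaling = {"type": "dynamic", "factor": 2.0}
model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
```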
