[Llama Source Code] Rotary Position Embedding (RoPE) -- Source Code Reading


Rotation matrix computation

rotary_emb corresponds to the LlamaRotaryEmbedding layer. Its built-in __init__ method and forward call are responsible for generating the cos and sin terms of the rotation matrix.

Code

class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
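As a quick sanity check, the layer can be instantiated and called as below (a minimal usage sketch, not from the original post; the tensor sizes are made up). It returns the cached cos/sin tables, broadcastable over the batch and head dimensions:

import torch

# hypothetical sizes: head_size=8, 4 heads, sequence length 10
rotary_emb = LlamaRotaryEmbedding(dim=8, max_position_embeddings=2048)
x = torch.randn(1, 4, 10, 8)    # [bs, num_attention_heads, seq_len, head_size]
cos, sin = rotary_emb(x, seq_len=10)
print(cos.shape, sin.shape)     # torch.Size([1, 1, 10, 8]) for both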

Key code in __init__

Compute $\theta$ according to the formula

Source:

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))

Formula:

$\theta_i = 10000^{\frac{-2i}{d}}$

Example: assume base=10000, dim=8, device="cpu"

dim, base, device = 8, 10000, 'cpu'
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
inv_freq

inv_freq is a tensor of size torch.Size([dim//2]):

tensor([1.0000, 0.1000, 0.0100, 0.0010]) 
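Plugging $d=8$ into the formula above reproduces these four values:

$\theta_0 = 10000^{0} = 1,\quad \theta_1 = 10000^{-2/8} = 0.1,\quad \theta_2 = 10000^{-4/8} = 0.01,\quad \theta_3 = 10000^{-6/8} = 0.001$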
Generate the IDs for all positions

Source:

t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)

Example: assume max_position_embeddings=10, then

max_position_embeddings = 10
t = torch.arange(max_position_embeddings, dtype=inv_freq.dtype)
t

Output:

tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]) 
Compute $m\theta$

Source:

freqs = torch.einsum("i,j->ij", t, self.inv_freq)

Example:

freqs = torch.einsum("i,j->ij", t, inv_freq)
freqs

Output:

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03],
        [2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03],
        [3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03],
        [4.0000e+00, 4.0000e-01, 4.0000e-02, 4.0000e-03],
        [5.0000e+00, 5.0000e-01, 5.0000e-02, 5.0000e-03],
        [6.0000e+00, 6.0000e-01, 6.0000e-02, 6.0000e-03],
        [7.0000e+00, 7.0000e-01, 7.0000e-02, 7.0000e-03],
        [8.0000e+00, 8.0000e-01, 8.0000e-02, 8.0000e-03],
        [9.0000e+00, 9.0000e-01, 9.0000e-02, 9.0000e-03]])
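Each entry is simply the position index times the corresponding frequency, i.e. $\text{freqs}[m, i] = m \cdot \theta_i$; for example $\text{freqs}[3, 1] = 3 \times 0.1 = 0.3$, matching the output above.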
Concatenate two copies of $m\theta$

Source:

emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

Example:

emb = torch.cat((freqs, freqs), dim=-1)
emb

Output:

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03, 1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03],
        [2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03, 2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03],
        [3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03, 3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03],
        [4.0000e+00, 4.0000e-01, 4.0000e-02, 4.0000e-03, 4.0000e+00, 4.0000e-01, 4.0000e-02, 4.0000e-03],
        [5.0000e+00, 5.0000e-01, 5.0000e-02, 5.0000e-03, 5.0000e+00, 5.0000e-01, 5.0000e-02, 5.0000e-03],
        [6.0000e+00, 6.0000e-01, 6.0000e-02, 6.0000e-03, 6.0000e+00, 6.0000e-01, 6.0000e-02, 6.0000e-03],
        [7.0000e+00, 7.0000e-01, 7.0000e-02, 7.0000e-03, 7.0000e+00, 7.0000e-01, 7.0000e-02, 7.0000e-03],
        [8.0000e+00, 8.0000e-01, 8.0000e-02, 8.0000e-03, 8.0000e+00, 8.0000e-01, 8.0000e-02, 8.0000e-03],
        [9.0000e+00, 9.0000e-01, 9.0000e-02, 9.0000e-03, 9.0000e+00, 9.0000e-01, 9.0000e-02, 9.0000e-03]])
Compute sin and cos

Source:

self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
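Continuing the toy example (a sketch, assuming emb from the previous step), the two leading None dimensions give the cached tables the shape [1, 1, max_position_embeddings, dim], so they broadcast over the batch and head dimensions:

cos_cached = emb.cos()[None, None, :, :]
sin_cached = emb.sin()[None, None, :, :]
print(cos_cached.shape)   # torch.Size([1, 1, 10, 8])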

rotate_half

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

The input vector is split down the middle into two halves x1 and x2, which are then concatenated as [-x2, x1]:

[q1,q2,q3,q4,q5,q6,q7,q8,q9,q10] -> [-q6,-q7,-q8,-q9,-q10,q1,q2,q3,q4,q5] 
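A quick check on a toy tensor (a sketch, assuming rotate_half as defined above):

import torch

x = torch.arange(1, 11).float()   # stands in for [q1, ..., q10]
print(rotate_half(x))
# tensor([ -6.,  -7.,  -8.,  -9., -10.,   1.,   2.,   3.,   4.,   5.])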

apply_rotary_pos_emb

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
Build the index tensor

gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
Look up the $\cos{m\theta}$ and $\sin{m\theta}$ values for each position by index

cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
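In shape terms, the gather selects, for every position_id, the matching row of the cached table. A small sketch with made-up sizes (bs=2, seq_len=4, dim=8, max cached length 10; the random tensor stands in for the real cos cache):

import torch

bs, seq_len, dim, max_len = 2, 4, 8, 10
cos_cached = torch.randn(1, 1, max_len, dim)               # stand-in for the cached cos table
position_ids = torch.arange(seq_len).expand(bs, seq_len)   # [bs, seq_len]

gather_indices = position_ids[:, None, :, None]            # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos_cached.shape[1], 1, cos_cached.shape[3])
cos = torch.gather(cos_cached.repeat(bs, 1, 1, 1), 2, gather_indices)
print(cos.shape)                                            # torch.Size([2, 1, 4, 8])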
Compute the $\cos$ term and the $\sin\theta\,*$ rotate_half term, then add them

q_embed = (q * cos) + (rotate_half(q) * sin)
Note

After rotate_half, the element-wise alignment of q, cos, and sin is:

q (after rotate_half) = [-q6, -q7, -q8, -q9, -q10, q1, q2, q3, q4, q5]
cos = [cosθ1, cosθ2, cosθ3, cosθ4, cosθ5, cosθ1, cosθ2, cosθ3, cosθ4, cosθ5]
sin = [sinθ1, sinθ2, sinθ3, sinθ4, sinθ5, sinθ1, sinθ2, sinθ3, sinθ4, sinθ5]
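Writing the two terms out per element for this dim=10 example shows that, despite the split-half permutation (different from the interleaved pairs in the RoPE paper, as the code comment notes), each pair $(q_j, q_{j+5})$ is rotated by the same angle $m\theta_j$:

$q'_j = q_j\cos(m\theta_j) - q_{j+5}\sin(m\theta_j), \qquad q'_{j+5} = q_j\sin(m\theta_j) + q_{j+5}\cos(m\theta_j), \qquad j = 1,\dots,5$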

