LLaMA-Adapter Source Code Analysis

Pseudocode

def transformer_block_with_llama_adapter(x, gating_factor, soft_prompt):
    residual = x
    y = zero_init_attention(soft_prompt, x)  # llama-adapter: prepend prefix
    x = self_attention(x)
    x = x + gating_factor * y                # llama-adapter: apply zero_init_attention
    x = LayerNorm(x + residual)
    residual = x
    x = FullyConnectedLayers(x)
    x = AdapterLayers(x)
    x = LayerNorm(x + residual)
    return x
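Note that the gating factor (the gate parameter in the source below) is initialized to zero, so the adapter branch contributes nothing at the start of training and the block initially reproduces the frozen pretrained LLaMA block. A minimal numerical sketch of that property (the tensor shapes here are toy values of my own choosing, not from the repository):

    import torch

    torch.manual_seed(0)
    x = torch.randn(1, 8, 32)                           # (batch, seq_len, dim), toy sizes
    y = torch.randn(1, 8, 32)                           # stand-in for the zero_init_attention output
    gating_factor = torch.nn.Parameter(torch.zeros(1))  # learnable, initialized to zero

    out = x + gating_factor * y
    print(torch.allclose(out, x))                       # True: the adapter branch is a no-op at initialization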

Source Code

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        # llama-adapter: the zero-initialized gating factor, the only extra
        # learnable parameter added to the attention module.
        self.gate = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor,
                mask: Optional[torch.Tensor], adapter=None):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        if adapter is not None:
            # llama-adapter: project the learnable prompt (soft prefix) into
            # keys and values, shared across the batch.
            adapter_len = adapter.shape[1]
            adapter_k = self.wk(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
            adapter_v = self.wv(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
            adapter_k = adapter_k.transpose(1, 2)
            adapter_v = adapter_v.transpose(1, 2)

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)

        if adapter is not None:
            # llama-adapter: attend to the prompt tokens separately, scale the
            # result by the zero-initialized gate, and add it to the output.
            adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)
            adapter_scores = self.gate * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)
            output = output + torch.matmul(adapter_scores, adapter_v)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
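For readers who want to run the adapter branch in isolation, the following is a simplified, single-process sketch of the same gated prefix attention. It is not the repository's code: the fairscale ColumnParallelLinear/RowParallelLinear layers are replaced with plain nn.Linear, and the rotary embeddings, KV cache, and causal mask are omitted; the class name, dimensions, and prompt length below are illustrative assumptions.

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SimpleGatedPrefixAttention(nn.Module):
        """Single-process sketch of the adapter branch above: plain nn.Linear
        instead of the parallel layers, no RoPE, no KV cache, no mask."""
        def __init__(self, dim: int, n_heads: int):
            super().__init__()
            self.n_heads = n_heads
            self.head_dim = dim // n_heads
            self.wq = nn.Linear(dim, dim, bias=False)
            self.wk = nn.Linear(dim, dim, bias=False)
            self.wv = nn.Linear(dim, dim, bias=False)
            self.wo = nn.Linear(dim, dim, bias=False)
            self.gate = nn.Parameter(torch.zeros(1))  # zero-init gating factor

        def forward(self, x, adapter):
            bsz, seqlen, dim = x.shape
            adapter_len = adapter.shape[1]

            def split_heads(t, length):
                return t.view(bsz, length, self.n_heads, self.head_dim).transpose(1, 2)

            xq = split_heads(self.wq(x), seqlen)
            xk = split_heads(self.wk(x), seqlen)
            xv = split_heads(self.wv(x), seqlen)
            ak = split_heads(self.wk(adapter.expand(bsz, -1, -1)), adapter_len)
            av = split_heads(self.wv(adapter.expand(bsz, -1, -1)), adapter_len)

            # Ordinary multi-head self-attention over the input tokens.
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            output = torch.matmul(F.softmax(scores, dim=-1), xv)

            # Adapter branch: attention over the prompt tokens, scaled by the gate.
            a_scores = torch.matmul(xq, ak.transpose(2, 3)) / math.sqrt(self.head_dim)
            output = output + torch.matmul(self.gate * F.softmax(a_scores, dim=-1), av)

            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, dim)
            return self.wo(output)

    # Toy usage: a 10-token learnable prompt attended to by a 16-token sequence.
    attn = SimpleGatedPrefixAttention(dim=64, n_heads=4)
    x = torch.randn(2, 16, 64)
    adapter_prompt = torch.randn(1, 10, 64)
    print(attn(x, adapter_prompt).shape)  # torch.Size([2, 16, 64])

Because gate starts at zero, this module's output at initialization is identical to ordinary multi-head attention over x; the learnable prompt only begins to influence the output once training moves the gate away from zero.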
