源码解析LLaMA-Factory/src/llmtuner/data/template.py + Qwen模板

源码解析LLaMA-Factory/src/llmtuner/data/template.py + Qwen模板

    正在检查是否收录...
@dataclass class Template: format_user: "Formatter" format_assistant: "Formatter" format_system: "Formatter" format_function: "Formatter" format_observation: "Formatter" format_tools: "Formatter" format_separator: "Formatter" default_system: str stop_words: List[str] efficient_eos: bool replace_eos: bool force_system: bool def encode_oneturn( self, tokenizer: "PreTrainedTokenizer", messages: List[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, cutoff_len: int = 1_000_000, reserved_label_len: int = 1, ) -> Tuple[List[int], List[int]]: r""" Returns a single pair of token ids representing prompt and response respectively. """ encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) prompt_ids = [] for query_ids, resp_ids in encoded_pairs[:-1]: prompt_ids += query_ids + resp_ids prompt_ids = prompt_ids + encoded_pairs[-1][0] answer_ids = encoded_pairs[-1][1] return prompt_ids, answer_ids def encode_multiturn( self, tokenizer: "PreTrainedTokenizer", messages: List[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, cutoff_len: int = 1_000_000, reserved_label_len: int = 1, ) -> Sequence[Tuple[List[int], List[int]]]: r""" Returns multiple pairs of token ids representing prompts and responses respectively. """ return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) def _encode( self, tokenizer: "PreTrainedTokenizer", messages: List[Dict[str, str]], system: str, tools: str, cutoff_len: int, reserved_label_len: int, ) -> Sequence[Tuple[List[int], List[int]]]: r""" Encodes formatted inputs to pairs of token ids. Turn 0: system + query resp Turn t: sep + query resp """ system = system or self.default_system encoded_messages = [] for i, message in enumerate(messages): elements = [] if i == 0 and (system or tools or self.force_system): tool_text = self.format_tools.apply(content=tools)[0] if tools else "" elements += self.format_system.apply(content=(system + tool_text)) elif i > 0 and i % 2 == 0: elements += self.format_separator.apply() if message["role"] == Role.USER.value: elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) elif message["role"] == Role.ASSISTANT.value: elements += self.format_assistant.apply(content=message["content"]) elif message["role"] == Role.OBSERVATION.value: elements += self.format_observation.apply(content=message["content"]) elif message["role"] == Role.FUNCTION.value: elements += self.format_function.apply(content=message["content"]) else: raise NotImplementedError("Unexpected role: {}".format(message["role"])) encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)

LLamafactory用作Formatter的方法Qwen/lib/python3.12/abc.py-CSDN博客

这段代码定义了一个名为 Template 的数据类,并包含了多个方法用于将消息编码为一系列的 token id。我们逐行解释这段代码:

数据类定义

from dataclasses import dataclass from typing import List, Dict, Optional, Sequence, Tuple @dataclass class Template: 
使用 @dataclass 装饰器定义了一个 Template 类。dataclass 会自动生成初始化方法和其他方法。

类属性定义

 format_user: "Formatter" format_assistant: "Formatter" format_system: "Formatter" format_function: "Formatter" format_observation: "Formatter" format_tools: "Formatter" format_separator: "Formatter" default_system: str stop_words: List[str] efficient_eos: bool replace_eos: bool force_system: bool 
定义了多个类属性,这些属性主要是不同角色的格式化器(Formatter),以及一些控制编码行为的布尔值和字符串。

encode_oneturn 方法

python

复制

 def encode_oneturn( self, tokenizer: "PreTrainedTokenizer", messages: List[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, cutoff_len: int = 1_000_000, reserved_label_len: int = 1, ) -> Tuple[List[int], List[int]]: r""" Returns a single pair of token ids representing prompt and response respectively. """ 
定义了一个用于编码单轮对话的 encode_oneturn 方法。

python

复制

 encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) 
调用私有方法 _encode 进行实际的编码工作,返回编码后的消息对。

python

复制

 prompt_ids = [] for query_ids, resp_ids in encoded_pairs[:-1]: prompt_ids += query_ids + resp_ids prompt_ids = prompt_ids + encoded_pairs[-1][0] answer_ids = encoded_pairs[-1][1] return prompt_ids, answer_ids 
将所有轮次的 query_ids 和 resp_ids 拼接成一个完整的 prompt_ids,并返回最后一轮的 response_ids 作为 answer_ids

encode_multiturn 方法

python

复制

 def encode_multiturn( self, tokenizer: "PreTrainedTokenizer", messages: List[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, cutoff_len: int = 1_000_000, reserved_label_len: int = 1, ) -> Sequence[Tuple[List[int], List[int]]]: r""" Returns multiple pairs of token ids representing prompts and responses respectively. """ 
定义了一个用于编码多轮对话的 encode_multiturn 方法。

python

复制

 return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) 
直接调用 _encode 方法并返回其结果。

_encode 私有方法

python

复制

 def _encode( self, tokenizer: "PreTrainedTokenizer", messages: List[Dict[str, str]], system: str, tools: str, cutoff_len: int, reserved_label_len: int, ) -> Sequence[Tuple[List[int], List[int]]]: r""" Encodes formatted inputs to pairs of token ids. Turn 0: system + query resp Turn t: sep + query resp """ 
定义了一个私有方法 _encode,用于实际的编码过程。
 system = system or self.default_system 
如果 system 参数为空,使用类属性 default_system

python

复制

 encoded_messages = [] for i, message in enumerate(messages): elements = [] if i == 0 and (system or tools or self.force_system): tool_text = self.format_tools.apply(content=tools)[0] if tools else "" elements += self.format_system.apply(content=(system + tool_text)) elif i > 0 and i % 2 == 0: elements += self.format_separator.apply() 
循环遍历消息列表,根据消息位置和内容组装元素列表。 对于第一条消息,如果有 system 或 tools 或者 force_system 为真,则应用 format_system 和(如果有的话) format_tools

从 _encode 方法的消息处理循环部分:

python

复制

 if message["role"] == Role.USER.value: elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) elif message["role"] == Role.ASSISTANT.value: elements += self.format_assistant.apply(content=message["content"]) elif message["role"] == Role.OBSERVATION.value: elements += self.format_observation.apply(content=message["content"]) elif message["role"] == Role.FUNCTION.value: elements += self.format_function.apply(content=message["content"]) else: raise NotImplementedError("Unexpected role: {}".format(message["role"])) 
根据消息的角色(role)应用不同的格式化器: 如果角色是用户(Role.USER),使用 format_user。 如果角色是助手(Role.ASSISTANT),使用 format_assistant。 如果角色是观察(Role.OBSERVATION),使用 format_observation。 如果角色是功能(Role.FUNCTION),使用 format_function。 如果遇到未实现的角色,抛出 NotImplementedError

python

复制

 encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) 
将格式化后的元素转换为 token id 并添加到 encoded_messages 列表中。

python

复制

 return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) 
调用 _make_pairs 方法,将编码后的消息组装成 token id 对并返回。

其他私有方法(假设)

以下是一些假设的私有方法,用于完成编码和转换的具体操作:

python

复制

 def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: List[str]) -> List[int]: """ Converts a list of formatted elements into token ids using the tokenizer. """ token_ids = [] for element in elements: token_ids.extend(tokenizer.encode(element, add_special_tokens=False)) return token_ids def _make_pairs(self, encoded_messages: List[List[int]], cutoff_len: int, reserved_label_len: int) -> Sequence[Tuple[List[int], List[int]]]: """ Creates pairs of prompt and response token ids with cutoff and reserved lengths. """ pairs = [] for i in range(0, len(encoded_messages), 2): query_ids = encoded_messages[i] resp_ids = encoded_messages[i + 1] if i + 1 < len(encoded_messages) else [] if len(query_ids) + len(resp_ids) > cutoff_len: query_ids = query_ids[:cutoff_len - len(resp_ids)] resp_ids = resp_ids[:reserved_label_len] pairs.append((query_ids, resp_ids)) return pairs 
_convert_elements_to_ids 方法:将格式化的元素列表转换为 token id 列表。 _make_pairs 方法:根据 cutoff_len 和 reserved_label_len 将编码后的消息组合成 token id 对。

总结

这段代码定义了一个 Template 类,通过不同的格式化器和编码方法,将一系列消息转换为 token id。主要方法包括 encode_oneturn 和 encode_multiturn,分别用于编码单轮和多轮对话。私有方法 _encode 负责实际的编码过程,利用 _convert_elements_to_ids 和 _make_pairs 方法完成元素转换和对话对的创建。

def get_template_and_fix_tokenizer( tokenizer: "PreTrainedTokenizer", name: Optional[str] = None, ) -> Template: if name is None: template = templates["vanilla"] # placeholder else: template = templates.get(name, None) if template is None: raise ValueError("Template {} does not exist.".format(name)) stop_words = template.stop_words if template.replace_eos: if not stop_words: raise ValueError("Stop words are required to replace the EOS token.") _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0]) stop_words = stop_words[1:] if tokenizer.eos_token_id is None: _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>") if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token logger.info("Add pad token: {}".format(tokenizer.pad_token)) if stop_words: num_added_tokens = tokenizer.add_special_tokens( dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False ) logger.info("Add {} to stop words.".format(",".join(stop_words))) if num_added_tokens > 0: logger.warning("New tokens have been added, make sure `resize_vocab` is True.") try: tokenizer.chat_template = _get_jinja_template(template, tokenizer) except ValueError: logger.info("Cannot add this chat template to tokenizer.") return template _register_template( name="alpaca", format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), format_separator=EmptyFormatter(slots=["\n\n"]), default_system=( "Below is an instruction that describes a task. " "Write a response that appropriately completes the request." ), ) 

逐行解释这段代码。

get_template_and_fix_tokenizer 函数

python

复制

def get_template_and_fix_tokenizer( tokenizer: "PreTrainedTokenizer", name: Optional[str] = None, ) -> Template: 
定义一个函数 get_template_and_fix_tokenizer,用于获取模板并修正 tokenizer。参数包括 tokenizer 和可选的模板名称 name

python

复制

 if name is None: template = templates["vanilla"] # placeholder else: template = templates.get(name, None) if template is None: raise ValueError("Template {} does not exist.".format(name)) 
如果 name 参数为空,默认使用 vanilla 模板。 否则,尝试获取指定名称的模板。如果模板不存在,抛出 ValueError

python

复制

 stop_words = template.stop_words if template.replace_eos: if not stop_words: raise ValueError("Stop words are required to replace the EOS token.") _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0]) stop_words = stop_words[1:] 
获取模板中的 stop_words。 如果模板要求替换 EOS(End Of Sentence)标记,但 stop_words 为空,抛出 ValueError。 否则,用 stop_words 中的第一个词替换 EOS 标记,并移除已使用的词。

解释 get_template_and_fix_tokenizer 函数:

python

复制

 if tokenizer.eos_token_id is None: _add_or_replace_eos_token(tokenizer, eos_token="

二、Qwen模板

_register_template( name="qwen", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_separator=EmptyFormatter(slots=["\n"]), default_system="You are a helpful assistant.", stop_words=["<|im_end|>"], replace_eos=True, ) 

新代码解释:

_register_template(name="qwen", ...): 这行代码注册了一个名为"qwen"的模板。

format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]):
定义了用户输入的格式。它使用StringFormatter,将用户的内容包装在特定的标记中。

format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]):
定义了系统消息的格式,同样使用StringFormatter。

format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]):
定义了工具观察结果的格式,这在之前的代码中没有。

format_separator=EmptyFormatter(slots=["\n"]):
定义了分隔符,这里使用了一个空行。

default_system="You are a helpful assistant.":
设置了默认的系统消息。

stop_words=["<|im_end|>"]:
定义了停止词,当模型生成这个标记时会停止继续生成。

replace_eos=True:
指示是否替换结束标记。

与之前代码的主要区别: 旧代码指swift

实现方式:

新代码使用了函数式的方法来注册模板。 旧代码定义了一个类 QwenTemplate,继承自 Template

格式化方法:

新代码使用 StringFormatter 和 EmptyFormatter 来定义各种消息的格式。 旧代码直接在 __init__ 方法中定义字符串格式。

灵活性:

新代码提供了更多的格式化选项,如 format_observation 用于工具观察结果。 新代码明确定义了 format_separator,而旧代码没有。

默认系统消息:

新代码直接指定了默认系统消息。 旧代码使用了一个未显示的 DEFAULT_SYSTEM 常量。

停止词和EOS替换:

新代码明确定义了停止词和EOS替换选项。 旧代码没有这些明确的定义。

参数处理:

新代码没有 auto_add_bos 参数。 旧代码包含了 auto_add_bos 参数,用于控制是否自动添加开始标记。

总的来说,新代码提供了更灵活和详细的模板定义方式,包括了更多的格式化选项和控制参数。它似乎是为了提供更通用和可配置的模板注册方法,而旧代码更像是一个特定的模板实现。新代码可能更容易扩展和修改,以适应不同的需求。

Qwen范例

这段代码是在 LLaMA Factory 中注册 Qwen 模型的对话模板。让我为您渲染一下这个模板的真实输出效果,假设我们有一个包含系统消息、用户输入、助手回复和工具观察的对话:

<|im_start|>system
You are a helpful assistant.
<|im_end|>

<|im_start|>user
你好,请告诉我今天的天气如何。
<|im_end|>

<|im_start|>assistant
当然,我很乐意为您提供今天的天气信息。不过,作为一个AI助手,我没有实时访问天气数据的能力。为了给您最准确的信息,我需要使用一个天气查询工具。让我为您查询一下。

<|im_start|>tool
正在查询当前位置的天气信息...
查询结果:今天天气晴朗,气温在20°C到25°C之间,微风,适合户外活动。
<|im_end|>

参考文献或函数引用:

LLamafactory用作Formatter的方法Qwen/lib/python3.12/abc.py-CSDN博客

总结

tokenstemcodewordapppythonassistantprompt格式化fixllamactochatprompts多轮对话create模板定义ai助手参考文献elo
  • 本文作者:李琛
  • 本文链接: https://wapzz.net/post-20469.html
  • 版权声明:本博客所有文章除特别声明外,均默认采用 CC BY-NC-SA 4.0 许可协议。
本站部分内容来源于网络转载,仅供学习交流使用。如涉及版权问题,请及时联系我们,我们将第一时间处理。
文章很赞!支持一下吧 还没有人为TA充电
为TA充电
还没有人为TA充电
0
  • 支付宝打赏
    支付宝扫一扫
  • 微信打赏
    微信扫一扫
感谢支持
文章很赞!支持一下吧
关于作者
2.3W+
5
0
1
WAP站长官方

再战电商,AI会成为百度的“胜负手”吗?

上一篇

OpenAI“草莓”值万亿吗?

下一篇
  • 复制图片
按住ctrl可打开默认菜单