llama factory 是如何加载数据集 通过对数据集加载的代码的理解编写自定义数据集训练代码

llama factory 是如何加载数据集 通过对数据集加载的代码的理解编写自定义数据集训练代码

    正在检查是否收录...

第一层从训练代码追踪到以下代码

def get_dataset( tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo"], # split: Optional[str] = "train", # TODO: add split ) -> Union["Dataset", "IterableDataset"]: template = get_template_and_fix_tokenizer(tokenizer, data_args.template) if data_args.train_on_prompt and template.efficient_eos: raise ValueError("Current template does not support `train_on_prompt`.") # Load from cache if data_args.cache_path is not None: if os.path.exists(data_args.cache_path): logger.warning("Loading dataset from disk will ignore other data arguments.") dataset = load_from_disk(data_args.cache_path) if data_args.streaming: dataset = dataset.to_iterable_dataset() return dataset with training_args.main_process_first(desc="load dataset"): all_datasets = [] for dataset_attr in get_dataset_list(data_args): all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args)) dataset = merge_dataset(all_datasets, data_args, training_args) with training_args.main_process_first(desc="pre-process dataset"): preprocess_func, print_function = get_preprocess_and_print_func( tokenizer, template, data_args, training_args, stage ) column_names = list(next(iter(dataset)).keys()) kwargs = {} if not data_args.streaming: kwargs = dict( num_proc=data_args.preprocessing_num_workers, load_from_cache_file=(not data_args.overwrite_cache), desc="Running tokenizer on dataset", ) dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) if data_args.cache_path is not None and not os.path.exists(data_args.cache_path): if training_args.should_save: dataset.save_to_disk(data_args.cache_path) logger.info("Dataset cache saved at {}.".format(data_args.cache_path)) if training_args.should_log: try: print_function(next(iter(dataset))) except StopIteration: raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.") return dataset 

这段Python代码定义了一个名为get_dataset的函数,其目的是根据给定的参数加载和预处理一个数据集。下面是该函数的逐步解读:

函数参数: tokenizer: 一个预训练的tokenizer对象,用于处理文本数据。 model_args, data_args, training_args: 分别包含模型、数据和训练的参数。 stage: 指定当前的训练阶段,如"pt"(预训练)、“sft”(监督微调)、“rm”(奖励模型训练)或"ppo"(PPO训练)。 split: 指定数据集的分割,默认为"train"。 函数逻辑: 首先,获取模板并修复tokenizer(get_template_and_fix_tokenizer函数未在代码中给出)。 检查是否支持train_on_prompt功能,如果不支持则抛出错误。 尝试从磁盘加载数据集。如果设置了cache_path且该路径下数据集存在,则直接从磁盘加载,忽略其他数据参数。如果需要流式传输,则将数据集转换为可迭代的。 如果数据集不存在或需要重新生成,则使用get_dataset_list函数获取所有数据集属性,并使用load_single_dataset函数为每个属性加载数据集。然后,使用merge_dataset函数合并所有数据集。 对数据集进行预处理。预处理函数preprocess_func和打印函数print_functionget_preprocess_and_print_func函数返回。预处理包括将数据集的每一行映射到tokenizer。如果不在流式传输模式下,还会使用多进程进行预处理。 如果设置了cache_path,并且尚未创建,则将数据集保存到磁盘。 如果需要日志记录,则打印数据集的一个样本。 函数返回:
返回一个数据集对象,可以是普通的Dataset或可迭代的IterableDataset
这个函数的主要目的是提供一个统一的接口来加载、合并和预处理数据集,同时支持缓存和流式传输,适用于不同的训练阶段。

第二层 阅读加载单个数据的代码

def load_single_dataset( dataset_attr: "DatasetAttr", model_args: "ModelArguments", data_args: "DataArguments", ): logger.info("Loading dataset {}...".format(dataset_attr)) data_path, data_name, data_dir, data_files = None, None, None, None if dataset_attr.load_from in ["hf_hub", "ms_hub"]: data_path = dataset_attr.dataset_name data_name = dataset_attr.subset data_dir = dataset_attr.folder elif dataset_attr.load_from == "script": data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) data_name = dataset_attr.subset data_dir = dataset_attr.folder elif dataset_attr.load_from == "file": data_files = [] local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) if os.path.isdir(local_path): # is directory for file_name in os.listdir(local_path): data_files.append(os.path.join(local_path, file_name)) if data_path is None: data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None) elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None): raise ValueError("File types should be identical.") elif os.path.isfile(local_path): # is file data_files.append(local_path) data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None) else: raise ValueError("File not found.") if data_path is None: raise ValueError("File extension must be txt, csv, json or jsonl.") checksum(data_files, dataset_attr.file_sha1) else: raise NotImplementedError if dataset_attr.load_from == "ms_hub": try: from modelscope import MsDataset from modelscope.utils.config_ds import MS_DATASETS_CACHE cache_dir = model_args.cache_dir or MS_DATASETS_CACHE dataset = MsDataset.load( dataset_name=data_path, subset_name=data_name, data_dir=data_dir, data_files=data_files, split=data_args.split, cache_dir=cache_dir, token=model_args.ms_hub_token, use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")), ).to_hf_dataset() except ImportError: raise ImportError("Please install modelscope via `pip install modelscope -U`") else: if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0 kwargs = {"trust_remote_code": True} else: kwargs = {} dataset = load_dataset( path=data_path, name=data_name, data_dir=data_dir, data_files=data_files, split=data_args.split, cache_dir=model_args.cache_dir, token=model_args.hf_hub_token, streaming=(data_args.streaming and (dataset_attr.load_from != "file")), **kwargs, ) if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter if data_args.max_samples is not None: # truncate dataset num_samples = min(data_args.max_samples, len(dataset)) dataset = dataset.select(range(num_samples)) return align_dataset(dataset, dataset_attr, data_args) 

是一个独立文件读取的Python函数,用于根据提供的参数加载数据集。下面是该函数的中文解释:

日志记录:记录开始加载数据集的信息。 确定数据路径和名称:根据数据集的来源(“hf_hub”、“ms_hub”、“script”或“file”),计算数据集文件的正确路径。 校验和验证:如果数据集是从本地文件加载的,函数会根据dataset_attr中提供的预期值校验文件的有效SHA1校验和。 数据集加载:使用datasets库中的load_dataset函数加载数据集。加载数据集的参数根据来源和提供的额外参数确定。 流调整:如果设置了data_args.streaming且数据集是从文件加载的,则将数据集转换为可迭代的,更适合流式传输的数据集。 数据集截断:如果设置了data_args.max_samples,则截断数据集到指定的样本数。 对齐数据集:调用align_dataset函数将数据集与dataset_attrdata_args对齐。这个函数在提供的代码中没有定义,所以它的确切行为是未知的。 返回数据集:返回已加载和处理的数据集。
请注意,该函数假设存在某些变量和函数,如loggerosinspectload_dataset,这些都是Python代码中的典型内容。此外,align_dataset在提供的代码中被引用,但没有定义,这表明可能还有其他代码定义了这个函数及其行为。

数据集tokenpythonpromptapptodojson日志记录scriptpython代码codefix预训练处理文本文本数据idecsv模型训练cto奖励模型
  • 本文作者:李琛
  • 本文链接: https://wapzz.net/post-17259.html
  • 版权声明:本博客所有文章除特别声明外,均默认采用 CC BY-NC-SA 4.0 许可协议。
本站部分内容来源于网络转载,仅供学习交流使用。如涉及版权问题,请及时联系我们,我们将第一时间处理。
文章很赞!支持一下吧 还没有人为TA充电
为TA充电
还没有人为TA充电
0
  • 支付宝打赏
    支付宝扫一扫
  • 微信打赏
    微信扫一扫
感谢支持
文章很赞!支持一下吧
关于作者
2.3W+
5
0
1
WAP站长官方

免费ai写作软件有哪些?分享10个给你 #媒体#学习#媒体

上一篇

曝苹果正多方下注布局AI商店:OpenAI终究只是备胎

下一篇
  • 复制图片
按住ctrl可打开默认菜单