def test_concat_dataset_basics(self):
    d = ConcatDataset([self.dataset_1, self.dataset_2])
    assert len(d) == 2
    assert d[0]["source"][0] == 1
    assert d[1]["source"][0] == 2

    d = ConcatDataset([self.dataset_1, self.dataset_2], sample_ratios=[1, 2])
    assert len(d) == 3
    assert d[0]["source"][0] == 1
    assert d[1]["source"][0] == 2
    assert d[2]["source"][0] == 2

    d = ConcatDataset([self.dataset_1, self.dataset_2], sample_ratios=[2, 1])
    assert len(d) == 3
    assert d[0]["source"][0] == 1
    assert d[1]["source"][0] == 1
    assert d[2]["source"][0] == 2
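
# Standalone sketch of the upsampling behavior the test above exercises. The
# _ListDataset stand-in is hypothetical (not a fairseq class); it only needs
# __len__ and __getitem__, which is all ConcatDataset indexes through. With
# sample_ratios=[1, 2], every example of the second dataset appears twice in
# the concatenated index space, so len(d) == 1 * len(d1) + 2 * len(d2).
from fairseq.data.concat_dataset import ConcatDataset


class _ListDataset:
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


d = ConcatDataset([_ListDataset(["a"]), _ListDataset(["b"])], sample_ratios=[1, 2])
assert len(d) == 3
assert [d[i] for i in range(len(d))] == ["a", "b", "b"]
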
import itertools
import logging

logger = logging.getLogger(__name__)


def load_indexed_dataset(
    path, dictionary=None, dataset_impl=None, combine=False, default="cached"
):
    """A helper function for loading indexed datasets.

    Args:
        path (str): path to indexed dataset (e.g., 'data-bin/train')
        dictionary (~fairseq.data.Dictionary): data dictionary
        dataset_impl (str, optional): which dataset implementation to use. If
            not provided, it will be inferred automatically. For legacy indexed
            data we use the 'cached' implementation by default.
        combine (bool, optional): automatically load and combine multiple
            datasets. For example, if *path* is 'data-bin/train', then we will
            combine 'data-bin/train', 'data-bin/train1', ... and return a
            single ConcatDataset instance.
    """
    import fairseq.data.indexed_dataset as indexed_dataset
    from fairseq.data.concat_dataset import ConcatDataset

    datasets = []
    for k in itertools.count():
        path_k = path + (str(k) if k > 0 else "")
        try:
            path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
        except Exception as e:
            if "StorageException: [404] Path not found" in str(e):
                logger.warning(f"dataset not found: {path_k} ({e})")
            else:
                raise e
        dataset_impl_k = dataset_impl
        if dataset_impl_k is None:
            dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
        dataset = indexed_dataset.make_dataset(
            path_k,
            impl=dataset_impl_k or default,
            fix_lua_indexing=True,
            dictionary=dictionary,
        )
        if dataset is None:
            break
        logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
        datasets.append(dataset)
        if not combine:
            break
    if len(datasets) == 0:
        return None
    elif len(datasets) == 1:
        return datasets[0]
    else:
        return ConcatDataset(datasets)
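
# Hypothetical usage sketch for the helper above; the paths are placeholders
# and assume data binarized with fairseq-preprocess. With combine=True, shards
# named train.en, train.en1, ... are merged into a single ConcatDataset.
from fairseq.data import Dictionary

dictionary = Dictionary.load("data-bin/dict.en.txt")  # placeholder path
dataset = load_indexed_dataset("data-bin/train.en", dictionary, combine=True)
if dataset is not None:
    print(f"loaded {len(dataset):,} examples")
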
def load_indexed_dataset(path, dictionary, dataset_impl=None, combine=False,
                         default='cached'):
    """A helper function for loading indexed datasets.

    Args:
        path (str): path to indexed dataset (e.g., 'data-bin/train')
        dictionary (~fairseq.data.Dictionary): data dictionary
        dataset_impl (str, optional): which dataset implementation to use. If
            not provided, it will be inferred automatically. For legacy indexed
            data we use the 'cached' implementation by default.
        combine (bool, optional): automatically load and combine multiple
            datasets. For example, if *path* is 'data-bin/train', then we will
            combine 'data-bin/train', 'data-bin/train1', ... and return a
            single ConcatDataset instance.
    """
    from fairseq.data.concat_dataset import ConcatDataset
    import fairseq.data.indexed_dataset as indexed_dataset

    datasets = []
    # Enumerate k = 0, 1, 2, ... indefinitely, to handle the case of multiple
    # training shards; the first shard (k == 0) carries no numeric suffix.
    for k in itertools.count():
        path_k = path + (str(k) if k > 0 else '')
        dataset_impl_k = dataset_impl
        if dataset_impl_k is None:
            dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
        # With dataset_impl_k == 'lazy', make_dataset builds an IndexedDataset:
        # it reads the index metadata (the .idx file) up front, after which the
        # data can be indexed by id like a list, with each access reading
        # directly from the binary file.
        dataset = indexed_dataset.make_dataset(
            path_k,
            impl=dataset_impl_k or default,
            fix_lua_indexing=True,
            dictionary=dictionary,
        )
        if dataset is None:
            break
        print('| loaded {} examples from: {}'.format(len(dataset), path_k))
        datasets.append(dataset)
        if not combine:
            break
    if len(datasets) == 0:
        return None
    elif len(datasets) == 1:
        return datasets[0]
    else:
        return ConcatDataset(datasets)
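
# Why the variants above pass "impl=dataset_impl_k or default":
# infer_dataset_impl returns None when it cannot recognize the index file, and
# the `or` then falls back to the legacy 'cached' implementation. A minimal
# sketch of that fallback:
default = 'cached'
dataset_impl_k = None        # what infer_dataset_impl yields for unknown data
assert (dataset_impl_k or default) == 'cached'
dataset_impl_k = 'mmap'      # a recognized implementation wins
assert (dataset_impl_k or default) == 'mmap'
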
def load_indexed_dataset(path, dictionary, dataset_impl=None, combine=False,
                         default='cached', path_xml=None):
    """A helper function for loading indexed datasets.

    Args:
        path (str): path to indexed dataset (e.g., 'data-bin/train')
        dictionary (~fairseq.data.Dictionary): data dictionary
        dataset_impl (str, optional): which dataset implementation to use. If
            not provided, it will be inferred automatically. For legacy indexed
            data we use the 'cached' implementation by default.
        combine (bool, optional): automatically load and combine multiple
            datasets. For example, if *path* is 'data-bin/train', then we will
            combine 'data-bin/train', 'data-bin/train1', ... and return a
            single ConcatDataset instance.
    """
    from fairseq.data.concat_dataset import ConcatDataset
    import fairseq.data.indexed_dataset as indexed_dataset

    datasets = []
    for k in itertools.count():
        path_k = path + (str(k) if k > 0 else '')
        dataset_impl_k = dataset_impl
        if dataset_impl_k is None:
            dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
        dataset = indexed_dataset.make_dataset(
            path_k,
            impl=dataset_impl_k or default,
            fix_lua_indexing=True,
            dictionary=dictionary,
            path_xml=path_xml,
        )
        if dataset is None:
            break
        print('| loaded {} examples from: {}'.format(len(dataset), path_k))
        datasets.append(dataset)
        if not combine:
            break
    if len(datasets) == 0:
        return None
    elif len(datasets) == 1:
        return datasets[0]
    else:
        return ConcatDataset(datasets)
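
# Hypothetical call sketch for the fork-specific path_xml variant above. The
# extra argument is forwarded verbatim to indexed_dataset.make_dataset, so the
# expected side-input format depends on that fork's make_dataset; all paths
# here are placeholders.
from fairseq.data import Dictionary

dictionary = Dictionary.load("data-bin/dict.en.txt")  # placeholder path
dataset = load_indexed_dataset(
    "data-bin/train.en",             # placeholder dataset prefix
    dictionary,
    path_xml="data-bin/train.xml",   # placeholder XML side input
)
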
import itertools
import logging
import os

from fairseq.data import (
    AppendTokenDataset,
    ConcatDataset,
    LanguagePairDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TruncateDataset,
    data_utils,
    indexed_dataset,
)

logger = logging.getLogger(__name__)


def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        if truncate_source:
            # Strip the trailing eos, truncate to leave room for it, then
            # re-append it so truncated examples still end with eos.
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info('{} {} {}-{} {} examples'.format(
            data_path, split_k, src, tgt, len(src_datasets[-1])
        ))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        # Upsample the first (primary) dataset relative to the extra shards.
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
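
# Hypothetical call sketch for load_langpair_dataset; the dictionary paths and
# data directory are placeholders, and the argument values mirror common
# fairseq translation-task defaults rather than any particular config.
from fairseq.data import Dictionary

src_dict = Dictionary.load("data-bin/dict.en.txt")  # placeholder paths
tgt_dict = Dictionary.load("data-bin/dict.de.txt")

dataset = load_langpair_dataset(
    "data-bin", "train",
    "en", src_dict,
    "de", tgt_dict,
    combine=True,
    dataset_impl=None,          # infer from the files on disk
    upsample_primary=1,
    left_pad_source=True,
    left_pad_target=False,
    max_source_positions=1024,
    max_target_positions=1024,
)
print(f"{len(dataset):,} sentence pairs")
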