def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a dataset split and store it in ``self.datasets[split]``.

        ``self.args.data`` is a ':'-separated list of directories; the one
        used is selected by ``epoch`` modulo the number of directories.
        Shards are looked up as ``<split><k>.<a>-<b>.<lang>`` (the first shard
        has no ``k`` suffix), trying the language order source-target first
        and target-source second.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): index used to pick the data directory
            combine (bool): when true, keep loading ``<split>1``,
                ``<split>2``, ... until a shard is missing
            **kwargs: accepted and ignored

        Raises:
            FileNotFoundError: if the first shard of the split is missing
        """
        args = self.args
        candidate_dirs = args.data.split(':')
        assert len(candidate_dirs) > 0
        data_path = candidate_dirs[epoch % len(candidate_dirs)]

        src, tgt = args.source_lang, args.target_lang
        impl = args.dataset_impl

        def find_prefix(shard_name):
            """Return the file prefix of *shard_name*, or None if it is absent."""
            for first, second in ((src, tgt), (tgt, src)):
                probe = os.path.join(
                    data_path, '{}.{}-{}.{}'.format(shard_name, first, second, src))
                if indexed_dataset.dataset_exists(probe, impl=impl):
                    return os.path.join(
                        data_path, '{}.{}-{}.'.format(shard_name, first, second))
            return None

        src_shards, tgt_shards = [], []

        for shard_idx in itertools.count():
            shard_name = split + ('' if shard_idx == 0 else str(shard_idx))

            prefix = find_prefix(shard_name)
            if prefix is None:
                # Running out of numbered shards is the normal way to stop;
                # a missing first shard means the split does not exist.
                if shard_idx == 0:
                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
                break

            src_shards.append(indexed_dataset.make_dataset(
                prefix + src, impl=impl, fix_lua_indexing=True, dictionary=self.src_dict))
            tgt_shards.append(indexed_dataset.make_dataset(
                prefix + tgt, impl=impl, fix_lua_indexing=True, dictionary=self.tgt_dict))

            print('| {} {} {} examples'.format(data_path, shard_name, len(src_shards[-1])))

            if not combine:
                break

        assert len(src_shards) == len(tgt_shards)

        if len(src_shards) == 1:
            src_dataset, tgt_dataset = src_shards[0], tgt_shards[0]
        else:
            # Only the first (primary) shard gets the upsampling ratio.
            sample_ratios = [self.args.upsample_primary] + [1] * (len(src_shards) - 1)
            src_dataset = ConcatDataset(src_shards, sample_ratios)
            tgt_dataset = ConcatDataset(tgt_shards, sample_ratios)

        pair_options = dict(
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
        )
        self.datasets[split] = LanguagePairDataset(
            src_dataset, src_dataset.sizes, self.src_dict,
            tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
            **pair_options
        )
def load_langpair_dataset(
        data_path, split,
        src, src_dict,
        tgt, tgt_dict,
        combine, dataset_impl, upsample_primary,
        left_pad_source, left_pad_target, max_source_positions, max_target_positions,
):
    """Load the source/target shards of a split into a SpellCorrectDataset.

    Shards are looked up in ``data_path`` as ``<split><k>.<a>-<b>.<lang>``
    (no ``k`` suffix for the first shard), trying ``src-tgt`` first and
    ``tgt-src`` second.  With ``combine`` set, numbered shards are loaded
    until one is missing and then concatenated, the first shard being
    weighted by ``upsample_primary``.

    Raises:
        FileNotFoundError: if the first shard of the split is missing
    """
    def locate_prefix(shard_name):
        """Return the file prefix of *shard_name*, or None if it is absent."""
        for lang_a, lang_b in ((src, tgt), (tgt, src)):
            probe = os.path.join(
                data_path, '{}.{}-{}.{}'.format(shard_name, lang_a, lang_b, src))
            if indexed_dataset.dataset_exists(probe, impl=dataset_impl):
                return os.path.join(
                    data_path, '{}.{}-{}.'.format(shard_name, lang_a, lang_b))
        return None

    src_shards, tgt_shards = [], []

    for shard_idx in itertools.count():
        shard_name = split + ('' if shard_idx == 0 else str(shard_idx))

        prefix = locate_prefix(shard_name)
        if prefix is None:
            # A missing numbered shard ends the scan; a missing first shard
            # means the split itself is not there.
            if shard_idx == 0:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
            break

        src_shards.append(indexed_dataset.make_dataset(
            prefix + src, impl=dataset_impl, fix_lua_indexing=True, dictionary=src_dict))
        tgt_shards.append(indexed_dataset.make_dataset(
            prefix + tgt, impl=dataset_impl, fix_lua_indexing=True, dictionary=tgt_dict))

        print('| {} {} {}-{} {} examples'.format(
            data_path, shard_name, src, tgt, len(src_shards[-1])))

        if not combine:
            break

    assert len(src_shards) == len(tgt_shards)

    if len(src_shards) == 1:
        src_dataset, tgt_dataset = src_shards[0], tgt_shards[0]
    else:
        # Only the first (primary) shard gets the upsampling ratio.
        sample_ratios = [upsample_primary] + [1] * (len(src_shards) - 1)
        src_dataset = ConcatDataset(src_shards, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_shards, sample_ratios)

    return SpellCorrectDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
    )
Beispiel #3
0
def load_pair_dataset(
    data_path, split,
    src, src_dicts, src_feat,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target, max_source_positions, max_target_positions,
):
    """Load multi-feature source data and target text into a DataToTextDataset.

    For every shard ``<split><k>`` (no ``k`` suffix for the first one) the
    target file ``<shard>.<src>-<tgt>.<tgt>`` must exist in ``data_path``.
    One source dataset is loaded per entry of ``src_feat`` from
    ``<shard>.<src>-<tgt>.<feat>`` using the dictionary ``src_dicts[feat]``,
    and the per-feature datasets of a shard are kept in a ``FeatDict``.

    With ``combine`` set, numbered shards are loaded until one is missing.
    Multiple shards are concatenated per feature, the first shard being
    weighted by ``upsample_primary``.

    Raises:
        FileNotFoundError: if the first shard of the split is missing
    """
    def split_exists(split, src, tgt, feat, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, feat))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        # Only the target file is probed to decide whether the shard exists.
        if split_exists(split_k, src, tgt, tgt, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_data = FeatDict()
        for feat in src_feat:
            src_data[feat] = indexed_dataset.make_dataset(prefix + feat, impl=dataset_impl,
                                                         fix_lua_indexing=True, dictionary=src_dicts[feat])

        src_datasets.append(src_data)
        tgt_datasets.append(indexed_dataset.make_dataset(prefix + tgt, impl=dataset_impl,
                                                         fix_lua_indexing=True, dictionary=tgt_dict))
        # Report the example count of the first feature's dataset explicitly
        # instead of relying on what len() of the feature container returns.
        print('| {} {} {}-{} {} examples'.format(
            data_path, split_k, src_feat[0], tgt, len(src_data[src_feat[0]])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        # Concatenate per feature so the result is still a feature container
        # that can be indexed by feature name below, exactly like the
        # single-shard case.  Concatenating the containers themselves would
        # not support ``src_dataset[src_feat[0]]``.
        src_dataset = FeatDict()
        for feat in src_feat:
            src_dataset[feat] = ConcatDataset(
                [shard[feat] for shard in src_datasets], sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    return DataToTextDataset(
        src_dataset, src_dataset[src_feat[0]].sizes, src_dicts, src_feat,
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        remove_eos_from_source=True,
    )
    def load_dataset(self, split, epoch=0, **kwargs):
        """Load a dataset split for every configured language pair.

        ``self.args.data`` is a ':'-separated list of directories; the one
        used is selected by ``epoch`` modulo the number of directories.  For
        each entry ``src-tgt`` of ``self.args.lang_pairs`` the files
        ``<split>.<src>-<tgt>.*`` (or ``<split>.<tgt>-<src>.*``) are loaded if
        present; pairs without data are skipped.  The loaded pairs are wrapped
        in a ``RoundRobinZipDatasets`` stored in ``self.datasets[split]``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): index used to pick the data directory
            **kwargs: accepted and ignored

        Raises:
            FileNotFoundError: if no language pair has data for the split
        """

        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]

        def split_exists(split, src, tgt, lang):
            filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

        src_datasets, tgt_datasets = {}, {}
        for lang_pair in self.args.lang_pairs:
            src, tgt = lang_pair.split('-')
            if split_exists(split, src, tgt, src):
                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
            elif split_exists(split, tgt, src, src):
                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
            else:
                # No data for this pair in this split: skip it.
                continue
            src_datasets[lang_pair] = indexed_dataset.make_dataset(prefix + src, impl=self.args.dataset_impl,
                                                                   fix_lua_indexing=True, dictionary=self.dicts[src])
            tgt_datasets[lang_pair] = indexed_dataset.make_dataset(prefix + tgt, impl=self.args.dataset_impl,
                                                                   fix_lua_indexing=True, dictionary=self.dicts[tgt])
            print('| {} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))

        if len(src_datasets) == 0:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split('-')
            src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
            return self.alter_dataset_langtok(
                LanguagePairDataset(
                    src_dataset, src_dataset.sizes, self.dicts[src],
                    tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                    max_source_positions=self.args.max_source_positions,
                    max_target_positions=self.args.max_target_positions,
                ),
                src_eos=self.dicts[tgt].eos(),
                src_lang=src,
                tgt_lang=tgt,
            )

        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict([
                (lang_pair, language_pair_dataset(lang_pair))
                for lang_pair in self.args.lang_pairs
                # Pairs skipped above have no entry in src_datasets; building
                # a dataset for them would raise KeyError.
                if lang_pair in src_datasets
            ]),
            eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )
Beispiel #5
0
def load_indexed_dataset(path,
                         dictionary=None,
                         dataset_impl=None,
                         combine=False,
                         default="cached"):
    """A helper function for loading indexed datasets.

    Each candidate path is first passed through
    ``indexed_dataset.get_indexed_dataset_to_local``; an exception whose text
    contains a storage 404 is logged as a warning and loading proceeds with
    the unmodified path, any other exception is re-raised.

    Args:
        path (str): path to indexed dataset (e.g., 'data-bin/train')
        dictionary (~fairseq.data.Dictionary): data dictionary
        dataset_impl (str, optional): which dataset implementation to use. If
            not provided, it will be inferred automatically. For legacy indexed
            data we use the 'cached' implementation by default.
        combine (bool, optional): automatically load and combine multiple
            datasets. For example, if *path* is 'data-bin/train', then we will
            combine 'data-bin/train', 'data-bin/train1', ... and return a
            single ConcatDataset instance.
        default (str, optional): implementation used when none is given and
            none could be inferred.

    Returns:
        The loaded dataset, a ConcatDataset over all loaded shards, or None
        when nothing could be loaded.
    """
    import fairseq.data.indexed_dataset as indexed_dataset
    from fairseq.data.concat_dataset import ConcatDataset

    datasets = []
    for k in itertools.count():
        path_k = path + (str(k) if k > 0 else "")
        try:
            path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
        except Exception as e:
            if "StorageException: [404] Path not found" in str(e):
                # Log the actual path that was looked up, not the literal
                # variable name, so the warning is actionable.
                logger.warning("%s not found: %s", path_k, e)
            else:
                # Bare raise keeps the original traceback intact.
                raise

        dataset_impl_k = dataset_impl
        if dataset_impl_k is None:
            dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
        dataset = indexed_dataset.make_dataset(
            path_k,
            impl=dataset_impl_k or default,
            fix_lua_indexing=True,
            dictionary=dictionary,
        )
        if dataset is None:
            # No (further) shard could be loaded: stop scanning.
            break
        logger.info("loaded {:,} examples from: {}".format(
            len(dataset), path_k))
        datasets.append(dataset)
        if not combine:
            break
    if len(datasets) == 0:
        return None
    elif len(datasets) == 1:
        return datasets[0]
    else:
        return ConcatDataset(datasets)
Beispiel #6
0
    def compare_ds_data(self, summary, data, prefix, impl, vocab):
        """Assert that the dataset stored at *prefix* matches *data*.

        Checks the sequence and token counts recorded in *summary*, then loads
        the dataset with ``indexed_dataset.make_dataset(prefix, impl)`` and
        compares its length, its decoded contents (via ``vocab.string``) and
        its ``sizes`` against *data*.
        """
        self.assertEqual(summary.num_seq, len(data))
        token_total = sum(len(seq) for seq in data)
        self.assertEqual(summary.num_tok, token_total)

        loaded = indexed_dataset.make_dataset(prefix, impl)
        self.assertEqual(len(loaded), len(data))

        # Decode every stored item back to whitespace-separated tokens.
        decoded = []
        for idx in range(len(loaded)):
            decoded.append(vocab.string(loaded[idx]).split())
        self.assertEqual(decoded, data)

        stored_sizes = [size.item() for size in loaded.sizes]
        self.assertEqual(stored_sizes, sizes(data))
Beispiel #7
0
def load_indexed_dataset(path,
                         dictionary,
                         dataset_impl=None,
                         combine=False,
                         default='cached'):
    """Load an indexed dataset, optionally merging numbered shards.

    Args:
        path (str): path to indexed dataset (e.g., 'data-bin/train')
        dictionary (~fairseq.data.Dictionary): data dictionary
        dataset_impl (str, optional): which dataset implementation to use. If
            not provided, it is inferred from the path of each shard.
        combine (bool, optional): also load 'data-bin/train1',
            'data-bin/train2', ... (for *path* 'data-bin/train') until one
            cannot be loaded, and return them as a single ConcatDataset.
        default (str, optional): implementation used when none is given and
            none could be inferred.

    Returns:
        The loaded dataset, a ConcatDataset over all loaded shards, or None
        when nothing could be loaded.
    """
    from fairseq.data.concat_dataset import ConcatDataset
    import fairseq.data.indexed_dataset as indexed_dataset

    shards = []
    # Count up from 0 without an upper bound, so that any number of training
    # shards can be picked up.
    for shard_idx in itertools.count():
        # The first shard (index 0) carries no numeric suffix in its name.
        suffix = str(shard_idx) if shard_idx > 0 else ''
        shard_path = path + suffix

        impl = dataset_impl
        if impl is None:
            impl = indexed_dataset.infer_dataset_impl(shard_path)

        # Original author's note (translated; not verifiable from this file):
        # with the 'lazy' implementation make_dataset builds an IndexedDataset
        # that reads only the index information (the .idx file); items can
        # then be indexed like a list, each access reading from the binary
        # file directly.
        shard = indexed_dataset.make_dataset(
            shard_path,
            impl=impl or default,
            fix_lua_indexing=True,
            dictionary=dictionary,
        )
        if shard is None:
            break
        print('| loaded {} examples from: {}'.format(len(shard), shard_path))
        shards.append(shard)
        if not combine:
            break

    if not shards:
        return None
    if len(shards) == 1:
        return shards[0]
    return ConcatDataset(shards)
Beispiel #8
0
def load_indexed_dataset(path,
                         dictionary,
                         dataset_impl=None,
                         combine=False,
                         default='cached',
                         path_xml=None):
    """Load an indexed dataset, optionally merging numbered shards.

    Args:
        path (str): path to indexed dataset (e.g., 'data-bin/train')
        dictionary (~fairseq.data.Dictionary): data dictionary
        dataset_impl (str, optional): which dataset implementation to use. If
            not provided, it is inferred from the path of each shard.
        combine (bool, optional): also load 'data-bin/train1',
            'data-bin/train2', ... (for *path* 'data-bin/train') until one
            cannot be loaded, and return them as a single ConcatDataset.
        default (str, optional): implementation used when none is given and
            none could be inferred.
        path_xml (optional): forwarded unchanged to
            ``indexed_dataset.make_dataset`` for every shard.

    Returns:
        The loaded dataset, a ConcatDataset over all loaded shards, or None
        when nothing could be loaded.
    """
    from fairseq.data.concat_dataset import ConcatDataset
    import fairseq.data.indexed_dataset as indexed_dataset

    loaded = []
    index = 0
    while True:
        # The first shard has no numeric suffix; later ones are path1, path2, ...
        current_path = path + (str(index) if index > 0 else '')

        if dataset_impl is None:
            impl = indexed_dataset.infer_dataset_impl(current_path)
        else:
            impl = dataset_impl

        shard = indexed_dataset.make_dataset(
            current_path,
            impl=impl or default,
            fix_lua_indexing=True,
            dictionary=dictionary,
            path_xml=path_xml,
        )
        if shard is None:
            break
        print('| loaded {} examples from: {}'.format(len(shard), current_path))
        loaded.append(shard)
        if not combine:
            break
        index += 1

    if not loaded:
        return None
    if len(loaded) == 1:
        return loaded[0]
    return ConcatDataset(loaded)
    def _load_single_lang_dataset(self, split, epoch):
        """Load every shard of *split* as token blocks.

        ``self.args.data`` is a ':'-separated list of directories; the one
        used is selected by ``epoch`` modulo the number of directories.
        Shards ``<split>``, ``<split>1``, ``<split>2``, ... are loaded until
        one is missing, each wrapped in a ``TokenBlockDataset``.

        Returns:
            tuple: ``(dataset, sizes)`` where *dataset* is the single block
            dataset or a ``ConcatDataset`` of all of them, and *sizes* is the
            matching ``sizes`` attribute (concatenated for several shards).

        Raises:
            FileNotFoundError: if the first shard of the split is missing
        """
        block_datasets = []

        candidate_dirs = self.args.data.split(':')
        assert len(candidate_dirs) > 0
        data_path = candidate_dirs[epoch % len(candidate_dirs)]

        for shard_idx in itertools.count():
            shard_name = split + ('' if shard_idx == 0 else str(shard_idx))
            shard_path = os.path.join(data_path, shard_name)

            raw = indexed_dataset.make_dataset(
                shard_path,
                impl=self.args.dataset_impl,
                fix_lua_indexing=True,
                dictionary=self.dictionary,
            )
            if raw is None:
                if shard_idx == 0:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(split, data_path))
                break

            # Since we append each block with the classification_token,
            # we need to effectively create blocks of length
            # tokens_per_sample-1
            block_len = self.args.tokens_per_sample - 1
            block_datasets.append(
                TokenBlockDataset(
                    raw,
                    raw.sizes,
                    block_len,
                    pad=self.dictionary.pad(),
                    eos=self.dictionary.eos(),
                ))

            print('| {} {} {} examples'.format(data_path, shard_name,
                                               len(block_datasets[-1])))

        if len(block_datasets) == 1:
            only = block_datasets[0]
            return only, only.sizes

        merged = ConcatDataset(block_datasets)
        merged_sizes = np.concatenate([blocks.sizes for blocks in block_datasets])
        return merged, merged_sizes
Beispiel #10
0
def main():
    """Print every item of an indexed dataset, one per line.

    Arguments come from ``get_parser()``.  When ``--dict`` style input is
    given (``args.dict`` is not None) items are rendered with the
    dictionary's ``string`` method; otherwise the raw integer ids are
    printed separated by spaces.
    """
    args = get_parser().parse_args()

    dictionary = None
    if args.dict is not None:
        dictionary = Dictionary.load(args.dict)

    dataset = indexed_dataset.make_dataset(
        args.input,
        impl=args.dataset_impl,
        fix_lua_indexing=True,
        dictionary=dictionary,
    )

    for tensor_line in dataset:
        if dictionary is not None:
            rendered = dictionary.string(tensor_line)
        else:
            rendered = ' '.join(str(int(token)) for token in tensor_line)

        print(rendered)
Beispiel #11
0
    def load_dataset(self, split, epoch=1, combine=False):
        """Load a given dataset split.

        The data directory is chosen from ``utils.split_paths(self.args.data)``
        by ``epoch - 1`` modulo the number of directories.  Each shard is
        wrapped in a ``BlockPairDataset`` (built under a numpy seed of
        ``self.seed + k`` for shard ``k``) and the result is stored as a
        ``MaskedLMDataset`` in ``self.datasets[split]``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based index used to pick the data directory
            combine (bool): when true, keep loading ``<split>1``,
                ``<split>2``, ... until a shard is missing

        Raises:
            FileNotFoundError: if the first shard of the split is missing
        """
        loaded_datasets = []

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        # The logging API %-formats msg with the extra arguments, so the
        # message needs a placeholder for data_path; without one the record
        # fails to format and the path is never logged.
        logger.info("data_path: %s", data_path)

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(data_path, split_k)
            ds = indexed_dataset.make_dataset(
                path,
                impl=self.args.dataset_impl,
                fix_lua_indexing=True,
                dictionary=self.dictionary,
            )

            if ds is None:
                # A missing numbered shard ends the scan; a missing first
                # shard means the split itself is not there.
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(split, data_path))

            with data_utils.numpy_seed(self.seed + k):
                loaded_datasets.append(
                    BlockPairDataset(
                        ds,
                        self.dictionary,
                        ds.sizes,
                        self.args.tokens_per_sample,
                        break_mode=self.args.break_mode,
                        doc_break_size=1,
                    ))

            logger.info('{} {} {} examples都是非常重要的例子'.format(
                data_path, split_k, len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MaskedLMDataset(
            dataset=dataset,
            sizes=sizes,
            vocab=self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            classif_token_idx=self.dictionary.cls(),
            sep_token_idx=self.dictionary.sep(),
            shuffle=self.args.shuffle_dataset,
            seed=self.seed,
        )
Beispiel #12
0
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a dataset split and store it in ``self.datasets[split]``.

        ``self.args.data`` is a ':'-separated list of directories; the one
        used is selected by ``epoch`` modulo the number of directories.  Each
        shard is wrapped in a ``TokenBlockDataset`` and the result is stored
        as a ``MonolingualDataset``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): index used to pick the data directory
            combine (bool): when true, keep loading ``<split>1``,
                ``<split>2``, ... until a shard is missing
            **kwargs: accepted and ignored

        Raises:
            FileNotFoundError: if the first shard of the split is missing
        """

        block_datasets = []

        candidate_dirs = self.args.data.split(':')
        assert len(candidate_dirs) > 0
        data_path = candidate_dirs[epoch % len(candidate_dirs)]

        for shard_idx in itertools.count():
            shard_name = split + ('' if shard_idx == 0 else str(shard_idx))
            shard_path = os.path.join(data_path, shard_name)
            raw = indexed_dataset.make_dataset(
                shard_path,
                impl=self.args.dataset_impl,
                fix_lua_indexing=True,
                dictionary=self.dictionary,
            )

            if raw is None:
                # A missing numbered shard ends the scan; a missing first
                # shard means the split itself is not there.
                if shard_idx == 0:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(split, data_path))
                break

            blocks = TokenBlockDataset(
                raw,
                raw.sizes,
                self.args.tokens_per_sample,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos(),
                break_mode=self.args.sample_break_mode,
                include_targets=True,
            )
            block_datasets.append(blocks)

            print('| {} {} {} examples'.format(data_path, shard_name,
                                               len(block_datasets[-1])))

            if not combine:
                break

        if len(block_datasets) == 1:
            dataset = block_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(block_datasets)
            sizes = np.concatenate([blocks.sizes for blocks in block_datasets])

        # EOS is added for other targets only when a real break mode is set.
        break_mode = self.args.sample_break_mode
        add_eos_for_other_targets = break_mode is not None and break_mode != 'none'

        self.datasets[split] = MonolingualDataset(
            dataset,
            sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
            add_bos_token=self.args.add_bos_token,
        )
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target, max_source_positions, max_target_positions,
    ratio, pred_probs, bert_model_name,
):
    """Load source, target and BERT-side source shards of a split.

    Shards are looked up in ``data_path`` as ``<split><k>.<a>-<b>.<lang>``
    (no ``k`` suffix for the first shard), trying ``src-tgt`` first and
    ``tgt-src`` second.  The BERT-side source data is read from the matching
    ``<split><k>.bert.<a>-<b>.<src>`` file, without a dictionary.  With
    ``combine`` set, numbered shards are loaded until one is missing and then
    concatenated, the first shard being weighted by ``upsample_primary``.

    Returns:
        BertLanguagePairDataset when ``split == 'test'``, otherwise
        BertXYNoisyLanguagePairDataset (shuffled, with ``ratio`` and
        ``pred_probs``).

    Raises:
        FileNotFoundError: if the first shard of the split is missing
        ValueError: if several shards are combined and only some of them have
            BERT-side source data
    """
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    srcbert_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
            bertprefix = os.path.join(data_path, '{}.bert.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
            bertprefix = os.path.join(data_path, '{}.bert.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
        src_datasets.append(indexed_dataset.make_dataset(prefix + src, impl=dataset_impl,
                                                         fix_lua_indexing=True, dictionary=src_dict))
        tgt_datasets.append(indexed_dataset.make_dataset(prefix + tgt, impl=dataset_impl,
                                                         fix_lua_indexing=True, dictionary=tgt_dict))
        srcbert_datasets.append(indexed_dataset.make_dataset(bertprefix + src, impl=dataset_impl,
                                                             fix_lua_indexing=True, ))

        print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        srcbert_dataset = srcbert_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        # The BERT-side shards must be combined the same way as the source
        # shards; a plain list has no ``sizes`` and would not line up with
        # the upsampled source data.
        bert_missing = [shard is None for shard in srcbert_datasets]
        if all(bert_missing):
            srcbert_dataset = None
        elif any(bert_missing):
            raise ValueError(
                'BERT source data is missing for some shards of: {} ({})'.format(split, data_path))
        else:
            srcbert_dataset = ConcatDataset(srcbert_datasets, sample_ratios)

    srcbert_sizes = srcbert_dataset.sizes if srcbert_dataset is not None else None

    berttokenizer = BertTokenizer.from_pretrained(bert_model_name)
    if split == 'test':
        return BertLanguagePairDataset(
            src_dataset, src_dataset.sizes, src_dict,
            tgt_dataset, tgt_dataset.sizes, tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            srcbert=srcbert_dataset,
            srcbert_sizes=srcbert_sizes,
            berttokenizer=berttokenizer,
        )
    else:
        return BertXYNoisyLanguagePairDataset(
            src_dataset, src_dataset.sizes, src_dict,
            tgt_dataset, tgt_dataset.sizes, tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            shuffle=True,
            ratio=ratio,
            pred_probs=pred_probs,
            srcbert=srcbert_dataset,
            srcbert_sizes=srcbert_sizes,
            berttokenizer=berttokenizer,
        )