Example #1
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
    bert_model_name=None,
    bart_model_name=None,
    electra_model_name=None,
    electra_pretrain=False,
    denoising=False,
    masking=False,
    extra_data=False,
    input_mapping=False,
    mask_ratio=None,
    random_ratio=None,
    insert_ratio=None,
    rotate_ratio=None,
    permute_sentence_ratio=None,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path,
                                "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name,
                                                   do_lower_case=False)
    if denoising:
        bart_tokenizer = AutoTokenizer.from_pretrained(bart_model_name,
                                                       do_lower_case=False)
        #bart_tokenizer = BartTokenizer.from_pretrained(bart_model_name, do_lower_case=False)
    if electra_pretrain:
        electra_tokenizer = ElectraTokenizer.from_pretrained(
            electra_model_name)
    srcbert_datasets = []
    extra_datasets = []
    extra_bert_datasets = []
    extra_bert_mapping_datasets = []
    extra_bart_datasets = []
    extra_bart_mapping_datasets = []
    if denoising:
        srcbart_datasets = []
    if electra_pretrain:
        srcelectra_datasets = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path,
                                  "{}.{}-{}.".format(split_k, src, tgt))
            bertprefix = os.path.join(
                data_path, '{}.bert.{}-{}.'.format(split_k, src, tgt))
            bert_mapping_prefix = os.path.join(
                data_path, '{}.bert.map.{}-{}.'.format(split_k, src, tgt))

            if denoising:
                bartprefix = os.path.join(
                    data_path, '{}.bart.{}-{}.'.format(split_k, src, tgt))
                bart_mapping_prefix = os.path.join(
                    data_path, '{}.bart.map.{}-{}.'.format(split_k, src, tgt))

            if electra_pretrain:
                electraprefix = os.path.join(
                    data_path, '{}.electra.{}-{}.'.format(split_k, src, tgt))
                electra_mapping_prefix = os.path.join(
                    data_path,
                    '{}.electra.map.{}-{}.'.format(split_k, src, tgt))

            if extra_data:
                extraprefix = os.path.join(
                    data_path, '{}.extra.{}-{}.'.format(split_k, src, tgt))
                extra_bert_prefix = os.path.join(
                    data_path,
                    '{}.extra.bert.{}-{}.'.format(split_k, src, tgt))
                extra_bert_mapping_prefix = os.path.join(
                    data_path,
                    '{}.extra.bert.map.{}-{}.'.format(split_k, src, tgt))
                extra_bart_prefix = os.path.join(
                    data_path,
                    '{}.extra.bart.{}-{}.'.format(split_k, src, tgt))
                extra_bart_mapping_prefix = os.path.join(
                    data_path,
                    '{}.extra.bart.map.{}-{}.'.format(split_k, src, tgt))

        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path,
                                  "{}.{}-{}.".format(split_k, tgt, src))
            bertprefix = os.path.join(
                data_path, '{}.bert.{}-{}.'.format(split_k, tgt, src))
            bert_mapping_prefix = os.path.join(
                data_path, '{}.bert.map.{}-{}.'.format(split_k, src, tgt))

            if denoising:
                bartprefix = os.path.join(
                    data_path, '{}.bart.{}-{}.'.format(split_k, tgt, src))
                bart_mapping_prefix = os.path.join(
                    data_path, '{}.bart.map.{}-{}.'.format(split_k, src, tgt))

            if electra_pretrain:
                electraprefix = os.path.join(
                    data_path, '{}.electra.{}-{}.'.format(split_k, src, tgt))
                electra_mapping_prefix = os.path.join(
                    data_path,
                    '{}.electra.map.{}-{}.'.format(split_k, src, tgt))

            if extra_data:
                extraprefix = os.path.join(
                    data_path, '{}.extra.{}-{}.'.format(split_k, src, tgt))
                extra_bert_prefix = os.path.join(
                    data_path,
                    '{}.extra.bert.{}-{}.'.format(split_k, src, tgt))
                extra_bert_mapping_prefix = os.path.join(
                    data_path,
                    '{}.extra.bert.map.{}-{}.'.format(split_k, src, tgt))
                extra_bart_prefix = os.path.join(
                    data_path,
                    '{}.extra.bart.{}-{}.'.format(split_k, src, tgt))
                extra_bart_mapping_prefix = os.path.join(
                    data_path,
                    '{}.extra.bart.map.{}-{}.'.format(split_k, src, tgt))

        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict,
                                                      dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                                      dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        # srcbert_datasets.append(indexed_dataset.make_dataset(bertprefix + src, impl=dataset_impl,
        #                                                      fix_lua_indexing=True, ))
        # if denoising:
        #     srcbart_datasets.append(indexed_dataset.make_dataset(bartprefix + src, impl=dataset_impl,
        #                                                          fix_lua_indexing=True, ))
        # if extra_data:
        #     extra_datasets.append(indexed_dataset.make_dataset(extraprefix + src, impl=dataset_impl,
        #                                                        fix_lua_indexing=True, ))
        srcbert_datasets.append(
            data_utils.load_indexed_dataset(
                bertprefix + src,
                dataset_impl=dataset_impl,
            ))
        if denoising:
            srcbart_datasets.append(
                data_utils.load_indexed_dataset(
                    bartprefix + src,
                    dataset_impl=dataset_impl,
                ))
        if electra_pretrain:
            srcelectra_datasets.append(
                data_utils.load_indexed_dataset(
                    electraprefix + src,
                    dataset_impl=dataset_impl,
                ))
        if extra_data and split == 'train':
            extra_datasets.append(
                data_utils.load_indexed_dataset(
                    extraprefix + src,
                    dataset_impl=dataset_impl,
                ))
            extra_bert_datasets.append(
                data_utils.load_indexed_dataset(
                    extra_bert_prefix + src,
                    dataset_impl=dataset_impl,
                ))
            extra_bert_mapping_datasets.append(
                data_utils.load_indexed_dataset(
                    extra_bert_mapping_prefix + src,
                    dataset_impl=dataset_impl,
                ))
            extra_bart_datasets.append(
                data_utils.load_indexed_dataset(
                    extra_bart_prefix + src,
                    dataset_impl=dataset_impl,
                ))
            extra_bart_mapping_datasets.append(
                data_utils.load_indexed_dataset(
                    extra_bart_mapping_prefix + src,
                    dataset_impl=dataset_impl,
                ))
            #import pdb; pdb.set_trace()
            assert (extra_datasets != [] or extra_bert_datasets != []
                    or extra_bert_mapping_datasets != []
                    or extra_bart_datasets != []
                    or extra_bart_mapping_datasets != [])

            #extra_datasets = extra_datasets[0]
        #import pdb; pdb.set_trace()
        src_datasets[-1] = PrependTokenDataset(src_datasets[-1],
                                               token=src_dict.bos_index)
        if extra_data and split == 'train':
            extra_datasets[-1] = PrependTokenDataset(extra_datasets[-1],
                                                     token=src_dict.bos_index)
        if denoising is True:
            if input_mapping is True and split == 'train':
                bart_mapping_dataset = data_utils.load_indexed_dataset(
                    bart_mapping_prefix + src, dataset_impl=dataset_impl)
            else:
                bart_mapping_dataset = None

            src_datasets[-1] = DenoisingBartDataset(
                src_datasets[-1],
                src_datasets[-1].sizes,
                src_dict,
                srcbart_datasets[-1],
                srcbart_datasets[-1].sizes,
                bart_tokenizer,
                map_dataset=bart_mapping_dataset,
                mask_ratio=mask_ratio,
                random_ratio=random_ratio,
                insert_ratio=insert_ratio,
                rotate_ratio=rotate_ratio,
                permute_sentence_ratio=permute_sentence_ratio,
            )

        if electra_pretrain is True:
            if input_mapping is True and split == 'train':
                electra_mapping_dataset = data_utils.load_indexed_dataset(
                    electra_mapping_prefix + src, dataset_impl=dataset_impl)
            else:
                electra_mapping_dataset = None

            src_datasets[-1] = ElectrapretrainDataset(
                src_datasets[-1],
                src_datasets[-1].sizes,
                src_dict,
                srcelectra_datasets[-1],
                srcelectra_datasets[-1].sizes,
                electra_tokenizer,
                map_dataset=electra_mapping_dataset,
                left_pad_source=left_pad_source,
                left_pad_target=left_pad_target,
                max_source_positions=max_source_positions,
                max_target_positions=max_target_positions,
            )

        if masking is True:
            if input_mapping is True and split == 'train':
                #bert_mapping_dataset = indexed_dataset.make_dataset(bert_mapping_prefix + src, impl=dataset_impl, fix_lua_indexing=True)
                bert_mapping_dataset = data_utils.load_indexed_dataset(
                    bert_mapping_prefix + src, dataset_impl=dataset_impl)
            else:
                bert_mapping_dataset = None
            src_datasets[-1] = MaskingDataset(
                src_datasets[-1],
                src_datasets[-1].sizes,
                src_dict,
                srcbert_datasets[-1],
                srcbert_datasets[-1].sizes,
                bert_tokenizer,
                map_dataset=bert_mapping_dataset,
                left_pad_source=left_pad_source,
                left_pad_target=left_pad_target,
                max_source_positions=max_source_positions,
                max_target_positions=max_target_positions,
            )

        if extra_data is True and split == 'train':

            assert input_mapping is True
            src_datasets[-1] = MaskingExtraDataset(
                src_datasets[-1],
                src_datasets[-1].sizes,
                src_dict,
                extra_datasets[-1],
                extra_datasets[-1].sizes,
                extra_bert_datasets[-1],
                extra_bert_datasets[-1].sizes,
                bert_tokenizer,
                map_dataset=extra_bert_mapping_datasets[-1],
                left_pad_source=left_pad_source,
                left_pad_target=left_pad_target,
                max_source_positions=max_source_positions,
                max_target_positions=max_target_positions,
            )

            src_datasets[-1] = DenoisingBartExtraDataset(
                src_datasets[-1],
                src_datasets[-1].sizes,
                src_dict,
                extra_datasets[-1],
                extra_datasets[-1].sizes,
                extra_bart_datasets[-1],
                extra_bart_datasets[-1].sizes,
                bart_tokenizer,
                map_dataset=extra_bart_mapping_datasets[-1],
            )

        logger.info("{} {} {}-{} {} examples".format(data_path, split_k,
                                                     src, tgt,
                                                     len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
        # srcbert_datasets = srcbert_datasets[0]
        # if denoising:
        #     srcbart_datasets = srcbart_datasets[0]

    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index("[{}]".format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index("[{}]".format(tgt)))
        eos = tgt_dict.index("[{}]".format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path,
                                  "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    src_bart_dataset = None
    src_bert_dataset = None
    src_electra_dataset = None

    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        masking,
        src_bert_dataset,
        denoising,
        src_bart_dataset,
        src_electra_dataset,
        #extra_datasets,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
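For reference, a minimal sketch of how a fairseq-style translation task might call the loader above from its own `load_dataset` method. The task class and the `self.args.*` attributes (including `bert_model_name`) are assumptions, not part of the example:

    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Hypothetical call site for the load_langpair_dataset() defined above."""
        paths = self.args.data.split(os.pathsep)
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]
        src, tgt = self.args.source_lang, self.args.target_lang
        self.datasets[split] = load_langpair_dataset(
            data_path, split,
            src, self.src_dict, tgt, self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            # The loader builds a BertTokenizer unconditionally, so a model
            # name must be supplied even when masking/denoising are off.
            bert_model_name=self.args.bert_model_name,
        )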
Example #2
    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                raise Exception("Unable to handle raw text.")
            dataset = IndexedDataset(path, fix_lua_indexing=True)

            return dataset

        pair_datasets = OrderedDict()

        if split == "valid":
            self.datasets[split] = pair_datasets
            return

        if split not in self.config:
            raise FileNotFoundError(
                "Dataset not found in config file: {}".format(split)
            )

        size_by_corpus = defaultdict(int)
        size_sum = 0
        size_sum_with_subsampling = 0
        init_pair_datasets = {}

        for dataset_config in self.config[split]:
            src_path = os.path.dirname(dataset_config["src"])
            corpus_name = src_path.split("/")[-2]
            language_pair_name = src_path.split("/")[-1]
            pair_datasets_key = corpus_name + "-" + language_pair_name

            logger.info(f"loading... {pair_datasets_key}")
            if "src" in dataset_config:
                src_dataset = indexed_dataset(
                    dataset_config["src"], self.src_dictionary
                )
            else:
                src_dataset = None

            if "tgt" in dataset_config:
                tgt_dataset = indexed_dataset(
                    dataset_config["tgt"], self.tgt_dictionary
                )
            else:
                tgt_dataset = None

            dataset = LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.src_dictionary,
                tgt_dataset,
                tgt_dataset.sizes,
                self.tgt_dictionary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )

            if pair_datasets_key in init_pair_datasets:
                logger.warning(
                    f"Ignoring already added {pair_datasets_key}. "
                    f"Consider using `sample` key in order to upsample."
                )
            else:
                init_pair_datasets[pair_datasets_key] = {
                    "dataset": dataset,
                    "sample": dataset_config.get("sample", None),
                    "id": dataset_config.get("id", None),
                    "len": len(dataset),
                }

        length_sum = 0
        weighted_freqs_sum = 0
        freq_per_dataset = {}
        vmax = 0
        vmin = 1
        weighted_freq_per_dataset = {}

        if self.args.weighting_alpha:
            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    length_sum += len(init_pair_datasets[key]["dataset"])

            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    val = float(init_pair_datasets[key]["len"]) / length_sum
                    freq_per_dataset[key] = val
                    weighted_freqs_sum += val ** self.args.weighting_alpha

            for key in freq_per_dataset:
                val = (
                    freq_per_dataset[key] ** self.args.weighting_alpha
                    / weighted_freqs_sum
                )
                vmin = min(vmin, val)
                vmax = max(vmax, val)
                weighted_freq_per_dataset[key] = val

        for pair_datasets_key in init_pair_datasets:
            dataset_config = init_pair_datasets[pair_datasets_key]
            dataset = dataset_config["dataset"]
            sample = dataset_config["sample"]
            if sample is None:
                sample = 1.0

            if pair_datasets_key in weighted_freq_per_dataset:
                w = vmax / weighted_freq_per_dataset[pair_datasets_key]
                sample = w

            sample = round(sample)

            initial_sample = sample
            initial_pair_datasets_key = pair_datasets_key

            while sample >= 1.0:
                assert (
                    pair_datasets_key not in pair_datasets
                ), f"{pair_datasets_key} already in"
                size_sum_with_subsampling += len(dataset)
                pair_datasets[pair_datasets_key] = MultitaskDatasetWrapper(
                    dataset, dataset_config.get("id", 0), 1.0, name=pair_datasets_key
                )
                size_sum += len(dataset)
                sample -= 1.0
                pair_datasets_key += "-up"

            assert sample < 1e-6, f"sample remains > 0 {pair_datasets_key}"

            logger.info(
                f"added pair {initial_pair_datasets_key} length {len(dataset)} new_length = {len(dataset)*initial_sample}"
            )
            size_by_corpus[corpus_name] += len(dataset)

        self.datasets[split] = pair_datasets
        logger.info(
            f"Datasets number = {len(self.datasets[split])} size = {size_sum} size_sum_with_subsampling = {size_sum_with_subsampling}"
        )
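The `weighting_alpha` branch above implements temperature-style sampling: the relative frequency of each corpus (among those without an explicit `sample`) is raised to the power `weighting_alpha` and renormalized, and every corpus is then repeated `round(vmax / weight)` times, so the largest corpus keeps ratio 1 and smaller corpora are upsampled. A standalone sketch of that arithmetic with made-up corpus names and sizes:

# Standalone sketch of the weighting logic above; names and sizes are made up.
sizes = {"cc-en-fr": 1_000_000, "wiki-en-fr": 100_000, "ted-en-fr": 10_000}
alpha = 0.5  # corresponds to args.weighting_alpha

length_sum = sum(sizes.values())
freqs = {k: v / length_sum for k, v in sizes.items()}
weighted_sum = sum(f ** alpha for f in freqs.values())
weights = {k: (f ** alpha) / weighted_sum for k, f in freqs.items()}

vmax = max(weights.values())
upsampling = {k: round(vmax / w) for k, w in weights.items()}
print(upsampling)  # largest corpus gets 1, smaller corpora get larger factors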
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]

        def split_exists(split, src, tgt, lang, data_path):
            filename = os.path.join(
                data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            return indexed_dataset.dataset_exists(filename,
                                                  impl=self.args.dataset_impl)

        src_datasets = []
        tgt_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')

            # infer langcode
            src, tgt = self.args.source_lang, self.args.target_lang
            if split_exists(split_k, src, tgt, src, data_path):
                prefix = os.path.join(data_path,
                                      '{}.{}-{}.'.format(split_k, src, tgt))
            elif split_exists(split_k, tgt, src, src, data_path):
                prefix = os.path.join(data_path,
                                      '{}.{}-{}.'.format(split_k, tgt, src))
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(split, data_path))

            src_datasets.append(
                indexed_dataset.make_dataset(prefix + src,
                                             impl=self.args.dataset_impl,
                                             fix_lua_indexing=True,
                                             dictionary=self.src_dict))
            tgt_datasets.append(
                indexed_dataset.make_dataset(prefix + tgt,
                                             impl=self.args.dataset_impl,
                                             fix_lua_indexing=True,
                                             dictionary=self.tgt_dict))

            print('| {} {} {} examples'.format(data_path, split_k,
                                               len(src_datasets[-1])))

            if not combine:
                break

        assert len(src_datasets) == len(tgt_datasets)

        if len(src_datasets) == 1:
            src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        else:
            sample_ratios = [1] * len(src_datasets)
            sample_ratios[0] = self.args.upsample_primary
            src_dataset = ConcatDataset(src_datasets, sample_ratios)
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

        self.datasets[split] = LanguagePairDataset(
            src_dataset,
            src_dataset.sizes,
            self.src_dict,
            tgt_dataset,
            tgt_dataset.sizes,
            self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
        )
Example #4
    def build_dataset_for_inference(self, src_tokens, src_lengths):
        return LanguagePairDataset(src_tokens, src_lengths,
                                   self.source_dictionary)

    def load_dataset(self, split, epoch=0, **kwargs):
        """Load a dataset split."""

        paths = self.args.data.split(os.pathsep)
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]

        def split_exists(split, src, tgt, lang):
            if src is not None:
                filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            else:
                filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, src, tgt))
            return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

        def load_indexed_dataset(path, dictionary):
            return data_utils.load_indexed_dataset(path, dictionary, self.args.dataset_impl)

        # load parallel datasets
        src_datasets, tgt_datasets = {}, {}
        if (self.lambda_parallel > 0.0 or self.lambda_parallel_steps is not None or not split.startswith("train")):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split('-')
                if split_exists(split, src, tgt, src):
                    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
                elif split_exists(split, tgt, src, src):
                    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
                else:
                    continue
                src_datasets[lang_pair] = load_indexed_dataset(prefix + src, self.dicts[src])
                tgt_datasets[lang_pair] = load_indexed_dataset(prefix + tgt, self.dicts[tgt])
                logger.info('parallel-{} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))
            if len(src_datasets) == 0:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        # back translation datasets
        backtranslate_datasets = {}
        if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split('-')
                if not split_exists(split, tgt, None, tgt):
                    raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, data_path))
                filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
                dataset = load_indexed_dataset(filename, self.dicts[tgt])
                lang_pair_dataset_tgt = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                lang_pair_dataset = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    src_dict=self.dicts[src],
                    tgt=dataset,
                    tgt_sizes=dataset.sizes,
                    tgt_dict=self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                backtranslate_datasets[lang_pair] = BacktranslationDataset(
                    tgt_dataset=self.alter_dataset_langtok(
                        lang_pair_dataset_tgt,
                        src_eos=self.dicts[tgt].eos(),
                        src_lang=tgt,
                        tgt_lang=src,
                    ),
                    backtranslation_fn=self.backtranslators[lang_pair],
                    src_dict=self.dicts[src], tgt_dict=self.dicts[tgt],
                    output_collater=self.alter_dataset_langtok(
                        lang_pair_dataset=lang_pair_dataset,
                        src_eos=self.dicts[src].eos(),
                        src_lang=src,
                        tgt_eos=self.dicts[tgt].eos(),
                        tgt_lang=tgt,
                    ).collater,
                )
                logger.info('backtranslate-{}: {} {} {} examples'.format(
                    tgt, data_path, split, len(backtranslate_datasets[lang_pair]),
                ))
                self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair]

        # denoising autoencoder
        noising_datasets = {}
        if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split('-')
                if not split_exists(split, tgt, None, tgt):
                    continue
                filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
                tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
                tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
                noising_dataset = NoisingDataset(
                    tgt_dataset1,
                    self.dicts[tgt],
                    seed=1,
                    max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                    word_dropout_prob=self.args.word_dropout_prob,
                    word_blanking_prob=self.args.word_blanking_prob,
                )
                noising_datasets[lang_pair] = self.alter_dataset_langtok(
                    LanguagePairDataset(
                        noising_dataset,
                        tgt_dataset1.sizes,
                        self.dicts[tgt],
                        tgt_dataset2,
                        tgt_dataset2.sizes,
                        self.dicts[tgt],
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                    ),
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                )
                logger.info('denoising-{}: {} {} {} examples'.format(
                    tgt, data_path, split, len(noising_datasets[lang_pair]),
                ))

        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split('-')
            src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
            return self.alter_dataset_langtok(
                LanguagePairDataset(
                    src_dataset, src_dataset.sizes, self.dicts[src],
                    tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                    max_source_positions=self.args.max_source_positions,
                    max_target_positions=self.args.max_target_positions,
                ),
                self.dicts[src].eos(),
                src,
                self.dicts[tgt].eos(),
                tgt,
            )

        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict([
                (lang_pair, language_pair_dataset(lang_pair))
                for lang_pair in src_datasets.keys()
            ] + [
                (_get_bt_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in backtranslate_datasets.items()
            ] + [
                (_get_denoising_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in noising_datasets.items()
            ]),
            eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )
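The RoundRobinZipDatasets keys for the backtranslation and denoising datasets come from two small helpers; in fairseq's semi-supervised translation task they are essentially simple prefixes (shown here as an assumed sketch, not copied from the example):

def _get_bt_dataset_key(lang_pair):
    # e.g. "bt:de-en"
    return "bt:" + lang_pair


def _get_denoising_dataset_key(lang_pair):
    # e.g. "denoising:de-en"
    return "denoising:" + lang_pair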
    def load_dataset(self, split, seed=None):
        """Load split, which is train (monolingual data, optional parallel data),
        or eval (always parallel data).
        """
        if split == self.args.valid_subset:
            # tune set is always parallel
            primal_parallel, _, _ = data_utils.load_parallel_dataset(
                source_lang=self.source_lang,
                target_lang=self.target_lang,
                src_bin_path=self.args.forward_eval_source_binary_path,
                tgt_bin_path=self.args.forward_eval_target_binary_path,
                source_dictionary=self.primal_src_dict,
                target_dictionary=self.primal_tgt_dict,
                split=split,
                remove_eos_from_source=not self.args.append_eos_to_source,
                append_eos_to_target=True,
                char_source_dict=None,
                log_verbose=self.args.log_verbose,
            )
            # now just flip the source and target
            dual_parallel, _, _ = data_utils.load_parallel_dataset(
                source_lang=self.target_lang,
                target_lang=self.source_lang,
                src_bin_path=self.args.backward_eval_source_binary_path,
                tgt_bin_path=self.args.backward_eval_target_binary_path,
                source_dictionary=self.dual_src_dict,
                target_dictionary=self.dual_src_dict,
                split=split,
                remove_eos_from_source=not self.args.append_eos_to_source,
                append_eos_to_target=True,
                char_source_dict=None,
                log_verbose=self.args.log_verbose,
            )
            self.datasets[split] = RoundRobinZipDatasets(
                OrderedDict([
                    ("primal_parallel", primal_parallel),
                    ("dual_parallel", dual_parallel),
                ]))
        elif split == self.args.train_subset:
            src_dataset = data_utils.load_monolingual_dataset(
                self.args.train_mono_source_binary_path, is_source=True)
            tgt_dataset = data_utils.load_monolingual_dataset(
                self.args.train_mono_target_binary_path, is_source=True)
            primal_source_mono = LanguagePairDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.primal_src_dict,
                tgt=None,
                tgt_sizes=None,
                tgt_dict=None,
            )
            dual_source_mono = LanguagePairDataset(
                src=tgt_dataset,
                src_sizes=tgt_dataset.sizes,
                src_dict=self.dual_src_dict,
                tgt=None,
                tgt_sizes=None,
                tgt_dict=None,
            )

            primal_parallel, _, _ = data_utils.load_parallel_dataset(
                source_lang=self.source_lang,
                target_lang=self.target_lang,
                src_bin_path=self.args.forward_train_source_binary_path,
                tgt_bin_path=self.args.forward_train_target_binary_path,
                source_dictionary=self.primal_src_dict,
                target_dictionary=self.primal_tgt_dict,
                split=split,
                remove_eos_from_source=not self.args.append_eos_to_source,
                append_eos_to_target=True,
                char_source_dict=None,
                log_verbose=self.args.log_verbose,
            )
            dual_parallel, _, _ = data_utils.load_parallel_dataset(
                source_lang=self.target_lang,
                target_lang=self.source_lang,
                src_bin_path=self.args.backward_train_source_binary_path,
                tgt_bin_path=self.args.backward_train_target_binary_path,
                source_dictionary=self.dual_src_dict,
                target_dictionary=self.dual_src_dict,
                split=split,
                remove_eos_from_source=not self.args.append_eos_to_source,
                append_eos_to_target=True,
                char_source_dict=None,
                log_verbose=self.args.log_verbose,
            )
            self.datasets[split] = RoundRobinZipDatasets(
                OrderedDict([
                    ("primal_parallel", primal_parallel),
                    ("dual_parallel", dual_parallel),
                    ("primal_source", primal_source_mono),
                    ("dual_source", dual_source_mono),
                ]))
        else:
            raise ValueError("Invalid data split.")
Example #7
    def _backtranslation_dataset_helper(
        self,
        remove_eos_from_input_src,
        remove_eos_from_output_src,
    ):
        tgt_dataset = LanguagePairDataset(
            src=self.tgt_dataset,
            src_sizes=self.tgt_dataset.sizes,
            src_dict=self.tgt_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )

        generator = SequenceGenerator(
            [self.model],
            tgt_dict=self.tgt_dict,
            max_len_a=0,
            max_len_b=200,
            beam_size=2,
            unk_penalty=0,
        )

        backtranslation_dataset = BacktranslationDataset(
            tgt_dataset=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.tgt_dict.eos(),
                # remove eos from the input src
                remove_eos_from_src=remove_eos_from_input_src,
            ),
            src_dict=self.tgt_dict,
            backtranslation_fn=(
                lambda sample: generator.generate([self.model], sample)),
            output_collater=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.tgt_dict.eos(),
                # if we remove eos from the input src, then we need to add it
                # back to the output tgt
                append_eos_to_tgt=remove_eos_from_input_src,
                remove_eos_from_src=remove_eos_from_output_src,
            ).collater,
            cuda=self.cuda,
        )
        dataloader = torch.utils.data.DataLoader(
            backtranslation_dataset,
            batch_size=2,
            collate_fn=backtranslation_dataset.collater,
        )
        backtranslation_batch_result = next(iter(dataloader))

        eos, pad, w1, w2 = (self.tgt_dict.eos(), self.tgt_dict.pad(),
                            self.w1, self.w2)

        # Note that we sort by src_lengths and add left padding, so actually
        # ids will look like: [1, 0]
        expected_src = torch.LongTensor([[w1, w2, w1, eos],
                                         [pad, pad, w1, eos]])
        if remove_eos_from_output_src:
            expected_src = expected_src[:, :-1]
        expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
        generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
        tgt_tokens = backtranslation_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)
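# The following fragment comes from the same test module; `vocab`, `x`,
# `src_lengths`, and `PermutedDataset` are assumed to be defined earlier in
# the test setup (a dummy dictionary and token tensors).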
d1 = vocab
d2 = vocab
token1 = x.t()
tokens_ds1 = TokenBlockDataset(
    token1,
    sizes=src_lengths,
    break_mode='complete',
    block_size=1,
    pad=0,
    eos=1,
    include_targets=False,
)
token2 = x.t()
tokens_ds2 = TokenBlockDataset(
    token2,
    sizes=src_lengths,
    break_mode='complete',
    block_size=1,
    pad=0,
    eos=1,
    include_targets=False,
)
p_tokens_ds2 = PermutedDataset(tokens_ds2, d2, seed=123)
dataset = LanguagePairDataset(tokens_ds1,
                              tokens_ds1.sizes,
                              d1,
                              tokens_ds2,
                              tokens_ds2.sizes,
                              d2,
                              shuffle=False)
Example #9
    def _load_dataset_multi_path_helper(
        self,
        split: str,
        src_multiple_bin_paths: Dict[str, str],
        tgt_multiple_bin_paths: Dict[str, str],
        dataset_upsampling: Optional[Dict[str, float]] = None,
        dataset_relative_ratio: Optional[Tuple[str, float]] = None,
        seed: Optional[int] = None,
        noiser: Optional[Dict[str, UnsupervisedMTNoising]] = None,
        is_npz: bool = True,
    ):
        corpora_map = pytorch_translate_data.ParallelCorporaMapConfig(
            src_files=src_multiple_bin_paths, tgt_files=tgt_multiple_bin_paths)
        datasets = OrderedDict()
        for key in corpora_map.src_files:
            src, tgt = corpora_map.src_files[key], corpora_map.tgt_files[key]
            tgt_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file(
                tgt, is_npz=is_npz)

            if self.char_source_dict is not None:
                src_dataset = char_data.InMemoryNumpyWordCharDataset.create_from_file(
                    src)

            else:
                src_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file(
                    src, is_npz=is_npz)
            src_sizes = src_dataset.sizes
            if noiser is not None and key in noiser:
                src_dataset = NoisingDataset(
                    src_dataset=src_dataset,
                    src_dict=self.source_dictionary,
                    seed=seed,
                    noiser=noiser[key],
                )
            if self.char_source_dict is not None:
                datasets[key] = char_data.LanguagePairSourceCharDataset(
                    src=src_dataset,
                    src_sizes=src_sizes,
                    src_dict=self.source_dictionary,
                    tgt=tgt_dataset,
                    tgt_sizes=tgt_dataset.sizes,
                    tgt_dict=self.target_dictionary,
                )
            else:
                datasets[key] = LanguagePairDataset(
                    src=src_dataset,
                    src_sizes=src_sizes,
                    src_dict=self.source_dictionary,
                    tgt=tgt_dataset,
                    tgt_sizes=tgt_dataset.sizes,
                    tgt_dict=self.target_dictionary,
                    left_pad_source=False,
                )
        total_line_count = sum(len(datasets[key]) for key in datasets)
        if dataset_relative_ratio:
            ds, ratio = dataset_relative_ratio
            line_count = len(datasets[ds])
            # By definition ratio = u * line_count / sum(#lines of other datasets)
            u = (total_line_count - line_count) / line_count * ratio
            dataset_upsampling = {ds: u}
        elif not dataset_upsampling:
            dataset_upsampling = {}

        print(f"|dataset upsampling:{dataset_upsampling}")
        ds_list = []
        sample_ratios = []
        for key, val in datasets.items():
            ds_list.append(val)
            sample_ratios.append(int(dataset_upsampling.get(key, 1)))

        self.datasets[split] = LanguagePairUpsamplingDataset(
            datasets=datasets.values(), sample_ratios=sample_ratios)
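A quick numeric check of the `dataset_relative_ratio` arithmetic used above, with made-up counts:

# Made-up counts illustrating the upsampling factor computed above.
total_line_count = 101_000
line_count = 1_000   # size of the dataset named in dataset_relative_ratio
ratio = 0.5          # requested size relative to the other datasets combined
u = (total_line_count - line_count) / line_count * ratio
print(u)  # 50.0 -> that dataset is repeated ~50x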
def load_langpair_dataset(data_path,
                          split,
                          src,
                          src_dict,
                          tgt,
                          tgt_dict,
                          combine,
                          dataset_impl,
                          upsample_primary,
                          left_pad_source,
                          left_pad_target,
                          max_source_positions,
                          max_target_positions,
                          prepend_bos=False,
                          load_alignments=False,
                          load_cls_labels=False,
                          load_cls_indices=False,
                          load_sample_weights=False,
                          truncate_source=False,
                          append_source_id=False,
                          shuffle=True):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path,
                                '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path,
                                  '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path,
                                  '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict,
                                                      dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                                      dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info('{} {} {}-{} {} examples'.format(data_path, split_k,
                                                     src, tgt,
                                                     len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    src_prepended_bos = False
    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
        src_prepended_bos = True

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index('[{}]'.format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path,
                                  '{}.align.{}-{}'.format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    cls_dataset = None
    if load_cls_labels:
        cls_labels_path = os.path.join(data_path, '{}.cls'.format(split))
        if indexed_dataset.dataset_exists(cls_labels_path, impl=dataset_impl):
            cls_dataset = data_utils.load_indexed_dataset(
                cls_labels_path, None, dataset_impl)
            if truncate_source:
                cls_dataset = AppendTokenDataset(
                    TruncateDataset(
                        TruncateLastElementDataset(cls_dataset),
                        max_source_positions - 1,
                    ),
                    -1,  # will ignore -1 label in training
                )
            if src_prepended_bos:
                cls_dataset = PrependTokenDataset(cls_dataset, -1)
        else:
            print("cls_labels dataset NOT FOUND!", cls_labels_path)

    cls_indices_dataset = None
    if load_cls_indices:
        cls_indices_path = os.path.join(data_path, '{}.cls_ind'.format(split))
        if indexed_dataset.dataset_exists(cls_indices_path, impl=dataset_impl):
            cls_indices_dataset = data_utils.load_indexed_dataset(
                cls_indices_path, None, dataset_impl)

    sample_weights = None
    if load_sample_weights:
        weights_file = os.path.join(
            data_path, '{}.{}-{}.weights.npy'.format(split, src, tgt))
        assert os.path.exists(weights_file)
        with open(weights_file, 'rb') as f:
            sample_weights = np.load(f)
        logger.info('Loaded {} weights from {}'.format(len(sample_weights),
                                                       weights_file))

    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=align_dataset,
        eos=eos,
        cls_dataset=cls_dataset,
        cls_indices_dataset=cls_indices_dataset,
        sample_weights=sample_weights,
        shuffle=shuffle,
    )
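The `load_sample_weights` branch above expects a NumPy array with one weight per example saved next to the binarized data. A minimal sketch of producing such a file (the path and values are hypothetical):

import numpy as np

# One weight per src-tgt pair; the count and values here are placeholders.
sample_weights = np.ones(123456, dtype=np.float32)
np.save("data-bin/train.en-de.weights.npy", sample_weights)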
Example #11
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    feature_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_features=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path,
                                '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path,
                                  '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path,
                                  '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict,
                                                      dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                                      dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info('{} {} {}-{} {} examples'.format(data_path, split_k,
                                                     src, tgt,
                                                     len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index('[{}]'.format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path,
                                  '{}.align.{}-{}'.format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl)

    # print("feature_dict", feature_dict.symbols, feature_dict.count) #feature_dict ['<s>', '<pad>', '</s>', '<unk>', '<ori>', '<rep>', 'madeupword0000', 'madeupword0001'] [1, 1, 1, 1, 18558611, 5354704, 0, 0]

    feature_dataset = None
    if load_features:
        feature_path = os.path.join(
            data_path, '{}.feature.{}-{}.{}'.format(split, src, tgt, src))
        if indexed_dataset.dataset_exists(feature_path, impl=dataset_impl):
            feature_dataset = data_utils.load_indexed_dataset(
                feature_path, feature_dict, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        feature_dataset=feature_dataset,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
    )
Example #12
    def load_dataset(self, split, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""

        logger.info("load dataset start")
        prefix = os.path.join(self.args.data, '{}.input-label'.format(split))

        # Read input sentences.
        sentences, lengths = [], []
        with open(prefix + '.input', encoding='utf-8') as file:
            for line in file:
                sentence = line.strip()
                #print('sentence: {} '.format((sentence)))
                # Tokenize the sentence, splitting on spaces
                tokens = self.input_vocab.encode_line(
                    sentence,
                    add_if_not_exist=False,
                )
                #print('token: {} '.format((tokens)))
                #token: tensor([48,  4, 13, 15,  5,  8,  2], dtype=torch.int32)
                sentences.append(tokens)
                lengths.append(tokens.numel())
                # print(lengths) [7, 8, 8, 5, 12, 6, 6, 5 ...

        # Read labels.
        labels = []
        with open(prefix + '.label', encoding='utf-8') as file:
            print(prefix + '.label')
            for line in file:
                label = line.strip()
                # print('label: {} '.format((label)))
                labels.append(
                    # Convert label to a numeric ID.
                    torch.LongTensor([self.label_vocab.add_symbol(label)]))
                #print(labels[0]) tensor([5])
                # if label == 'Russian':
                #     print(self.label_vocab.index('Russian'))
                #     print(self.label_vocab.count[4])
        print("lables are {}".format(np.unique(labels)))
        print(self.label_vocab.indices.keys())
        print(self.label_vocab.indices.values())
        for i in range(len(self.label_vocab.count)):
            print(self.label_vocab.symbols[i])
            print(self.label_vocab.count[i])
        print('label_vocab: {} '.format(self.label_vocab.symbols))

        assert len(sentences) == len(labels)
        print('| {} {} {} examples'.format(self.args.data, split,
                                           len(sentences)))

        # We reuse LanguagePairDataset since classification can be modeled as a
        # sequence-to-sequence task where the target sequence has length 1.
        self.datasets[split] = LanguagePairDataset(
            src=sentences,
            src_sizes=lengths,
            src_dict=self.input_vocab,
            tgt=labels,
            tgt_sizes=torch.ones(len(labels)),  # targets have length 1
            tgt_dict=self.label_vocab,
            left_pad_source=False,
            # Since our target is a single class label, there's no need for
            # teacher forcing. If we set this to ``True`` then our Model's
            # ``forward()`` method would receive an additional argument called
            # *prev_output_tokens* that would contain a shifted version of the
            # target sequence.
            input_feeding=False,
        )
        print(self.datasets[split])
        print("load dataset complete")
        assert len(sentences) == len(labels)
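
A minimal sketch of what the classification-as-translation trick above produces at batch time, assuming fairseq's Dictionary and LanguagePairDataset (the vocabularies and the single toy sentence/label are placeholders):

import torch
from fairseq.data import Dictionary, LanguagePairDataset

input_vocab, label_vocab = Dictionary(), Dictionary()
sentences = [input_vocab.encode_line("hello world", add_if_not_exist=True)]
labels = [torch.LongTensor([label_vocab.add_symbol("English")])]

dataset = LanguagePairDataset(
    src=sentences,
    src_sizes=[t.numel() for t in sentences],
    src_dict=input_vocab,
    tgt=labels,
    tgt_sizes=torch.ones(len(labels)),  # every target is a single label
    tgt_dict=label_vocab,
    left_pad_source=False,
    input_feeding=False,  # no prev_output_tokens needed for classification
)
batch = dataset.collater([dataset[i] for i in range(len(dataset))])
print(batch["net_input"]["src_tokens"].shape, batch["target"].shape)
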
Exemple #13
0
    def build_dataset_for_inference(self, src_tokens, src_lengths, tgt_tokens=None, tgt_lengths=None, num_source_inputs=1):
        if num_source_inputs == 1:
            return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary, tgt=tgt_tokens, tgt_sizes=tgt_lengths)
        else:
            return MultiSourceTranslationDataset(src_tokens, src_lengths, self.source_dictionary, tgt=tgt_tokens, tgt_sizes=tgt_lengths)
Exemple #14
0
    def build_dataset_for_evaluation(self, src_tokens, src_lengths, tgt_tokens, tgt_lengths):
        return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary, tgt_tokens, tgt_lengths)
Exemple #15
0
    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary,
                                   tgt_dict=self.target_dictionary,
                                   constraints=constraints)
Exemple #16
0
    def load_dataset(self, split, epoch=0, **kwargs):

        if self.retrieve_fn is None:
            self.build_model(self.args)
            # raise ValueError(
            #     "retrieve_fn is None !"
            # )

        retrieve_dataset = None
        if self.retrieve_pool is None:
            paths = self.args.data.split(os.pathsep)
            assert len(paths) > 0
            data_path = paths[epoch % len(paths)]
            split_path = os.path.join(data_path, split)

            dataset = data_utils.load_indexed_dataset(split_path,
                                                      self.dictionary,
                                                      self.args.dataset_impl)

            if dataset is None:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, split_path))

            lang_pair_dataset = LanguagePairDataset(
                dataset,
                dataset.sizes,
                self.src_dict,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )

            if split == self.args.retrieve_split:
                print("split {} is used as the retrieve_pool".format(split))
                retrieve_dataset = lang_pair_dataset
            else:
                print("loading the retrieve split {}".format(
                    self.args.retrieve_split))

                split_path = os.path.join(self.args.data,
                                          self.args.retrieve_split)
                dataset = data_utils.load_indexed_dataset(
                    split_path, self.dictionary, self.args.dataset_impl)

                if dataset is None:
                    raise FileNotFoundError(
                        "Dataset not found: {} ({})".format(
                            self.args.retrieve_split, split_path))

                if self.args.prune_num > 0:
                    retrieve_dataset = LanguagePairMapDataset(
                        dataset,
                        dataset.sizes,
                        self.src_dict,
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                    )
                else:
                    retrieve_dataset = LanguagePairDataset(
                        dataset,
                        dataset.sizes,
                        self.src_dict,
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                    )

            self.retrieve_pool = retrieve_dataset

        elif split == self.args.retrieve_split:
            print(
                "skip reading split {} since it is used as the retrieve_pool".
                format(split))
            lang_pair_dataset = self.retrieve_pool

        else:
            paths = self.args.data.split(os.pathsep)
            assert len(paths) > 0
            data_path = paths[epoch % len(paths)]
            split_path = os.path.join(data_path, split)

            dataset = data_utils.load_indexed_dataset(split_path,
                                                      self.dictionary,
                                                      self.args.dataset_impl)

            if dataset is None:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, split_path))

            lang_pair_dataset = LanguagePairDataset(
                dataset,
                dataset.sizes,
                self.src_dict,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )

        # always use unbiased estimator at test time
        # Avoid selecting self as templates at training time
        if 'train' not in split and self.args.criterion != 'guu_elbo':
            sampling = True
            masks = None
        else:

            def read_mask(fpath):
                with open(fpath) as fin:
                    return [int(x.rstrip()) for x in fin]

            sampling = options.eval_bool(self.args.reinforce)

            if os.path.exists(os.path.join(self.args.data, 'mask_id.txt')):
                masks = read_mask(os.path.join(self.args.data, 'mask_id.txt'))
            else:
                masks = None

        self.datasets[split] = RetrievePrototypeDataset(
            lang_pair_dataset,
            self.src_dict,
            retrieve_dataset=self.retrieve_pool,
            retrieve_fn=self.retrieve_fn,
            cuda=not self.args.cpu,
            num_samples=self.args.infer_ns,
            temperature=self.args.reinforce_temperature,
            sampling=sampling,
            edit_dict=self.edit_dict,
            split=split,
            masks=masks,
        )
Exemple #17
0
    def load_dataset(self, split, **kwargs):
        def split_exists(split, lang):
            filename = os.path.join(self.args.data,
                                    '{}.{}'.format(split, lang))
            if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                return True
            elif not self.args.raw_text and IndexedDataset.exists(filename):
                return True
            return False

        def split_para_exists(split, key, lang):
            filename = os.path.join(self.args.data,
                                    '{}.{}.{}'.format(split, key, lang))
            if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                return True
            elif not self.args.raw_text and IndexedDataset.exists(filename):
                return True
            return False

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                return IndexedRawTextDataset(path, dictionary)
            elif IndexedDataset.exists(path):
                if self.args.lazy_load:
                    return IndexedDataset(path, fix_lua_indexing=True)
                else:
                    return IndexedCachedDataset(path, fix_lua_indexing=True)
            return None

        src_mono_datasets = {}
        for lang_pair in self.args.mono_lang_pairs:
            lang = lang_pair.split('-')[0]

            if split_exists(split, lang):
                prefix = os.path.join(self.args.data,
                                      '{}.{}'.format(split, lang))
            else:
                raise FileNotFoundError(
                    'No {} dataset available for lang ({})'.format(
                        split, lang))

            src_mono_datasets[lang_pair] = indexed_dataset(
                prefix, self.dicts[lang])
            print('| monolingual {}-{}: {} examples'.format(
                split, lang, len(src_mono_datasets[lang_pair])))

        src_para_datasets = {}
        for lang_pair in self.args.para_lang_pairs:
            src, tgt = lang_pair.split('-')
            key = '-'.join(sorted([src, tgt]))
            if not split_para_exists(split, key, src):
                raise FileNotFoundError(
                    'No {}-{} parallel dataset available for lang ({})'.format(
                        split, key, src))
            if not split_para_exists(split, key, tgt):
                raise FileNotFoundError(
                    'No {}-{} parallel dataset available for lang ({})'.format(
                        split, key, tgt))

            prefix = os.path.join(self.args.data, '{}.{}'.format(split, key))
            if '{}.{}'.format(key, src) not in src_para_datasets:
                src_para_datasets[key + '.' + src] = indexed_dataset(
                    prefix + '.' + src, self.dicts[src])
            if '{}.{}'.format(key, tgt) not in src_para_datasets:
                src_para_datasets[key + '.' + tgt] = indexed_dataset(
                    prefix + '.' + tgt, self.dicts[tgt])

            print('| bilingual {} {}-{}.{}: {} examples'.format(
                split, src, tgt, src, len(src_para_datasets[key + '.' + src])))
            print('| bilingual {} {}-{}.{}: {} examples'.format(
                split, src, tgt, tgt, len(src_para_datasets[key + '.' + tgt])))

        mt_para_dataset = {}
        for lang_pair in self.args.mt_steps:
            src, tgt = lang_pair.split('-')
            key = '-'.join(sorted([src, tgt]))
            src_key = key + '.' + src
            tgt_key = key + '.' + tgt
            src_dataset = src_para_datasets[src_key]
            tgt_dataset = src_para_datasets[tgt_key]
            mt_para_dataset[lang_pair] = LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.dicts[src],
                tgt_dataset,
                tgt_dataset.sizes,
                self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
            )

        eval_para_dataset = {}
        if split != 'train':
            for lang_pair in self.args.valid_lang_pairs:
                src, tgt = lang_pair.split('-')
                if src == tgt:
                    src_key = src + '-' + tgt
                    tgt_key = src + '-' + tgt
                    src_dataset = src_mono_datasets[src_key]
                    tgt_dataset = src_mono_datasets[tgt_key]
                else:
                    key = '-'.join(sorted([src, tgt]))
                    src_key = key + '.' + src
                    tgt_key = key + '.' + tgt
                    src_dataset = src_para_datasets[src_key]
                    tgt_dataset = src_para_datasets[tgt_key]
                eval_para_dataset[lang_pair] = LanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    self.dicts[src],
                    tgt_dataset,
                    tgt_dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                    max_source_positions=self.args.max_source_positions,
                    max_target_positions=self.args.max_target_positions,
                )

        memt_para_dataset = {}
        if split == 'train':
            for lang_pair in self.args.memt_steps:
                src, tgt = lang_pair.split('-')
                key = '-'.join(sorted([src, tgt]))
                src_key = key + '.' + src
                tgt_key = key + '.' + tgt
                src_id, tgt_id = self.args.langs_id[src], self.args.langs_id[
                    tgt]
                src_dataset = src_para_datasets[src_key]
                tgt_dataset = src_para_datasets[tgt_key]
                memt_para_dataset[lang_pair] = NoisyLanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    tgt_dataset,
                    tgt_dataset.sizes,
                    self.dicts[src],
                    self.dicts[tgt],
                    src_id,
                    tgt_id,
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                    max_source_positions=self.args.max_source_positions,
                    max_target_positions=self.args.max_target_positions,
                    ratio=self.args.word_mask,
                    pred_probs=self.args.pred_probs,
                )

        mass_mono_datasets = {}
        if split == 'train':
            for lang_pair in self.args.mass_steps:
                src_dataset = src_mono_datasets[lang_pair]
                lang = lang_pair.split('-')[0]
                mass_mono_dataset = MaskedLanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    self.dicts[lang],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                    max_source_positions=self.args.max_source_positions,
                    max_target_positions=self.args.max_target_positions,
                    shuffle=True,
                    lang_id=self.args.langs_id[lang],
                    ratio=self.args.word_mask,
                    pred_probs=self.args.pred_probs,
                )
                mass_mono_datasets[lang_pair] = mass_mono_dataset

        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [(_get_mt_dataset_key(lang_pair), mt_para_dataset[lang_pair])
                 for lang_pair in mt_para_dataset.keys()] +
                [(_get_memt_dataset_key(lang_pair), memt_para_dataset[lang_pair])
                 for lang_pair in memt_para_dataset.keys()] +
                [(_get_mass_dataset_key(lang_pair), mass_mono_datasets[lang_pair])
                 for lang_pair in mass_mono_datasets.keys()] +
                [(_get_mt_dataset_key(lang_pair), eval_para_dataset[lang_pair])
                 for lang_pair in eval_para_dataset.keys()]),
            eval_key=None if self.training else self.args.eval_lang_pair)
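
The _get_*_dataset_key helpers used above are not shown in this snippet. A hedged sketch of the overall pattern with illustrative stand-ins for those helpers; RoundRobinZipDatasets itself is fairseq's, and simply round-robins over the per-key datasets, optionally exposing only eval_key at evaluation time:

from collections import OrderedDict

from fairseq.data import RoundRobinZipDatasets


def _get_mt_dataset_key(lang_pair):    # hypothetical stand-in
    return "mt:" + lang_pair


def _get_memt_dataset_key(lang_pair):  # hypothetical stand-in
    return "memt:" + lang_pair


def _get_mass_dataset_key(lang_pair):  # hypothetical stand-in
    return "mass:" + lang_pair


def zip_task_datasets(mt, memt, mass, eval_para, training, eval_lang_pair):
    """Combine the per-step dataset dicts into one round-robin dataset."""
    pairs = (
        [(_get_mt_dataset_key(lp), d) for lp, d in mt.items()]
        + [(_get_memt_dataset_key(lp), d) for lp, d in memt.items()]
        + [(_get_mass_dataset_key(lp), d) for lp, d in mass.items()]
        + [(_get_mt_dataset_key(lp), d) for lp, d in eval_para.items()]
    )
    return RoundRobinZipDatasets(
        OrderedDict(pairs),
        eval_key=None if training else eval_lang_pair,
    )
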
Exemple #18
0
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    shuffle,
    is_infer,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path,
                                '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path,
                                  '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path,
                                  '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, data_path))

        src_datasets.append(
            data_utils.load_indexed_dataset(prefix + src, src_dict,
                                            dataset_impl))
        tgt_datasets.append(
            data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                            dataset_impl))

        print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt,
                                                 len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
    # for inference, return the dataset without truncation
    if is_infer:
        return LanguagePairDataset(
            src_dataset,
            src_dataset.sizes,
            src_dict,
            tgt_dataset,
            tgt_dataset.sizes,
            tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            shuffle=shuffle,
        )
    # for train and valid splits, return the truncated dataset
    else:
        return TruncateLanguagePairDataset(
            src_dataset,
            src_dataset.sizes,
            src_dict,
            DEFAULT_MAX_SRC_LEN,
            tgt_dataset,
            tgt_dataset.sizes,
            tgt_dict,
            DEFAULT_MAX_TGT_LEN,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            shuffle=shuffle,
        )
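
A hedged usage sketch of the loader above: the only difference between the branches is is_infer, which selects the untruncated LanguagePairDataset for decoding and the TruncateLanguagePairDataset otherwise (the dictionaries, language codes, and data path are placeholders):

def build_splits(data_path, src_dict, tgt_dict):
    common = dict(
        combine=True, dataset_impl="mmap", upsample_primary=1,
        left_pad_source=True, left_pad_target=False,
        max_source_positions=1024, max_target_positions=1024,
    )
    train = load_langpair_dataset(
        data_path, "train", "en", src_dict, "de", tgt_dict,
        shuffle=True, is_infer=False, **common)  # truncated for training
    test = load_langpair_dataset(
        data_path, "test", "en", src_dict, "de", tgt_dict,
        shuffle=False, is_infer=True, **common)  # untruncated for decoding
    return train, test
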
Exemple #19
0
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target, max_source_positions, max_target_positions,
    src_tag=None, tgt_tag=None, src_tau=-1, tgt_tau=-1, epoch=0, id_to_sample_probabilities=None, lm=None,
    idx_to_src_gradnorm=None 
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_datasets.append(
            data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        )
        tgt_datasets.append(
            data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        )

        print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    return LanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        src_tag=src_tag,
        tgt_tag=tgt_tag,
        src_tau=src_tau,
        tgt_tau=tgt_tau,
        id_to_sample_probabilities=id_to_sample_probabilities,
        lm=lm,
        idx_to_src_gradnorm=idx_to_src_gradnorm,
    )
Exemple #20
0
    def load_dataset(self, split, combine=False, only_train=False):
        """Load a dataset split."""
        def split_exists(split, src, tgt, lang):
            filename = os.path.join(
                self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                return True
            elif not self.args.raw_text and IndexedInMemoryDataset.exists(
                    filename):
                return True
            return False

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                return IndexedRawTextDataset(path, dictionary)
            elif IndexedInMemoryDataset.exists(path):
                return IndexedInMemoryDataset(path, fix_lua_indexing=True)
            return None

        src_datasets = []
        tgt_datasets = []
        pivot_datasets = []
        mt_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')

            # infer langcode
            src, tgt, pivot, mt = \
                self.args.source_lang, self.args.target_lang, self.args.p, self.args.mt

            if split_exists(split_k, src, tgt, src):
                prefix = os.path.join(self.args.data,
                                      '{}.{}-{}.'.format(split_k, src, tgt))
            elif split_exists(split_k, tgt, src, src):
                prefix = os.path.join(self.args.data,
                                      '{}.{}-{}.'.format(split_k, tgt, src))
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(
                            split, self.args.data))

            src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
            tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
            if only_train:
                pivot_datasets.append(
                    indexed_dataset(prefix + pivot, self.tgt_dict))
                mt_datasets.append(indexed_dataset(prefix + mt, self.tgt_dict))

            print('| {} {} {} examples'.format(self.args.data, split_k,
                                               len(src_datasets[-1])))

            if not combine:
                break

        if only_train:
            assert len(src_datasets) == len(tgt_datasets) == len(
                pivot_datasets) == len(mt_datasets)
        else:
            assert len(src_datasets) == len(tgt_datasets)

        if len(src_datasets) == 1:
            src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
            src_sizes = src_dataset.sizes
            tgt_sizes = tgt_dataset.sizes

            if only_train:
                pivot_dataset = pivot_datasets[0]
                pivot_sizes = pivot_dataset.sizes

                mt_dataset = mt_datasets[0]
                mt_sizes = mt_dataset.sizes
            else:
                pivot_dataset = None
                pivot_sizes = None
                mt_dataset = None
                mt_sizes = None
        else:
            src_dataset = ConcatDataset(src_datasets)
            tgt_dataset = ConcatDataset(tgt_datasets)
            src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
            tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])

            if only_train:
                pivot_dataset = ConcatDataset(pivot_datasets)
                pivot_sizes = np.concatenate(
                    [ds.sizes for ds in pivot_datasets])
                mt_dataset = ConcatDataset(mt_datasets)
                mt_sizes = np.concatenate([ds.sizes for ds in mt_datasets])
            else:
                pivot_dataset = None
                pivot_sizes = None
                mt_dataset = None
                mt_sizes = None

        self.datasets[split] = LanguagePairDataset(
            src_dataset,
            src_sizes,
            self.src_dict,
            pivot_dataset,
            pivot_sizes,
            mt_dataset,
            mt_sizes,
            tgt_dataset,
            tgt_sizes,
            self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
        )
Exemple #21
0
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        def split_exists(split, src, tgt, lang, data_path):
            filename = os.path.join(
                data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                return True
            elif not self.args.raw_text and IndexedDataset.exists(filename):
                return True
            return False

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                return IndexedRawTextDataset(path, dictionary)
            elif IndexedDataset.exists(path):
                return IndexedCachedDataset(path, fix_lua_indexing=True)
            return None

        src_datasets = []
        tgt_datasets = []

        data_paths = self.args.data

        for dk, data_path in enumerate(data_paths):
            for k in itertools.count():
                split_k = split + (str(k) if k > 0 else '')

                # infer langcode
                src, tgt = self.args.source_lang, self.args.target_lang
                if split_exists(split_k, src, tgt, src, data_path):
                    prefix = os.path.join(
                        data_path, '{}.{}-{}.'.format(split_k, src, tgt))
                elif split_exists(split_k, tgt, src, src, data_path):
                    prefix = os.path.join(
                        data_path, '{}.{}-{}.'.format(split_k, tgt, src))
                else:
                    if k > 0 or dk > 0:
                        break
                    else:
                        raise FileNotFoundError(
                            'Dataset not found: {} ({})'.format(
                                split, data_path))

                src_datasets.append(
                    indexed_dataset(prefix + src, self.src_dict))
                tgt_datasets.append(
                    indexed_dataset(prefix + tgt, self.tgt_dict))

                print('| {} {} {} examples'.format(data_path, split_k,
                                                   len(src_datasets[-1])))

                if not combine:
                    break

        assert len(src_datasets) == len(tgt_datasets)

        if len(src_datasets) == 1:
            src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        else:
            sample_ratios = [1] * len(src_datasets)
            sample_ratios[0] = self.args.upsample_primary
            src_dataset = ConcatDataset(src_datasets, sample_ratios)
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

        self.datasets[split] = LanguagePairDataset(
            src_dataset,
            src_dataset.sizes,
            self.src_dict,
            tgt_dataset,
            tgt_dataset.sizes,
            self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
        )
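
When several shards are concatenated as above, sample_ratios controls how often each shard is revisited. A small sketch of the effect with fairseq's ConcatDataset (the toy dataset is illustrative):

import torch
from fairseq.data import ConcatDataset


class ToyDataset:
    """Bare-bones dataset: indexable token tensors plus a sizes list."""

    def __init__(self, tensors):
        self.tensors = tensors
        self.sizes = [t.numel() for t in tensors]

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)


primary = ToyDataset([torch.tensor([1]), torch.tensor([2])])
extra = ToyDataset([torch.tensor([3])])

# upsample_primary=2 repeats every primary example twice in the concatenation
combined = ConcatDataset([primary, extra], sample_ratios=[2, 1])
print(len(combined))             # 2*2 + 1 = 5
print(combined[0], combined[2])  # both map back to the first primary example
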
Exemple #22
0
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    extra_feature_dicts,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path,
                                '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    extra_feature_datasets = defaultdict(list)

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path,
                                  '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path,
                                  '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, data_path))

        src_datasets.append(
            data_utils.load_indexed_dataset(prefix + src, src_dict,
                                            dataset_impl))
        tgt_datasets.append(
            data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                            dataset_impl))

        if extra_feature_dicts:
            for feature_type in extra_feature_dicts:
                extra_feature_datasets[feature_type].append(
                    data_utils.load_indexed_dataset(
                        prefix + feature_type,
                        extra_feature_dicts[feature_type], dataset_impl))

        print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt,
                                                 len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        extra_feature_datasets = {
            feature_type: datasets[0]
            for feature_type, datasets in extra_feature_datasets.items()
        }

    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        extra_feature_datasets = {
            feature_type: ConcatDataset(datasets)
            for feature_type, datasets in extra_feature_datasets.items()
        }

    if len(extra_feature_datasets.keys()) > 0:
        return LanguagePairDatasetWithExtraFeatures(
            src_dataset,
            src_dataset.sizes,
            src_dict,
            tgt_dataset,
            tgt_dataset.sizes,
            tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            extra_feature_dicts=extra_feature_dicts,
            extra_features=extra_feature_datasets)
    else:
        return LanguagePairDataset(
            src_dataset,
            src_dataset.sizes,
            src_dict,
            tgt_dataset,
            tgt_dataset.sizes,
            tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
        )
Exemple #23
0
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path,
                                '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path,
                                  '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path,
                                  '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict,
                                                      dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)
        tgt_datasets.append(
            data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                            dataset_impl))

        logger.info('{} {} {}-{} {} examples'.format(data_path, split_k,
                                                     src, tgt,
                                                     len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path,
                                  '{}.align.{}-{}'.format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl)

    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=align_dataset,
    )
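
A minimal sketch of the truncate_source wrapping used above: strip the trailing EOS, cap the length at max_source_positions - 1, then re-append EOS so every truncated example still ends with </s> (assumes fairseq; the toy tokens are illustrative):

import numpy as np
import torch
from fairseq.data import (AppendTokenDataset, Dictionary, ListDataset,
                          StripTokenDataset, TruncateDataset)

src_dict = Dictionary()
eos = src_dict.eos()
base = ListDataset([torch.LongTensor([10, 11, 12, 13, eos])],
                   sizes=np.array([5]))

max_source_positions = 4
truncated = AppendTokenDataset(
    TruncateDataset(StripTokenDataset(base, eos), max_source_positions - 1),
    eos,
)
print(truncated[0])  # tensor([10, 11, 12, eos]): at most 4 tokens, EOS preserved
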
Exemple #24
0
    def build_dataset_for_inference(self, src_tokens, src_lengths):
        # TODO: add extra features if they exist
        return LanguagePairDataset(src_tokens, src_lengths,
                                   self.source_dictionary)
Exemple #25
0
    def load_langpair_dataset(
        self,
        data_path,
        split,
        src,
        src_dict,
        tgt,
        tgt_dict,
        combine,
        dataset_impl,
        upsample_primary,
        left_pad_source,
        left_pad_target,
        max_source_positions,
        max_target_positions,
        prepend_bos=False,
        load_alignments=False,
        truncate_source=False,
        src_dataset_transform_func=lambda dataset: dataset,
        tgt_dataset_transform_func=lambda dataset: dataset,
        src_lang_id=None,
        tgt_lang_id=None,
        langpairs_sharing_datasets=None,
    ):
        norm_direction = "-".join(sorted([src, tgt]))
        if langpairs_sharing_datasets is not None:
            src_dataset = langpairs_sharing_datasets.get(
                (data_path, split, norm_direction, src), "NotInCache"
            )
            tgt_dataset = langpairs_sharing_datasets.get(
                (data_path, split, norm_direction, tgt), "NotInCache"
            )
            align_dataset = langpairs_sharing_datasets.get(
                (data_path, split, norm_direction, src, tgt), "NotInCache"
            )

        # a hack: any one is not in cache, we need to reload them
        if (
            langpairs_sharing_datasets is None
            or src_dataset == "NotInCache"
            or tgt_dataset == "NotInCache"
            or align_dataset == "NotInCache"
            or split != getattr(self.args, "train_subset", None)
        ):
            # source and target datasets can be reused in reversed directions to save memory
            # reversed directions of valid and test data will not share source and target datasets
            src_dataset, tgt_dataset, align_dataset = self.load_lang_dataset(
                data_path,
                split,
                src,
                src_dict,
                tgt,
                tgt_dict,
                combine,
                dataset_impl,
                upsample_primary,
                max_source_positions=max_source_positions,
                prepend_bos=prepend_bos,
                load_alignments=load_alignments,
                truncate_source=truncate_source,
            )
            src_dataset = src_dataset_transform_func(src_dataset)
            tgt_dataset = tgt_dataset_transform_func(tgt_dataset)
            if langpairs_sharing_datasets is not None:
                langpairs_sharing_datasets[
                    (data_path, split, norm_direction, src)
                ] = src_dataset
                langpairs_sharing_datasets[
                    (data_path, split, norm_direction, tgt)
                ] = tgt_dataset
                langpairs_sharing_datasets[
                    (data_path, split, norm_direction, src, tgt)
                ] = align_dataset
                if align_dataset is None:
                    # no align data so flag the reverse direction as well in sharing
                    langpairs_sharing_datasets[
                        (data_path, split, norm_direction, tgt, src)
                    ] = align_dataset
        else:
            logger.info(
                f"Reusing source and target datasets of [{split}] {tgt}-{src} for reversed direction: "
                f"[{split}] {src}-{tgt}: src length={len(src_dataset)}; tgt length={len(tgt_dataset)}"
            )

        return LanguagePairDataset(
            src_dataset,
            src_dataset.sizes,
            src_dict,
            tgt_dataset,
            tgt_dataset.sizes if tgt_dataset is not None else None,
            tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            align_dataset=align_dataset,
            src_lang_id=src_lang_id,
            tgt_lang_id=tgt_lang_id,
        )
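
A plain-dict sketch of the sharing trick above: both directions of a pair normalize to the same key, so the second (reversed) direction reuses what the first already loaded (the lambda loader and the keys are illustrative stand-ins for load_lang_dataset):

cache = {}


def get_or_load(data_path, split, src, tgt, lang, load_fn):
    norm_direction = "-".join(sorted([src, tgt]))
    key = (data_path, split, norm_direction, lang)
    if cache.get(key, "NotInCache") == "NotInCache":
        cache[key] = load_fn(lang)
    return cache[key]


calls = []
loader = lambda lang: calls.append(lang) or "{}-data".format(lang)
get_or_load("data-bin", "train", "en", "fr", "en", loader)
get_or_load("data-bin", "train", "fr", "en", "en", loader)
print(calls)  # ['en'] -- the reversed direction hits the cache instead of reloading
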
Exemple #26
0
    def load_dataset(self, split, combine=False):
        """Load a dataset split."""
        def split_exists(split, src, tgt, lang):
            filename = os.path.join(
                self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            print('filename:', filename)
            print('raw_text:', self.args.raw_text)
            if self.args.raw_text and IndexedRawTokenIDDataset.exists(
                    filename):
                return True
            elif not self.args.raw_text and IndexedInMemoryDataset.exists(
                    filename):
                return True
            return False

        def indexed_dataset(path, dictionary):
            if self.args.raw_text and not self.args.uniform_n_seq_per_batch and not self.args.uniform_seq_len_per_batch:
                return IndexedRawTokenIDDataset(path, dictionary)
            elif IndexedInMemoryDataset.exists(
                    path
            ) and not self.args.uniform_n_seq_per_batch and not self.args.uniform_seq_len_per_batch:
                return IndexedInMemoryDataset(path)
            elif self.args.uniform_n_seq_per_batch and self.args.uniform_seq_len_per_batch:
                if self.args.uniform_n_seq_in_dataset:
                    return MockedInMemoryDataset(
                        path, self.args.uniform_n_seq_in_dataset,
                        self.args.uniform_n_seq_per_batch,
                        self.args.uniform_seq_len_per_batch)
            return None

        src_datasets = []
        tgt_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')

            # infer langcode
            src, tgt = self.args.source_lang, self.args.target_lang
            if split_exists(split_k, src, tgt, src):
                prefix = os.path.join(self.args.data,
                                      '{}.{}-{}.'.format(split_k, src, tgt))
            elif split_exists(split_k, tgt, src, src):
                prefix = os.path.join(self.args.data,
                                      '{}.{}-{}.'.format(split_k, tgt, src))
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(
                            split, self.args.data))

            src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
            tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))

            print('| {} {} {} examples'.format(self.args.data, split_k,
                                               len(src_datasets[-1])))

            if not combine:
                break

        assert len(src_datasets) == len(tgt_datasets)

        if len(src_datasets) == 1:
            src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
            src_sizes = src_dataset.sizes
            tgt_sizes = tgt_dataset.sizes
        else:
            src_dataset = ConcatDataset(src_datasets)
            tgt_dataset = ConcatDataset(tgt_datasets)
            src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
            tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])

        print('srcline:', src_dataset[0])

        self.datasets[split] = LanguagePairDataset(
            src_dataset,
            src_sizes,
            self.src_dict,
            tgt_dataset,
            tgt_sizes,
            self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            seq_len_multiple=self.args.seq_len_multiple,
        )
Exemple #27
0
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path,
                                "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path,
                                  "{}.{}-{}.".format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path,
                                  "{}.{}-{}.".format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict,
                                                      dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                                      dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info("{} {} {}-{} {} examples".format(data_path, split_k,
                                                     src, tgt,
                                                     len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index("[{}]".format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index("[{}]".format(tgt)))
        eos = tgt_dict.index("[{}]".format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path,
                                  "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
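
A hedged sketch of how a task's load_dataset would typically call the loader above; the args attributes and the src/tgt dictionaries are assumed to exist on the task object:

from fairseq import utils


def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    paths = utils.split_paths(self.args.data)
    data_path = paths[(epoch - 1) % len(paths)]
    src, tgt = self.args.source_lang, self.args.target_lang
    self.datasets[split] = load_langpair_dataset(
        data_path, split, src, self.src_dict, tgt, self.tgt_dict,
        combine=combine,
        dataset_impl=self.args.dataset_impl,
        upsample_primary=self.args.upsample_primary,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
        load_alignments=self.args.load_alignments,
        truncate_source=self.args.truncate_source,
        num_buckets=self.args.num_batch_buckets,
        shuffle=(split != "test"),
    )
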
Exemple #28
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        is_train_subset = split == getattr(self.args, "train_subset", None)
        if not is_train_subset:
            # if not training data set, use the first shard for valid and test
            paths = paths[:1]
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang
        """
        this is mask_word_initial
        WordNoising uses mask_word_end or mask_bpe_cont
        probably easiest to write a FlippedDataset that reverses sequences
        and use the standard pipeline

        load_langpair_dataset:
            find files by pattern
            load_indexed source
                maybe truncate
                load target
            check shard counts
            sample ratios
            bos, source_id
            load_alignments
            LanguagePairDataset constructor

        """

        src_dataset, tgt_dataset = load_unpaired_langpair(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            truncate_source=self.args.truncate_source,
            prepend_bos=self.args.prepend_bos,
        )

        if self.args.bpe_dropout > 0:
            src_dataset = DynamicGPT2BPEDropoutResampling(
                self.args,
                src_dataset,
                self.source_dictionary,
                dropout=self.args.bpe_dropout,
            )

        # load backtranslation
        if is_train_subset and not self.args.skip_backtranslation_data:
            """
            noised vs unnoised valdation set? they might converge at different times
            """
            bt_src_dataset, bt_tgt_dataset = load_unpaired_langpair(
                # data_path, "{}.bt".format(split), src, self.src_dict, tgt, self.tgt_dict,
                data_path,
                "{}.bt".format(split),
                src,
                self.src_dict,
                tgt,
                self.tgt_dict,
                combine=combine,
                dataset_impl=self.args.dataset_impl,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
                truncate_source=self.args.truncate_source,
                prepend_bos=self.args.prepend_bos,
            )
            if self.args.bpe == "gpt2":
                mask_is_beginning_of_word = get_whole_word_mask(
                    self.args, self.source_dictionary)
                mask_is_beginning_of_word = (
                    mask_is_beginning_of_word.numpy().astype(bool))
                # noiser = GPT2WordNoising(
                #     self.src_dict,
                #     mask_is_beginning_of_word,
                #     self.args.max_word_shuffle_distance,
                #     self.args.word_dropout_prob,
                #     self.args.word_blanking_prob,
                # )
                if self.args.bpe_dropout > 0:
                    bt_src_dataset = DynamicGPT2BPEDropoutResampling(
                        self.args,
                        bt_src_dataset,
                        self.source_dictionary,
                        dropout=self.args.bpe_dropout,
                    )
                noiser = GPT2WordNoisingV2(
                    self.src_dict,
                    mask_is_beginning_of_word,
                    self.args.max_word_shuffle_distance,
                    self.args.word_dropout_prob,
                    self.args.word_blanking_prob,
                )
                bt_src_dataset = DynamicNoisingDataset(
                    bt_src_dataset,
                    self.src_dict,
                    seed=1,
                    noiser=noiser,
                )

                # try:
                #     from icecream import ic
                #     ic.configureOutput(includeContext=True)
                # except ImportError:  # Graceful fallback if IceCream isn't installed.
                #     ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a)  # noqa
                # ic("gpt2 bbpe")
                # bpe = encoders.build_bpe(self.args)
                # def decode(foo):
                #     return bpe.decode(self.src_dict.string(foo))
                # def disp(foo):
                #     return " ".join([bpe.decode(i) for i in self.src_dict.string(foo).split(" ")])
                #     # foo = [bpe.decode(str(i)) for i in range(0,1000)]
                #     # doo = [bpe.decode((i)) for i in self.src_dict.symbols[4:1000]]
                # for i in range(5):
                #     ic(_bt_src_dataset[i])
                #     ic(decode(_bt_src_dataset[i]))
                #     ic(disp(_bt_src_dataset[i]))
                #     ic(disp(bt_src_dataset[i]))
                #     ic(bt_src_dataset[i])
                # import pdb; pdb.set_trace()
            else:
                assert self.args.bpe_dropout <= 0, "BPE dropout not supported for this BPE scheme"
                # standard bpe with @@ as continuation marker
                bt_src_dataset = DynamicNoisingDataset(
                    bt_src_dataset,
                    self.src_dict,
                    seed=1,
                    max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                    word_dropout_prob=self.args.word_dropout_prob,
                    word_blanking_prob=self.args.word_blanking_prob,
                )
            # if self.append_backtranslation_tag:
            if self.args.tagged_backtranslation:
                bt_src_dataset = AppendTokenDataset(
                    AppendTokenDataset(
                        StripTokenDataset(bt_src_dataset, self.src_dict.eos()),
                        self.bt_idx),
                    self.src_dict.eos(),
                )

            sample_ratios = [self.args.upsample_primary, 1]
            src_dataset = ConcatDataset([src_dataset, bt_src_dataset],
                                        sample_ratios)
            tgt_dataset = ConcatDataset([tgt_dataset, bt_tgt_dataset],
                                        sample_ratios)

        self.datasets[split] = LanguagePairDataset(
            src_dataset,
            src_dataset.sizes,
            self.src_dict,
            tgt_dataset,
            tgt_dataset.sizes,
            self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            align_dataset=None,
            eos=self.tgt_dict.eos(),
            num_buckets=self.args.num_batch_buckets,
            shuffle=(split not in ("test", "valid")),
        )