Example #1
    def build_dataset_for_inference(self,
                                    src_tokens,
                                    src_lengths,
                                    constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the multilingual_translation task is not supported"
            )

        src_data = ListDataset(src_tokens, src_lengths)
        dataset = LanguagePairDataset(src_data, src_lengths,
                                      self.source_dictionary)
        src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"]
        if self.args.lang_tok_replacing_bos_eos:
            dataset = self.data_manager.alter_dataset_langtok(
                dataset,
                src_eos=self.source_dictionary.eos(),
                src_lang=self.args.source_lang,
                tgt_eos=self.target_dictionary.eos(),
                tgt_lang=self.args.target_lang,
                src_langtok_spec=src_langtok_spec,
                tgt_langtok_spec=tgt_langtok_spec,
            )
        else:
            dataset.src = self.data_manager.src_dataset_tranform_func(
                self.args.source_lang,
                self.args.target_lang,
                dataset=dataset.src,
                spec=src_langtok_spec,
            )
        return dataset
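
A minimal usage sketch for this kind of method (the instantiated `task` object is assumed, as are the token IDs; `src_tokens` is a list of 1-D LongTensors and `src_lengths` holds their lengths, roughly mirroring what fairseq's interactive front end passes in):

import torch

# Hypothetical, already-binarized source sentences (the IDs are placeholders).
src_tokens = [torch.LongTensor([4, 15, 27, 2]), torch.LongTensor([9, 8, 2])]
src_lengths = torch.LongTensor([len(t) for t in src_tokens])

# `task` is assumed to be a loaded multilingual_translation task instance.
dataset = task.build_dataset_for_inference(src_tokens, src_lengths)
batch = dataset.collater([dataset[i] for i in range(len(dataset))])
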
Example #2
    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        if constraints is not None:
            # Though see Susanto et al. (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/
            raise NotImplementedError(
                "Constrained decoding with the translation_lev task is not supported"
            )

        return LanguagePairDataset(
            src_tokens, src_lengths, self.source_dictionary, append_bos=True
        )
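
The guard at the top is easy to exercise; a minimal sketch, again assuming an instantiated `task` of this type and placeholder token IDs:

import torch

src_tokens = [torch.LongTensor([4, 15, 27, 2])]
src_lengths = torch.LongTensor([4])

# `task` is assumed to be a loaded translation_lev task instance.
# Unconstrained inference builds the dataset as usual ...
dataset = task.build_dataset_for_inference(src_tokens, src_lengths)

# ... while any constraints are rejected up front by the check above.
try:
    task.build_dataset_for_inference(
        src_tokens, src_lengths, constraints=[torch.LongTensor([7, 8])])
except NotImplementedError as err:
    print(err)
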
Example #3
    def build_dataset_for_inference(self,
                                    src_tokens,
                                    src_lengths,
                                    constraints=None):
        return LanguagePairDataset(
            src_tokens,
            src_lengths,
            self.source_dictionary,
            tgt_dict=self.target_dictionary,
            constraints=constraints,
        )
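
Unlike the two examples above, this variant forwards `constraints` straight into LanguagePairDataset, so constrained decoding is supported. In fairseq the per-sentence constraint phrases are typically packed into a single tensor with `pack_constraints` before this call; a sketch under that assumption (the `task` object and token IDs are hypothetical):

import torch
from fairseq.token_generation_constraints import pack_constraints

# One list of constraint phrases (as token-ID tensors) per input sentence.
constraints_per_sentence = [[torch.LongTensor([7, 8])], []]
packed_constraints = pack_constraints(constraints_per_sentence)

# `task` is assumed to be a loaded translation task instance.
src_tokens = [torch.LongTensor([4, 15, 27, 2]), torch.LongTensor([9, 8, 2])]
src_lengths = torch.LongTensor([4, 3])
dataset = task.build_dataset_for_inference(
    src_tokens, src_lengths, constraints=packed_constraints)
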
Example #4
    def build_dataset_for_inference(self,
                                    src_tokens,
                                    src_lengths,
                                    constraints=None):
        src_lang_id = self.source_dictionary.index("[{}]".format(
            self.args.source_lang))
        source_tokens = []
        for s_t in src_tokens:
            s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)])
            source_tokens.append(s_t)
        dataset = LanguagePairDataset(
            source_tokens,
            src_lengths,
            self.source_dictionary,
            tgt_dict=self.target_dictionary,
            constraints=constraints,
        )
        return dataset
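
The loop above appends a language-ID token to every source sentence before wrapping them in a LanguagePairDataset; the tensor manipulation in isolation (the dictionary index for "[en]" is purely illustrative):

import torch

src_lang_id = 250004  # hypothetical index of "[en]" in the source dictionary
s_t = torch.LongTensor([4, 15, 27, 2])
s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)])
print(s_t.tolist())  # [4, 15, 27, 2, 250004]
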
Example #5
    def language_pair_dataset(lang_pair):
        src, tgt = lang_pair.split("-")
        src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[
            lang_pair]
        return self.alter_dataset_langtok(
            LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.dicts[src],
                tgt_dataset,
                tgt_dataset.sizes,
                self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            ),
            self.dicts[src].eos(),
            src,
            self.dicts[tgt].eos(),
            tgt,
        )
Example #6
    def build_dataset_for_inference(self,
                                    src_tokens,
                                    src_lengths,
                                    constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the multilingual_translation task is not supported"
            )

        lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang)
        return RoundRobinZipDatasets(
            OrderedDict([(
                lang_pair,
                self.alter_dataset_langtok(
                    LanguagePairDataset(src_tokens, src_lengths,
                                        self.source_dictionary),
                    src_eos=self.source_dictionary.eos(),
                    src_lang=self.args.source_lang,
                    tgt_eos=self.target_dictionary.eos(),
                    tgt_lang=self.args.target_lang,
                ),
            )]),
            eval_key=lang_pair,
        )
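
Here the single inference direction is wrapped in a RoundRobinZipDatasets keyed by "src-tgt"; passing that key as eval_key makes the wrapper hand through batches of just that dataset. The key construction itself is plain string formatting (a sketch with a hypothetical direction and a placeholder in place of the real dataset):

from collections import OrderedDict

source_lang, target_lang = "de", "en"  # hypothetical direction
lang_pair = "%s-%s" % (source_lang, target_lang)
datasets = OrderedDict([(lang_pair, "<LanguagePairDataset for de-en>")])
print(lang_pair, lang_pair in datasets)  # de-en True
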
Example #7
    def load_langpair_dataset(
        self,
        data_path,
        split,
        src,
        src_dict,
        tgt,
        tgt_dict,
        combine,
        dataset_impl,
        upsample_primary,
        left_pad_source,
        left_pad_target,
        max_source_positions,
        max_target_positions,
        prepend_bos=False,
        load_alignments=False,
        truncate_source=False,
        src_dataset_transform_func=lambda dataset: dataset,
        tgt_dataset_transform_func=lambda dataset: dataset,
        src_lang_id=None,
        tgt_lang_id=None,
        langpairs_sharing_datasets=None,
    ):
        norm_direction = "-".join(sorted([src, tgt]))
        if langpairs_sharing_datasets is not None:
            src_dataset = langpairs_sharing_datasets.get(
                (data_path, split, norm_direction, src), "NotInCache")
            tgt_dataset = langpairs_sharing_datasets.get(
                (data_path, split, norm_direction, tgt), "NotInCache")
            align_dataset = langpairs_sharing_datasets.get(
                (data_path, split, norm_direction, src, tgt), "NotInCache")

        # a hack: if any one is not in cache, we need to reload them all
        if (langpairs_sharing_datasets is None or src_dataset == "NotInCache"
                or tgt_dataset == "NotInCache" or align_dataset == "NotInCache"
                or split != getattr(self.args, "train_subset", None)):
            # source and target datasets can be reused in reversed directions to save memory
            # reversed directions of valid and test data will not share source and target datasets
            src_dataset, tgt_dataset, align_dataset = self.load_lang_dataset(
                data_path,
                split,
                src,
                src_dict,
                tgt,
                tgt_dict,
                combine,
                dataset_impl,
                upsample_primary,
                max_source_positions=max_source_positions,
                prepend_bos=prepend_bos,
                load_alignments=load_alignments,
                truncate_source=truncate_source,
            )
            src_dataset = src_dataset_transform_func(src_dataset)
            tgt_dataset = tgt_dataset_transform_func(tgt_dataset)
            if langpairs_sharing_datasets is not None:
                langpairs_sharing_datasets[(data_path, split, norm_direction,
                                            src)] = src_dataset
                langpairs_sharing_datasets[(data_path, split, norm_direction,
                                            tgt)] = tgt_dataset
                langpairs_sharing_datasets[(data_path, split, norm_direction,
                                            src, tgt)] = align_dataset
                if align_dataset is None:
                    # no align data so flag the reverse direction as well in sharing
                    langpairs_sharing_datasets[(data_path, split,
                                                norm_direction, tgt,
                                                src)] = align_dataset
        else:
            logger.info(
                f"Reusing source and target datasets of [{split}] {tgt}-{src} for reversed direction: "
                f"[{split}] {src}-{tgt}: src length={len(src_dataset)}; tgt length={len(tgt_dataset)}"
            )

        return LanguagePairDataset(
            src_dataset,
            src_dataset.sizes,
            src_dict,
            tgt_dataset,
            tgt_dataset.sizes if tgt_dataset is not None else None,
            tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            align_dataset=align_dataset,
            src_lang_id=src_lang_id,
            tgt_lang_id=tgt_lang_id,
        )
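
The caching in this variant hinges on a direction-normalized key, so that en-de and de-en map to the same source and target datasets; a sketch of the key scheme with hypothetical values (the tuple layout mirrors the lookups above):

data_path, split = "/data/multilingual-bin", "train"
src, tgt = "en", "de"
norm_direction = "-".join(sorted([src, tgt]))  # "de-en" for both en-de and de-en
src_key = (data_path, split, norm_direction, src)
tgt_key = (data_path, split, norm_direction, tgt)
align_key = (data_path, split, norm_direction, src, tgt)
print(norm_direction, src_key)
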
Example #8
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path,
                                "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path,
                                  "{}.{}-{}.".format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path,
                                  "{}.{}-{}.".format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict,
                                                      dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                                      dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info("{} {} {}-{} {} examples".format(data_path, split_k,
                                                     src, tgt,
                                                     len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index("[{}]".format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index("[{}]".format(tgt)))
        eos = tgt_dict.index("[{}]".format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path,
                                  "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
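
split_exists above probes for binarized files following fairseq's "{split}.{src}-{tgt}.{lang}" naming convention; a sketch of the paths it would check for one direction (the data directory is hypothetical):

import os

data_path, split, src, tgt = "data-bin/iwslt14.de-en", "train", "de", "en"
for lang in (src, tgt):
    print(os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)))
# data-bin/iwslt14.de-en/train.de-en.de
# data-bin/iwslt14.de-en/train.de-en.en
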
Example #9
    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        def split_exists(split, src, tgt, lang):
            if src is not None:
                filename = os.path.join(
                    data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
            else:
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, src, tgt))
            return indexed_dataset.dataset_exists(filename,
                                                  impl=self.args.dataset_impl)

        def load_indexed_dataset(path, dictionary):
            return data_utils.load_indexed_dataset(path, dictionary,
                                                   self.args.dataset_impl)

        # load parallel datasets
        src_datasets, tgt_datasets = {}, {}
        if (self.lambda_parallel > 0.0
                or self.lambda_parallel_steps is not None
                or not split.startswith("train")):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                if split_exists(split, src, tgt, src):
                    prefix = os.path.join(data_path,
                                          "{}.{}-{}.".format(split, src, tgt))
                elif split_exists(split, tgt, src, src):
                    prefix = os.path.join(data_path,
                                          "{}.{}-{}.".format(split, tgt, src))
                else:
                    continue
                src_datasets[lang_pair] = load_indexed_dataset(
                    prefix + src, self.dicts[src])
                tgt_datasets[lang_pair] = load_indexed_dataset(
                    prefix + tgt, self.dicts[tgt])
                logger.info("parallel-{} {} {} examples".format(
                    data_path, split, len(src_datasets[lang_pair])))
            if len(src_datasets) == 0:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, data_path))

        # back translation datasets
        backtranslate_datasets = {}
        if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps
                is not None) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                if not split_exists(split, tgt, None, tgt):
                    raise FileNotFoundError(
                        "Dataset not found: backtranslation {} ({})".format(
                            split, data_path))
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt))
                dataset = load_indexed_dataset(filename, self.dicts[tgt])
                lang_pair_dataset_tgt = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                lang_pair_dataset = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    src_dict=self.dicts[src],
                    tgt=dataset,
                    tgt_sizes=dataset.sizes,
                    tgt_dict=self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                backtranslate_datasets[lang_pair] = BacktranslationDataset(
                    tgt_dataset=self.alter_dataset_langtok(
                        lang_pair_dataset_tgt,
                        src_eos=self.dicts[tgt].eos(),
                        src_lang=tgt,
                        tgt_lang=src,
                    ),
                    backtranslation_fn=self.backtranslators[lang_pair],
                    src_dict=self.dicts[src],
                    tgt_dict=self.dicts[tgt],
                    output_collater=self.alter_dataset_langtok(
                        lang_pair_dataset=lang_pair_dataset,
                        src_eos=self.dicts[src].eos(),
                        src_lang=src,
                        tgt_eos=self.dicts[tgt].eos(),
                        tgt_lang=tgt,
                    ).collater,
                )
                logger.info("backtranslate-{}: {} {} {} examples".format(
                    tgt,
                    data_path,
                    split,
                    len(backtranslate_datasets[lang_pair]),
                ))
                self.backtranslate_datasets[
                    lang_pair] = backtranslate_datasets[lang_pair]

        # denoising autoencoder
        noising_datasets = {}
        if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps
                is not None) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                if not split_exists(split, tgt, None, tgt):
                    continue
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt))
                tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
                tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
                noising_dataset = NoisingDataset(
                    tgt_dataset1,
                    self.dicts[tgt],
                    seed=1,
                    max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                    word_dropout_prob=self.args.word_dropout_prob,
                    word_blanking_prob=self.args.word_blanking_prob,
                )
                noising_datasets[lang_pair] = self.alter_dataset_langtok(
                    LanguagePairDataset(
                        noising_dataset,
                        tgt_dataset1.sizes,
                        self.dicts[tgt],
                        tgt_dataset2,
                        tgt_dataset2.sizes,
                        self.dicts[tgt],
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                    ),
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                )
                logger.info("denoising-{}: {} {} {} examples".format(
                    tgt,
                    data_path,
                    split,
                    len(noising_datasets[lang_pair]),
                ))

        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split("-")
            src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[
                lang_pair]
            return self.alter_dataset_langtok(
                LanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    self.dicts[src],
                    tgt_dataset,
                    tgt_dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                self.dicts[src].eos(),
                src,
                self.dicts[tgt].eos(),
                tgt,
            )

        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [(lang_pair, language_pair_dataset(lang_pair))
                 for lang_pair in src_datasets.keys()] +
                [(_get_bt_dataset_key(lang_pair), dataset)
                 for lang_pair, dataset in backtranslate_datasets.items()] +
                [(_get_denoising_dataset_key(lang_pair), dataset)
                 for lang_pair, dataset in noising_datasets.items()]),
            eval_key=None if self.training else "%s-%s" %
            (self.args.source_lang, self.args.target_lang),
        )
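
Finally, all parallel, back-translation, and denoising datasets are zipped round-robin; during training eval_key is None so every direction is iterated, while at inference only the evaluated direction is exposed. The eval_key selection in isolation (hypothetical flags mirroring the expression above):

training = False
source_lang, target_lang = "de", "en"
eval_key = None if training else "%s-%s" % (source_lang, target_lang)
print(eval_key)  # de-en
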