Example 1
    def alter_dataset_langtok(
        self,
        lang_pair_dataset,
        src_eos=None,
        src_lang=None,
        tgt_eos=None,
        tgt_lang=None,
        src_langtok_spec=None,
        tgt_langtok_spec=None,
    ):
        if src_langtok_spec is None and tgt_langtok_spec is None:
            return lang_pair_dataset

        new_src_eos = None
        if (src_langtok_spec is not None and src_eos is not None
                and (src_lang is not None or tgt_lang is not None)):
            new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang,
                                                   src_langtok_spec)
        else:
            src_eos = None

        new_tgt_bos = None
        if tgt_langtok_spec and tgt_eos is not None and tgt_lang is not None:
            new_tgt_bos = self.get_decoder_langtok(tgt_lang, tgt_langtok_spec)
        else:
            tgt_eos = None

        return TransformEosLangPairDataset(
            lang_pair_dataset,
            src_eos=src_eos,
            new_src_eos=new_src_eos,
            tgt_bos=tgt_eos,
            new_tgt_bos=new_tgt_bos,
        )
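
The wrapper itself only swaps boundary tokens when a batch is collated. Below is a minimal, self-contained sketch of that effect, assuming a recent fairseq; the toy sentences, the tiny dictionary, and the "__de__" tag name are illustrative, not taken from the examples on this page:

    import torch
    from fairseq.data import Dictionary, LanguagePairDataset, TransformEosLangPairDataset

    d = Dictionary()
    for w in ("hello", "world", "hallo", "welt"):
        d.add_symbol(w)
    de_tag = d.add_symbol("__de__")  # a language tag is just another dictionary symbol

    src = [torch.tensor([d.index("hello"), d.index("world"), d.eos()])]
    tgt = [torch.tensor([d.index("hallo"), d.index("welt"), d.eos()])]
    pairs = LanguagePairDataset(
        src, [len(t) for t in src], d,
        tgt=tgt, tgt_sizes=[len(t) for t in tgt], tgt_dict=d,
    )

    wrapped = TransformEosLangPairDataset(
        pairs,
        src_eos=d.eos(),
        new_src_eos=de_tag,  # the source sentence now ends in the language tag
        tgt_bos=d.eos(),     # prev_output_tokens starts with eos by default...
        new_tgt_bos=de_tag,  # ...and gets the tag instead at collate time
    )
    batch = wrapped.collater([wrapped[0]])  # the swaps happen here, on the batch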
Example 2
    def alter_dataset_langtok(self, lang_pair_dataset,
                              src_eos=None, src_lang=None, tgt_eos=None, tgt_lang=None):
        if self.args.encoder_langtok is None and not self.args.decoder_langtok:
            return lang_pair_dataset

        new_src_eos = None
        if self.args.encoder_langtok is not None and src_eos is not None \
           and src_lang is not None and tgt_lang is not None:
            new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang)
        else:
            src_eos = None

        new_tgt_bos = None
        if self.args.decoder_langtok and tgt_eos is not None and tgt_lang is not None:
            new_tgt_bos = self.get_decoder_langtok(tgt_lang)
        else:
            tgt_eos = None

        return TransformEosLangPairDataset(
            lang_pair_dataset,
            src_eos=src_eos,
            new_src_eos=new_src_eos,
            tgt_bos=tgt_eos,
            new_tgt_bos=new_tgt_bos,
        )
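
For context, the helpers this older args-based variant calls come from fairseq's multilingual_translation task; paraphrased from upstream (check your fairseq version), they resolve a "__lang__" tag in the dictionary:

    def get_encoder_langtok(self, src_lang, tgt_lang):
        # --encoder-langtok {src,tgt} chooses which language's tag
        # replaces the source-side eos; unset keeps the plain eos.
        if self.args.encoder_langtok is None:
            return self.source_dictionary.eos()
        if self.args.encoder_langtok == "src":
            return _lang_token_index(self.source_dictionary, src_lang)
        return _lang_token_index(self.source_dictionary, tgt_lang)

    def get_decoder_langtok(self, tgt_lang):
        # --decoder-langtok replaces the decoder bos with the target tag.
        if not self.args.decoder_langtok:
            return self.target_dictionary.eos()
        return _lang_token_index(self.target_dictionary, tgt_lang)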
Example 3
 def _prepend_lang_bos_to_target(self, dataset: LanguagePairDataset,
                                 lang: str) -> LanguagePairDataset:
     bos = _lang_token_index(self.dictionary, lang)
     return TransformEosLangPairDataset(
         dataset,
         src_eos=self.dictionary.eos(),
         new_src_eos=self.dictionary.eos(),
         tgt_bos=self.dictionary.eos(),
         new_tgt_bos=bos,
     )
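
Note that this variant leaves the source eos untouched (new_src_eos equals src_eos) and only swaps the target-side bos. The _lang_token_index helper it relies on is defined in fairseq's multilingual_translation task roughly as follows; the "__{lang}__" format is upstream's convention:

    def _lang_token(lang: str) -> str:
        return "__{}__".format(lang)

    def _lang_token_index(dic, lang: str) -> int:
        idx = dic.index(_lang_token(lang))
        # Dictionary.index falls back to unk for unknown symbols, so a
        # missing tag would otherwise silently become <unk>.
        assert idx != dic.unk_index, "cannot find language token for lang {}".format(lang)
        return idx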
Example 4
 def load_langpair_dataset(self,
                           prepend_tgt_lang_tag=False,
                           sampling_alpha=1.0,
                           epoch=0):
     lang_pairs = []
     text_dataset = None
     split = "train"
     for lp in self.args.langpairs.split(","):
         src, tgt = lp.split("-")
         text_dataset = load_langpair_dataset(
             self.args.parallel_text_data,
             split,
             src,
             self.src_dict,
             tgt,
             self.tgt_dict,
             combine=True,
             dataset_impl=None,
             upsample_primary=1,
             left_pad_source=False,
             left_pad_target=False,
             max_source_positions=self.args.max_positions_text,
             max_target_positions=self.args.max_target_positions,
             load_alignments=False,
             truncate_source=False,
         )
         if prepend_tgt_lang_tag:
             # TODO
             text_dataset = TransformEosLangPairDataset(
                 text_dataset,
                 src_eos=self.src_dict.eos(),
                 tgt_bos=self.tgt_dict.eos(),  # 'prev_output_tokens' starts with eos
                 new_tgt_bos=self.tgt_dict.index(LANG_TAG_TEMPLATE.format(tgt)),
             )
         lang_pairs.append(text_dataset)
     if len(lang_pairs) > 1:
         if sampling_alpha != 1.0:
             size_ratios = SpeechToTextDatasetCreator.get_size_ratios(
                 self.args.langpairs.split(","),
                 [len(s) for s in lang_pairs],
                 alpha=sampling_alpha,
             )
             lang_pairs = [
                 ResamplingDataset(d,
                                   size_ratio=r,
                                   epoch=epoch,
                                   replace=(r >= 1.0))
                 for d, r in zip(lang_pairs, size_ratios)
             ]
         return ConcatDataset(lang_pairs)
     return text_dataset
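
The sampling_alpha path applies temperature-based resampling (Arivazhagan et al., 2019): with alpha < 1.0, low-resource pairs are upsampled relative to their empirical share. A minimal sketch of the size-ratio arithmetic behind SpeechToTextDatasetCreator.get_size_ratios (the exact upstream rounding and logging are omitted, so treat this as an approximation):

    import numpy as np

    def temperature_size_ratios(sizes, alpha=1.0):
        # prob is proportional to dataset size; raising it to alpha
        # flattens the distribution, and the ratio rescales each dataset
        # so the resampled concat matches the smoothed probabilities.
        sizes = np.asarray(sizes, dtype=np.float64)
        prob = sizes / sizes.sum()
        smoothed = prob**alpha
        smoothed /= smoothed.sum()
        return smoothed * sizes.sum() / sizes

    print(temperature_size_ratios([900, 100], alpha=0.5))  # [0.833..., 2.5]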
Example 5
    def alter_dataset_langtok(self,
                              lang_pair_dataset,
                              src_eos=None,
                              src_lang=None,
                              tgt_eos=None,
                              tgt_lang=None,
                              tgt_langs=(),  # avoid a mutable default argument
                              split='train'):
        if self.args.encoder_langtok is None and not self.args.decoder_langtok:
            return lang_pair_dataset

        new_src_eos = None
        if self.args.encoder_langtok is not None and src_eos is not None \
           and src_lang is not None and tgt_lang is not None:
            new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang)
        else:
            src_eos = None

        new_tgt_bos = None
        if self.args.decoder_langtok and tgt_eos is not None and tgt_lang is not None:
            new_tgt_bos = self.get_decoder_langtok(tgt_lang)
        else:
            tgt_eos = None

        # With fewer than two candidate tags, the division below would
        # divide by zero; skip tag sampling in that case.
        if split == 'train' and tgt_lang in tgt_langs and len(tgt_langs) > 1:
            cur_tgt_idx = tgt_langs.index(tgt_lang)
            p = self.args.sample_tag_prob / (len(tgt_langs) - 1)
            new_src_eos_list_probs = [p for _ in range(len(tgt_langs))]
            new_src_eos_list_probs[cur_tgt_idx] = 1 - self.args.sample_tag_prob
            new_src_eos_list = [
                self.get_encoder_langtok(src_lang, t) for t in tgt_langs
            ]
        else:
            new_src_eos_list = None
            new_src_eos_list_probs = None

        return TransformEosLangPairDataset(
            lang_pair_dataset,
            src_eos=src_eos,
            new_src_eos=new_src_eos,
            tgt_bos=tgt_eos,
            new_tgt_bos=new_tgt_bos,
            new_src_eos_list=new_src_eos_list,
            new_src_eos_list_probs=new_src_eos_list_probs,
            split=split,
        )
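
The sample_tag_prob branch in this custom variant keeps the true target tag with probability 1 - sample_tag_prob and spreads the remaining mass evenly over the other candidate tags. The arithmetic, spelled out on toy values:

    sample_tag_prob = 0.2
    tgt_langs = ["de", "fr", "cs"]
    cur = tgt_langs.index("fr")  # the pair actually being trained

    p = sample_tag_prob / (len(tgt_langs) - 1)
    probs = [p] * len(tgt_langs)
    probs[cur] = 1 - sample_tag_prob
    print(probs)  # [0.1, 0.8, 0.1], sums to 1.0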
Example 6
    def load_dataset_only(self,
                          split,
                          lang_pairs,
                          do_mask=True,
                          epoch=1,
                          combine=False):
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # TODO unk token will be considered as first word too, though it might be an unknown phoneme within a word
        # get_whole_word_mask returns a tensor (size V by 1 ) to indicate if a token is a word start token
        mask_whole_src_words = gen_whole_word_mask(self.args, self.src_dict)
        language_without_segmentations = self.args.no_whole_word_mask_langs.split(",")
        lang_datasets = []
        eos_bos = []
        lang_pairs = lang_pairs.split(",") if lang_pairs != "" else []
        assert len(lang_pairs) > 0
        for lp in lang_pairs:
            src, tgt = lp.split("-")
            lang_mask_whole_src_words = (
                mask_whole_src_words
                if src not in language_without_segmentations else None
            )

            end_token = (
                self.source_dictionary.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(src))
                if self.args.add_src_lang_token else None
            )
            bos_token = (
                self.target_dictionary.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(tgt))
                if self.args.add_tgt_lang_token else None
            )
            src_lang_id = None

            if self.args.add_src_lang_token or self.args.add_tgt_lang_token:
                eos_bos.append((end_token, bos_token))

            dataset = PairedDenoisingTask.language_pair_denoising_dataset(
                data_path,
                do_mask,
                split,
                src,
                self.source_dictionary,
                tgt,
                self.target_dictionary,
                self.mask_idx,
                lang_mask_whole_src_words,
                self.args.seed,
                self.args,
                self.args.dataset_impl,
                combine=combine,
                left_pad_source=utils.eval_bool(self.args.left_pad_source),
                left_pad_target=utils.eval_bool(self.args.left_pad_target),
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
                src_lang_id=src_lang_id,
            )

            lang_datasets.append(dataset)

        if len(lang_datasets) == 0:
            return
        elif len(lang_datasets) == 1:
            dataset = lang_datasets[0]
            if self.args.add_src_lang_token or self.args.add_tgt_lang_token:
                end_token, bos_token = eos_bos[0]
                dataset = TransformEosLangPairDataset(
                    dataset,
                    src_eos=self.source_dictionary.eos(),
                    new_src_eos=end_token,
                    tgt_bos=self.target_dictionary.eos(),
                    new_tgt_bos=bos_token,
                )
        else:
            end_tokens = [item[0] for item in eos_bos if item[0] is not None]
            bos_tokens = [item[1] for item in eos_bos if item[1] is not None]
            lang_datasets = self.resample_datasets(lang_datasets, lang_pairs,
                                                   epoch)
            dataset = TransformEosConcatLangPairDataset(
                lang_datasets,
                self.source_dictionary.eos(),
                self.target_dictionary.eos(),
                new_src_eos=end_tokens,
                new_tgt_bos=bos_tokens,
            )
        return dataset
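
Where the single-dataset branch passes one replacement token each, the concatenated branch hands TransformEosConcatLangPairDataset one token per constituent dataset, kept in the same order as lang_datasets. Illustratively, with tag strings standing in for dictionary indices:

    lang_pairs = ["en-de", "en-fr"]
    end_tokens = ["__{}__".format(lp.split("-")[0]) for lp in lang_pairs]  # source tags
    bos_tokens = ["__{}__".format(lp.split("-")[1]) for lp in lang_pairs]  # target tags
    for lp, src_tag, tgt_tag in zip(lang_pairs, end_tokens, bos_tokens):
        print(lp, src_tag, tgt_tag)  # en-de __en__ __de__ / en-fr __en__ __fr__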