def alter_dataset_langtok(
    self,
    lang_pair_dataset,
    src_eos=None,
    src_lang=None,
    tgt_eos=None,
    tgt_lang=None,
    src_langtok_spec=None,
    tgt_langtok_spec=None,
):
    """Wrap *lang_pair_dataset* so language tokens replace source EOS / target BOS.

    Returns the dataset unchanged when neither langtok spec is given;
    otherwise returns a ``TransformEosLangPairDataset`` configured to swap
    the source EOS and/or the target BOS for the appropriate language token.
    """
    # No language-token spec on either side: nothing to transform.
    if src_langtok_spec is None and tgt_langtok_spec is None:
        return lang_pair_dataset

    # Source side: replace EOS with an encoder language token when we have
    # enough information (a spec, an EOS id, and at least one language).
    new_src_eos = None
    has_any_lang = src_lang is not None or tgt_lang is not None
    if src_langtok_spec is not None and src_eos is not None and has_any_lang:
        new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang, src_langtok_spec)
    else:
        src_eos = None  # leave the source stream untouched

    # Target side: replace BOS with a decoder language token when possible.
    new_tgt_bos = None
    if tgt_langtok_spec and tgt_eos is not None and tgt_lang is not None:
        new_tgt_bos = self.get_decoder_langtok(tgt_lang, tgt_langtok_spec)
    else:
        tgt_eos = None  # leave the target stream untouched

    return TransformEosLangPairDataset(
        lang_pair_dataset,
        src_eos=src_eos,
        new_src_eos=new_src_eos,
        tgt_bos=tgt_eos,
        new_tgt_bos=new_tgt_bos,
    )
def alter_dataset_langtok(
    self,
    lang_pair_dataset,
    src_eos=None,
    src_lang=None,
    tgt_eos=None,
    tgt_lang=None,
):
    """Wrap *lang_pair_dataset* with language-token EOS/BOS replacement.

    Controlled by ``self.args.encoder_langtok`` / ``self.args.decoder_langtok``;
    when neither is enabled the dataset is returned unchanged.
    """
    encoder_langtok = self.args.encoder_langtok
    decoder_langtok = self.args.decoder_langtok
    if encoder_langtok is None and not decoder_langtok:
        # No language tokens requested on either side.
        return lang_pair_dataset

    # Source side: swap EOS for the encoder language token when configured.
    new_src_eos = None
    if (
        encoder_langtok is not None
        and src_eos is not None
        and src_lang is not None
        and tgt_lang is not None
    ):
        new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang)
    else:
        src_eos = None  # leave the source stream untouched

    # Target side: swap BOS for the decoder language token when configured.
    new_tgt_bos = None
    if decoder_langtok and tgt_eos is not None and tgt_lang is not None:
        new_tgt_bos = self.get_decoder_langtok(tgt_lang)
    else:
        tgt_eos = None  # leave the target stream untouched

    return TransformEosLangPairDataset(
        lang_pair_dataset,
        src_eos=src_eos,
        new_src_eos=new_src_eos,
        tgt_bos=tgt_eos,
        new_tgt_bos=new_tgt_bos,
    )
def _prepend_lang_bos_to_target(
    self, dataset: LanguagePairDataset, lang: str
) -> LanguagePairDataset:
    """Return *dataset* wrapped so the target sequence starts with the
    language token for *lang* instead of the plain EOS/BOS marker.

    The source side is left unchanged (its EOS maps to itself).
    """
    eos = self.dictionary.eos()
    lang_bos = _lang_token_index(self.dictionary, lang)
    return TransformEosLangPairDataset(
        dataset,
        src_eos=eos,
        new_src_eos=eos,
        tgt_bos=eos,
        new_tgt_bos=lang_bos,
    )
def load_langpair_dataset(self, prepend_tgt_lang_tag=False, sampling_alpha=1.0, epoch=0):
    """Load parallel-text datasets for every pair in ``self.args.langpairs``.

    Each pair is loaded for the "train" split.  When *prepend_tgt_lang_tag*
    is set, the target side is wrapped so its BOS becomes the target
    language tag.  Multiple pairs are concatenated, optionally resampled
    with temperature *sampling_alpha* before concatenation.
    """
    split = "train"
    text_dataset = None
    datasets = []
    for pair in self.args.langpairs.split(","):
        src, tgt = pair.split("-")
        text_dataset = load_langpair_dataset(
            self.args.parallel_text_data,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=True,
            dataset_impl=None,
            upsample_primary=1,
            left_pad_source=False,
            left_pad_target=False,
            max_source_positions=self.args.max_positions_text,
            max_target_positions=self.args.max_target_positions,
            load_alignments=False,
            truncate_source=False,
        )
        if prepend_tgt_lang_tag:
            # 'prev_output_tokens' starts with eos; replace it with the
            # target language tag so the decoder is told the output language.
            text_dataset = TransformEosLangPairDataset(
                text_dataset,
                src_eos=self.src_dict.eos(),
                tgt_bos=self.tgt_dict.eos(),
                new_tgt_bos=self.tgt_dict.index(LANG_TAG_TEMPLATE.format(tgt)),
            )
        datasets.append(text_dataset)

    if len(datasets) <= 1:
        # Single pair: no concatenation/resampling needed.
        return text_dataset

    if sampling_alpha != 1.0:
        ratios = SpeechToTextDatasetCreator.get_size_ratios(
            self.args.langpairs.split(","),
            [len(d) for d in datasets],
            alpha=sampling_alpha,
        )
        datasets = [
            ResamplingDataset(d, size_ratio=r, epoch=epoch, replace=(r >= 1.0))
            for d, r in zip(datasets, ratios)
        ]
    return ConcatDataset(datasets)
def alter_dataset_langtok(
    self,
    lang_pair_dataset,
    src_eos=None,
    src_lang=None,
    tgt_eos=None,
    tgt_lang=None,
    tgt_langs=None,
    split='train',
):
    """Wrap *lang_pair_dataset* with language-token EOS/BOS replacement,
    optionally sampling alternative target-language tags during training.

    Controlled by ``self.args.encoder_langtok`` / ``self.args.decoder_langtok``;
    when neither is enabled the dataset is returned unchanged.

    Fixes vs. previous version:
    - ``tgt_langs`` default changed from the mutable ``[]`` to ``None``
      (same behavior, avoids the shared-mutable-default pitfall).
    - tag sampling now requires ``len(tgt_langs) > 1``, preventing a
      ZeroDivisionError in ``sample_tag_prob / (len(tgt_langs) - 1)``
      when only one target language is listed.
    """
    if tgt_langs is None:
        tgt_langs = []
    if self.args.encoder_langtok is None and not self.args.decoder_langtok:
        return lang_pair_dataset

    # Source side: swap EOS for the encoder language token when configured.
    new_src_eos = None
    if (
        self.args.encoder_langtok is not None
        and src_eos is not None
        and src_lang is not None
        and tgt_lang is not None
    ):
        new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang)
    else:
        src_eos = None  # leave the source stream untouched

    # Target side: swap BOS for the decoder language token when configured.
    new_tgt_bos = None
    if self.args.decoder_langtok and tgt_eos is not None and tgt_lang is not None:
        new_tgt_bos = self.get_decoder_langtok(tgt_lang)
    else:
        tgt_eos = None  # leave the target stream untouched

    new_src_eos_list = None
    new_src_eos_list_probs = None
    # Training-time tag sampling: keep the true target tag with probability
    # (1 - sample_tag_prob) and spread sample_tag_prob uniformly over the
    # remaining languages.  Needs at least two candidate languages.
    if split == 'train' and len(tgt_langs) > 1 and tgt_lang in tgt_langs:
        cur_tgt_idx = tgt_langs.index(tgt_lang)
        p = self.args.sample_tag_prob / (len(tgt_langs) - 1)
        new_src_eos_list_probs = [p] * len(tgt_langs)
        new_src_eos_list_probs[cur_tgt_idx] = 1 - self.args.sample_tag_prob
        new_src_eos_list = [
            self.get_encoder_langtok(src_lang, t) for t in tgt_langs
        ]

    return TransformEosLangPairDataset(
        lang_pair_dataset,
        src_eos=src_eos,
        new_src_eos=new_src_eos,
        tgt_bos=tgt_eos,
        new_tgt_bos=new_tgt_bos,
        new_src_eos_list=new_src_eos_list,
        new_src_eos_list_probs=new_src_eos_list_probs,
        split=split,
    )
def load_dataset_only(self, split, lang_pairs, do_mask=True, epoch=1, combine=False):
    """Build the (possibly multi-pair) denoising dataset for *split*.

    Loads one language-pair denoising dataset per entry in *lang_pairs*
    (comma-separated "src-tgt" strings), optionally wraps it with
    language-token EOS/BOS replacement, and — for multiple pairs —
    resamples and concatenates the per-pair datasets.

    NOTE(review): assumes ``self.args`` carries data/masking options
    (``add_src_lang_token``, ``add_tgt_lang_token``, ``dataset_impl``, …)
    set up by the task's argument parser — confirm against the task setup.
    """
    # Data may be sharded across several paths; pick the shard for this epoch.
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    # TODO unk token will be considered as first word too, though it might be an unknown phoneme within a word
    # get_whole_word_mask returns a tensor (size V by 1 ) to indicate if a token is a word start token
    mask_whole_src_words = gen_whole_word_mask(self.args, self.src_dict)
    language_without_segmentations = self.args.no_whole_word_mask_langs.split(",")
    lang_datasets = []
    eos_bos = []  # (end_token, bos_token) per pair, only when lang tokens are enabled
    lang_pairs = lang_pairs.split(",") if lang_pairs != "" else []
    assert len(lang_pairs) > 0
    for lp in lang_pairs:
        src, tgt = lp.split("-")
        # Disable whole-word masking for languages with no word segmentation.
        lang_mask_whole_src_words = (
            mask_whole_src_words
            if src not in language_without_segmentations
            else None
        )
        # Language-token ids to splice in for source EOS / target BOS,
        # or None when the corresponding option is disabled.
        end_token = (
            self.source_dictionary.index(
                PairedDenoisingTask.LANG_TAG_TEMPLATE.format(src)
            )
            if self.args.add_src_lang_token
            else None
        )
        bos_token = (
            self.target_dictionary.index(
                PairedDenoisingTask.LANG_TAG_TEMPLATE.format(tgt)
            )
            if self.args.add_tgt_lang_token
            else None
        )
        src_lang_id = None
        if self.args.add_src_lang_token or self.args.add_tgt_lang_token:
            eos_bos.append((end_token, bos_token))
        dataset = PairedDenoisingTask.language_pair_denoising_dataset(
            data_path,
            do_mask,
            split,
            src,
            self.source_dictionary,
            tgt,
            self.target_dictionary,
            self.mask_idx,
            lang_mask_whole_src_words,
            self.args.seed,
            self.args,
            self.args.dataset_impl,
            combine=combine,
            left_pad_source=utils.eval_bool(self.args.left_pad_source),
            left_pad_target=utils.eval_bool(self.args.left_pad_target),
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            src_lang_id=src_lang_id,
        )
        lang_datasets.append(dataset)
    if len(lang_datasets) == 0:
        # Unreachable given the assert above; kept for safety.
        return
    elif len(lang_datasets) == 1:
        # Single pair: wrap directly (no resampling/concatenation needed).
        dataset = lang_datasets[0]
        if self.args.add_src_lang_token or self.args.add_tgt_lang_token:
            end_token, bos_token = eos_bos[0]
            dataset = TransformEosLangPairDataset(
                dataset,
                src_eos=self.source_dictionary.eos(),
                new_src_eos=end_token,
                tgt_bos=self.target_dictionary.eos(),
                new_tgt_bos=bos_token,
            )
    else:
        # Multiple pairs: resample each dataset, then concatenate with
        # per-pair language-token replacement.
        end_tokens = [item[0] for item in eos_bos if item[0] is not None]
        bos_tokens = [item[1] for item in eos_bos if item[1] is not None]
        lang_datasets = self.resample_datasets(lang_datasets, lang_pairs, epoch)
        dataset = TransformEosConcatLangPairDataset(
            lang_datasets,
            self.source_dictionary.eos(),
            self.target_dictionary.eos(),
            new_src_eos=end_tokens,
            new_tgt_bos=bos_tokens,
        )
    return dataset