def build_generator(
    self,
    models,
    args,
    seq_gen_cls=None,
    extra_gen_cls_kwargs=None,
):
    if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
        raise ValueError(
            'Please set "--prefix-size 1" since '
            "target language ID token is prepended as BOS."
        )
    # Collect the indices of all language-tag symbols so the generator can
    # strip them from its output.
    lang_token_ids = {
        i
        for s, i in self.tgt_dict.indices.items()
        if SpeechToTextDataset.is_lang_tag(s)
    }
    if extra_gen_cls_kwargs is None:
        extra_gen_cls_kwargs = {"symbols_to_strip_from_output": lang_token_ids}
    else:
        extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids
    return super().build_generator(
        models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs
    )
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
    # Interactive inference needs a target language so that the language tag
    # can be attached to each (empty) target text.
    assert self.interactive_tgt_lang is not None
    return SpeechToTextDataset(
        "interactive",
        False,
        self.data_cfg,
        src_tokens,
        src_lengths,
        tgt_texts=[""] * len(src_tokens),
        tgt_langs=[self.interactive_tgt_lang] * len(src_tokens),
        tgt_dict=self.tgt_dict,
    )
@classmethod
def get_prefix_token(cls, task, lang):
    # A non-zero prefix size means the target-language tag is prepended as
    # BOS, so decoding must be primed with that tag's index.
    prefix_size = int(task.data_cfg.prepend_tgt_lang_tag)
    prefix_tokens = None
    if prefix_size > 0:
        assert lang is not None
        lang_tag = SpeechToTextDataset.get_lang_tag_idx(lang, task.tgt_dict)
        prefix_tokens = torch.Tensor([lang_tag]).long().unsqueeze(0)
    return prefix_tokens
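# --- Illustrative usage sketch (not part of the original source) -----------
# A minimal, hedged example of how the prefix tensor produced above is
# typically consumed: it is expanded to the batch size and passed to
# fairseq's task.inference_step so decoding starts from the language tag.
# `prefix_provider` stands in for whatever class defines get_prefix_token;
# `task`, `generator`, `models`, `sample`, and `tgt_lang` are assumed to come
# from a standard fairseq inference loop.
def decode_with_lang_prefix(prefix_provider, task, generator, models, sample, tgt_lang):
    # (1, 1) tensor holding the target-language tag index, or None.
    prefix_tokens = prefix_provider.get_prefix_token(task, tgt_lang)
    if prefix_tokens is not None:
        # Repeat the prefix for every sentence in the batch.
        prefix_tokens = prefix_tokens.expand(sample["nsentences"], -1)
    return task.inference_step(generator, models, sample, prefix_tokens=prefix_tokens)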
def build_generator(
    self,
    models,
    args,
    seq_gen_cls=None,
    extra_gen_cls_kwargs=None,
):
    if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
        raise ValueError(
            'Please set "--prefix-size 1" since '
            "target language ID token is prepended as BOS."
        )
    lang_token_ids = {
        i
        for s, i in self.tgt_dict.indices.items()
        if SpeechToTextDataset.is_lang_tag(s)
    }
    if extra_gen_cls_kwargs is None:
        extra_gen_cls_kwargs = {}
    # Strip language-tag symbols from generated hypotheses.
    extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids

    # An explicit eos token (e.g. an appended language tag) may be supplied
    # via --eos-token or the data config; otherwise the dictionary's default
    # eos is kept.
    eos_token = (
        args.eos_token
        if "eos_token" in args and args.eos_token is not None
        else self.data_cfg.config.get("eos_token", None)
    )
    if self.data_cfg.prepend_bos_and_append_tgt_lang_tag and not eos_token:
        raise Warning(
            "Please provide --eos_token to replace eos in sequence generator"
        )
    eos_id = self.tgt_dict.index(eos_token) if eos_token else None
    extra_gen_cls_kwargs["eos"] = eos_id

    return super().build_generator(
        models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs
    )
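# --- Illustrative usage sketch (not part of the original source) -----------
# A hedged example of satisfying the constraints enforced above when building
# the generator by hand: prefix_size must be 1 when the target-language tag is
# prepended, and an explicit eos token may be given when the tag is appended
# instead. `task` and `models` are assumed to be a loaded speech-to-text task
# and its checkpoint(s); the namespace below is a minimal stand-in for the
# full fairseq generation args, with values chosen only for illustration.
from argparse import Namespace

def build_prefix_aware_generator(task, models):
    gen_args = Namespace(
        beam=5,
        prefix_size=1,   # required because prepend_tgt_lang_tag is set
        eos_token=None,  # e.g. a language-tag string when it is appended to targets
    )
    return task.build_generator(models, gen_args)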
@classmethod
def build_dataset_for_inference(cls, audio_paths, n_frames):
    return SpeechToTextDataset("interactive", False, {}, audio_paths, n_frames)
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
    return SpeechToTextDataset(
        "interactive", False, self.data_cfg, src_tokens, src_lengths
    )
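# --- Illustrative usage sketch (not part of the original source) -----------
# A hedged example of how the interactive dataset built above is typically
# consumed during inference: audio paths and frame counts are wrapped in a
# SpeechToTextDataset, batched through the task's iterator, and decoded.
# `task`, `generator`, and `models` are assumed to be already constructed;
# the batching parameters are illustrative only.
def transcribe_interactive(task, generator, models, audio_paths, n_frames):
    dataset = task.build_dataset_for_inference(audio_paths, n_frames)
    itr = task.get_batch_iterator(
        dataset=dataset, max_tokens=40000
    ).next_epoch_itr(shuffle=False)
    results = []
    for sample in itr:
        # Each entry in `hypos` is a list of beam hypotheses for one utterance.
        hypos = task.inference_step(generator, models, sample)
        results.extend(hypos)
    return results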