コード例 #1
0
ファイル: speech_to_text.py プロジェクト: xuhu357/fairseq
 def build_generator(
     self,
     models,
     args,
     seq_gen_cls=None,
     extra_gen_cls_kwargs=None,
 ):
     """Build a sequence generator whose output omits target-language tags.

     Args:
         models: models to generate with; forwarded unchanged to the base
             class ``build_generator``.
         args: generation arguments. ``args.prefix_size`` must be 1 when the
             data config prepends the target-language tag.
         seq_gen_cls: optional generator class, forwarded to the base class
             (``None`` leaves the choice to the base class).
         extra_gen_cls_kwargs: optional extra keyword arguments for the
             generator. The caller's dict is copied, never modified.

     Returns:
         Whatever the base class ``build_generator`` returns.

     Raises:
         ValueError: if the target-language tag is prepended as BOS but
             ``args.prefix_size`` is not 1.
     """
     if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
         raise ValueError('Please set "--prefix-size 1" since '
                          "target language ID token is prepended as BOS.")
     # Ids of every target-dictionary symbol recognised as a language tag;
     # handed to the generator as symbols_to_strip_from_output.
     lang_token_ids = {
         i
         for s, i in self.tgt_dict.indices.items()
         if SpeechToTextDataset.is_lang_tag(s)
     }
     # Work on a copy so the caller's dict is not changed as a side effect.
     extra_gen_cls_kwargs = dict(extra_gen_cls_kwargs or {})
     extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids
     return super().build_generator(
         models,
         args,
         # Honour the caller's generator class instead of discarding it.
         seq_gen_cls=seq_gen_cls,
         extra_gen_cls_kwargs=extra_gen_cls_kwargs)
コード例 #2
0
 def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
     """Wrap raw inputs in an interactive SpeechToTextDataset.

     Every sample gets an empty target text and ``self.interactive_tgt_lang``
     as its target language.

     Args:
         src_tokens: source inputs; ``len(src_tokens)`` is the sample count.
         src_lengths: per-sample lengths, forwarded unchanged.
         **kwargs: accepted for interface compatibility; unused.

     Returns:
         A ``SpeechToTextDataset`` for split ``"interactive"``.

     Raises:
         ValueError: if ``self.interactive_tgt_lang`` is ``None``.
     """
     # Explicit check instead of ``assert`` so it still fires under ``python -O``.
     if self.interactive_tgt_lang is None:
         raise ValueError(
             "interactive_tgt_lang must be set to build an inference dataset"
         )
     n_samples = len(src_tokens)
     return SpeechToTextDataset(
         "interactive", False, self.data_cfg, src_tokens, src_lengths,
         tgt_texts=[""] * n_samples,
         tgt_langs=[self.interactive_tgt_lang] * n_samples,
         tgt_dict=self.tgt_dict,
     )
コード例 #3
0
 def get_prefix_token(cls, task, lang):
     """Return the target-language prefix token tensor, or ``None``.

     When ``task.data_cfg.prepend_tgt_lang_tag`` is truthy, the index of
     ``lang``'s tag in ``task.tgt_dict`` is returned as a ``(1, 1)`` int64
     tensor. Otherwise no prefix is used and ``None`` is returned.

     Args:
         task: object exposing ``data_cfg`` and ``tgt_dict``.
         lang: target language; required when the tag is prepended.

     Raises:
         ValueError: if the tag is prepended but ``lang`` is ``None``.
     """
     prefix_size = int(task.data_cfg.prepend_tgt_lang_tag)
     if prefix_size <= 0:
         return None
     # Explicit check rather than ``assert`` so it survives ``python -O``.
     if lang is None:
         raise ValueError(
             "lang must be provided when the target language tag is prepended"
         )
     lang_tag = SpeechToTextDataset.get_lang_tag_idx(lang, task.tgt_dict)
     # Build directly as int64: the legacy torch.Tensor constructor creates a
     # float32 tensor first, which cannot represent every large integer id.
     return torch.tensor([[lang_tag]], dtype=torch.long)
コード例 #4
0
ファイル: speech_to_text.py プロジェクト: freewym/espresso
    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        """Build a sequence generator that strips language tags and may override eos.

        Args:
            models: models to generate with; forwarded to the base class.
            args: generation arguments. ``args.prefix_size`` must be 1 when the
                target-language tag is prepended; ``args.eos_token`` (if
                present and not ``None``) overrides the data config's
                ``eos_token``.
            seq_gen_cls: optional generator class, forwarded to the base class
                (``None`` leaves the choice to the base class).
            extra_gen_cls_kwargs: optional extra keyword arguments for the
                generator. The caller's dict is copied, never modified.

        Returns:
            Whatever the base class ``build_generator`` returns.

        Raises:
            ValueError: if the target-language tag is prepended as BOS but
                ``args.prefix_size`` is not 1.
            Warning: if ``prepend_bos_and_append_tgt_lang_tag`` is set but no
                eos token is configured. Note this is raised, so it aborts.
        """
        if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
            raise ValueError(
                'Please set "--prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        # Ids of every target-dictionary symbol recognised as a language tag.
        lang_token_ids = {
            i
            for s, i in self.tgt_dict.indices.items()
            if SpeechToTextDataset.is_lang_tag(s)
        }

        # Work on a copy so the caller's dict is not changed as a side effect.
        extra_gen_cls_kwargs = dict(extra_gen_cls_kwargs or {})
        extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids

        # args.eos_token takes precedence over the data config's eos_token.
        eos_token = (
            args.eos_token
            if "eos_token" in args and args.eos_token is not None
            else self.data_cfg.config.get("eos_token", None)
        )

        if self.data_cfg.prepend_bos_and_append_tgt_lang_tag and not eos_token:
            raise Warning(
                "Please provide --eos_token to replace eos in sequence generator"
            )

        eos_id = self.tgt_dict.index(eos_token) if eos_token else None
        extra_gen_cls_kwargs["eos"] = eos_id

        return super().build_generator(
            models,
            args,
            # Honour the caller's generator class instead of discarding it.
            seq_gen_cls=seq_gen_cls,
            extra_gen_cls_kwargs=extra_gen_cls_kwargs,
        )
コード例 #5
0
 def build_dataset_for_inference(cls, audio_paths, n_frames):
     """Create an ``"interactive"`` (non-training) SpeechToTextDataset.

     Args:
         audio_paths: audio inputs, forwarded unchanged.
         n_frames: per-input frame counts, forwarded unchanged.
     """
     # NOTE(review): an empty dict is passed where a data config is expected;
     # confirm SpeechToTextDataset tolerates this.
     split_name, is_train_split, data_cfg = "interactive", False, {}
     return SpeechToTextDataset(
         split_name,
         is_train_split,
         data_cfg,
         audio_paths,
         n_frames,
     )
コード例 #6
0
ファイル: speech_to_text.py プロジェクト: freewym/espresso
 def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
     """Create an ``"interactive"`` (non-training) SpeechToTextDataset.

     Args:
         src_tokens: source inputs, forwarded unchanged.
         src_lengths: matching lengths, forwarded unchanged.
         **kwargs: accepted for interface compatibility; ignored.
     """
     split_name, is_train_split = "interactive", False
     return SpeechToTextDataset(
         split_name,
         is_train_split,
         self.data_cfg,
         src_tokens,
         src_lengths,
     )