Exemple #1
0
    def __call__(self, inputs: np.array) -> Tuple[np.array, int, List[str]]:
        """
        Run the speech-to-text model on a waveform and, when a TTS model is
        loaded, synthesize the predicted text back into audio.

        Args:
            inputs (:obj:`np.array`):
                The raw waveform of audio received. By default sampled at `self.sampling_rate`.
                The shape of this array is `T`, where `T` is the time axis
        Return:
            A :obj:`tuple` containing:
              - :obj:`np.array`:
                 The return shape of the array must be `C'`x`T'`.
                 NOTE: when no TTS model is loaded (`self.tts_model is None`)
                 an empty array of shape `(0,)` is returned instead.
              - a :obj:`int`: the sampling rate as an int in Hz
                 (`self.sampling_rate` without TTS, otherwise the rate
                 reported by the TTS prediction).
              - a :obj:`List[str]`: the annotation for each out channel.
                    This can be the name of the instruments for audio source separation
                    or some annotation for speech enhancement. The length must be `C'`.
                    Here it is a single-element list holding the predicted text.
        """
        # Add a leading axis of size 1 in front of the time axis.
        waveform = torch.from_numpy(inputs).unsqueeze(0)
        model_input = S2THubInterface.get_model_input(self.task, waveform)
        text = S2THubInterface.get_prediction(
            self.task, self.model, self.generator, model_input
        )
        annotations = [text]

        # Without a TTS model there is no audio to produce: text only.
        if self.tts_model is None:
            return np.zeros((0, )), self.sampling_rate, annotations

        tts_input = TTSHubInterface.get_model_input(self.tts_task, text)
        prediction = TTSHubInterface.get_prediction(
            self.tts_task, self.tts_model, self.tts_generator, tts_input
        )
        wav, sr = prediction
        # Add a leading axis of size 1 so the array is 2-D as documented.
        return wav.unsqueeze(0).numpy(), sr, annotations
Exemple #2
0
 def __init__(self, model_id: str):
     """
     Load a text-to-speech model, its config and its task from the hub
     and prepare a generator for inference on CPU.

     Args:
         model_id (:obj:`str`):
             Identifier passed to `load_model_ensemble_and_task_from_hf_hub`.
     """
     overrides = {"vocoder": "griffin_lim", "fp16": False}
     hub_cache = os.getenv("HUGGINGFACE_HUB_CACHE")
     ensemble, cfg, task = load_model_ensemble_and_task_from_hf_hub(
         model_id,
         arg_overrides=overrides,
         cache_dir=hub_cache,
     )

     # Only the first model of the ensemble is kept; move it to CPU and
     # switch it to eval mode.
     self.model = ensemble[0].cpu()
     self.model.eval()
     cfg["task"].cpu = True

     self.task = task
     TTSHubInterface.update_cfg_with_data_cfg(cfg, self.task.data_cfg)
     # The generator is built from the full ensemble list, as returned.
     self.generator = self.task.build_generator(ensemble, cfg)
Exemple #3
0
    def __call__(self, inputs: str) -> Tuple[np.array, int]:
        """
        Synthesize audio for a piece of text.

        Args:
            inputs (:obj:`str`):
                The text to generate audio from. Leading and trailing
                NUL characters (`"\\x00"`) are stripped before use.
        Return:
            A :obj:`np.array` and a :obj:`int`: The raw waveform as a numpy
            array, and the sampling rate as an int. If nothing is left
            after stripping, an empty array of shape `(0,)` and
            `self.task.sr` are returned without calling the model.
        """
        text = inputs.strip("\x00")
        if not text:
            # Nothing to synthesize.
            return np.zeros((0, )), self.task.sr

        model_input = TTSHubInterface.get_model_input(self.task, text)
        prediction = TTSHubInterface.get_prediction(
            self.task, self.model, self.generator, model_input
        )
        wav, sr = prediction
        return wav.numpy(), sr
Exemple #4
0
    def __init__(self, model_id: str):
        """
        Load a speech-to-text model from the hub and, when the task's data
        config names one, a companion text-to-speech model as well.

        Args:
            model_id (:obj:`str`):
                Identifier passed to `load_model_ensemble_and_task_from_hf_hub`.

        After construction `self.tts_model`, `self.tts_task` and
        `self.tts_generator` are either all populated or all `None`.
        """
        hub_cache_var = "HUGGINGFACE_HUB_CACHE"

        ensemble, cfg, task = load_model_ensemble_and_task_from_hf_hub(
            model_id,
            arg_overrides={"config_yaml": "config.yaml"},
            cache_dir=os.getenv(hub_cache_var),
        )
        # Keep only the first model of the ensemble, on CPU, in eval mode.
        self.model = ensemble[0].cpu()
        self.model.eval()
        cfg["task"].cpu = True
        self.task = task
        self.generator = task.build_generator([self.model], cfg)

        # Fall back to 16 kHz when the task exposes no (or a falsy) `sr`.
        self.sampling_rate = getattr(self.task, "sr", None) or 16_000

        # The hub section of the data config may name a TTS model, with the
        # key optionally prefixed by the target language.
        tgt_lang = self.task.data_cfg.hub.get("tgt_lang", None)
        if self.task.data_cfg.prepend_tgt_lang_tag:
            key_prefix = f"{tgt_lang}_"
        else:
            key_prefix = ""
        tts_model_id = self.task.data_cfg.hub.get(
            f"{key_prefix}tts_model_id", None
        )

        self.tts_model = None
        self.tts_task = None
        self.tts_generator = None
        if tts_model_id is None:
            return

        # The id must contain exactly one ":"; only the part after it is
        # used, and the model is always fetched from the "facebook" namespace.
        _repo, _id = tts_model_id.split(":")
        tts_overrides = {"vocoder": "griffin_lim", "fp16": False}
        tts_ensemble, tts_cfg, self.tts_task = (
            load_model_ensemble_and_task_from_hf_hub(
                f"facebook/{_id}",
                arg_overrides=tts_overrides,
                cache_dir=os.getenv(hub_cache_var),
            )
        )
        self.tts_model = tts_ensemble[0].cpu()
        self.tts_model.eval()
        tts_cfg["task"].cpu = True
        TTSHubInterface.update_cfg_with_data_cfg(
            tts_cfg, self.tts_task.data_cfg
        )
        self.tts_generator = self.tts_task.build_generator(
            [self.tts_model], tts_cfg
        )
Exemple #5
0
 def from_pretrained(
     cls,
     model_name_or_path,
     checkpoint_file="model.pt",
     data_name_or_path=".",
     config_yaml="config.yaml",
     vocoder: str = "griffin_lim",
     fp16: bool = False,
     **kwargs,
 ):
     """
     Load a pretrained model through `fairseq.hub_utils` and wrap it in a
     `TTSHubInterface`.

     Args:
         model_name_or_path: forwarded positionally to
             `hub_utils.from_pretrained`.
         checkpoint_file: forwarded positionally; defaults to `"model.pt"`.
         data_name_or_path: forwarded positionally; defaults to `"."`.
         config_yaml: forwarded as a keyword; defaults to `"config.yaml"`.
         vocoder: forwarded as a keyword; defaults to `"griffin_lim"`.
         fp16: forwarded as a keyword; defaults to `False`.
         **kwargs: any further keyword arguments, forwarded unchanged.

     Return:
         A `TTSHubInterface` built from the loaded `"args"`, `"task"` and
         the first entry of `"models"`.

     NOTE(review): takes `cls` and calls `cls.hub_models()`, so this is
     presumably decorated with `@classmethod` outside this view - confirm.
     """
     # Imported lazily, inside the function, as in the original.
     from fairseq import hub_utils

     loaded = hub_utils.from_pretrained(
         model_name_or_path,
         checkpoint_file,
         data_name_or_path,
         archive_map=cls.hub_models(),
         config_yaml=config_yaml,
         vocoder=vocoder,
         fp16=fp16,
         **kwargs,
     )
     args = loaded["args"]
     task = loaded["task"]
     first_model = loaded["models"][0]
     return TTSHubInterface(args, task, first_model)