コード例 #1
0
def waveglow_infer(mel, config):
    """Vocode a mel spectrogram into a peak-normalised waveform with WaveGlow.

    A WaveGlow model is built from ``config``, its weights are restored from
    ``config.vocoder_path``, weight normalisation is removed, and the model is
    run in eval mode under ``torch.no_grad()``. The raw output is passed
    through a ``Denoiser`` and then scaled by its peak absolute value.

    Args:
        mel: Mel-spectrogram input handed to ``waveglow.infer``.
        config: Settings object. This function reads ``vocoder_path``,
            ``device``, ``sigma`` and ``denoising_strength`` from it and also
            passes it to the ``WaveGlow`` and ``Denoiser`` constructors.

    Returns:
        The denoised waveform moved to the CPU. It is divided by its maximum
        absolute value; an all-zero waveform is returned unscaled instead of
        being turned into NaNs by a division by zero.
    """
    print(
        colored('Running WaveGlow with ', 'blue', attrs=['bold']) +
        config.vocoder_path)

    waveglow = WaveGlow(config)
    waveglow, _, _ = load_checkpoint(config.vocoder_path, waveglow)

    # NOTE: a pretrained model could alternatively be fetched with
    # torch.hub.load('nvidia/DeepLearningExamples:torchhub', 'nvidia_waveglow').
    waveglow = waveglow.remove_weightnorm(waveglow)
    waveglow = set_device(waveglow, config.device)
    waveglow.eval()

    denoiser = Denoiser(waveglow, config)
    denoiser = set_device(denoiser, config.device)

    with torch.no_grad():
        wave = waveglow.infer(mel, config.sigma).float()
        wave = denoiser(wave, strength=config.denoising_strength)

    # Peak-normalise, but leave pure silence untouched: dividing an all-zero
    # signal by its (zero) peak would fill the result with NaNs.
    peak = torch.max(torch.abs(wave))
    if peak > 0:
        wave = wave / peak

    return wave.cpu()
コード例 #2
0
ファイル: tts.py プロジェクト: malarinv/tacotron2
class TTSModel(object):
    """Text-to-speech wrapper around Tacotron2 with an optional WaveGlow vocoder.

    Text is converted to a post-net mel spectrogram by Tacotron2. The mel is
    then turned into audio either by WaveGlow followed by a ``Denoiser``, or by
    the ``griffin_lim`` routine when no WaveGlow checkpoint was supplied.

    The public entry point ``synth_speech`` is assigned in ``__init__``: it is
    ``_synth_speech`` (WaveGlow) or ``_synth_speech_fast`` (Griffin-Lim),
    wrapped in ``klepto.safe.inf_cache`` backed by ``self.k_cache``.
    """

    def __init__(self, tacotron2_path, waveglow_path, **kwargs):
        """Load the models and set up the cached ``synth_speech`` callable.

        Args:
            tacotron2_path: Checkpoint path passed to ``torch.load``; its
                ``"state_dict"`` entry is loaded into the Tacotron2 model.
            waveglow_path: WaveGlow checkpoint path. When falsy, no vocoder or
                denoiser is created and ``synth_speech`` uses the
                ``VOCODER_GL`` path instead.
            **kwargs: Forwarded unchanged to ``HParams``.
        """
        super(TTSModel, self).__init__()
        hparams = HParams(**kwargs)
        self.hparams = hparams
        use_cuda = torch.cuda.is_available()

        self.model = Tacotron2(hparams)
        self.model.load_state_dict(
            self._load_checkpoint(tacotron2_path)["state_dict"])
        if use_cuda:
            self.model.cuda()
        self.model.eval()

        self.k_cache = klepto.archives.file_archive(cached=False)
        if waveglow_path:
            wave_params = self._load_checkpoint(waveglow_path)
            try:
                # First treat the checkpoint as a plain state dict.
                self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
                self.waveglow.load_state_dict(wave_params)
            except Exception:
                # Otherwise fall back to a checkpoint that stores the model
                # object under "model". ``Exception`` (not a bare ``except``)
                # keeps KeyboardInterrupt/SystemExit from being swallowed.
                self.waveglow = wave_params["model"]
                self.waveglow = self.waveglow.remove_weightnorm(self.waveglow)
            if use_cuda:
                self.waveglow.cuda()
            self.waveglow.eval()
            # workaround from
            # https://github.com/NVIDIA/waveglow/issues/127
            for m in self.waveglow.modules():
                if "Conv" in str(type(m)):
                    setattr(m, "padding_mode", "zeros")
            # NOTE(review): each convinv entry is cast to float and then to
            # half; the return value is discarded, so this presumably relies
            # on the cast being in-place - confirm this is intended.
            for k in self.waveglow.convinv:
                k.float().half()
            self.denoiser = Denoiser(self.waveglow,
                                     n_mel_channels=hparams.n_mel_channels)
            self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
                self._synth_speech)
        else:
            self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
                self._synth_speech_fast)
        # NOTE(review): mel_fmax is hard-coded to 4000 here rather than read
        # from hparams - confirm it matches the trained model.
        self.taco_stft = TacotronSTFT(
            hparams.filter_length,
            hparams.hop_length,
            hparams.win_length,
            n_mel_channels=hparams.n_mel_channels,
            sampling_rate=hparams.sampling_rate,
            mel_fmax=4000,
        )

    @staticmethod
    def _load_checkpoint(path):
        """Return ``torch.load(path)``, mapped to the CPU when CUDA is absent."""
        if torch.cuda.is_available():
            return torch.load(path)
        return torch.load(path, map_location="cpu")

    def _generate_mel_postnet(self, text):
        """Run Tacotron2 inference on ``text`` and return the post-net mel output."""
        sequence = np.array(text_to_sequence(text,
                                             ["english_cleaners"]))[None, :]
        if torch.cuda.is_available():
            sequence = torch.autograd.Variable(
                torch.from_numpy(sequence)).cuda().long()
        else:
            sequence = torch.autograd.Variable(
                torch.from_numpy(sequence)).long()
        with torch.no_grad():
            # Only the second of the four inference outputs is used.
            _, mel_outputs_postnet, _, _ = self.model.inference(sequence)
        return mel_outputs_postnet

    def synth_speech_array(self, text, vocoder):
        """Synthesise ``text`` and return the audio as a NumPy array.

        Args:
            text: Input text to synthesise.
            vocoder: ``VOCODER_WAVEGLOW`` or ``VOCODER_GL``.

        Returns:
            The generated audio, moved to the CPU and converted with
            ``.numpy()``.

        Raises:
            ValueError: If ``vocoder`` is neither of the two known constants.
        """
        mel_outputs_postnet = self._generate_mel_postnet(text)

        if vocoder == VOCODER_WAVEGLOW:
            with torch.no_grad():
                audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
                audio_t = self.denoiser(audio_t, 0.1)[0]
            audio = audio_t[0].data
        elif vocoder == VOCODER_GL:
            mel_decompress = self.taco_stft.spectral_de_normalize(
                mel_outputs_postnet)
            mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
            spec_from_mel_scaling = 1000
            # Project the mel frames through the STFT helper's mel basis.
            spec_from_mel = torch.mm(mel_decompress[0],
                                     self.taco_stft.mel_basis)
            spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
            spec_from_mel = spec_from_mel * spec_from_mel_scaling
            spec_from_mel = (spec_from_mel.cuda()
                             if torch.cuda.is_available() else spec_from_mel)
            # The last index along the final axis is dropped before
            # reconstruction.
            audio = griffin_lim(
                torch.autograd.Variable(spec_from_mel[:, :, :-1]),
                self.taco_stft.stft_fn,
                GL_ITERS,
            )
            audio = audio.squeeze()
        else:
            raise ValueError("vocoder arg should be one of [wavglow|gl]")
        audio = audio.cpu().numpy()
        return audio

    def _synth_and_postprocess(self, text, vocoder, speed, sample_rate):
        """Synthesise with ``vocoder`` and pass the audio to ``postprocess_audio``."""
        audio = self.synth_speech_array(text, vocoder)

        return postprocess_audio(
            audio,
            src_rate=self.hparams.sampling_rate,
            dst_rate=sample_rate,
            tempo=speed,
        )

    def _synth_speech(self,
                      text,
                      speed: float = 1.0,
                      sample_rate: int = OUTPUT_SAMPLE_RATE):
        """Synthesise ``text`` with the WaveGlow vocoder and post-process it.

        ``speed`` is passed to ``postprocess_audio`` as ``tempo`` and
        ``sample_rate`` as ``dst_rate``.
        """
        return self._synth_and_postprocess(text, VOCODER_WAVEGLOW, speed,
                                           sample_rate)

    def _synth_speech_fast(self,
                           text,
                           speed: float = 1.0,
                           sample_rate: int = OUTPUT_SAMPLE_RATE):
        """Synthesise ``text`` with the Griffin-Lim path and post-process it.

        ``speed`` is passed to ``postprocess_audio`` as ``tempo`` and
        ``sample_rate`` as ``dst_rate``.
        """
        return self._synth_and_postprocess(text, VOCODER_GL, speed,
                                           sample_rate)
コード例 #3
0
ファイル: synthesis.py プロジェクト: dodohow1011/textglow
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
device = (torch.device('cuda')
          if torch.cuda.is_available() else torch.device('cpu'))

# 2 ** 15 as a float; presumably the full-scale magnitude of 16-bit PCM
# samples - it is not used in the visible part of this file, so confirm.
MAX_WAV_VALUE = float(2 ** 15)

if __name__ == "__main__":
    # Test
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    sampling_rate = 22050

    torch.manual_seed(hp.seed)
    torch.cuda.manual_seed(hp.seed)
    model = WaveGlow().cuda()
    checkpoint = torch.load('test/TTSglow_67000')
    model.load_state_dict(checkpoint['model'].state_dict())
    model = model.remove_weightnorm(model)

    dataset = FastSpeechDataset()
    testing_loader = DataLoader(dataset,
                                batch_size=1,
                                shuffle=False,
                                collate_fn=collate_fn,
                                drop_last=True,
                                num_workers=4)
    model = model.train()

    for i, data_of_batch in enumerate(testing_loader):
        audio_tgt = data_of_batch["audios"]
        src_seq = data_of_batch["texts"]
        src_pos = data_of_batch["pos"]
        mel_tgt = data_of_batch["mels"]