예제 #1
0
    def forward(self, *, text, durs=None, pitch=None, pace=1.0, splice=True):
        """Run encoder, predictors, length regulation and the generator.

        The encoder output gets an added pitch embedding and is expanded to
        decoder frames with ``regulate_len``. Optionally a random window of
        ``self.splice_length`` frames is cut from every sample before it is
        fed to ``self.generator``.

        Args:
            text: Token input passed to ``self.encoder``.
            durs: Ground-truth per-token durations. Required when
                ``self.training``; when ``None`` the clamped predicted
                durations are used for length regulation.
            pitch: Ground-truth pitch. Required when ``self.training``; when
                ``None`` the predicted pitch is embedded instead.
            pace: Forwarded unchanged to ``regulate_len``.
            splice: If True, crop each sample to a random window of
                ``self.splice_length`` frames before the generator.

        Returns:
            Tuple ``(output, splices, log_durs_predicted, pitch_predicted)``.
            ``splices`` is a tensor holding each window's start frame. It is
            empty when ``splice`` is False.

        Raises:
            ValueError: if ``splice`` is True and a sample has fewer than
                ``self.splice_length`` valid frames.
        """
        if self.training:
            assert durs is not None
            assert pitch is not None

        # Input FFT; conditioning is fixed to 0 (no speaker embedding here).
        enc_out, enc_mask = self.encoder(input=text, conditioning=0)

        # Durations are predicted in log space: map back, clamp to
        # [0, max_token_duration].
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(
            torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration)

        # Predict pitch; embed the ground truth when given, else the prediction.
        pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
        pitch_source = pitch_predicted if pitch is None else pitch
        pitch_emb = self.pitch_emb(pitch_source.unsqueeze(1))
        enc_out = enc_out + pitch_emb.transpose(1, 2)

        # Expand token-level features to decoder frames.
        len_regulated, dec_lens = regulate_len(
            durs_predicted if durs is None else durs, enc_out, pace)

        gen_in = len_regulated
        splices = []
        if splice:
            windows = []
            for i, sample in enumerate(len_regulated):
                # Only frames below both the padded length and dec_lens[i]
                # are valid, so the window must fit inside that prefix.
                valid_len = min(int(sample.size(0)), int(dec_lens[i]))
                max_start = valid_len - self.splice_length
                if max_start < 0:
                    raise ValueError(
                        f"Cannot splice {self.splice_length} frames from a "
                        f"sample with only {valid_len} valid frames")
                # np.random.randint excludes `high`; +1 keeps max_start
                # reachable (and avoids low >= high when max_start == 0).
                start = np.random.randint(low=0, high=max_start + 1)
                # Splice generated spec
                windows.append(sample[start:start + self.splice_length, :])
                splices.append(start)
            gen_in = torch.stack(windows)

        # The generator takes channels on dim 1, hence the transpose.
        output = self.generator(x=gen_in.transpose(1, 2))

        return output, torch.tensor(
            splices), log_durs_predicted, pitch_predicted
예제 #2
0
파일: fastpitch.py 프로젝트: sycomix/NeMo
    def infer(self, *, text, pitch=None, speaker=None, pace=1.0):
        """Synthesize a spectrogram from text using predicted durations/pitch.

        Args:
            text: Token input passed to ``self.encoder``.
            pitch: Optional additive offset applied to the predicted pitch.
                ``None`` (the default) means no offset.
            speaker: Optional speaker id(s). Ignored when the model has no
                ``speaker_emb``.
            pace: Forwarded unchanged to ``regulate_len``.

        Returns:
            Tuple ``(spect, dec_lens, durs_predicted, log_durs_predicted,
            pitch_predicted)``. ``spect`` is cast to ``torch.float``.
        """
        # Calculate speaker embedding (0 means "no conditioning").
        if self.speaker_emb is None or speaker is None:
            spk_emb = 0
        else:
            spk_emb = self.speaker_emb(speaker).unsqueeze(1)

        # Input FFT
        enc_out, enc_mask = self.encoder(input=text, conditioning=spk_emb)

        # Predict duration (log space, clamped to the model's token range).
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(
            torch.exp(log_durs_predicted) - 1.0, self.min_token_duration,
            self.max_token_duration)

        # Predict pitch. The unconditional `+ pitch` raised TypeError for the
        # default pitch=None, so the offset is applied only when provided.
        pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
        if pitch is not None:
            pitch_predicted = pitch_predicted + pitch
        pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
        enc_out = enc_out + pitch_emb.transpose(1, 2)

        # Expand to decoder time dimension
        len_regulated, dec_lens = regulate_len(durs_predicted, enc_out, pace)

        # Output FFT
        dec_out, _ = self.decoder(input=len_regulated, seq_lens=dec_lens)
        spect = self.proj(dec_out).transpose(1, 2)
        return spect.to(
            torch.float
        ), dec_lens, durs_predicted, log_durs_predicted, pitch_predicted
예제 #3
0
파일: fastpitch.py 프로젝트: sycomix/NeMo
    def forward(
        self,
        *,
        text,
        durs=None,
        pitch=None,
        speaker=None,
        pace=1.0,
        spec=None,
        attn_prior=None,
        mel_lens=None,
        input_lens=None,
    ):
        """FastPitch forward pass with optional learned text/spec alignment.

        Durations used for length regulation are chosen in this order:
        1. hard-alignment durations, when ``self.learn_alignment`` is set and
           ``spec`` is given;
        2. ``durs``, when given;
        3. the model's clamped duration predictions.

        Args:
            text: Token input passed to the encoder (and its ``word_emb``).
            durs: Ground-truth durations. Required in training unless
                ``self.learn_alignment``.
            pitch: Ground-truth pitch. Required in training unless
                ``self.learn_alignment``. If its last dim differs from the
                predicted pitch under learned alignment, it is averaged per
                token with the hard-alignment durations.
            speaker: Optional speaker id(s) for ``self.speaker_emb``.
            pace: Forwarded unchanged to ``regulate_len``.
            spec: Optional spectrogram used by the aligner.
            attn_prior: Optional prior passed to ``self.aligner``.
            mel_lens: Spectrogram lengths for ``binarize_attention_parallel``.
            input_lens: Text lengths for ``binarize_attention_parallel``.

        Returns:
            Tuple ``(spect, dec_lens, durs_predicted, log_durs_predicted,
            pitch_predicted, attn_soft, attn_logprob, attn_hard,
            attn_hard_dur, pitch)``. The four ``attn_*`` entries are None
            unless alignment ran. ``pitch`` is the possibly averaged input
            pitch.
        """
        if not self.learn_alignment and self.training:
            assert durs is not None
            assert pitch is not None

        # Calculate speaker embedding (0 means "no conditioning").
        if self.speaker_emb is None or speaker is None:
            spk_emb = 0
        else:
            spk_emb = self.speaker_emb(speaker).unsqueeze(1)

        # Input FFT
        enc_out, enc_mask = self.encoder(input=text, conditioning=spk_emb)

        # Durations are predicted in log space; map back and clamp.
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(
            torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration)

        attn_soft, attn_hard, attn_hard_dur, attn_logprob = None, None, None, None
        if self.learn_alignment and spec is not None:
            text_emb = self.encoder.word_emb(text)
            # NOTE(review): `enc_mask == 0` presumably marks padded tokens.
            attn_soft, attn_logprob = self.aligner(spec,
                                                   text_emb.permute(0, 2, 1),
                                                   enc_mask == 0, attn_prior)
            attn_hard = binarize_attention_parallel(attn_soft, input_lens,
                                                    mel_lens)
            # Per-token durations = frames assigned to each token.
            # Assumes attn_hard is (B, 1, T_spec, T_text) — TODO confirm.
            attn_hard_dur = attn_hard.sum(2)[:, 0, :]

        # Predict pitch
        pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
        if pitch is not None:
            needs_averaging = (self.learn_alignment
                               and pitch.shape[-1] != pitch_predicted.shape[-1])
            if needs_averaging:
                # Pitch during training is per spectrogram frame, but during
                # inference, it should be per character.
                # NOTE(review): attn_hard_dur is None when spec is None.
                pitch = average_pitch(pitch.unsqueeze(1),
                                      attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
        else:
            pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))

        enc_out = enc_out + pitch_emb.transpose(1, 2)

        # Expand to decoder time dimension. The original chain had no branch
        # for (spec given, learn_alignment off), leaving len_regulated
        # unbound; fall back to durs / predictions in that case.
        if self.learn_alignment and spec is not None:
            len_regulated, dec_lens = regulate_len(attn_hard_dur, enc_out,
                                                   pace)
        elif durs is not None:
            len_regulated, dec_lens = regulate_len(durs, enc_out, pace)
        else:
            # Use predictions during inference
            len_regulated, dec_lens = regulate_len(durs_predicted, enc_out,
                                                   pace)

        # Output FFT
        dec_out, _ = self.decoder(input=len_regulated, seq_lens=dec_lens)
        spect = self.proj(dec_out).transpose(1, 2)
        return (
            spect,
            dec_lens,
            durs_predicted,
            log_durs_predicted,
            pitch_predicted,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
            pitch,
        )