Example #1
    def run_aligner(self, text, text_len, text_mask, spect, spect_len, attn_prior):
        """Runs the alignment module and binarizes the soft attention into per-token durations."""
        text_emb = self.symbol_emb(text)
        # Soft alignment between spectrogram frames and text tokens
        attn_soft, attn_logprob = self.aligner(
            spect, text_emb.permute(0, 2, 1), mask=text_mask == 0, attn_prior=attn_prior,
        )
        # Hard alignment and per-token durations
        attn_hard = binarize_attention_parallel(attn_soft, text_len, spect_len)
        attn_hard_dur = attn_hard.sum(2)[:, 0, :]
        assert torch.all(torch.eq(attn_hard_dur.sum(dim=1), spect_len))
        return attn_soft, attn_logprob, attn_hard, attn_hard_dur
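For illustration only, here is a minimal stand-in for the binarization step above. The real binarize_attention_parallel in NeMo performs a monotonic (Viterbi-style) alignment search; this sketch simply takes a per-frame argmax over text tokens, so it only demonstrates the expected input/output shapes, not the actual algorithm.

import torch
import torch.nn.functional as F

def binarize_attention_argmax(attn_soft):
    # attn_soft: B x 1 x T_spec x T_text soft alignment -> one-hot hard alignment of the same shape.
    # Simplified stand-in: per-frame argmax, NOT the monotonic search used by binarize_attention_parallel.
    hard_idx = attn_soft.argmax(dim=3)                 # B x 1 x T_spec
    return F.one_hot(hard_idx, attn_soft.size(3)).to(attn_soft.dtype)

attn_soft = torch.rand(2, 1, 7, 4).softmax(dim=3)      # B=2, T_spec=7, T_text=4
attn_hard = binarize_attention_argmax(attn_soft)
assert attn_hard.shape == attn_soft.shape
assert torch.all(attn_hard.sum(3) == 1)                # exactly one text token per frame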
Example #2
    @staticmethod
    def get_durations(attn_soft, text_len, spect_len):
        """Calculation of durations.

        Args:
            attn_soft (torch.tensor): B x 1 x T1 x T2 tensor.
            text_len (torch.tensor): B tensor, lengths of text.
            spect_len (torch.tensor): B tensor, lengths of mel spectrogram.
        """
        attn_hard = binarize_attention_parallel(attn_soft, text_len, spect_len)
        durations = attn_hard.sum(2)[:, 0, :]
        assert torch.all(torch.eq(durations.sum(dim=1), spect_len))
        return durations
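As a quick sanity check of the indexing above (assumed toy shapes, not taken from NeMo): for a hard alignment of shape B x 1 x T_spec x T_text with exactly one text token selected per frame, summing over the frame axis yields per-token durations whose total equals the spectrogram length.

import torch

attn_hard = torch.zeros(1, 1, 5, 3)              # B=1, T_spec=5, T_text=3
attn_hard[0, 0] = torch.tensor([
    [1., 0., 0.],                                # frames 0-1 -> token 0
    [1., 0., 0.],
    [0., 1., 0.],                                # frame 2   -> token 1
    [0., 0., 1.],                                # frames 3-4 -> token 2
    [0., 0., 1.],
])
durations = attn_hard.sum(2)[:, 0, :]            # tensor([[2., 1., 2.]])
assert durations.sum(dim=1).item() == 5          # equals T_spec, as the assert above checks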
Example #3
    def forward(
        self,
        *,
        text,
        durs=None,
        pitch=None,
        speaker=None,
        pace=1.0,
        spec=None,
        attn_prior=None,
        mel_lens=None,
        input_lens=None,
    ):

        if not self.learn_alignment and self.training:
            assert durs is not None
            assert pitch is not None

        # Calculate speaker embedding
        if self.speaker_emb is None or speaker is None:
            spk_emb = 0
        else:
            spk_emb = self.speaker_emb(speaker).unsqueeze(1)

        # Input FFT
        enc_out, enc_mask = self.encoder(input=text, conditioning=spk_emb)

        # Predict durations (in log domain) from the encoder outputs
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(
            torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration)

        attn_soft, attn_hard, attn_hard_dur, attn_logprob = None, None, None, None
        if self.learn_alignment and spec is not None:
            # Compute the soft text-to-spectrogram alignment and binarize it into per-token durations
            text_emb = self.encoder.word_emb(text)
            attn_soft, attn_logprob = self.aligner(
                spec, text_emb.permute(0, 2, 1), enc_mask == 0, attn_prior)
            attn_hard = binarize_attention_parallel(attn_soft, input_lens, mel_lens)
            attn_hard_dur = attn_hard.sum(2)[:, 0, :]

        # Predict pitch
        pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
        if pitch is not None:
            if self.learn_alignment and pitch.shape[-1] != pitch_predicted.shape[-1]:
                # Ground-truth pitch comes per spectrogram frame; average it per token with the
                # hard durations so it matches the per-token pitch prediction
                pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
        else:
            pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))

        enc_out = enc_out + pitch_emb.transpose(1, 2)

        # Expand encoder outputs to frame level using the durations
        if self.learn_alignment and spec is not None:
            # Use durations from the aligner during training
            len_regulated, dec_lens = regulate_len(attn_hard_dur, enc_out, pace)
        elif spec is None and durs is not None:
            # Use supervised durations when provided
            len_regulated, dec_lens = regulate_len(durs, enc_out, pace)
        # Use predictions during inference
        elif spec is None:
            len_regulated, dec_lens = regulate_len(durs_predicted, enc_out, pace)

        # Output FFT
        dec_out, _ = self.decoder(input=len_regulated, seq_lens=dec_lens)
        spect = self.proj(dec_out).transpose(1, 2)
        return (
            spect,
            dec_lens,
            durs_predicted,
            log_durs_predicted,
            pitch_predicted,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
            pitch,
        )
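The length regulation step in the forward pass above is what turns per-token durations into a frame-level sequence before the output FFT. Below is a minimal sketch of that idea, assuming regulate_len repeats each encoder timestep according to its duration (scaled by pace) and returns the regulated sequence plus its lengths; regulate_len_sketch is a hypothetical illustration, not NeMo's regulate_len, and the pace handling is an assumption.

import torch

def regulate_len_sketch(durations, enc_out, pace=1.0):
    # durations: B x T_text durations, enc_out: B x T_text x C encoder outputs.
    # Assumption: pace > 1 shortens durations (faster speech); the real function may differ.
    durs = torch.round(durations.float() / pace).long()
    dec_lens = durs.sum(dim=1)
    out = torch.zeros(enc_out.size(0), int(dec_lens.max()), enc_out.size(2))
    for b in range(enc_out.size(0)):
        rep = torch.repeat_interleave(enc_out[b], durs[b], dim=0)
        out[b, : rep.size(0)] = rep
    return out, dec_lens

enc_out = torch.randn(1, 3, 8)                   # B=1, T_text=3, C=8
durs = torch.tensor([[2, 1, 3]])
len_regulated, dec_lens = regulate_len_sketch(durs, enc_out)
assert len_regulated.shape == (1, 6, 8) and dec_lens.tolist() == [6]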