Ejemplo n.º 1
0
    def infer(
        self,
        text,
        text_len=None,
        text_mask=None,
        spect=None,
        spect_len=None,
        attn_prior=None,
        use_gt_durs=False,
        lm_tokens=None,
        pitch=None,
    ):
        """Synthesize a spectrogram for ``text`` (inference path).

        Args:
            text: Token ids passed to ``self.encoder``.
            text_len: Per-item text lengths. Used to build ``text_mask`` when it is not
                given, and forwarded to the aligner when ``use_gt_durs`` is true.
            text_mask: Optional precomputed text mask; built from ``text_len`` when ``None``.
            spect, spect_len, attn_prior: Aligner inputs; only used when ``use_gt_durs`` is true.
            use_gt_durs: If true, run the aligner and expand the encoder output with its hard
                durations instead of the duration predictor's output.
            lm_tokens: Token ids for the LM-embedding branch; only read when
                ``self.cond_on_lm_embeddings`` is set.
            pitch: Optional pitch that is averaged over the aligner's hard durations. It is only
                used when ``use_gt_durs`` is true; otherwise the pitch predictor's output is used
                and ``pitch`` is ignored.

        Returns:
            ``self.proj`` applied to the decoder output.

        Raises:
            NotImplementedError: If ``use_gt_durs`` is true but the aligner returned no hard
                durations.
        """
        if text_mask is None:
            text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

        enc_out, enc_mask = self.encoder(text, text_mask)

        # Aligner: only its hard durations are needed at inference time.
        attn_hard_dur = None
        if use_gt_durs:
            _, _, _, attn_hard_dur = self.run_aligner(text, text_len, text_mask, spect, spect_len, attn_prior)
            # Validate immediately: pitch averaging below also consumes these durations, so a
            # later check would be pre-empted by an obscure failure inside average_pitch.
            if attn_hard_dur is None:
                raise NotImplementedError("use_gt_durs=True requires the aligner to return hard durations")

        if self.cond_on_lm_embeddings:
            lm_emb = self.lm_embeddings(lm_tokens)
            lm_features = self.self_attention_module(
                enc_out, lm_emb, lm_emb, q_mask=enc_mask.squeeze(2), kv_mask=lm_tokens != self.lm_padding_value
            )

        # Duration predictor (log-domain output, mapped back and clamped at 0)
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

        # Ground-truth pitch (averaged per token) when aligner durations are used, else predicted pitch
        if use_gt_durs and pitch is not None:
            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
        else:
            pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
            pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))

        enc_out = enc_out + pitch_emb.transpose(1, 2)

        if self.cond_on_lm_embeddings:
            enc_out = enc_out + lm_features

        # Expand token-level features with either the aligner's or the predicted durations
        durations = attn_hard_dur if use_gt_durs else durs_predicted
        len_regulated_enc_out, dec_lens = regulate_len(durations, enc_out)

        dec_out, _ = self.decoder(len_regulated_enc_out, get_mask_from_lengths(dec_lens).unsqueeze(2))
        pred_spect = self.proj(dec_out)

        return pred_spect
Ejemplo n.º 2
0
    def forward(self, *, text, durs=None, pitch=None, pace=1.0, splice=True):
        """Encode ``text``, predict durations/pitch, and run the generator.

        Args:
            text: Input passed to ``self.encoder``.
            durs: Optional durations for length regulation; the predicted durations are used
                when ``None``. Required in training mode (asserted).
            pitch: Optional pitch for the pitch embedding; the predicted pitch is used when
                ``None``. Required in training mode (asserted).
            pace: Forwarded to ``regulate_len``.
            splice: If true, feed the generator one random window of ``self.splice_length``
                frames per item instead of the full length-regulated features.

        Returns:
            Tuple ``(output, splices, log_durs_predicted, pitch_predicted)``. ``splices`` is the
            list of per-item window start frames, or ``None`` when ``splice`` is false.

        Raises:
            ValueError: If ``splice`` is true and an item has fewer than ``self.splice_length``
                valid frames.
        """
        if self.training:
            assert durs is not None
            assert pitch is not None

        # Input FFT
        enc_out, enc_mask = self.encoder(input=text, conditioning=0)

        # Predict durations (log domain -> linear, clamped to [0, max_token_duration])
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration)

        # Predict pitch; a provided pitch takes precedence for the embedding
        pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
        pitch_source = pitch_predicted if pitch is None else pitch
        pitch_emb = self.pitch_emb(pitch_source.unsqueeze(1))
        enc_out = enc_out + pitch_emb.transpose(1, 2)

        durations = durs_predicted if durs is None else durs
        len_regulated, dec_lens = regulate_len(durations, enc_out, pace)

        gen_in = len_regulated
        splices = None
        if splice:
            segments = []
            splices = []
            for i, sample in enumerate(len_regulated):
                # Only the first min(padded length, decoded length) frames are usable.
                valid_len = min(int(sample.size(0)), int(dec_lens[i]))
                max_start = valid_len - self.splice_length
                if max_start < 0:
                    raise ValueError(
                        f"Cannot splice {self.splice_length} frames from item {i}: "
                        f"only {valid_len} valid frames"
                    )
                # Valid starts are 0..max_start inclusive; randint's `high` is exclusive.
                start = np.random.randint(low=0, high=max_start + 1)
                segments.append(sample[start:start + self.splice_length, :])
                splices.append(start)
            gen_in = torch.stack(segments)

        output = self.generator(x=gen_in.transpose(1, 2))

        return output, splices, log_durs_predicted, pitch_predicted
Ejemplo n.º 3
0
    def forward(self, text, text_len, pitch=None, spect=None, spect_len=None, attn_prior=None, lm_tokens=None):
        """Run the full model with aligner-derived durations.

        Args:
            text: Token ids passed to ``self.encoder``.
            text_len: Per-item text lengths used to build the text mask.
            pitch: Pitch averaged over the aligner's hard durations. Required in training mode
                (asserted); at eval time the predicted pitch is used when ``None``.
            spect: Target spectrogram for the aligner. Required: length regulation always uses
                the aligner's hard durations.
            spect_len, attn_prior: Additional aligner inputs.
            lm_tokens: Token ids for the LM-embedding branch; only read when
                ``self.cond_on_lm_embeddings`` is set.

        Returns:
            Tuple ``(pred_spect, durs_predicted, log_durs_predicted, pitch_predicted,
            attn_soft, attn_logprob, attn_hard, attn_hard_dur)``.

        Raises:
            ValueError: If ``spect`` is ``None``.
        """
        if self.training:
            assert pitch is not None
        # Length regulation below needs the aligner's durations, and the aligner needs `spect`;
        # without it the call would only fail later inside regulate_len.
        if spect is None:
            raise ValueError("forward() requires `spect`: durations come from the aligner")

        text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

        enc_out, enc_mask = self.encoder(text, text_mask)

        # Aligner
        attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
            text, text_len, text_mask, spect, spect_len, attn_prior
        )

        if self.cond_on_lm_embeddings:
            lm_emb = self.lm_embeddings(lm_tokens)
            lm_features = self.self_attention_module(
                enc_out, lm_emb, lm_emb, q_mask=enc_mask.squeeze(2), kv_mask=lm_tokens != self.lm_padding_value
            )

        # Duration predictor (log domain -> linear, clamped at 0)
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

        # Pitch predictor
        pitch_predicted = self.pitch_predictor(enc_out, enc_mask)

        # Ground-truth pitch, averaged over the aligner durations, is used in training and
        # whenever it is supplied; otherwise (eval without pitch) the predicted pitch is used.
        if self.training or pitch is not None:
            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
        else:
            pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))

        enc_out = enc_out + pitch_emb.transpose(1, 2)

        if self.cond_on_lm_embeddings:
            enc_out = enc_out + lm_features

        # Regulate length with the aligner's hard durations
        len_regulated_enc_out, dec_lens = regulate_len(attn_hard_dur, enc_out)

        # The decoder's second output is unused; don't let it shadow dec_lens.
        dec_out, _ = self.decoder(len_regulated_enc_out, get_mask_from_lengths(dec_lens).unsqueeze(2))
        pred_spect = self.proj(dec_out)

        return (
            pred_spect,
            durs_predicted,
            log_durs_predicted,
            pitch_predicted,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        )