def run_aligner(self, text, text_len, text_mask, spect, spect_len, attn_prior):
    # Embed the input symbols and compute the soft alignment between
    # spectrogram frames and text tokens.
    text_emb = self.symbol_emb(text)
    attn_soft, attn_logprob = self.aligner(
        spect,
        text_emb.permute(0, 2, 1),
        mask=text_mask == 0,
        attn_prior=attn_prior,
    )
    # Binarize the soft alignment into a hard monotonic alignment and read
    # off per-token durations by summing over the spectrogram-frame axis.
    attn_hard = binarize_attention_parallel(attn_soft, text_len, spect_len)
    attn_hard_dur = attn_hard.sum(2)[:, 0, :]
    # Every spectrogram frame must be assigned to exactly one token.
    assert torch.all(torch.eq(attn_hard_dur.sum(dim=1), spect_len))
    return attn_soft, attn_logprob, attn_hard, attn_hard_dur
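# Hedged sketch: binarize_attention_parallel above is assumed to perform a
# batched monotonic alignment search (MAS) over the soft attention. The
# standalone single-example version below (our own names and shapes, not the
# library's code) illustrates the idea: dynamic programming over a width-1
# monotonic path, then backtracking to a hard 0/1 alignment.
import numpy as np

def mas_width1_sketch(log_attn):
    """log_attn: T_spect x T_text log-probabilities. Returns a 0/1 matrix in
    which every frame picks exactly one token and the chosen token index
    never decreases over time."""
    T, S = log_attn.shape
    assert T >= S, "need at least one frame per token"
    dp = np.full((T, S), -np.inf)
    back = np.zeros((T, S), dtype=np.int64)
    dp[0, 0] = log_attn[0, 0]
    for t in range(1, T):
        for s in range(S):
            stay = dp[t - 1, s]                            # keep same token
            move = dp[t - 1, s - 1] if s > 0 else -np.inf  # advance one token
            back[t, s] = 1 if move > stay else 0
            dp[t, s] = max(stay, move) + log_attn[t, s]
    # Backtrack from (last frame, last token) to recover the hard path.
    hard = np.zeros((T, S))
    s = S - 1
    for t in range(T - 1, -1, -1):
        hard[t, s] = 1.0
        if t > 0:
            s -= back[t, s]
    return hard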
def get_durations(attn_soft, text_len, spect_len):
    """Calculates hard durations from the soft alignment.

    Args:
        attn_soft (torch.tensor): B x 1 x T_spect x T_text soft-alignment tensor.
        text_len (torch.tensor): B tensor, lengths of text.
        spect_len (torch.tensor): B tensor, lengths of mel spectrogram.
    """
    attn_hard = binarize_attention_parallel(attn_soft, text_len, spect_len)
    durations = attn_hard.sum(2)[:, 0, :]
    # Durations must account for every spectrogram frame.
    assert torch.all(torch.eq(durations.sum(dim=1), spect_len))
    return durations
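# Toy illustration (shapes taken from the docstring above): summing a hard
# B x 1 x T_spect x T_text alignment over the frame axis counts how many
# spectrogram frames each text token was assigned, i.e. its duration.
import torch

attn_hard = torch.tensor(
    [[[[1.0, 0.0, 0.0],    # frames 0-1 -> token 0
       [1.0, 0.0, 0.0],
       [0.0, 1.0, 0.0],    # frames 2-4 -> token 1
       [0.0, 1.0, 0.0],
       [0.0, 1.0, 0.0],
       [0.0, 0.0, 1.0]]]]  # frame 5  -> token 2
)
durations = attn_hard.sum(2)[:, 0, :]  # tensor([[2., 3., 1.]])
assert durations.sum(dim=1).item() == attn_hard.size(2)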
def forward(
    self,
    *,
    text,
    durs=None,
    pitch=None,
    speaker=None,
    pace=1.0,
    spec=None,
    attn_prior=None,
    mel_lens=None,
    input_lens=None,
):
    if not self.learn_alignment and self.training:
        # Without a learned aligner, supervised durations and pitch are required.
        assert durs is not None
        assert pitch is not None

    # Calculate speaker embedding
    if self.speaker_emb is None or speaker is None:
        spk_emb = 0
    else:
        spk_emb = self.speaker_emb(speaker).unsqueeze(1)

    # Input FFT
    enc_out, enc_mask = self.encoder(input=text, conditioning=spk_emb)

    # Predict durations
    log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
    durs_predicted = torch.clamp(torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration)

    attn_soft, attn_hard, attn_hard_dur, attn_logprob = None, None, None, None
    if self.learn_alignment and spec is not None:
        # Align text tokens to spectrogram frames and extract hard durations.
        text_emb = self.encoder.word_emb(text)
        attn_soft, attn_logprob = self.aligner(spec, text_emb.permute(0, 2, 1), enc_mask == 0, attn_prior)
        attn_hard = binarize_attention_parallel(attn_soft, input_lens, mel_lens)
        attn_hard_dur = attn_hard.sum(2)[:, 0, :]

    # Predict pitch
    pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
    if pitch is not None:
        if self.learn_alignment and pitch.shape[-1] != pitch_predicted.shape[-1]:
            # During training pitch is given per spectrogram frame; average it
            # over the hard durations so it matches the per-token predictions.
            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
        pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
    else:
        pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
    enc_out = enc_out + pitch_emb.transpose(1, 2)

    # Regulate length: expand encoder outputs according to durations
    if self.learn_alignment and spec is not None:
        len_regulated, dec_lens = regulate_len(attn_hard_dur, enc_out, pace)
    elif spec is None and durs is not None:
        len_regulated, dec_lens = regulate_len(durs, enc_out, pace)
    # Use predictions during inference
    elif spec is None:
        len_regulated, dec_lens = regulate_len(durs_predicted, enc_out, pace)

    # Output FFT
    dec_out, _ = self.decoder(input=len_regulated, seq_lens=dec_lens)
    spect = self.proj(dec_out).transpose(1, 2)

    return (
        spect,
        dec_lens,
        durs_predicted,
        log_durs_predicted,
        pitch_predicted,
        attn_soft,
        attn_logprob,
        attn_hard,
        attn_hard_dur,
        pitch,
    )
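# Hedged sketch: regulate_len above is assumed to be a FastPitch-style length
# regulator. This minimal re-implementation (our own, not the library's code)
# shows the core idea: repeat each token's encoding for its duration,
# optionally rescaled by `pace`, then zero-pad across the batch.
import torch

def regulate_len_sketch(durations, enc_out, pace=1.0):
    """durations: B x T_text (frames per token); enc_out: B x T_text x C.
    Returns (len_regulated, dec_lens) with len_regulated of shape
    B x max(dec_lens) x C."""
    # Faster pace shortens each token's duration; slower pace lengthens it.
    reps = torch.round(durations.float() / pace).long()
    dec_lens = reps.sum(dim=1)
    out = enc_out.new_zeros(enc_out.size(0), int(dec_lens.max()), enc_out.size(2))
    for b in range(enc_out.size(0)):
        # Repeat each token's encoding for its number of frames.
        expanded = torch.repeat_interleave(enc_out[b], reps[b], dim=0)
        out[b, : expanded.size(0)] = expanded
    return out, dec_lens

# e.g. durations tensor([[2, 3, 1]]) with enc_out of shape 1 x 3 x C yields
# len_regulated of shape 1 x 6 x C and dec_lens tensor([6]).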