import numpy as np
import torch


def forward(self, *, text, durs=None, pitch=None, pace=1.0, splice=True):
    if self.training:
        assert durs is not None
        assert pitch is not None

    # Input FFT
    enc_out, enc_mask = self.encoder(input=text, conditioning=0)

    # Embedded for predictors
    pred_enc_out, pred_enc_mask = enc_out, enc_mask

    # Predict durations
    log_durs_predicted = self.duration_predictor(pred_enc_out, pred_enc_mask)
    durs_predicted = torch.clamp(
        torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration)

    # Predict pitch; use ground-truth pitch for the embedding when available
    pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
    if pitch is None:
        pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
    else:
        pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
    enc_out = enc_out + pitch_emb.transpose(1, 2)

    # Expand encoder output to frame level using ground-truth or predicted durations
    if durs is None:
        len_regulated, dec_lens = regulate_len(durs_predicted, enc_out, pace)
    else:
        len_regulated, dec_lens = regulate_len(durs, enc_out, pace)

    gen_in = len_regulated
    splices = []
    if splice:
        output = []
        for i, sample in enumerate(len_regulated):
            # Sample a random window of splice_length frames; assumes every
            # regulated sequence is longer than self.splice_length frames.
            start = np.random.randint(
                low=0,
                high=min(int(sample.size(0)), int(dec_lens[i])) - self.splice_length)
            # Splice generated spec
            output.append(sample[start:start + self.splice_length, :])
            splices.append(start)
        gen_in = torch.stack(output)

    output = self.generator(x=gen_in.transpose(1, 2))

    return output, torch.tensor(splices), log_durs_predicted, pitch_predicted
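
# Hedged usage sketch (illustration only, not part of the original module).
# Shows one way the spliced training forward above might be driven. All names
# and shapes below (`model`, `vocab_size`, `batch`, `text_len`) are assumptions.
def _example_spliced_forward(model, vocab_size=80, batch=2, text_len=32):
    """Drive forward() in training mode with ground-truth durations and pitch.

    Assumes each regulated sequence is longer than model.splice_length so the
    random splice window can be sampled.
    """
    model.train()
    text = torch.randint(0, vocab_size, (batch, text_len))  # token IDs [B, T]
    durs = torch.randint(1, 8, (batch, text_len)).float()   # per-token durations [B, T]
    pitch = torch.randn(batch, text_len)                    # per-token pitch targets [B, T]
    audio, splices, log_durs_pred, pitch_pred = model(
        text=text, durs=durs, pitch=pitch, splice=True)
    return audio, splices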
def infer(self, *, text, pitch=None, speaker=None, pace=1.0):
    # Calculate speaker embedding
    if self.speaker_emb is None or speaker is None:
        spk_emb = 0
    else:
        spk_emb = self.speaker_emb(speaker).unsqueeze(1)

    # Input FFT
    enc_out, enc_mask = self.encoder(input=text, conditioning=spk_emb)

    # Predict duration and pitch
    log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
    durs_predicted = torch.clamp(
        torch.exp(log_durs_predicted) - 1.0,
        self.min_token_duration,
        self.max_token_duration)

    # `pitch`, when given, is a per-token shift applied on top of the prediction
    pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
    if pitch is not None:
        pitch_predicted = pitch_predicted + pitch
    pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
    enc_out = enc_out + pitch_emb.transpose(1, 2)

    # Expand to decoder time dimension
    len_regulated, dec_lens = regulate_len(durs_predicted, enc_out, pace)

    # Output FFT
    dec_out, _ = self.decoder(input=len_regulated, seq_lens=dec_lens)
    spect = self.proj(dec_out).transpose(1, 2)

    return spect.to(torch.float), dec_lens, durs_predicted, log_durs_predicted, pitch_predicted
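
# Hedged usage sketch (illustration only): running infer() end to end. The
# zero pitch tensor and `model` are assumptions; pitch here is a per-token
# shift added on top of the predicted contour, so zeros leave prosody unmodified.
def _example_infer(model, text):
    """Generate a mel spectrogram from token IDs `text` of shape [B, T]."""
    model.eval()
    with torch.no_grad():
        pitch_shift = torch.zeros(text.shape, dtype=torch.float, device=text.device)
        spect, dec_lens, durs, log_durs, pitch = model.infer(
            text=text, pitch=pitch_shift, pace=1.0)
    return spect, dec_lens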
def forward(
    self,
    *,
    text,
    durs=None,
    pitch=None,
    speaker=None,
    pace=1.0,
    spec=None,
    attn_prior=None,
    mel_lens=None,
    input_lens=None,
):
    if not self.learn_alignment and self.training:
        assert durs is not None
        assert pitch is not None

    # Calculate speaker embedding
    if self.speaker_emb is None or speaker is None:
        spk_emb = 0
    else:
        spk_emb = self.speaker_emb(speaker).unsqueeze(1)

    # Input FFT
    enc_out, enc_mask = self.encoder(input=text, conditioning=spk_emb)

    # Predict durations
    log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
    durs_predicted = torch.clamp(
        torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration)

    # Learn the hard token-to-frame alignment from the spectrogram
    attn_soft, attn_hard, attn_hard_dur, attn_logprob = None, None, None, None
    if self.learn_alignment and spec is not None:
        text_emb = self.encoder.word_emb(text)
        attn_soft, attn_logprob = self.aligner(
            spec, text_emb.permute(0, 2, 1), enc_mask == 0, attn_prior)
        attn_hard = binarize_attention_parallel(attn_soft, input_lens, mel_lens)
        attn_hard_dur = attn_hard.sum(2)[:, 0, :]

    # Predict pitch
    pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
    if pitch is not None:
        # Pitch during training is per spectrogram frame, but during inference
        # it should be per token, so average frame-level pitch over each
        # token's aligned frames
        if self.learn_alignment and pitch.shape[-1] != pitch_predicted.shape[-1]:
            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
        pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
    else:
        pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
    enc_out = enc_out + pitch_emb.transpose(1, 2)

    # Regulate length with aligner durations, ground-truth durations, or predictions
    if self.learn_alignment and spec is not None:
        len_regulated, dec_lens = regulate_len(attn_hard_dur, enc_out, pace)
    elif spec is None and durs is not None:
        len_regulated, dec_lens = regulate_len(durs, enc_out, pace)
    # Use predictions during inference
    elif spec is None:
        len_regulated, dec_lens = regulate_len(durs_predicted, enc_out, pace)
    else:
        raise ValueError(
            "spec was provided without learn_alignment or durs; "
            "no durations are available to regulate length.")

    # Output FFT
    dec_out, _ = self.decoder(input=len_regulated, seq_lens=dec_lens)
    spect = self.proj(dec_out).transpose(1, 2)

    return (
        spect,
        dec_lens,
        durs_predicted,
        log_durs_predicted,
        pitch_predicted,
        attn_soft,
        attn_logprob,
        attn_hard,
        attn_hard_dur,
        pitch,
    )
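
# Hedged usage sketch (illustration only): the aligner-based training path of
# the forward above. Argument names mirror the signature; the tensors passed
# in (`spec` as [B, n_mel, T_mel], frame-level `pitch` as [B, T_mel], plus
# `attn_prior`, `mel_lens`, `input_lens`) are assumed to come from a data loader.
def _example_aligned_forward(model, text, spec, pitch, attn_prior, mel_lens, input_lens):
    """Run forward() with learned alignment; durations come from the aligner."""
    (spect, dec_lens, durs_pred, log_durs_pred, pitch_pred,
     attn_soft, attn_logprob, attn_hard, attn_hard_dur, avg_pitch) = model(
        text=text,
        pitch=pitch,  # frame-level; averaged per token inside forward()
        spec=spec,
        attn_prior=attn_prior,
        mel_lens=mel_lens,
        input_lens=input_lens,
    )
    return spect, attn_hard_dur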