def forward(self, text, text_len, pitch=None, spect=None, spect_len=None, attn_prior=None, lm_tokens=None):
    if self.training:
        assert pitch is not None

    text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

    enc_out, enc_mask = self.encoder(text, text_mask)

    # Aligner
    attn_soft, attn_logprob, attn_hard, attn_hard_dur = None, None, None, None
    if spect is not None:
        attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
            text, text_len, text_mask, spect, spect_len, attn_prior
        )

    if self.cond_on_lm_embeddings:
        lm_emb = self.lm_embeddings(lm_tokens)
        lm_features = self.self_attention_module(
            enc_out, lm_emb, lm_emb, q_mask=enc_mask.squeeze(2), kv_mask=lm_tokens != self.lm_padding_value
        )

    # Duration predictor
    log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
    durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

    # Pitch predictor
    pitch_predicted = self.pitch_predictor(enc_out, enc_mask)

    # Avg pitch, add pitch_emb
    if not self.training:
        if pitch is not None:
            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
        else:
            pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
    else:
        pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
        pitch_emb = self.pitch_emb(pitch.unsqueeze(1))

    enc_out = enc_out + pitch_emb.transpose(1, 2)

    if self.cond_on_lm_embeddings:
        enc_out = enc_out + lm_features

    # Regulate length
    len_regulated_enc_out, dec_lens = regulate_len(attn_hard_dur, enc_out)

    dec_out, dec_lens = self.decoder(len_regulated_enc_out, get_mask_from_lengths(dec_lens).unsqueeze(2))
    pred_spect = self.proj(dec_out)

    return (
        pred_spect,
        durs_predicted,
        log_durs_predicted,
        pitch_predicted,
        attn_soft,
        attn_logprob,
        attn_hard,
        attn_hard_dur,
    )
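# NOTE: illustrative sketch, not part of the model file. `regulate_len` above
# expands the token-level encoder output along time: each encoder frame is
# repeated according to its duration, and the resulting per-item lengths are
# returned for masking. A minimal version (assuming integer durations and
# `torch` imported at module top, as in the surrounding file; the library
# version additionally supports a pace factor and rounding of soft durations)
# could look like:
def _regulate_len_sketch(durations, enc_out):
    """durations: LongTensor [B, T_text]; enc_out: FloatTensor [B, T_text, C]."""
    dec_lens = durations.sum(dim=1)  # output length per batch item
    out = enc_out.new_zeros(enc_out.size(0), int(dec_lens.max()), enc_out.size(2))
    for b in range(enc_out.size(0)):
        # repeat each token's encoding durations[b][t] times along the time axis
        rep = torch.repeat_interleave(enc_out[b], durations[b], dim=0)
        out[b, : rep.size(0)] = rep
    return out, dec_lens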
def infer(
    self,
    text,
    text_len=None,
    text_mask=None,
    spect=None,
    spect_len=None,
    attn_prior=None,
    use_gt_durs=False,
    lm_tokens=None,
    pitch=None,
):
    if text_mask is None:
        text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

    enc_out, enc_mask = self.encoder(text, text_mask)

    # Aligner
    attn_hard_dur = None
    if use_gt_durs:
        attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
            text, text_len, text_mask, spect, spect_len, attn_prior
        )

    if self.cond_on_lm_embeddings:
        lm_emb = self.lm_embeddings(lm_tokens)
        lm_features = self.self_attention_module(
            enc_out, lm_emb, lm_emb, q_mask=enc_mask.squeeze(2), kv_mask=lm_tokens != self.lm_padding_value
        )

    # Duration predictor
    log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
    durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

    # Avg pitch, pitch predictor
    if use_gt_durs and pitch is not None:
        pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
        pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
    else:
        pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
        pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))

    # Add pitch emb
    enc_out = enc_out + pitch_emb.transpose(1, 2)

    if self.cond_on_lm_embeddings:
        enc_out = enc_out + lm_features

    if use_gt_durs:
        if attn_hard_dur is not None:
            len_regulated_enc_out, dec_lens = regulate_len(attn_hard_dur, enc_out)
        else:
            raise NotImplementedError
    else:
        len_regulated_enc_out, dec_lens = regulate_len(durs_predicted, enc_out)

    dec_out, _ = self.decoder(len_regulated_enc_out, get_mask_from_lengths(dec_lens).unsqueeze(2))
    pred_spect = self.proj(dec_out)

    return pred_spect
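# Hypothetical usage of `infer` for synthesis (the names `model`, `tokens`, and
# `vocoder` are assumptions, not defined in this file). With use_gt_durs=False
# only tokenized text is needed; `lm_tokens` must also be passed when the model
# was built with cond_on_lm_embeddings=True:
#
#     tokens = ...  # LongTensor [B, T_text] of input token ids
#     text_len = torch.full((tokens.size(0),), tokens.size(1), device=tokens.device)
#     with torch.no_grad():
#         mel = model.infer(text=tokens, text_len=text_len)  # [B, T_mel, n_mel]
#     # `mel` is the projected decoder output; a separate vocoder converts it
#     # to audio, e.g. audio = vocoder(mel.transpose(1, 2))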
def _metrics(
    self,
    true_durs,
    true_text_len,
    pred_durs,
    true_pitch,
    pred_pitch,
    true_spect=None,
    pred_spect=None,
    true_spect_len=None,
    attn_logprob=None,
    attn_soft=None,
    attn_hard=None,
    attn_hard_dur=None,
):
    text_mask = get_mask_from_lengths(true_text_len)
    mel_mask = get_mask_from_lengths(true_spect_len)
    loss = 0.0

    # Dur loss and metrics
    durs_loss = F.mse_loss(pred_durs, (true_durs + 1).float().log(), reduction='none')
    durs_loss = durs_loss * text_mask.float()
    durs_loss = durs_loss.sum() / text_mask.sum()

    durs_pred = pred_durs.exp() - 1
    durs_pred = torch.clamp_min(durs_pred, min=0)
    durs_pred = durs_pred.round().long()

    acc = ((true_durs == durs_pred) * text_mask).sum().float() / text_mask.sum() * 100
    acc_dist_1 = (((true_durs - durs_pred).abs() <= 1) * text_mask).sum().float() / text_mask.sum() * 100
    acc_dist_3 = (((true_durs - durs_pred).abs() <= 3) * text_mask).sum().float() / text_mask.sum() * 100

    pred_spect = pred_spect.transpose(1, 2)

    # Mel loss
    mel_loss = F.mse_loss(pred_spect, true_spect, reduction='none').mean(dim=-2)
    mel_loss = mel_loss * mel_mask.float()
    mel_loss = mel_loss.sum() / mel_mask.sum()

    loss = loss + self.durs_loss_scale * durs_loss + self.mel_loss_scale * mel_loss

    # Aligner loss
    bin_loss, ctc_loss = None, None
    ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob, in_lens=true_text_len, out_lens=true_spect_len)
    loss = loss + ctc_loss
    if self.add_bin_loss:
        bin_loss = self.bin_loss(hard_attention=attn_hard, soft_attention=attn_soft)
        loss = loss + self.bin_loss_scale * bin_loss

    true_avg_pitch = average_pitch(true_pitch.unsqueeze(1), attn_hard_dur).squeeze(1)

    # Pitch loss
    pitch_loss = F.mse_loss(pred_pitch, true_avg_pitch, reduction='none')  # noqa
    pitch_loss = (pitch_loss * text_mask).sum() / text_mask.sum()

    loss = loss + self.pitch_loss_scale * pitch_loss

    return loss, durs_loss, acc, acc_dist_1, acc_dist_3, pitch_loss, mel_loss, ctc_loss, bin_loss
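# NOTE: illustrative sketch, not the library helper. `average_pitch` collapses a
# frame-level pitch contour [B, 1, T_mel] to one value per input token by
# averaging over each token's duration span, giving the token-level target used
# in the pitch loss above. Simplified version (assumes integer durations and
# `F` = torch.nn.functional, and averages over *all* frames in a span, whereas
# the real helper may treat unvoiced/zero frames specially):
def _average_pitch_sketch(pitch, durs):
    ends = torch.cumsum(durs, dim=1)  # segment end frame per token, [B, T_text]
    starts = ends - durs  # segment start frame per token
    cums = F.pad(pitch.cumsum(dim=2), (1, 0))  # prefix sums, [B, 1, T_mel + 1]
    sums = cums.gather(2, ends.unsqueeze(1)) - cums.gather(2, starts.unsqueeze(1))
    return sums / durs.unsqueeze(1).clamp(min=1).float()  # mean pitch per token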
def validation_step(self, batch, batch_idx):
    attn_prior, lm_tokens = None, None
    if self.cond_on_lm_embeddings:
        audio, audio_len, text, text_len, attn_prior, pitch, _, lm_tokens = batch
    else:
        audio, audio_len, text, text_len, attn_prior, pitch, _ = batch

    spect, spect_len = self.preprocessor(input_signal=audio, length=audio_len)

    # pitch normalization
    zero_pitch_idx = pitch == 0
    pitch = (pitch - self.pitch_mean) / self.pitch_std
    pitch[zero_pitch_idx] = 0.0

    (
        pred_spect,
        _,
        pred_log_durs,
        pred_pitch,
        attn_soft,
        attn_logprob,
        attn_hard,
        attn_hard_dur,
    ) = self(
        text=text,
        text_len=text_len,
        pitch=pitch,
        spect=spect,
        spect_len=spect_len,
        attn_prior=attn_prior,
        lm_tokens=lm_tokens,
    )

    (
        loss,
        durs_loss,
        acc,
        acc_dist_1,
        acc_dist_3,
        pitch_loss,
        mel_loss,
        ctc_loss,
        bin_loss,
    ) = self._metrics(
        pred_durs=pred_log_durs,
        pred_pitch=pred_pitch,
        true_durs=attn_hard_dur,
        true_text_len=text_len,
        true_pitch=pitch,
        true_spect=spect,
        pred_spect=pred_spect,
        true_spect_len=spect_len,
        attn_logprob=attn_logprob,
        attn_soft=attn_soft,
        attn_hard=attn_hard,
        attn_hard_dur=attn_hard_dur,
    )

    # without ground truth internal features except for durations
    pred_spect, _, pred_log_durs, pred_pitch, attn_soft, attn_logprob, attn_hard, attn_hard_dur = self(
        text=text,
        text_len=text_len,
        pitch=None,
        spect=spect,
        spect_len=spect_len,
        attn_prior=attn_prior,
        lm_tokens=lm_tokens,
    )

    *_, with_pred_features_mel_loss, _, _ = self._metrics(
        pred_durs=pred_log_durs,
        pred_pitch=pred_pitch,
        true_durs=attn_hard_dur,
        true_text_len=text_len,
        true_pitch=pitch,
        true_spect=spect,
        pred_spect=pred_spect,
        true_spect_len=spect_len,
        attn_logprob=attn_logprob,
        attn_soft=attn_soft,
        attn_hard=attn_hard,
        attn_hard_dur=attn_hard_dur,
    )

    val_log = {
        'val_loss': loss,
        'val_durs_loss': durs_loss,
        'val_pitch_loss': torch.tensor(1.0).to(durs_loss.device) if pitch_loss is None else pitch_loss,
        'val_mel_loss': mel_loss,
        'val_with_pred_features_mel_loss': with_pred_features_mel_loss,
        'val_durs_acc': acc,
        'val_durs_acc_dist_3': acc_dist_3,
        'val_ctc_loss': torch.tensor(1.0).to(durs_loss.device) if ctc_loss is None else ctc_loss,
        'val_bin_loss': torch.tensor(1.0).to(durs_loss.device) if bin_loss is None else bin_loss,
    }
    self.log_dict(val_log, prog_bar=False, on_epoch=True, logger=True, sync_dist=True)

    if batch_idx == 0 and self.current_epoch % 5 == 0 and isinstance(self.logger, WandbLogger):
        specs = []
        pitches = []
        for i in range(min(3, spect.shape[0])):
            specs += [
                wandb.Image(
                    plot_spectrogram_to_numpy(spect[i, :, : spect_len[i]].data.cpu().numpy()),
                    caption=f"gt mel {i}",
                ),
                wandb.Image(
                    plot_spectrogram_to_numpy(
                        pred_spect.transpose(1, 2)[i, :, : spect_len[i]].data.cpu().numpy()
                    ),
                    caption=f"pred mel {i}",
                ),
            ]

            pitches += [
                wandb.Image(
                    plot_pitch_to_numpy(
                        average_pitch(pitch.unsqueeze(1), attn_hard_dur)
                        .squeeze(1)[i, : text_len[i]]
                        .data.cpu()
                        .numpy(),
                        ylim_range=[-2.5, 2.5],
                    ),
                    caption=f"gt pitch {i}",
                ),
            ]
            pitches += [
                wandb.Image(
                    plot_pitch_to_numpy(pred_pitch[i, : text_len[i]].data.cpu().numpy(), ylim_range=[-2.5, 2.5]),
                    caption=f"pred pitch {i}",
                ),
            ]

        self.logger.experiment.log({"specs": specs, "pitches": pitches})
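# The pitch normalization in validation_step zero-centers voiced frames
# ((pitch - mean) / std) while pinning unvoiced frames at exactly 0, which is
# why the pitch plots above use a fixed ylim_range of [-2.5, 2.5]. A
# hypothetical inverse, assuming the same statistics, would recover a Hz-scale
# contour for listening tests or plotting:
#
#     pitch_hz = pred_pitch * self.pitch_std + self.pitch_mean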