def infer(
    self,
    text,
    text_len=None,
    text_mask=None,
    spect=None,
    spect_len=None,
    attn_prior=None,
    use_gt_durs=False,
    lm_tokens=None,
    pitch=None,
):
    if text_mask is None:
        text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

    enc_out, enc_mask = self.encoder(text, text_mask)

    # Aligner
    attn_hard_dur = None
    if use_gt_durs:
        attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
            text, text_len, text_mask, spect, spect_len, attn_prior
        )

    if self.cond_on_lm_embeddings:
        lm_emb = self.lm_embeddings(lm_tokens)
        lm_features = self.self_attention_module(
            enc_out, lm_emb, lm_emb, q_mask=enc_mask.squeeze(2), kv_mask=lm_tokens != self.lm_padding_value
        )

    # Duration predictor
    log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
    durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

    # Avg pitch, pitch predictor
    if use_gt_durs and pitch is not None:
        pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
        pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
    else:
        pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
        pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))

    # Add pitch emb
    enc_out = enc_out + pitch_emb.transpose(1, 2)

    if self.cond_on_lm_embeddings:
        enc_out = enc_out + lm_features

    if use_gt_durs:
        if attn_hard_dur is not None:
            len_regulated_enc_out, dec_lens = regulate_len(attn_hard_dur, enc_out)
        else:
            raise NotImplementedError
    else:
        len_regulated_enc_out, dec_lens = regulate_len(durs_predicted, enc_out)

    dec_out, _ = self.decoder(len_regulated_enc_out, get_mask_from_lengths(dec_lens).unsqueeze(2))
    pred_spect = self.proj(dec_out)

    return pred_spect

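# A hedged usage sketch for `infer` above: `model` stands for an instance of the
# class this method belongs to, and `model.parse` is an assumed text frontend
# returning a [1, T_text] tensor of token ids; neither is defined in this section.
import torch

def infer_sketch(model, text="Hello world."):
    model.eval()
    tokens = model.parse(text)  # assumed tokenizer -> [1, T_text]
    text_len = torch.tensor([tokens.shape[1]], device=tokens.device)
    with torch.no_grad():
        return model.infer(tokens, text_len=text_len)  # predicted spectrogram
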
def forward(self, text, text_len, pitch=None, spect=None, spect_len=None, attn_prior=None, lm_tokens=None):
    if self.training:
        assert pitch is not None

    text_mask = get_mask_from_lengths(text_len).unsqueeze(2)
    enc_out, enc_mask = self.encoder(text, text_mask)

    # Aligner
    attn_soft, attn_logprob, attn_hard, attn_hard_dur = None, None, None, None
    if spect is not None:
        attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
            text, text_len, text_mask, spect, spect_len, attn_prior
        )

    if self.cond_on_lm_embeddings:
        lm_emb = self.lm_embeddings(lm_tokens)
        lm_features = self.self_attention_module(
            enc_out, lm_emb, lm_emb, q_mask=enc_mask.squeeze(2), kv_mask=lm_tokens != self.lm_padding_value
        )

    # Duration predictor
    log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
    durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

    # Pitch predictor
    pitch_predicted = self.pitch_predictor(enc_out, enc_mask)

    # Avg pitch, add pitch_emb
    if not self.training:
        if pitch is not None:
            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
        else:
            pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
    else:
        pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
        pitch_emb = self.pitch_emb(pitch.unsqueeze(1))

    enc_out = enc_out + pitch_emb.transpose(1, 2)

    if self.cond_on_lm_embeddings:
        enc_out = enc_out + lm_features

    # Regulate length
    len_regulated_enc_out, dec_lens = regulate_len(attn_hard_dur, enc_out)

    dec_out, dec_lens = self.decoder(len_regulated_enc_out, get_mask_from_lengths(dec_lens).unsqueeze(2))
    pred_spect = self.proj(dec_out)

    return (
        pred_spect,
        durs_predicted,
        log_durs_predicted,
        pitch_predicted,
        attn_soft,
        attn_logprob,
        attn_hard,
        attn_hard_dur,
    )

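# A hedged sketch of how the tuple returned by `forward` can feed the `_metrics`
# loss defined next. The batch layout is an assumption for illustration; the
# aligner's hard durations serve as the duration targets.
def training_step_sketch(model, batch):
    text, text_len, pitch, spect, spect_len = batch  # assumed batch layout
    (pred_spect, durs_pred, log_durs_pred, pitch_pred,
     attn_soft, attn_logprob, attn_hard, attn_hard_dur) = model(
        text=text, text_len=text_len, pitch=pitch, spect=spect, spect_len=spect_len
    )
    loss, *_ = model._metrics(
        true_durs=attn_hard_dur, true_text_len=text_len, pred_durs=log_durs_pred,
        true_pitch=pitch, pred_pitch=pitch_pred,
        true_spect=spect, pred_spect=pred_spect, true_spect_len=spect_len,
        attn_logprob=attn_logprob, attn_soft=attn_soft, attn_hard=attn_hard,
        attn_hard_dur=attn_hard_dur,
    )
    return loss
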
def _metrics(
    self,
    true_durs,
    true_text_len,
    pred_durs,
    true_pitch,
    pred_pitch,
    true_spect=None,
    pred_spect=None,
    true_spect_len=None,
    attn_logprob=None,
    attn_soft=None,
    attn_hard=None,
    attn_hard_dur=None,
):
    text_mask = get_mask_from_lengths(true_text_len)
    mel_mask = get_mask_from_lengths(true_spect_len)
    loss = 0.0

    # Dur loss and metrics
    durs_loss = F.mse_loss(pred_durs, (true_durs + 1).float().log(), reduction='none')
    durs_loss = durs_loss * text_mask.float()
    durs_loss = durs_loss.sum() / text_mask.sum()

    durs_pred = pred_durs.exp() - 1
    durs_pred = torch.clamp_min(durs_pred, min=0)
    durs_pred = durs_pred.round().long()

    acc = ((true_durs == durs_pred) * text_mask).sum().float() / text_mask.sum() * 100
    acc_dist_1 = (((true_durs - durs_pred).abs() <= 1) * text_mask).sum().float() / text_mask.sum() * 100
    acc_dist_3 = (((true_durs - durs_pred).abs() <= 3) * text_mask).sum().float() / text_mask.sum() * 100

    pred_spect = pred_spect.transpose(1, 2)

    # Mel loss
    mel_loss = F.mse_loss(pred_spect, true_spect, reduction='none').mean(dim=-2)
    mel_loss = mel_loss * mel_mask.float()
    mel_loss = mel_loss.sum() / mel_mask.sum()

    loss = loss + self.durs_loss_scale * durs_loss + self.mel_loss_scale * mel_loss

    # Aligner loss
    bin_loss, ctc_loss = None, None
    ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob, in_lens=true_text_len, out_lens=true_spect_len)
    loss = loss + ctc_loss
    if self.add_bin_loss:
        bin_loss = self.bin_loss(hard_attention=attn_hard, soft_attention=attn_soft)
        loss = loss + self.bin_loss_scale * bin_loss

    true_avg_pitch = average_pitch(true_pitch.unsqueeze(1), attn_hard_dur).squeeze(1)

    # Pitch loss
    pitch_loss = F.mse_loss(pred_pitch, true_avg_pitch, reduction='none')  # noqa
    pitch_loss = (pitch_loss * text_mask).sum() / text_mask.sum()

    loss = loss + self.pitch_loss_scale * pitch_loss

    return loss, durs_loss, acc, acc_dist_1, acc_dist_3, pitch_loss, mel_loss, ctc_loss, bin_loss

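# A standalone worked example of the log-domain duration encoding used above:
# targets are log(d + 1), so predictions are decoded with exp(p) - 1, clamped
# at zero, and rounded back to integer frame counts. The transform round-trips.
import torch

true_durs = torch.tensor([0, 2, 7])
target = (true_durs + 1).float().log()  # tensor([0.0000, 1.0986, 2.0794])
decoded = torch.clamp_min(target.exp() - 1, 0).round().long()
assert torch.equal(decoded, true_durs)
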
def _log(true_mel, true_len, pred_mel):
    loss = F.mse_loss(pred_mel, true_mel, reduction='none').mean(dim=-2)
    mask = get_mask_from_lengths(true_len)
    loss *= mask.float()
    loss = loss.sum() / mask.sum()
    return dict(loss=loss)

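# A minimal sketch of the masking helper these functions rely on. This is an
# assumption inferred from usage (True marks valid positions, False marks
# padding), not necessarily the library's exact implementation.
import torch

def get_mask_from_lengths_sketch(lengths, max_len=None):
    if max_len is None:
        max_len = int(lengths.max())
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)  # [B, max_len] bool

# The masked-mean pattern above then averages only over valid frames:
lengths = torch.tensor([2, 4])
per_frame_loss = torch.ones(2, 4)
mask = get_mask_from_lengths_sketch(lengths)
print((per_frame_loss * mask.float()).sum() / mask.sum())  # tensor(1.)
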
def validation_step(self, batch, batch_idx):
    audio, audio_len = batch

    with torch.no_grad():
        spec, _ = self.audio_to_melspec_precessor(audio, audio_len)
        audio_pred = self(spec=spec)

        loss = 0
        loss_dict = {}
        spec_pred, _ = self.audio_to_melspec_precessor(audio_pred.squeeze(1), audio_len)

        # Ensure that audio length is consistent between audio_pred and audio.
        # For the SC norm loss we can simply zero out the padding; for the
        # magnitude L1 loss we need to mask it.
        if audio_pred.shape[-1] < audio.shape[-1]:
            # Predicted audio is shorter than the reference; pad it to match
            pad_amount = audio.shape[-1] - audio_pred.shape[-1]
            audio_pred = torch.nn.functional.pad(audio_pred, (0, pad_amount), value=0.0)
        else:
            # Predicted audio is longer than the reference; slice it to match
            audio_pred = audio_pred[:, :, : audio.shape[1]]

        mask = ~get_mask_from_lengths(audio_len, max_len=torch.max(audio_len))
        mask = mask.unsqueeze(1)
        audio_pred.data.masked_fill_(mask, 0.0)

        # Full-band loss
        sc_loss, mag_loss = self.loss(x=audio_pred.squeeze(1), y=audio, input_lengths=audio_len)
        loss_feat = (sum(sc_loss) + sum(mag_loss)) / len(sc_loss)
        loss_dict["sc_loss"] = sc_loss
        loss_dict["mag_loss"] = mag_loss
        loss += loss_feat
        loss_dict["loss_feat"] = loss_feat

        if self.start_training_disc:
            fake_score = self.discriminator(x=audio_pred)[0]
            loss_gen = [0] * len(fake_score)
            for i, scale in enumerate(fake_score):
                loss_gen[i] += self.mse_loss(scale, scale.new_ones(scale.size()))
            loss_dict["gan_loss"] = loss_gen
            loss += sum(loss_gen) / len(fake_score)

        if not self.logged_real_samples:
            loss_dict["spec"] = spec
            loss_dict["audio"] = audio

        loss_dict["audio_pred"] = audio_pred
        loss_dict["spec_pred"] = spec_pred
        loss_dict["loss"] = loss

    return loss_dict

def forward(self, *, spec, spec_len, text, text_len, attn_prior=None):
    with torch.cuda.amp.autocast(enabled=False):
        attn_soft, attn_logprob = self.alignment_encoder(
            queries=spec,
            keys=self.embed(text).transpose(1, 2),
            mask=get_mask_from_lengths(text_len).unsqueeze(-1) == 0,
            attn_prior=attn_prior,
        )

    return attn_soft, attn_logprob

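# A simplified, illustrative sketch of how a soft text-to-spectrogram attention
# map can be reduced to per-token durations: assign each frame to its argmax
# token and count assignments. This is NOT the library's actual binarization
# (which uses a monotonic alignment search); it only shows the shapes involved.
import torch

T_text, T_spec = 3, 6
attn_soft = torch.softmax(torch.randn(1, 1, T_spec, T_text), dim=-1)  # [B, 1, T_spec, T_text]
frame_to_token = attn_soft.squeeze(1).argmax(dim=-1)                  # [B, T_spec]
attn_hard_dur = torch.zeros(1, T_text).scatter_add_(1, frame_to_token, torch.ones(1, T_spec))
assert int(attn_hard_dur.sum()) == T_spec  # durations partition the frames
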
def _metrics(true_durs, true_text_len, pred_durs):
    loss = F.mse_loss(pred_durs, (true_durs + 1).float().log(), reduction='none')
    mask = get_mask_from_lengths(true_text_len)
    loss *= mask.float()
    loss = loss.sum() / mask.sum()

    durs_pred = pred_durs.exp() - 1
    durs_pred[durs_pred < 0.0] = 0.0
    durs_pred = durs_pred.round().long()
    acc = ((true_durs == durs_pred) * mask).sum().float() / mask.sum() * 100

    return loss, acc

def generate_spectrogram(self, *, tokens):
    self.eval()
    self.calculate_loss = False
    token_len = torch.tensor([len(i) for i in tokens]).to(self.device)
    tensors = self(tokens=tokens, token_len=token_len)
    spectrogram_pred = tensors[1]

    if spectrogram_pred.shape[0] > 1:
        # Silence all frames past the predicted end
        mask = ~get_mask_from_lengths(tensors[-1])
        mask = mask.expand(spectrogram_pred.shape[1], mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)
        spectrogram_pred.data.masked_fill_(mask, self.pad_value)

    return spectrogram_pred

def infer(self, *, memory, memory_lengths):
    decoder_input = self.get_go_frame(memory)

    if memory.size(0) > 1:
        mask = ~get_mask_from_lengths(memory_lengths)
    else:
        mask = None

    self.initialize_decoder_states(memory, mask=mask)

    mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32)
    not_finished = torch.ones([memory.size(0)], dtype=torch.int32)
    if torch.cuda.is_available():
        mel_lengths = mel_lengths.cuda()
        not_finished = not_finished.cuda()

    mel_outputs, gate_outputs, alignments = [], [], []
    stepped = False
    while True:
        decoder_input = self.prenet(decoder_input, inference=True)
        mel_output, gate_output, alignment = self.decode(decoder_input)

        dec = torch.le(torch.sigmoid(gate_output.data), self.gate_threshold).to(torch.int32).squeeze(1)

        not_finished = not_finished * dec
        mel_lengths += not_finished

        if self.early_stopping and torch.sum(not_finished) == 0 and stepped:
            break
        stepped = True

        mel_outputs += [mel_output.squeeze(1)]
        gate_outputs += [gate_output]
        alignments += [alignment]

        if len(mel_outputs) == self.max_decoder_steps:
            logging.warning("Reached max decoder steps %d.", self.max_decoder_steps)
            break

        decoder_input = mel_output

    mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(mel_outputs, gate_outputs, alignments)

    return mel_outputs, gate_outputs, alignments, mel_lengths

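# A toy walk-through of the stop-token bookkeeping above, with made-up gate
# logits: `not_finished` stays 1 until the sigmoid of the gate exceeds the
# threshold, and `mel_lengths` counts frames produced before each utterance stopped.
import torch

gate_threshold = 0.5
# Per-step gate logits for a batch of 2 utterances (large positive logit => "stop")
gate_logits = torch.tensor([[-5.0, -5.0], [-5.0, 4.0], [4.0, 4.0]])

not_finished = torch.ones(2, dtype=torch.int32)
mel_lengths = torch.zeros(2, dtype=torch.int32)
for step_logits in gate_logits:
    dec = torch.le(torch.sigmoid(step_logits), gate_threshold).to(torch.int32)
    not_finished = not_finished * dec
    mel_lengths += not_finished

print(mel_lengths)  # tensor([2, 1], dtype=torch.int32): utterance 2 stopped after step 1
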
def forward(self, *, spec_pred_dec, spec_pred_postnet, gate_pred, spec_target, spec_target_len, pad_value):
    # Make the gate target
    max_len = spec_target.shape[2]
    gate_target = torch.zeros(spec_target_len.shape[0], max_len)
    gate_target = gate_target.type_as(gate_pred)
    for i, length in enumerate(spec_target_len):
        gate_target[i, length.data - 1 :] = 1

    spec_target.requires_grad = False
    gate_target.requires_grad = False
    gate_target = gate_target.view(-1, 1)

    max_len = spec_target.shape[2]

    if max_len < spec_pred_dec.shape[2]:
        # Predicted len is larger than reference
        # Need to slice
        spec_pred_dec = spec_pred_dec.narrow(2, 0, max_len)
        spec_pred_postnet = spec_pred_postnet.narrow(2, 0, max_len)
        gate_pred = gate_pred.narrow(1, 0, max_len).contiguous()
    elif max_len > spec_pred_dec.shape[2]:
        # Need to do padding
        pad_amount = max_len - spec_pred_dec.shape[2]
        spec_pred_dec = torch.nn.functional.pad(spec_pred_dec, (0, pad_amount), value=pad_value)
        spec_pred_postnet = torch.nn.functional.pad(spec_pred_postnet, (0, pad_amount), value=pad_value)
        gate_pred = torch.nn.functional.pad(gate_pred, (0, pad_amount), value=1e3)
        max_len = spec_pred_dec.shape[2]

    mask = ~get_mask_from_lengths(spec_target_len, max_len=max_len)
    mask = mask.expand(spec_target.shape[1], mask.size(0), mask.size(1))
    mask = mask.permute(1, 0, 2)
    spec_pred_dec.data.masked_fill_(mask, pad_value)
    spec_pred_postnet.data.masked_fill_(mask, pad_value)
    gate_pred.data.masked_fill_(mask[:, 0, :], 1e3)
    gate_pred = gate_pred.view(-1, 1)

    rnn_mel_loss = torch.nn.functional.mse_loss(spec_pred_dec, spec_target)
    postnet_mel_loss = torch.nn.functional.mse_loss(spec_pred_postnet, spec_target)
    gate_loss = torch.nn.functional.binary_cross_entropy_with_logits(gate_pred, gate_target)
    return rnn_mel_loss + postnet_mel_loss + gate_loss, gate_target

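# A tiny demonstration of the gate-target construction above: the target is 0
# for real frames and 1 from the last frame (length - 1) onward, which trains
# the stop token to fire at the end of each utterance.
import torch

spec_target_len = torch.tensor([3, 5])
max_len = 5
gate_target = torch.zeros(2, max_len)
for i, length in enumerate(spec_target_len):
    gate_target[i, length - 1:] = 1
print(gate_target)
# tensor([[0., 0., 1., 1., 1.],
#         [0., 0., 0., 0., 1.]])
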
def forward(self, dec_inp, seq_lens=None):
    if self.word_emb is None:
        inp = dec_inp
        mask = get_mask_from_lengths(seq_lens).unsqueeze(2)
    else:
        inp = self.word_emb(dec_inp)

        # [bsz x L x 1]
        mask = (dec_inp != self.padding_idx).unsqueeze(2)

    pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
    pos_emb = self.pos_emb(pos_seq) * mask
    out = self.drop(inp + pos_emb)

    for layer in self.layers:
        out = layer(out, mask=mask)

    # out = self.drop(out)

    return out, mask

def train_forward(self, *, memory, decoder_inputs, memory_lengths):
    decoder_input = self.get_go_frame(memory).unsqueeze(0)
    decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
    decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
    decoder_inputs = self.prenet(decoder_inputs)

    self.initialize_decoder_states(memory, mask=~get_mask_from_lengths(memory_lengths))

    mel_outputs, gate_outputs, alignments = [], [], []
    while len(mel_outputs) < decoder_inputs.size(0) - 1:
        decoder_input = decoder_inputs[len(mel_outputs)]
        mel_output, gate_output, attention_weights = self.decode(decoder_input)
        mel_outputs += [mel_output.squeeze(1)]
        gate_outputs += [gate_output.squeeze()]
        alignments += [attention_weights]

    mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(mel_outputs, gate_outputs, alignments)
    return mel_outputs, gate_outputs, alignments

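# A toy illustration of the teacher-forcing indexing above (the decoder itself
# is replaced by an identity step, so this only shows the off-by-one): a go
# frame is prepended, step t consumes ground-truth frame t-1, and the loop
# emits exactly T outputs for T target frames.
import torch

T, D = 4, 2
targets = torch.arange(T * D, dtype=torch.float32).view(T, 1, D)  # [T, B=1, D]
go = torch.zeros(1, 1, D)
inputs = torch.cat((go, targets), dim=0)                          # [T+1, 1, D]

outputs = []
while len(outputs) < inputs.size(0) - 1:
    step_in = inputs[len(outputs)]  # frame t-1 (go frame at t=0)
    outputs.append(step_in)         # a real decoder would predict frame t here
assert len(outputs) == T
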
def forward(self, *, spec_pred, spec_target, spec_target_len, pad_value):
    spec_target.requires_grad = False
    max_len = spec_target.shape[2]

    if max_len < spec_pred.shape[2]:
        # Predicted len is larger than reference
        # Need to slice
        spec_pred = spec_pred.narrow(2, 0, max_len)
    elif max_len > spec_pred.shape[2]:
        # Need to do padding
        pad_amount = max_len - spec_pred.shape[2]
        spec_pred = torch.nn.functional.pad(spec_pred, (0, pad_amount), value=pad_value)
        max_len = spec_pred.shape[2]

    mask = ~get_mask_from_lengths(spec_target_len, max_len=max_len)
    mask = mask.expand(spec_target.shape[1], mask.size(0), mask.size(1))
    mask = mask.permute(1, 0, 2)
    spec_pred.masked_fill_(mask, pad_value)

    mel_loss = torch.nn.functional.l1_loss(spec_pred, spec_target)
    return mel_loss

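# A tiny demonstration of the pad-or-slice alignment used here and in the loss
# above: predictions are trimmed or padded along the time axis so the masked
# loss compares tensors of identical shape. `pad_value` is a placeholder; the
# real call passes the model's spectrogram pad value.
import torch

target_len, pad_value = 5, 0.0
pred = torch.zeros(1, 80, 3)  # [B, n_mels, T_pred], shorter than the target here
if target_len < pred.shape[2]:
    pred = pred.narrow(2, 0, target_len)  # slice when the prediction is too long
elif target_len > pred.shape[2]:
    pred = torch.nn.functional.pad(pred, (0, target_len - pred.shape[2]), value=pad_value)
print(pred.shape)  # torch.Size([1, 80, 5])
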
def _acc(durs_true, len_true, durs_pred):
    mask = get_mask_from_lengths(len_true)

    durs_pred = durs_pred.exp() - 1
    durs_pred[durs_pred < 0.0] = 0.0
    durs_pred = durs_pred.round().long()

    return ((durs_true == durs_pred) * mask).sum().float() / mask.sum() * 100

def forward(self, *, durs_true, len_true, durs_pred):
    loss = F.mse_loss(durs_pred, (durs_true + 1).float().log(), reduction='none')
    mask = get_mask_from_lengths(len_true)
    loss *= mask.float()
    return loss.sum() / mask.sum()

def forward(self, *, x, x_len, dur_target=None, pitch_target=None, energy_target=None, spec_len=None):
    """
    Args:
        x: Input from the encoder.
        x_len: Length of the input.
        dur_target: Duration targets for the duration predictor. Needs to be passed in during training.
        pitch_target: Pitch targets for the pitch predictor. Needs to be passed in during training.
        energy_target: Energy targets for the energy predictor. Needs to be passed in during training.
        spec_len: Target spectrogram length. Needs to be passed in during training.
    """
    # Duration predictions (or ground truth) are fed into the Length Regulator
    # to expand the hidden states of the encoder embedding
    log_dur_preds = self.duration_predictor(x)
    log_dur_preds.masked_fill_(~get_mask_from_lengths(x_len), 0)

    # Output is Batch, Time
    if dur_target is not None:
        dur_out = self.length_regulator(x, dur_target)
    else:
        dur_preds = torch.clamp_min(torch.round(torch.exp(log_dur_preds)) - 1, 0).long()
        if not torch.sum(dur_preds, dim=1).bool().all():
            logging.error("Duration prediction failed on this batch. Setting durations to 1.")
            dur_preds += 1
        dur_out = self.length_regulator(x, dur_preds)
        spec_len = torch.sum(dur_preds, dim=1)
    out = dur_out
    out *= get_mask_from_lengths(spec_len).unsqueeze(-1)

    # Pitch
    pitch_preds = None
    if self.pitch:
        # Possible future work: add pitch spectrogram prediction & conversion back
        # to a pitch contour using iCWT (see Appendix C of the FastSpeech 2/2s paper).
        pitch_preds = self.pitch_predictor(dur_out)
        pitch_preds.masked_fill_(~get_mask_from_lengths(spec_len), 0)
        if pitch_target is not None:
            pitch_out = self.pitch_lookup(torch.bucketize(pitch_target, self.pitch_bins))
        else:
            pitch_out = self.pitch_lookup(torch.bucketize(pitch_preds.detach(), self.pitch_bins))
        out += pitch_out
        out *= get_mask_from_lengths(spec_len).unsqueeze(-1)

    # Energy
    energy_preds = None
    if self.energy:
        energy_preds = self.energy_predictor(dur_out)
        if energy_target is not None:
            energy_out = self.energy_lookup(torch.bucketize(energy_target, self.energy_bins))
        else:
            energy_out = self.energy_lookup(torch.bucketize(energy_preds.detach(), self.energy_bins))
        out += energy_out
        out *= get_mask_from_lengths(spec_len).unsqueeze(-1)

    return out, log_dur_preds, pitch_preds, energy_preds, spec_len

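# A minimal sketch of the length regulator the variance adaptor relies on,
# assuming the standard FastSpeech-style behavior: repeat each encoder state
# along the time axis by its (integer) duration. Illustrative only; the actual
# `self.length_regulator` module is not shown in this section.
import torch

def length_regulator_sketch(x, durations):
    # x: [B, T_text, D]; durations: [B, T_text] long
    return torch.nn.utils.rnn.pad_sequence(
        [xi.repeat_interleave(di, dim=0) for xi, di in zip(x, durations)],
        batch_first=True,
    )  # [B, max(sum of durations), D]

x = torch.arange(6, dtype=torch.float32).view(1, 3, 2)
print(length_regulator_sketch(x, torch.tensor([[1, 0, 2]])).shape)  # torch.Size([1, 3, 2])
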
def forward(self, *, mel_true, len_true, mel_pred):
    loss = F.mse_loss(mel_pred, mel_true, reduction='none').mean(dim=-2)
    mask = get_mask_from_lengths(len_true)
    loss *= mask.float()
    return loss.sum() / mask.sum()