Example #1
    def infer(
        self,
        text,
        text_len=None,
        text_mask=None,
        spect=None,
        spect_len=None,
        attn_prior=None,
        use_gt_durs=False,
        lm_tokens=None,
        pitch=None,
    ):
        if text_mask is None:
            text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

        enc_out, enc_mask = self.encoder(text, text_mask)

        # Aligner
        attn_hard_dur = None
        if use_gt_durs:
            attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
                text, text_len, text_mask, spect, spect_len, attn_prior
            )

        if self.cond_on_lm_embeddings:
            lm_emb = self.lm_embeddings(lm_tokens)
            lm_features = self.self_attention_module(
                enc_out, lm_emb, lm_emb, q_mask=enc_mask.squeeze(2), kv_mask=lm_tokens != self.lm_padding_value
            )

        # Duration predictor
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

        # Avg pitch, pitch predictor
        if use_gt_durs and pitch is not None:
            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
        else:
            pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
            pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))

        # Add pitch emb
        enc_out = enc_out + pitch_emb.transpose(1, 2)

        if self.cond_on_lm_embeddings:
            enc_out = enc_out + lm_features

        if use_gt_durs:
            if attn_hard_dur is not None:
                len_regulated_enc_out, dec_lens = regulate_len(attn_hard_dur, enc_out)
            else:
                raise NotImplementedError
        else:
            len_regulated_enc_out, dec_lens = regulate_len(durs_predicted, enc_out)

        dec_out, _ = self.decoder(len_regulated_enc_out, get_mask_from_lengths(dec_lens).unsqueeze(2))
        pred_spect = self.proj(dec_out)

        return pred_spect
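
Every example on this page leans on the same helper. Below is a minimal sketch of what get_mask_from_lengths computes, assuming a boolean [batch, max_len] output where True marks valid positions; NeMo's actual implementation may differ in details such as how max_len defaults.

import torch

def get_mask_from_lengths(lengths, max_len=None):
    # True for positions before each sequence's length, False for padding.
    if max_len is None:
        max_len = int(lengths.max().item())
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)

# get_mask_from_lengths(torch.tensor([2, 4]))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])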
Example #2
    def forward(self, text, text_len, pitch=None, spect=None, spect_len=None, attn_prior=None, lm_tokens=None):
        if self.training:
            assert pitch is not None

        text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

        enc_out, enc_mask = self.encoder(text, text_mask)

        # Aligner
        attn_soft, attn_logprob, attn_hard, attn_hard_dur = None, None, None, None
        if spect is not None:
            attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
                text, text_len, text_mask, spect, spect_len, attn_prior
            )

        if self.cond_on_lm_embeddings:
            lm_emb = self.lm_embeddings(lm_tokens)
            lm_features = self.self_attention_module(
                enc_out, lm_emb, lm_emb, q_mask=enc_mask.squeeze(2), kv_mask=lm_tokens != self.lm_padding_value
            )

        # Duration predictor
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

        # Pitch predictor
        pitch_predicted = self.pitch_predictor(enc_out, enc_mask)

        # Avg pitch, add pitch_emb
        if not self.training:
            if pitch is not None:
                pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
                pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
            else:
                pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
        else:
            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))

        enc_out = enc_out + pitch_emb.transpose(1, 2)

        if self.cond_on_lm_embeddings:
            enc_out = enc_out + lm_features

        # Regulate length
        len_regulated_enc_out, dec_lens = regulate_len(attn_hard_dur, enc_out)

        dec_out, dec_lens = self.decoder(len_regulated_enc_out, get_mask_from_lengths(dec_lens).unsqueeze(2))
        pred_spect = self.proj(dec_out)

        return (
            pred_spect,
            durs_predicted,
            log_durs_predicted,
            pitch_predicted,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        )
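
Both forward passes expand encoder frames by their durations with regulate_len. Here is a hypothetical re-implementation of that contract (NeMo's version also supports extras such as a pace factor), returning the expanded sequence plus the resulting decoder lengths:

import torch

def regulate_len(durations, enc_out):
    # durations: [B, T_text]; enc_out: [B, T_text, C].
    # Repeat encoder frame t of batch item b durations[b, t] times.
    dec_lens = durations.long().sum(dim=1)
    out = enc_out.new_zeros(enc_out.size(0), int(dec_lens.max()), enc_out.size(2))
    for b in range(enc_out.size(0)):
        expanded = torch.repeat_interleave(enc_out[b], durations[b].long(), dim=0)
        out[b, : expanded.size(0)] = expanded
    return out, dec_lens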
Example #3
    def _metrics(
        self,
        true_durs,
        true_text_len,
        pred_durs,
        true_pitch,
        pred_pitch,
        true_spect=None,
        pred_spect=None,
        true_spect_len=None,
        attn_logprob=None,
        attn_soft=None,
        attn_hard=None,
        attn_hard_dur=None,
    ):
        text_mask = get_mask_from_lengths(true_text_len)
        mel_mask = get_mask_from_lengths(true_spect_len)
        loss = 0.0

        # Dur loss and metrics
        durs_loss = F.mse_loss(pred_durs, (true_durs + 1).float().log(), reduction='none')
        durs_loss = durs_loss * text_mask.float()
        durs_loss = durs_loss.sum() / text_mask.sum()

        durs_pred = pred_durs.exp() - 1
        durs_pred = torch.clamp_min(durs_pred, min=0)
        durs_pred = durs_pred.round().long()

        acc = ((true_durs == durs_pred) * text_mask).sum().float() / text_mask.sum() * 100
        acc_dist_1 = (((true_durs - durs_pred).abs() <= 1) * text_mask).sum().float() / text_mask.sum() * 100
        acc_dist_3 = (((true_durs - durs_pred).abs() <= 3) * text_mask).sum().float() / text_mask.sum() * 100

        pred_spect = pred_spect.transpose(1, 2)

        # Mel loss
        mel_loss = F.mse_loss(pred_spect, true_spect, reduction='none').mean(dim=-2)
        mel_loss = mel_loss * mel_mask.float()
        mel_loss = mel_loss.sum() / mel_mask.sum()

        loss = loss + self.durs_loss_scale * durs_loss + self.mel_loss_scale * mel_loss

        # Aligner loss
        bin_loss, ctc_loss = None, None
        ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob, in_lens=true_text_len, out_lens=true_spect_len)
        loss = loss + ctc_loss
        if self.add_bin_loss:
            bin_loss = self.bin_loss(hard_attention=attn_hard, soft_attention=attn_soft)
            loss = loss + self.bin_loss_scale * bin_loss
        true_avg_pitch = average_pitch(true_pitch.unsqueeze(1), attn_hard_dur).squeeze(1)

        # Pitch loss
        pitch_loss = F.mse_loss(pred_pitch, true_avg_pitch, reduction='none')  # noqa
        pitch_loss = (pitch_loss * text_mask).sum() / text_mask.sum()

        loss = loss + self.pitch_loss_scale * pitch_loss

        return loss, durs_loss, acc, acc_dist_1, acc_dist_3, pitch_loss, mel_loss, ctc_loss, bin_loss
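
The duration predictor regresses log(d + 1), so the accuracy metrics above must first invert that transform before comparing against integer targets. A quick check of the round trip:

import torch

true_durs = torch.tensor([0, 2, 5])
log_target = (true_durs + 1).float().log()              # regression target
recovered = (log_target.exp() - 1).clamp(min=0).round().long()
assert torch.equal(recovered, true_durs)                # exp() - 1 undoes log(d + 1)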
Example #4
    def _log(true_mel, true_len, pred_mel):
        loss = F.mse_loss(pred_mel, true_mel, reduction='none').mean(dim=-2)
        mask = get_mask_from_lengths(true_len)
        loss *= mask.float()
        loss = loss.sum() / mask.sum()

        return dict(loss=loss)
Example #5
    def validation_step(self, batch, batch_idx):
        audio, audio_len = batch
        with torch.no_grad():
            spec, _ = self.audio_to_melspec_precessor(audio, audio_len)
            audio_pred = self(spec=spec)

            loss = 0
            loss_dict = {}
            spec_pred, _ = self.audio_to_melspec_precessor(
                audio_pred.squeeze(1), audio_len)

            # Ensure that audio len is consistent between audio_pred and audio
            # For SC Norm loss, we can just zero out
            # For Mag L1 loss, we need to mask
            if audio_pred.shape[-1] < audio.shape[-1]:
                # predicted audio is shorter than the reference; pad it to match
                pad_amount = audio.shape[-1] - audio_pred.shape[-1]
                audio_pred = torch.nn.functional.pad(audio_pred,
                                                     (0, pad_amount),
                                                     value=0.0)
            else:
                # predicted audio is longer than the reference; slice it to match
                audio_pred = audio_pred[:, :, :audio.shape[1]]

            mask = ~get_mask_from_lengths(audio_len,
                                          max_len=torch.max(audio_len))
            mask = mask.unsqueeze(1)
            audio_pred.data.masked_fill_(mask, 0.0)

            # full-band loss
            sc_loss, mag_loss = self.loss(x=audio_pred.squeeze(1),
                                          y=audio,
                                          input_lengths=audio_len)
            loss_feat = (sum(sc_loss) + sum(mag_loss)) / len(sc_loss)
            loss_dict["sc_loss"] = sc_loss
            loss_dict["mag_loss"] = mag_loss

            loss += loss_feat
            loss_dict["loss_feat"] = loss_feat

            if self.start_training_disc:
                fake_score = self.discriminator(x=audio_pred)[0]

                loss_gen = [0] * len(fake_score)
                for i, scale in enumerate(fake_score):
                    loss_gen[i] += self.mse_loss(scale,
                                                 scale.new_ones(scale.size()))

                loss_dict["gan_loss"] = loss_gen
                loss += sum(loss_gen) / len(fake_score)

        if not self.logged_real_samples:
            loss_dict["spec"] = spec
            loss_dict["audio"] = audio
        loss_dict["audio_pred"] = audio_pred
        loss_dict["spec_pred"] = spec_pred
        loss_dict["loss"] = loss
        return loss_dict
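
The pad-or-slice step above keeps audio_pred and audio the same length before the STFT losses. The same logic as a standalone helper (match_length is a hypothetical name, not part of NeMo):

import torch
import torch.nn.functional as F

def match_length(pred, target_len):
    # Zero-pad or slice the last dimension so pred matches target_len.
    diff = target_len - pred.shape[-1]
    if diff > 0:
        return F.pad(pred, (0, diff), value=0.0)
    return pred[..., :target_len]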
Example #6
    def forward(self, *, spec, spec_len, text, text_len, attn_prior=None):
        with torch.cuda.amp.autocast(enabled=False):
            attn_soft, attn_logprob = self.alignment_encoder(
                queries=spec,
                keys=self.embed(text).transpose(1, 2),
                mask=get_mask_from_lengths(text_len).unsqueeze(-1) == 0,
                attn_prior=attn_prior,
            )

        return attn_soft, attn_logprob
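
Note the == 0 at the end of the mask argument: it inverts get_mask_from_lengths, so the alignment encoder receives True at padded text positions, which it can then exclude from attention (assumed behavior, inferred from the inversion). Using the helper sketched under Example #1:

import torch

text_len = torch.tensor([4, 2])
valid = torch.arange(4).unsqueeze(0) < text_len.unsqueeze(1)   # True = real token
pad_mask = valid.unsqueeze(-1) == 0                            # True = padding
# pad_mask.squeeze(-1):
# tensor([[False, False, False, False],
#         [False, False,  True,  True]])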
Example #7
    def _metrics(true_durs, true_text_len, pred_durs):
        loss = F.mse_loss(pred_durs, (true_durs + 1).float().log(), reduction='none')
        mask = get_mask_from_lengths(true_text_len)
        loss *= mask.float()
        loss = loss.sum() / mask.sum()

        durs_pred = pred_durs.exp() - 1
        durs_pred[durs_pred < 0.0] = 0.0
        durs_pred = durs_pred.round().long()
        acc = ((true_durs == durs_pred) * mask).sum().float() / mask.sum() * 100

        return loss, acc
Example #8
    def generate_spectrogram(self, *, tokens):
        self.eval()
        self.calculate_loss = False
        token_len = torch.tensor([len(i) for i in tokens]).to(self.device)
        tensors = self(tokens=tokens, token_len=token_len)
        spectrogram_pred = tensors[1]

        if spectrogram_pred.shape[0] > 1:
            # Silence all frames past the predicted end
            mask = ~get_mask_from_lengths(tensors[-1])
            mask = mask.expand(spectrogram_pred.shape[1], mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)
            spectrogram_pred.data.masked_fill_(mask, self.pad_value)

        return spectrogram_pred
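
The expand/permute pair turns a [batch, time] length mask into a [batch, n_mel, time] one so every mel channel past the predicted end gets filled. An equivalent, arguably more direct broadcast (the fill value here is an assumed log-mel pad constant):

import torch

B, C, T = 2, 80, 7
lengths = torch.tensor([5, 7])
valid = torch.arange(T).unsqueeze(0) < lengths.unsqueeze(1)  # [B, T]
mask = ~valid.unsqueeze(1).expand(B, C, T)                   # [B, C, T]
spec = torch.randn(B, C, T)
spec.masked_fill_(mask, -11.52)  # silences all channels past each length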
Example #9
    def infer(self, *, memory, memory_lengths):
        decoder_input = self.get_go_frame(memory)

        if memory.size(0) > 1:
            mask = ~get_mask_from_lengths(memory_lengths)
        else:
            mask = None

        self.initialize_decoder_states(memory, mask=mask)

        mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32)
        not_finished = torch.ones([memory.size(0)], dtype=torch.int32)
        if torch.cuda.is_available():
            mel_lengths = mel_lengths.cuda()
            not_finished = not_finished.cuda()

        mel_outputs, gate_outputs, alignments = [], [], []
        stepped = False
        while True:
            decoder_input = self.prenet(decoder_input, inference=True)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            dec = torch.le(torch.sigmoid(gate_output.data),
                           self.gate_threshold).to(torch.int32).squeeze(1)

            not_finished = not_finished * dec
            mel_lengths += not_finished

            if self.early_stopping and torch.sum(
                    not_finished) == 0 and stepped:
                break
            stepped = True

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            if len(mel_outputs) == self.max_decoder_steps:
                logging.warning("Reached max decoder steps %d.",
                                self.max_decoder_steps)
                break

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments, mel_lengths
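
The not_finished / mel_lengths bookkeeping is the subtle part of this loop: each utterance keeps adding to its length only while its sigmoid gate stays at or below the threshold, and once the gate fires it stops counting for good. A toy trace (gate values invented for illustration):

import torch

gate_sigmoid = torch.tensor([
    [0.1, 0.9, 0.9, 0.9],   # gate fires at step 1
    [0.1, 0.1, 0.1, 0.9],   # gate fires at step 3
    [0.1, 0.1, 0.9, 0.9],   # gate fires at step 2
])
threshold = 0.5
not_finished = torch.ones(3, dtype=torch.int32)
mel_lengths = torch.zeros(3, dtype=torch.int32)
for step in range(gate_sigmoid.shape[1]):
    dec = (gate_sigmoid[:, step] <= threshold).to(torch.int32)
    not_finished = not_finished * dec
    mel_lengths += not_finished
# mel_lengths == tensor([1, 3, 2], dtype=torch.int32)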
Example #10
    def forward(self, *, spec_pred_dec, spec_pred_postnet, gate_pred,
                spec_target, spec_target_len, pad_value):
        # Make the gate target
        max_len = spec_target.shape[2]
        gate_target = torch.zeros(spec_target_len.shape[0], max_len)
        gate_target = gate_target.type_as(gate_pred)
        for i, length in enumerate(spec_target_len):
            gate_target[i, length.data - 1:] = 1

        spec_target.requires_grad = False
        gate_target.requires_grad = False
        gate_target = gate_target.view(-1, 1)

        max_len = spec_target.shape[2]

        if max_len < spec_pred_dec.shape[2]:
            # Predicted len is larger than reference
            # Need to slice
            spec_pred_dec = spec_pred_dec.narrow(2, 0, max_len)
            spec_pred_postnet = spec_pred_postnet.narrow(2, 0, max_len)
            gate_pred = gate_pred.narrow(1, 0, max_len).contiguous()
        elif max_len > spec_pred_dec.shape[2]:
            # Need to do padding
            pad_amount = max_len - spec_pred_dec.shape[2]
            spec_pred_dec = torch.nn.functional.pad(spec_pred_dec,
                                                    (0, pad_amount),
                                                    value=pad_value)
            spec_pred_postnet = torch.nn.functional.pad(spec_pred_postnet,
                                                        (0, pad_amount),
                                                        value=pad_value)
            gate_pred = torch.nn.functional.pad(gate_pred, (0, pad_amount),
                                                value=1e3)
            max_len = spec_pred_dec.shape[2]

        mask = ~get_mask_from_lengths(spec_target_len, max_len=max_len)
        mask = mask.expand(spec_target.shape[1], mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)
        spec_pred_dec.data.masked_fill_(mask, pad_value)
        spec_pred_postnet.data.masked_fill_(mask, pad_value)
        gate_pred.data.masked_fill_(mask[:, 0, :], 1e3)

        gate_pred = gate_pred.view(-1, 1)
        rnn_mel_loss = torch.nn.functional.mse_loss(spec_pred_dec, spec_target)
        postnet_mel_loss = torch.nn.functional.mse_loss(
            spec_pred_postnet, spec_target)
        gate_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            gate_pred, gate_target)
        return rnn_mel_loss + postnet_mel_loss + gate_loss, gate_target
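
The gate target built at the top marks the last valid frame and everything after it with 1, so the BCE term teaches the decoder when to stop. Worked on a tiny batch:

import torch

spec_target_len = torch.tensor([3, 5])
max_len = 5
gate_target = torch.zeros(2, max_len)
for i, length in enumerate(spec_target_len):
    gate_target[i, length - 1:] = 1
# tensor([[0., 0., 1., 1., 1.],
#         [0., 0., 0., 0., 1.]])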
Example #11
    def forward(self, dec_inp, seq_lens=None):
        if self.word_emb is None:
            inp = dec_inp
            mask = get_mask_from_lengths(seq_lens).unsqueeze(2)
        else:
            inp = self.word_emb(dec_inp)
            # [bsz x L x 1]
            mask = (dec_inp != self.padding_idx).unsqueeze(2)

        pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
        pos_emb = self.pos_emb(pos_seq) * mask
        out = self.drop(inp + pos_emb)

        for layer in self.layers:
            out = layer(out, mask=mask)

        # out = self.drop(out)
        return out, mask
Example #12
    def train_forward(self, *, memory, decoder_inputs, memory_lengths):
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(decoder_input)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(mel_outputs, gate_outputs, alignments)
        return mel_outputs, gate_outputs, alignments
Example #13
    def forward(self, *, spec_pred, spec_target, spec_target_len, pad_value):
        spec_target.requires_grad = False
        max_len = spec_target.shape[2]

        if max_len < spec_pred.shape[2]:
            # Predicted len is larger than reference
            # Need to slice
            spec_pred = spec_pred.narrow(2, 0, max_len)
        elif max_len > spec_pred.shape[2]:
            # Need to do padding
            pad_amount = max_len - spec_pred.shape[2]
            spec_pred = torch.nn.functional.pad(spec_pred, (0, pad_amount),
                                                value=pad_value)
            max_len = spec_pred.shape[2]

        mask = ~get_mask_from_lengths(spec_target_len, max_len=max_len)
        mask = mask.expand(spec_target.shape[1], mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)
        spec_pred.masked_fill_(mask, pad_value)

        mel_loss = torch.nn.functional.l1_loss(spec_pred, spec_target)
        return mel_loss
Example #14
    def _acc(durs_true, len_true, durs_pred):
        mask = get_mask_from_lengths(len_true)
        durs_pred = durs_pred.exp() - 1
        durs_pred[durs_pred < 0.0] = 0.0
        durs_pred = durs_pred.round().long()
        return ((durs_true == durs_pred) * mask).sum().float() / mask.sum() * 100
Example #15
    def forward(self, *, durs_true, len_true, durs_pred):
        loss = F.mse_loss(durs_pred, (durs_true + 1).float().log(), reduction='none')
        mask = get_mask_from_lengths(len_true)
        loss *= mask.float()
        return loss.sum() / mask.sum()
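
The sum-then-divide at the end is a masked mean: padded positions contribute zero to the numerator and are excluded from the denominator. A quick numeric check:

import torch
import torch.nn.functional as F

pred = torch.tensor([[0.5, 0.0, 9.9]])   # last position is padding
true = torch.tensor([[1.0, 0.0, 0.0]])
mask = torch.tensor([[True, True, False]])
loss = F.mse_loss(pred, true, reduction='none') * mask.float()
masked_mean = loss.sum() / mask.sum()    # 0.25 / 2 = 0.125
assert torch.isclose(masked_mean, torch.tensor(0.125))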
Example #16
    def forward(self,
                *,
                x,
                x_len,
                dur_target=None,
                pitch_target=None,
                energy_target=None,
                spec_len=None):
        """
        Args:
            x: Input from the encoder.
            x_len: Length of the input.
            dur_target:  Duration targets for the duration predictor. Needs to be passed in during training.
            pitch_target: Pitch targets for the pitch predictor. Needs to be passed in during training.
            energy_target: Energy targets for the energy predictor. Needs to be passed in during training.
            spec_len: Target spectrogram length. Needs to be passed in during training.
        """
        # Duration predictions (or ground truth) fed into Length Regulator to
        # expand the hidden states of the encoder embedding
        log_dur_preds = self.duration_predictor(x)
        log_dur_preds.masked_fill_(~get_mask_from_lengths(x_len), 0)
        # Output is Batch, Time
        if dur_target is not None:
            dur_out = self.length_regulator(x, dur_target)
        else:
            dur_preds = torch.clamp_min(
                torch.round(torch.exp(log_dur_preds)) - 1, 0).long()
            if not torch.sum(dur_preds, dim=1).bool().all():
                logging.error(
                    "Duration prediction failed on this batch. Setting durations to 1.")
                dur_preds += 1
            dur_out = self.length_regulator(x, dur_preds)
            spec_len = torch.sum(dur_preds, dim=1)
        out = dur_out
        out *= get_mask_from_lengths(spec_len).unsqueeze(-1)

        # Pitch
        pitch_preds = None
        if self.pitch:
            # Possible future work:
            #   Add pitch spectrogram prediction & conversion back to pitch contour using iCWT
            #   (see Appendix C of the FastSpeech 2/2s paper).
            pitch_preds = self.pitch_predictor(dur_out)
            pitch_preds.masked_fill_(~get_mask_from_lengths(spec_len), 0)
            if pitch_target is not None:
                pitch_out = self.pitch_lookup(
                    torch.bucketize(pitch_target, self.pitch_bins))
            else:
                pitch_out = self.pitch_lookup(
                    torch.bucketize(pitch_preds.detach(), self.pitch_bins))
            out += pitch_out
        out *= get_mask_from_lengths(spec_len).unsqueeze(-1)

        # Energy
        energy_preds = None
        if self.energy:
            energy_preds = self.energy_predictor(dur_out)
            if energy_target is not None:
                energy_out = self.energy_lookup(
                    torch.bucketize(energy_target, self.energy_bins))
            else:
                energy_out = self.energy_lookup(
                    torch.bucketize(energy_preds.detach(), self.energy_bins))
            out += energy_out
        out *= get_mask_from_lengths(spec_len).unsqueeze(-1)

        return out, log_dur_preds, pitch_preds, energy_preds, spec_len
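
The pitch and energy lookups quantize continuous targets into embedding indices with torch.bucketize. A sketch of that step (the bin edges below are invented; in practice they come from dataset statistics, and the embedding needs len(bins) + 1 rows because bucketize can return an index one past the last edge):

import torch

pitch_bins = torch.linspace(80.0, 400.0, steps=255)  # hypothetical edges
pitch_hz = torch.tensor([[0.0, 120.5, 210.0, 390.0]])
ids = torch.bucketize(pitch_hz, pitch_bins)          # one bin id per frame
pitch_lookup = torch.nn.Embedding(256, 8)            # 256 = 255 edges + 1
pitch_emb = pitch_lookup(ids)                        # [1, 4, 8], added to out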
Example #17
    def forward(self, *, mel_true, len_true, mel_pred):
        loss = F.mse_loss(mel_pred, mel_true, reduction='none').mean(dim=-2)
        mask = get_mask_from_lengths(len_true)
        loss *= mask.float()
        return loss.sum() / mask.sum()