Example #1
 def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask):
     # find the max alignment path
     attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
     log_p = self.compute_log_probs(mu, log_sigma, y)
     # [B, T_en, T_dec]
     attn = maximum_path(log_p, attn_mask.squeeze(1)).unsqueeze(1)
     dr_mas = torch.sum(attn, -1)
     return dr_mas.squeeze(1), log_p
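
Every example in this listing hands the hard alignment off to maximum_path. Its implementation is not part of the listing; below is a minimal single-sample NumPy sketch (hypothetical helper name, no masking or batching) of the monotonic alignment search dynamic programming such a helper is assumed to perform: each decoder frame is assigned to exactly one text token, and the path may only stay on the current token or advance by one token per frame. The batched callers above additionally pass attn_mask so padded positions never enter the search.

import numpy as np

def maximum_path_sketch(log_p):
    """Best monotonic alignment for a single sample (illustrative only).

    log_p: [T_en, T_dec] log-likelihood of decoder frame j under token i.
    Returns a 0/1 path of the same shape: path[i, j] == 1 iff frame j is
    assigned to token i; the path starts at (0, 0), ends in the bottom-right
    corner and may only stay on the same token or advance by one per frame.
    """
    t_en, t_dec = log_p.shape
    neg_inf = -1e9
    q = np.full((t_en, t_dec), neg_inf)
    q[0, 0] = log_p[0, 0]
    for j in range(1, t_dec):
        for i in range(t_en):
            stay = q[i, j - 1]
            move = q[i - 1, j - 1] if i > 0 else neg_inf
            q[i, j] = log_p[i, j] + max(stay, move)
    # backtrack from the last token / last frame
    path = np.zeros_like(log_p, dtype=np.int32)
    i = t_en - 1
    for j in range(t_dec - 1, -1, -1):
        path[i, j] = 1
        if i > 0 and j > 0 and q[i - 1, j - 1] >= q[i, j - 1]:
            i -= 1
    return path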
Example #2
    def _forward_aligner(
        self, x: torch.FloatTensor, y: torch.FloatTensor,
        x_mask: torch.IntTensor, y_mask: torch.IntTensor
    ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor,
               torch.FloatTensor]:
        """Aligner forward pass.

        1. Compute a mask to apply to the attention map.
        2. Run the alignment network.
        3. Apply MAS to compute the hard alignment map.
        4. Compute the durations from the hard alignment map.

        Args:
            x (torch.FloatTensor): Input sequence.
            y (torch.FloatTensor): Output sequence.
            x_mask (torch.IntTensor): Input sequence mask.
            y_mask (torch.IntTensor): Output sequence mask.

        Returns:
            Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
                Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials,
                hard alignment map.

        Shapes:
            - x: :math:`[B, T_en, C_en]`
            - y: :math:`[B, T_de, C_de]`
            - x_mask: :math:`[B, 1, T_en]`
            - y_mask: :math:`[B, 1, T_de]`

            - o_alignment_dur: :math:`[B, T_en]`
            - alignment_soft: :math:`[B, T_en, T_de]`
            - alignment_logprob: :math:`[B, 1, T_de, T_en]`
            - alignment_mas: :math:`[B, T_en, T_de]`
        """
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        alignment_soft, alignment_logprob = self.aligner(
            y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
        alignment_mas = maximum_path(
            alignment_soft.squeeze(1).transpose(1, 2).contiguous(),
            attn_mask.squeeze(1).contiguous())
        o_alignment_dur = torch.sum(alignment_mas, -1).int()
        alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
        return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
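
A toy illustration of step 4 in the docstring above, with made-up values: the hard alignment map activates exactly one token per decoder frame, so summing it over the decoder axis yields per-token durations that add up to the number of frames.

import torch

alignment_mas = torch.tensor([  # [T_en=3, T_de=6], hard alignment for one sample
    [1, 1, 0, 0, 0, 0],
    [0, 0, 1, 0, 0, 0],
    [0, 0, 0, 1, 1, 1],
], dtype=torch.float32)
o_alignment_dur = alignment_mas.sum(-1).int()  # tensor([2, 1, 3], dtype=torch.int32)
assert int(o_alignment_dur.sum()) == alignment_mas.shape[-1]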
Example #3
    def inference_with_MAS(self,
                           x,
                           x_lengths,
                           y=None,
                           y_lengths=None,
                           aux_input={
                               "d_vectors": None,
                               "speaker_ids": None
                           }):  # pylint: disable=dangerous-default-value
        """
        It is similar to teacher forcing in Tacotron.
        It was proposed in: https://arxiv.org/abs/2104.05557

        Shapes:
            - x: :math:`[B, T]`
            - x_lengths: :math:`B`
            - y: :math:`[B, T, C]`
            - y_lengths: :math:`B`
            - g: :math:`[B, C]` or :math:`B`
        """
        y = y.transpose(1, 2)
        y_max_length = y.size(2)
        # norm speaker embeddings
        g = aux_input[
            "d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
        if self.use_speaker_embedding or self.use_d_vector_file:
            if not self.use_d_vector_file:
                g = F.normalize(g).unsqueeze(-1)
            else:
                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
        # embedding pass
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                              x_lengths,
                                                              g=g)
        # drop residual frames w.r.t. num_squeeze and set y_lengths.
        y, y_lengths, y_max_length, attn = self.preprocess(
            y, y_lengths, y_max_length, None)
        # create masks
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                                 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        # decoder pass
        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
        # find the alignment path between z and encoder output
        o_scale = torch.exp(-2 * o_log_scale)
        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
                          [1]).unsqueeze(-1)  # [b, t, 1]
        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
                             (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
                             z)  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
                          [1]).unsqueeze(-1)  # [b, t, 1]
        logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
            attn, o_mean, o_log_scale, x_mask)
        attn = attn.squeeze(1).permute(0, 2, 1)

        # get the predicted aligned distribution
        z = y_mean * y_mask

        # reverse the decoder and predict using the aligned distribution
        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
        outputs = {
            "model_outputs": z.transpose(1, 2),
            "logdet": logdet,
            "y_mean": y_mean.transpose(1, 2),
            "y_log_scale": y_log_scale.transpose(1, 2),
            "alignments": attn,
            "durations_log": o_dur_log.transpose(1, 2),
            "total_durations_log": o_attn_dur.transpose(1, 2),
        }
        return outputs
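
The logp1..logp4 terms above (and in the examples that follow) are a batched way of evaluating the diagonal-Gaussian log-likelihood of every decoder frame z_j under the prior of every encoder token (mean o_mean_i, scale exp(o_log_scale_i)), summed over channels: expanding -0.5 * (z - mu)^2 / sigma^2 with o_scale = 1 / sigma^2 = exp(-2 * o_log_scale) gives the four terms. A small self-contained check of that decomposition with random tensors (toy shapes, not taken from the model):

import math
import torch

b, d, t_en, t_de = 2, 8, 5, 11
o_mean = torch.randn(b, d, t_en)
o_log_scale = 0.1 * torch.randn(b, d, t_en)
z = torch.randn(b, d, t_de)

o_scale = torch.exp(-2 * o_log_scale)  # 1 / sigma^2
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)
logp = logp1 + logp2 + logp3 + logp4  # [b, t_en, t_de]

# direct per-pair computation: sum_d log N(z[:, d, j]; o_mean[:, d, i], sigma[:, d, i])
diff = z.unsqueeze(2) - o_mean.unsqueeze(-1)  # [b, d, t_en, t_de]
direct = torch.sum(
    -0.5 * math.log(2 * math.pi) - o_log_scale.unsqueeze(-1)
    - 0.5 * diff ** 2 * o_scale.unsqueeze(-1),
    dim=1,
)
assert torch.allclose(logp, direct, atol=1e-4)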
Example #4
 def forward(self,
             x,
             x_lengths,
             y,
             y_lengths=None,
             aux_input={
                 "d_vectors": None,
                 "speaker_ids": None
             }):  # pylint: disable=dangerous-default-value
     """
     Shapes:
         - x: :math:`[B, T]`
         - x_lengths: :math:`B`
         - y: :math:`[B, T, C]`
         - y_lengths: :math:`B`
         - g: :math:`[B, C]` or :math:`B`
     """
     # [B, T, C] -> [B, C, T]
     y = y.transpose(1, 2)
     y_max_length = y.size(2)
     # norm speaker embeddings
     g = aux_input[
         "d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
     if self.use_speaker_embedding or self.use_d_vector_file:
         if not self.use_d_vector_file:
             g = F.normalize(g).unsqueeze(-1)
         else:
             g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
     # embedding pass
     o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                           x_lengths,
                                                           g=g)
     # drop residual frames w.r.t. num_squeeze and set y_lengths.
     y, y_lengths, y_max_length, attn = self.preprocess(
         y, y_lengths, y_max_length, None)
     # create masks
     y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                              1).to(x_mask.dtype)
     # [B, 1, T_en, T_de]
     attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
     # decoder pass
     z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
     # find the alignment path
     with torch.no_grad():
         o_scale = torch.exp(-2 * o_log_scale)
         logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
                           [1]).unsqueeze(-1)  # [b, t, 1]
         logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
                              (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
         logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
                              z)  # [b, t, d] x [b, d, t'] = [b, t, t']
         logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
                           [1]).unsqueeze(-1)  # [b, t, 1]
         logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
         attn = maximum_path(logp,
                             attn_mask.squeeze(1)).unsqueeze(1).detach()
     y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
         attn, o_mean, o_log_scale, x_mask)
     attn = attn.squeeze(1).permute(0, 2, 1)
     outputs = {
         "z": z.transpose(1, 2),
         "logdet": logdet,
         "y_mean": y_mean.transpose(1, 2),
         "y_log_scale": y_log_scale.transpose(1, 2),
         "alignments": attn,
         "durations_log": o_dur_log.transpose(1, 2),
         "total_durations_log": o_attn_dur.transpose(1, 2),
     }
     return outputs
Example #5
    def forward(
        self,
        x: torch.tensor,
        x_lengths: torch.tensor,
        y: torch.tensor,
        y_lengths: torch.tensor,
        aux_input={
            "d_vectors": None,
            "speaker_ids": None
        },
    ) -> Dict:
        """Forward pass of the model.

        Args:
            x (torch.tensor): Batch of input character sequence IDs.
            x_lengths (torch.tensor): Batch of input character sequence lengths.
            y (torch.tensor): Batch of input spectrograms.
            y_lengths (torch.tensor): Batch of input spectrogram lengths.
            aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.

        Returns:
            Dict: model outputs keyed by the output name.

        Shapes:
            - x: :math:`[B, T_seq]`
            - x_lengths: :math:`[B]`
            - y: :math:`[B, C, T_spec]`
            - y_lengths: :math:`[B]`
            - d_vectors: :math:`[B, C, 1]`
            - speaker_ids: :math:`[B]`
        """
        outputs = {}
        sid, g = self._set_cond_input(aux_input)
        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)

        # speaker embedding
        if self.num_speakers > 1 and sid is not None:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

        # posterior encoder
        z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)

        # flow layers
        z_p = self.flow(z, y_mask, g=g)

        # find the alignment path
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        with torch.no_grad():
            o_scale = torch.exp(-2 * logs_p)
            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p,
                              [1]).unsqueeze(-1)  # [b, t, 1]
            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
            logp4 = torch.sum(-0.5 * (m_p**2) * o_scale,
                              [1]).unsqueeze(-1)  # [b, t, 1]
            logp = logp2 + logp3 + logp1 + logp4
            attn = maximum_path(logp,
                                attn_mask.squeeze(1)).unsqueeze(1).detach()

        # duration predictor
        attn_durations = attn.sum(3)
        if self.args.use_sdp:
            loss_duration = self.duration_predictor(
                x.detach() if self.args.detach_dp_input else x,
                x_mask,
                attn_durations,
                g=g.detach()
                if self.args.detach_dp_input and g is not None else g,
            )
            loss_duration = loss_duration / torch.sum(x_mask)
        else:
            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
            log_durations = self.duration_predictor(
                x.detach() if self.args.detach_dp_input else x,
                x_mask,
                g=g.detach()
                if self.args.detach_dp_input and g is not None else g,
            )
            loss_duration = torch.sum((log_durations - attn_log_durations)**2,
                                      [1, 2]) / torch.sum(x_mask)
        outputs["loss_duration"] = loss_duration

        # expand prior
        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

        # select a random feature segment for the waveform decoder
        z_slice, slice_ids = rand_segments(z, y_lengths,
                                           self.spec_segment_size)
        o = self.waveform_decoder(z_slice, g=g)
        outputs.update({
            "model_outputs": o,
            "alignments": attn.squeeze(1),
            "slice_ids": slice_ids,
            "z": z,
            "z_p": z_p,
            "m_p": m_p,
            "logs_p": logs_p,
            "m_q": m_q,
            "logs_q": logs_q,
        })
        return outputs
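
Compared with the Glow-TTS style examples above, this snippet writes the same contractions with einsum and also uses einsum to spread the per-token prior statistics over decoder frames through the hard alignment map. A quick equivalence check with toy shapes (illustrative values only):

import torch

b, d, t_en, t_de = 2, 4, 5, 9
o_scale = torch.rand(b, d, t_en)
z_p = torch.randn(b, d, t_de)
m_p = torch.randn(b, d, t_en)
attn = torch.rand(b, 1, t_en, t_de)

# "klm, kln -> kmn" contracts the channel axis l, like the transposed matmul form above
lhs = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
rhs = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z_p ** 2))
assert torch.allclose(lhs, rhs, atol=1e-5)

# "klmn, kjm -> kjn" expands token-level stats to frame-level stats via the alignment
expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
assert torch.allclose(expanded, torch.matmul(m_p, attn.squeeze(1)), atol=1e-5)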
Example #6
    def forward(self,
                x,
                x_lengths,
                y,
                y_lengths=None,
                aux_input={
                    "d_vectors": None,
                    "speaker_ids": None
                }):  # pylint: disable=dangerous-default-value
        """
        Args:
            x (torch.Tensor):
                Input text sequence ids. :math:`[B, T_en]`

            x_lengths (torch.Tensor):
                Lengths of input text sequences. :math:`[B]`

            y (torch.Tensor):
                Target mel-spectrogram frames. :math:`[B, T_de, C_mel]`

            y_lengths (torch.Tensor):
                Lengths of target mel-spectrogram frames. :math:`[B]`

            aux_input (Dict):
                Auxiliary inputs. `d_vectors` holds speaker embedding vectors for a multi-speaker model,
                :math:`[B, D_vec]`. `speaker_ids` holds speaker IDs for a multi-speaker model using a speaker-embedding
                layer, :math:`B`.

        Returns:
            Dict:
                - z: :math:`[B, T_de, C]`
                - logdet: :math:`B`
                - y_mean: :math:`[B, T_de, C]`
                - y_log_scale: :math:`[B, T_de, C]`
                - alignments: :math:`[B, T_en, T_de]`
                - durations_log: :math:`[B, T_en, 1]`
                - total_durations_log: :math:`[B, T_en, 1]`
        """
        # [B, T, C] -> [B, C, T]
        y = y.transpose(1, 2)
        y_max_length = y.size(2)
        # norm speaker embeddings
        g = self._speaker_embedding(aux_input)
        # embedding pass
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                              x_lengths,
                                                              g=g)
        # drop residual frames w.r.t. num_squeeze and set y_lengths.
        y, y_lengths, y_max_length, attn = self.preprocess(
            y, y_lengths, y_max_length, None)
        # create masks
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                                 1).to(x_mask.dtype)
        # [B, 1, T_en, T_de]
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        # decoder pass
        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
        # find the alignment path
        with torch.no_grad():
            o_scale = torch.exp(-2 * o_log_scale)
            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
                              [1]).unsqueeze(-1)  # [b, t, 1]
            logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
                                 (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
            logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
                                 z)  # [b, t, d] x [b, d, t'] = [b, t, t']
            logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
                              [1]).unsqueeze(-1)  # [b, t, 1]
            logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
            attn = maximum_path(logp,
                                attn_mask.squeeze(1)).unsqueeze(1).detach()
        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
            attn, o_mean, o_log_scale, x_mask)
        attn = attn.squeeze(1).permute(0, 2, 1)
        outputs = {
            "z": z.transpose(1, 2),
            "logdet": logdet,
            "y_mean": y_mean.transpose(1, 2),
            "y_log_scale": y_log_scale.transpose(1, 2),
            "alignments": attn,
            "durations_log": o_dur_log.transpose(1, 2),
            "total_durations_log": o_attn_dur.transpose(1, 2),
        }
        return outputs
Example #7
    def forward(
        self,
        x: torch.tensor,
        x_lengths: torch.tensor,
        y: torch.tensor,
        y_lengths: torch.tensor,
        waveform: torch.tensor,
        aux_input={
            "d_vectors": None,
            "speaker_ids": None,
            "language_ids": None
        },
    ) -> Dict:
        """Forward pass of the model.

        Args:
            x (torch.tensor): Batch of input character sequence IDs.
            x_lengths (torch.tensor): Batch of input character sequence lengths.
            y (torch.tensor): Batch of input spectrograms.
            y_lengths (torch.tensor): Batch of input spectrogram lengths.
            waveform (torch.tensor): Batch of ground truth waveforms per sample.
            aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training.
                Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}.

        Returns:
            Dict: model outputs keyed by the output name.

        Shapes:
            - x: :math:`[B, T_seq]`
            - x_lengths: :math:`[B]`
            - y: :math:`[B, C, T_spec]`
            - y_lengths: :math:`[B]`
            - waveform: :math:`[B, T_wav, 1]`
            - d_vectors: :math:`[B, C, 1]`
            - speaker_ids: :math:`[B]`
            - language_ids: :math:`[B]`
        """
        outputs = {}
        sid, g, lid = self._set_cond_input(aux_input)
        # speaker embedding
        if self.args.use_speaker_embedding and sid is not None:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

        # language embedding
        lang_emb = None
        if self.args.use_language_embedding and lid is not None:
            lang_emb = self.emb_l(lid).unsqueeze(-1)

        x, m_p, logs_p, x_mask = self.text_encoder(x,
                                                   x_lengths,
                                                   lang_emb=lang_emb)

        # posterior encoder
        z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)

        # flow layers
        z_p = self.flow(z, y_mask, g=g)

        # find the alignment path
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        with torch.no_grad():
            o_scale = torch.exp(-2 * logs_p)
            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p,
                              [1]).unsqueeze(-1)  # [b, t, 1]
            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
            logp4 = torch.sum(-0.5 * (m_p**2) * o_scale,
                              [1]).unsqueeze(-1)  # [b, t, 1]
            logp = logp2 + logp3 + logp1 + logp4
            attn = maximum_path(logp,
                                attn_mask.squeeze(1)).unsqueeze(1).detach()

        # duration predictor
        attn_durations = attn.sum(3)
        if self.args.use_sdp:
            loss_duration = self.duration_predictor(
                x.detach() if self.args.detach_dp_input else x,
                x_mask,
                attn_durations,
                g=g.detach()
                if self.args.detach_dp_input and g is not None else g,
                lang_emb=lang_emb.detach() if self.args.detach_dp_input
                and lang_emb is not None else lang_emb,
            )
            loss_duration = loss_duration / torch.sum(x_mask)
        else:
            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
            log_durations = self.duration_predictor(
                x.detach() if self.args.detach_dp_input else x,
                x_mask,
                g=g.detach()
                if self.args.detach_dp_input and g is not None else g,
                lang_emb=lang_emb.detach() if self.args.detach_dp_input
                and lang_emb is not None else lang_emb,
            )
            loss_duration = torch.sum((log_durations - attn_log_durations)**2,
                                      [1, 2]) / torch.sum(x_mask)
        outputs["loss_duration"] = loss_duration

        # expand prior
        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

        # select a random feature segment for the waveform decoder
        z_slice, slice_ids = rand_segments(z, y_lengths,
                                           self.spec_segment_size)
        o = self.waveform_decoder(z_slice, g=g)

        wav_seg = segment(
            waveform,
            slice_ids * self.config.audio.hop_length,
            self.args.spec_segment_size * self.config.audio.hop_length,
        )

        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
            # concatenate generated and GT waveforms
            wavs_batch = torch.cat((wav_seg, o), dim=0)

            # resample audio to speaker encoder sample_rate
            # pylint: disable=W0105
            if self.audio_transform is not None:
                wavs_batch = self.audio_transform(wavs_batch)

            pred_embs = self.speaker_manager.speaker_encoder.forward(
                wavs_batch, l2_norm=True)

            # split generated and GT speaker embeddings
            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
        else:
            gt_spk_emb, syn_spk_emb = None, None

        outputs.update({
            "model_outputs": o,
            "alignments": attn.squeeze(1),
            "z": z,
            "z_p": z_p,
            "m_p": m_p,
            "logs_p": logs_p,
            "m_q": m_q,
            "logs_q": logs_q,
            "waveform_seg": wav_seg,
            "gt_spk_emb": gt_spk_emb,
            "syn_spk_emb": syn_spk_emb,
        })
        return outputs
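
rand_segments and segment are slicing helpers whose implementations are not part of this listing. The sketch below shows the behavior the example appears to rely on, with hypothetical names and simplified signatures (not the library API): pick a random valid start frame per sample, cut a fixed-size window from the latent, and cut the matching waveform window by scaling the same start indices with hop_length.

import torch

def segment_sketch(x, start_ids, segment_size):
    """x: [B, C, T]; cut a window of segment_size frames per sample."""
    return torch.stack(
        [x[i, :, s:s + segment_size] for i, s in enumerate(start_ids.tolist())],
        dim=0)

def rand_segments_sketch(x, x_lengths, segment_size):
    """Pick a random valid start frame per sample and slice the segment."""
    max_start = (x_lengths - segment_size).clamp(min=0)
    start_ids = (torch.rand(x.shape[0], device=x.device) * (max_start + 1).float()).long()
    return segment_sketch(x, start_ids, segment_size), start_ids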