Example #1
 def forward(self, x, target, length):
     """
     Args:
         x: A Variable containing a FloatTensor of size
             (batch, max_len, dim) which contains the
             unnormalized probability for each class.
         target: A Variable containing a LongTensor of size
             (batch, max_len, dim) which contains the index of the true
             class for each corresponding step.
         length: A Variable containing a LongTensor of size (batch,)
             which contains the length of each data in a batch.
     Shapes:
         x: B x T X D
         target: B x T x D
         length: B
     Returns:
         loss: An average loss value in range [0, 1] masked by the length.
     """
     # mask: (batch, max_len, 1)
     target.requires_grad = False
     mask = sequence_mask(sequence_length=length,
                          max_len=target.size(1)).unsqueeze(2).float()
     if self.seq_len_norm:
         norm_w = mask / mask.sum(dim=1, keepdim=True)
         out_weights = norm_w.div(target.shape[0] * target.shape[2])
         mask = mask.expand_as(x)
         loss = functional.l1_loss(x * mask,
                                   target * mask,
                                   reduction="none")
         loss = loss.mul(out_weights.to(loss.device)).sum()
     else:
         mask = mask.expand_as(x)
         loss = functional.l1_loss(x * mask, target * mask, reduction="sum")
         loss = loss / mask.sum()
     return loss
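A quick usage sketch of the masked loss above; `MaskedL1Loss` is a hypothetical name for a module wrapping this `forward`, and the tensors are random stand-ins:

    import torch

    criterion = MaskedL1Loss()            # hypothetical wrapper around the forward() above
    x = torch.rand(4, 8, 80)              # (batch, max_len, dim) predictions
    target = torch.rand(4, 8, 80)         # (batch, max_len, dim) targets
    length = torch.tensor([8, 6, 5, 3])   # valid steps per sample
    loss = criterion(x, target, length)   # scalar; padded steps contribute nothing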
Example #2
 def forward(self, x, target, length):
     """
     Args:
         x: A Variable containing a FloatTensor of size
             (batch, max_len) which contains the
             unnormalized probability for each class.
         target: A Variable containing a LongTensor of size
             (batch, max_len) which contains the index of the true
             class for each corresponding step.
         length: A Variable containing a LongTensor of size (batch,)
             which contains the length of each data in a batch.
     Shapes:
         x: B x T
         target: B x T
         length: B
     Returns:
         loss: An average loss value in range [0, 1] masked by the length.
     """
     # mask: (batch, max_len, 1)
     target.requires_grad = False
     if length is not None:
         mask = sequence_mask(sequence_length=length,
                              max_len=target.size(1)).float()
         x = x * mask
         target = target * mask
         num_items = mask.sum()
     else:
         num_items = torch.numel(x)
     loss = functional.binary_cross_entropy_with_logits(
         x, target, pos_weight=self.pos_weight, reduction="sum")
     loss = loss / num_items
     return loss
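A similar sketch for the BCE variant, typically used for stop-token prediction; `BCELossMasked` and the `pos_weight` value are assumptions here:

    import torch

    criterion = BCELossMasked(pos_weight=torch.tensor(10.0))  # assumed wrapper
    logits = torch.randn(4, 8)                   # (batch, max_len) unnormalized scores
    target = (torch.rand(4, 8) > 0.8).float()    # 1.0 at stop frames, 0.0 elsewhere
    length = torch.tensor([8, 6, 5, 3])
    loss = criterion(logits, target, length)     # averaged over the unmasked elements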
Example #3
 def decoder_inference(self,
                       y,
                       y_lengths=None,
                       aux_input={
                           "d_vectors": None,
                           "speaker_ids": None
                       }):  # pylint: disable=dangerous-default-value
     """
     Shapes:
         - y: :math:`[B, T, C]`
         - y_lengths: :math:`B`
          - g: :math:`[B, C]` or :math:`B`
     """
     y = y.transpose(1, 2)
     y_max_length = y.size(2)
     g = self._speaker_embedding(aux_input)
     y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                              1).to(y.dtype)
     # decoder pass
     z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
     # reverse decoder and predict
     y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
     outputs = {}
     outputs["model_outputs"] = y.transpose(1, 2)
     outputs["logdet"] = logdet
     return outputs
Example #4
def sequence_mask_test():
    lengths = T.randint(10, 15, (8, ))
    mask = sequence_mask(lengths)
    for i in range(8):
        l = lengths[i].item()
        assert mask[i, :l].sum() == l
        assert mask[i, l:].sum() == 0
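The test above pins down the contract of `sequence_mask`; a minimal sketch consistent with it (and with the optional `max_len` argument used throughout these examples) could look like this:

    import torch

    def sequence_mask(sequence_length, max_len=None):
        # boolean mask: True for positions t < sequence_length[i]
        if max_len is None:
            max_len = int(sequence_length.max().item())
        steps = torch.arange(max_len, device=sequence_length.device)
        # (B, 1) vs (1, T) broadcasts to a (B, T) boolean mask
        return steps.unsqueeze(0) < sequence_length.unsqueeze(1)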
Example #5
 def _forward_mdn(self, o_en, y, y_lengths, x_mask):
     # MAS potentials and alignment
     mu, log_sigma = self.mdn_block(o_en)
     y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
                              1).to(o_en.dtype)
     dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask,
                                            y_mask)
     return dr_mas, mu, log_sigma, logp
Example #6
    def forward(
        self,
        mel_slice,
        mel_slice_hat,
        z_p,
        logs_q,
        m_p,
        logs_p,
        z_len,
        scores_disc_fake,
        feats_disc_fake,
        feats_disc_real,
        loss_duration,
        use_speaker_encoder_as_loss=False,
        gt_spk_emb=None,
        syn_spk_emb=None,
    ):
        """
        Shapes:
            - mel_slice : :math:`[B, 1, T]`
            - mel_slice_hat: :math:`[B, 1, T]`
            - z_p: :math:`[B, C, T]`
            - logs_q: :math:`[B, C, T]`
            - m_p: :math:`[B, C, T]`
            - logs_p: :math:`[B, C, T]`
            - z_len: :math:`[B]`
            - scores_disc_fake[i]: :math:`[B, C]`
            - feats_disc_fake[i][j]: :math:`[B, C, T', P]`
            - feats_disc_real[i][j]: :math:`[B, C, T', P]`
        """
        loss = 0.0
        return_dict = {}
        z_mask = sequence_mask(z_len).float()
        # compute losses
        loss_kl = (
            self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1))
            * self.kl_loss_alpha
        )
        loss_feat = (
            self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
        )
        loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
        loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha
        loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
        loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration

        if use_speaker_encoder_as_loss:
            loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
            loss = loss + loss_se
            return_dict["loss_spk_encoder"] = loss_se
        # pass losses to the dict
        return_dict["loss_gen"] = loss_gen
        return_dict["loss_kl"] = loss_kl
        return_dict["loss_feat"] = loss_feat
        return_dict["loss_mel"] = loss_mel
        return_dict["loss_duration"] = loss_duration
        return_dict["loss"] = loss
        return return_dict
Example #7
 def generate_attn(dr, x_mask, y_mask=None):
     # compute decode mask from the durations
     if y_mask is None:
         y_lengths = dr.sum(1).long()
         y_lengths[y_lengths < 1] = 1
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
     attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
     attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
     return attn
Example #8
 def forward(self, x, y, length=None):
     """
     Shapes:
         x: B x T
         y: B x T
         length: B
     """
     mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float()
     return torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum()
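Note that the signature accepts `length=None` but the body calls `sequence_mask` unconditionally, so omitting `length` would fail. A guarded variant, assuming the 2-D shapes from the docstring (the mask is kept 2-D here; the `unsqueeze(2)` above targets 3-D inputs), might read:

    import torch

    def masked_smooth_l1(x, y, length=None):
        if length is None:
            # no lengths given: fall back to a plain unmasked mean
            return torch.nn.functional.smooth_l1_loss(x, y, reduction="mean")
        mask = sequence_mask(sequence_length=length, max_len=y.size(1)).float()  # B x T
        return torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum()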
Example #9
def test_decoder():
    input_dummy = torch.rand(8, 128, 37).to(device)
    input_lengths = torch.randint(31, 37, (8,)).long().to(device)
    input_lengths[-1] = 37

    input_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
    # residual bn conv decoder
    layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
    # transformer decoder
    layer = Decoder(
        out_channels=11,
        in_hidden_channels=128,
        decoder_type="relative_position_transformer",
        decoder_params={
            "hidden_channels_ffn": 128,
            "num_heads": 2,
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 8,
            "rel_attn_window_size": 4,
            "input_length": None,
        },
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
    # wavenet decoder
    layer = Decoder(
        out_channels=11,
        in_hidden_channels=128,
        decoder_type="wavenet",
        decoder_params={
            "num_blocks": 12,
            "hidden_channels": 192,
            "kernel_size": 5,
            "dilation_rate": 1,
            "num_layers": 4,
            "dropout_p": 0.05,
        },
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
    # FFTransformer decoder
    layer = Decoder(
        out_channels=11,
        in_hidden_channels=128,
        decoder_type="fftransformer",
        decoder_params={
            "hidden_channels_ffn": 31,
            "num_heads": 2,
            "dropout_p": 0.1,
            "num_layers": 2,
        },
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
Example #10
 def forward(
     self, x, x_lengths, y, y_lengths, aux_input={"d_vectors": None}, phase=None
 ):  # pylint: disable=unused-argument
     """
     Shapes:
         - x: :math:`[B, T_max]`
         - x_lengths: :math:`[B]`
         - y_lengths: :math:`[B]`
         - dr: :math:`[B, T_max]`
         - g: :math:`[B, C]`
     """
     y = y.transpose(1, 2)
     g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
     o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None
     if phase == 0:
         # train encoder and MDN
         o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
         dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
         attn = self.generate_attn(dr_mas, x_mask, y_mask)
     elif phase == 1:
         # train decoder
         o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
         dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask)
         o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g)
     elif phase == 2:
         # train the whole model except the duration predictor
         o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
         dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
         o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
     elif phase == 3:
         # train duration predictor
         o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
         o_dr_log = self.duration_predictor(x, x_mask)
         dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
         o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
         o_dr_log = o_dr_log.squeeze(1)
     else:
         o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
         o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
         dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
         o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
         o_dr_log = o_dr_log.squeeze(1)
     dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
     outputs = {
         "model_outputs": o_de.transpose(1, 2),
         "alignments": attn,
         "durations_log": o_dr_log,
         "durations_mas_log": dr_mas_log,
         "mu": mu,
         "log_sigma": log_sigma,
         "logp": logp,
     }
     return outputs
Example #11
def generate_path_test():
    durations = T.randint(1, 4, (10, 21))
    x_length = T.randint(18, 22, (10, ))
    x_mask = sequence_mask(x_length).unsqueeze(1).long()
    durations = durations * x_mask.squeeze(1)
    y_length = durations.sum(1)
    y_mask = sequence_mask(y_length).unsqueeze(1).long()
    attn_mask = (T.unsqueeze(x_mask, -1) *
                 T.unsqueeze(y_mask, 2)).squeeze(1).long()
    print(attn_mask.shape)
    path = generate_path(durations, attn_mask)
    assert path.shape == (10, 21, durations.sum(1).max().item())
    for b in range(durations.shape[0]):
        current_idx = 0
        for t in range(durations.shape[1]):
            assert all(path[b, t, current_idx:current_idx +
                            durations[b, t].item()] == 1.0)
            assert all(path[b, t, :current_idx] == 0.0)
            assert all(path[b, t,
                            current_idx + durations[b, t].item():] == 0.0)
            current_idx += durations[b, t].item()
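The contract asserted by this test also admits a simple loop-based reference implementation; this sketch is for illustration only, not the vectorized `generate_path` used by the models:

    import torch

    def generate_path_reference(durations, mask):
        # each input step t fills a contiguous run of durations[b, t] output frames
        b, t_x, t_y = mask.shape
        path = torch.zeros(b, t_x, t_y)
        for i in range(b):
            idx = 0
            for t in range(t_x):
                d = int(durations[i, t])
                path[i, t, idx:idx + d] = 1.0
                idx += d
        return path * mask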
Example #12
    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
        """
        Shapes:
            - x: :math:`[B, T_seq]`
            - d_vectors: :math:`[B, C, 1]`
            - speaker_ids: :math:`[B]`
        """
        sid, g = self._set_cond_input(aux_input)
        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)

        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)

        if self.num_speakers > 0 and (sid is not None):
            g = self.emb_g(sid).unsqueeze(-1)

        if self.args.use_sdp:
            logw = self.duration_predictor(
                x,
                x_mask,
                g=g,
                reverse=True,
                noise_scale=self.inference_noise_scale_dp)
        else:
            logw = self.duration_predictor(x, x_mask, g=g)

        w = torch.exp(logw) * x_mask * self.length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = sequence_mask(y_lengths, None).unsqueeze(1).to(x_mask.dtype)  # [B, 1, T_dec]
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = generate_path(w_ceil.squeeze(1),
                             attn_mask.squeeze(1).transpose(1, 2))

        m_p = torch.matmul(attn.transpose(1, 2),
                           m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.transpose(1, 2),
                              logs_p.transpose(1, 2)).transpose(1, 2)

        z_p = (m_p + torch.randn_like(m_p) * torch.exp(logs_p) *
               self.inference_noise_scale)
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.waveform_decoder((z * y_mask)[:, :, :self.max_inference_len],
                                  g=g)

        outputs = {
            "model_outputs": o,
            "alignments": attn.squeeze(1),
            "z": z,
            "z_p": z_p,
            "m_p": m_p,
            "logs_p": logs_p,
        }
        return outputs
Example #13
 def forward(
     self,
     waveform,
     waveform_hat,
     z_p,
     logs_q,
     m_p,
     logs_p,
     z_len,
     scores_disc_fake,
     feats_disc_fake,
     feats_disc_real,
     loss_duration,
 ):
     """
     Shapes:
         - waveform : :math:`[B, 1, T]`
         - waveform_hat: :math:`[B, 1, T]`
         - z_p: :math:`[B, C, T]`
         - logs_q: :math:`[B, C, T]`
         - m_p: :math:`[B, C, T]`
         - logs_p: :math:`[B, C, T]`
         - z_len: :math:`[B]`
         - scores_disc_fake[i]: :math:`[B, C]`
         - feats_disc_fake[i][j]: :math:`[B, C, T', P]`
         - feats_disc_real[i][j]: :math:`[B, C, T', P]`
     """
     loss = 0.0
     return_dict = {}
     z_mask = sequence_mask(z_len).float()
     # compute mel spectrograms from the waveforms
     mel = self.stft(waveform)
     mel_hat = self.stft(waveform_hat)
     # compute losses
     loss_feat = self.feature_loss(feats_disc_fake,
                                   feats_disc_real) * self.feat_loss_alpha
     loss_gen = self.generator_loss(
         scores_disc_fake)[0] * self.gen_loss_alpha
     loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p,
                            z_mask.unsqueeze(1)) * self.kl_loss_alpha
     loss_mel = torch.nn.functional.l1_loss(mel,
                                            mel_hat) * self.mel_loss_alpha
     loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
     loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
     # pass losses to the dict
     return_dict["loss_gen"] = loss_gen
     return_dict["loss_kl"] = loss_kl
     return_dict["loss_feat"] = loss_feat
     return_dict["loss_mel"] = loss_mel
     return_dict["loss_duration"] = loss_duration
     return_dict["loss"] = loss
     return return_dict
Example #14
 def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
     y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
     # expand o_en with durations
     o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
     # positional encoding
     if hasattr(self, "pos_encoder"):
         o_en_ex = self.pos_encoder(o_en_ex, y_mask)
     # speaker embedding
     if g is not None:
         o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
     # decoder pass
     o_de = self.decoder(o_en_ex, y_mask, g=g)
     return o_de, attn.transpose(1, 2)
Example #15
    def inference(self,
                  x,
                  aux_input={
                      "x_lengths": None,
                      "d_vectors": None,
                      "speaker_ids": None
                  }):  # pylint: disable=dangerous-default-value
        x_lengths = aux_input["x_lengths"]
        g = aux_input[
            "d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None

        if g is not None:
            if self.d_vector_dim:
                g = F.normalize(g).unsqueeze(-1)
            else:
                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]

        # embedding pass
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                              x_lengths,
                                                              g=g)
        # compute output durations
        w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = None
        # compute masks
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                                 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        # compute attention mask
        attn = generate_path(w_ceil.squeeze(1),
                             attn_mask.squeeze(1)).unsqueeze(1)
        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
            attn, o_mean, o_log_scale, x_mask)

        z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
             self.inference_noise_scale) * y_mask
        # decoder pass
        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
        attn = attn.squeeze(1).permute(0, 2, 1)
        outputs = {
            "model_outputs": y.transpose(1, 2),
            "logdet": logdet,
            "y_mean": y_mean.transpose(1, 2),
            "y_log_scale": y_log_scale.transpose(1, 2),
            "alignments": attn,
            "durations_log": o_dur_log.transpose(1, 2),
            "total_durations_log": o_attn_dur.transpose(1, 2),
        }
        return outputs
Example #16
 def forward(self, x, x_lengths, g=None):
     """
     Shapes:
         - x: :math:`[B, C, T]`
          - x_lengths: :math:`[B]`
         - g: :math:`[B, C, 1]`
     """
     x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
     x = self.pre(x) * x_mask
     x = self.enc(x, x_mask, g=g)
     stats = self.proj(x) * x_mask
     mean, log_scale = torch.split(stats, self.out_channels, dim=1)
     z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
     return z, mean, log_scale, x_mask
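The last step is the standard reparameterization trick, z = mean + sigma * eps with eps ~ N(0, I), restricted to valid frames by the mask; a minimal illustration with assumed shapes:

    import torch

    mean = torch.zeros(2, 4, 6)               # (B, C, T)
    log_scale = torch.zeros(2, 4, 6)          # sigma = exp(log_scale) = 1
    eps = torch.randn_like(mean)              # eps ~ N(0, I)
    z = mean + eps * torch.exp(log_scale)     # element-wise sample from N(mean, sigma^2)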
Example #17
    def forward(self, x, x_lengths):
        """
        Shapes:
            - x: :math:`[B, T]`
            - x_lengths: :math:`[B]`
        """
        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask
Example #18
    def generate_attn(dr, x_mask, y_mask=None):
        """Generate an attention mask from the durations.

        Shapes:
           - dr: :math:`(B, T_{en})`
           - x_mask: :math:`(B, 1, T_{en})`
           - y_mask: :math:`(B, 1, T_{de})`
        """
        # compute decode mask from the durations
        if y_mask is None:
            y_lengths = dr.sum(1).long()
            y_lengths[y_lengths < 1] = 1
            y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
                                     1).to(dr.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
        return attn
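A tiny worked example of the resulting hard alignment, treating the static method as a free function (values assumed for illustration):

    import torch

    dr = torch.tensor([[2.0, 1.0]])      # two encoder steps with durations 2 and 1
    x_mask = torch.ones(1, 1, 2)         # (B, 1, T_en)
    attn = generate_attn(dr, x_mask)     # (1, 2, 3)
    # attn[0] == [[1., 1., 0.],
    #             [0., 0., 1.]]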
Example #19
 def forward(self, y_hat, y, length=None):
     """
     Args:
         y_hat (tensor): model prediction values.
         y (tensor): target values.
         length (tensor): length of each sample in a batch.
     Shapes:
         y_hat: B x T x D
         y: B x T x D
         length: B
     Returns:
         loss: An average loss value in range [0, 1] masked by the length.
     """
     if length is not None:
         m = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float().to(y_hat.device)
         y_hat, y = y_hat * m, y * m
     return 1 - self.loss_func(y_hat.unsqueeze(1), y.unsqueeze(1))
Example #20
    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):  # pylint: disable=unused-argument
        """Model's inference pass.

        Args:
            x (torch.LongTensor): Input character sequence.
            aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`.

        Shapes:
            - x: [B, T_max]
            - x_lengths: [B]
            - g: [B, C]
        """
        g = self._set_speaker_input(aux_input)
        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
                                 1).to(x.dtype).float()
        # encoder pass
        o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
        # duration predictor pass
        o_dr_log = self.duration_predictor(o_en, x_mask)
        o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
        y_lengths = o_dr.sum(1)
        # pitch predictor pass
        o_pitch = None
        if self.args.use_pitch:
            o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
            o_en = o_en + o_pitch_emb
        # decoder pass
        o_de, attn = self._forward_decoder(o_en,
                                           o_dr,
                                           x_mask,
                                           y_lengths,
                                           g=None)
        outputs = {
            "model_outputs": o_de,
            "alignments": attn,
            "pitch": o_pitch,
            "durations_log": o_dr_log,
        }
        return outputs
Example #21
def test_encoder():
    input_dummy = torch.rand(8, 14, 37).to(device)
    input_lengths = torch.randint(31, 37, (8,)).long().to(device)
    input_lengths[-1] = 37
    input_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
    # relative positional transformer encoder
    layer = Encoder(
        out_channels=11,
        in_hidden_channels=14,
        encoder_type="relative_position_transformer",
        encoder_params={
            "hidden_channels_ffn": 768,
            "num_heads": 2,
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 6,
            "rel_attn_window_size": 4,
            "input_length": None,
        },
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
    # residual conv bn encoder
    layer = Encoder(
        out_channels=11,
        in_hidden_channels=14,
        encoder_type="residual_conv_bn",
        encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
    # FFTransformer encoder
    layer = Encoder(
        out_channels=14,
        in_hidden_channels=14,
        encoder_type="fftransformer",
        encoder_params={"hidden_channels_ffn": 31, "num_heads": 2, "num_layers": 2, "dropout_p": 0.1},
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 14, 37]
Example #22
    def _forward_decoder(
        self,
        o_en: torch.FloatTensor,
        dr: torch.IntTensor,
        x_mask: torch.FloatTensor,
        y_lengths: torch.IntTensor,
        g: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Decoding forward pass.

        1. Compute the decoder output mask
        2. Expand encoder output with the durations.
        3. Apply position encoding.
        4. Add speaker embeddings if multi-speaker mode.
        5. Run the decoder.

        Args:
            o_en (torch.FloatTensor): Encoder output.
            dr (torch.IntTensor): Ground truth durations or alignment network durations.
            x_mask (torch.FloatTensor): Input sequence mask.
            y_lengths (torch.IntTensor): Output sequence lengths.
            g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.

        Returns:
            Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
        """
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
                                 1).to(o_en.dtype)
        # expand o_en with durations
        o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
        # positional encoding
        if hasattr(self, "pos_encoder"):
            o_en_ex = self.pos_encoder(o_en_ex, y_mask)
        # speaker embedding
        if g is not None:
            o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
        # decoder pass
        o_de = self.decoder(o_en_ex, y_mask, g=g)
        return o_de.transpose(1, 2), attn.transpose(1, 2)
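For context, a plausible sketch of `expand_encoder_outputs`, consistent with the shapes used here (the actual implementation may differ): each encoder frame is repeated according to its duration via the hard alignment matrix.

    import torch

    def expand_encoder_outputs_sketch(en, dr, x_mask, y_mask):
        attn = generate_attn(dr, x_mask, y_mask)  # (B, T_en, T_de)
        # (B, T_de, T_en) x (B, T_en, C) -> (B, T_de, C) -> (B, C, T_de)
        o_en_ex = torch.matmul(attn.transpose(1, 2).to(en.dtype),
                               en.transpose(1, 2)).transpose(1, 2)
        return o_en_ex, attn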
Example #23
 def forward(self, x, x_lengths, g=None):
     """
      Shapes:
          - x: :math:`[B, T]`
          - x_lengths: :math:`[B]`
          - g (optional): :math:`[B, C, 1]`
     """
     # embedding layer
     # [B ,T, D]
     x = self.emb(x) * math.sqrt(self.hidden_channels)
     # [B, D, T]
     x = torch.transpose(x, 1, -1)
     # compute input sequence mask
     x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)),
                              1).to(x.dtype)
     # prenet
     if hasattr(self, "prenet") and self.use_prenet:
         x = self.prenet(x, x_mask)
     # encoder
     x = self.encoder(x, x_mask)
     # postnet
     if hasattr(self, "postnet"):
         x = self.postnet(x) * x_mask
     # set duration predictor input
     if g is not None:
         g_exp = g.expand(-1, -1, x.size(-1))
         x_dp = torch.cat([x.detach(), g_exp], 1)
     else:
         x_dp = x.detach()
     # final projection layer
     x_m = self.proj_m(x) * x_mask
     if not self.mean_only:
         x_logs = self.proj_s(x) * x_mask
     else:
         x_logs = torch.zeros_like(x_m)
     # duration predictor
     logw = self.duration_predictor(x_dp, x_mask)
     return x_m, x_logs, logw, x_mask
Example #24
    def forward(self, x, x_lengths, lang_emb=None):
        """
        Shapes:
            - x: :math:`[B, T]`
            - x_lengths: :math:`[B]`
        """
        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]

        # concat the lang emb in embedding chars
        if lang_emb is not None:
            x = torch.cat(
                (x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)),
                dim=-1)

        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)),
                                 1).to(x.dtype)

        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask
Example #25
    def decoder_inference(self,
                          y,
                          y_lengths=None,
                          aux_input={
                              "d_vectors": None,
                              "speaker_ids": None
                          }):  # pylint: disable=dangerous-default-value
        """
        Shapes:
            - y: :math:`[B, T, C]`
            - y_lengths: :math:`B`
            - g: :math:`[B, C]` or :math:`B`
        """
        y = y.transpose(1, 2)
        y_max_length = y.size(2)
        g = aux_input[
            "d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
        # norm speaker embeddings
        if g is not None:
            if self.external_d_vector_dim:
                g = F.normalize(g).unsqueeze(-1)
            else:
                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]

        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                                 1).to(y.dtype)

        # decoder pass
        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)

        # reverse decoder and predict
        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)

        outputs = {}
        outputs["model_outputs"] = y.transpose(1, 2)
        outputs["logdet"] = logdet
        return outputs
Example #26
    def _forward_encoder(self, x, x_lengths, g=None):
        if hasattr(self, "emb_g"):
            g = nn.functional.normalize(self.speaker_embedding(g))  # [B, C]

        if g is not None:
            g = g.unsqueeze(-1)

        # [B, T, C]
        x_emb = self.emb(x)
        # [B, C, T]
        x_emb = torch.transpose(x_emb, 1, -1)

        # compute sequence masks
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)

        # encoder pass
        o_en = self.encoder(x_emb, x_mask)

        # speaker conditioning for duration predictor
        if g is not None:
            o_en_dp = self._concat_speaker_embedding(o_en, g)
        else:
            o_en_dp = o_en
        return o_en, o_en_dp, x_mask, g
Example #27
 def _make_masks(ilens, olens):
     in_masks = sequence_mask(ilens)
     out_masks = sequence_mask(olens)
     return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2)
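Usage sketch: the combined mask is True only where both the input step and the output step are valid.

    import torch

    ilens = torch.tensor([3, 2])
    olens = torch.tensor([4, 2])
    masks = _make_masks(ilens, olens)    # (2, 4, 3) boolean, i.e. (B, T_out, T_in)
    # masks[1] is True only in its top-left 2 x 2 block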
Example #28
    def forward(
            self,
            x: torch.LongTensor,
            x_lengths: torch.LongTensor,
            y_lengths: torch.LongTensor,
            y: torch.FloatTensor = None,
            dr: torch.IntTensor = None,
            pitch: torch.FloatTensor = None,
            aux_input: Dict = {
                "d_vectors": None,
                "speaker_ids": None
            },  # pylint: disable=unused-argument
    ) -> Dict:
        """Model's forward pass.

        Args:
            x (torch.LongTensor): Input character sequences.
            x_lengths (torch.LongTensor): Input sequence lengths.
            y_lengths (torch.LongTensor): Output sequence lengths.
            y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None.
            dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None.
            pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None.
            aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": None, "speaker_ids": None}`.

        Shapes:
            - x: :math:`[B, T_max]`
            - x_lengths: :math:`[B]`
            - y_lengths: :math:`[B]`
            - y: :math:`[B, T_max2]`
            - dr: :math:`[B, T_max]`
            - g: :math:`[B, C]`
            - pitch: :math:`[B, 1, T]`
        """
        g = self._set_speaker_input(aux_input)
        # compute sequence masks
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
                                 1).float()
        # encoder pass
        o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
        # duration predictor pass
        if self.args.detach_duration_predictor:
            o_dr_log = self.duration_predictor(o_en.detach(), x_mask)
        else:
            o_dr_log = self.duration_predictor(o_en, x_mask)
        o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
        # generate attn mask from predicted durations
        o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
        # aligner
        o_alignment_dur = None
        alignment_soft = None
        alignment_logprob = None
        alignment_mas = None
        if self.use_aligner:
            o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner(
                x_emb, y, x_mask, y_mask)
            alignment_soft = alignment_soft.transpose(1, 2)
            alignment_mas = alignment_mas.transpose(1, 2)
            dr = o_alignment_dur
        # pitch predictor pass
        o_pitch = None
        avg_pitch = None
        if self.args.use_pitch:
            o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(
                o_en, x_mask, pitch, dr)
            o_en = o_en + o_pitch_emb
        # decoder pass
        o_de, attn = self._forward_decoder(
            o_en, dr, x_mask, y_lengths,
            g=None)  # TODO: maybe pass speaker embedding (g) too
        outputs = {
            "model_outputs": o_de,  # [B, T, C]
            "durations_log": o_dr_log.squeeze(1),  # [B, T]
            "durations": o_dr.squeeze(1),  # [B, T]
            "attn_durations": o_attn,  # for visualization [B, T_en, T_de']
            "pitch_avg": o_pitch,
            "pitch_avg_gt": avg_pitch,
            "alignments": attn,  # [B, T_de, T_en]
            "alignment_soft": alignment_soft,
            "alignment_mas": alignment_mas,
            "o_alignment_dur": o_alignment_dur,
            "alignment_logprob": alignment_logprob,
            "x_mask": x_mask,
            "y_mask": y_mask,
        }
        return outputs
Example #29
    def inference_with_MAS(self,
                           x,
                           x_lengths,
                           y=None,
                           y_lengths=None,
                           aux_input={
                               "d_vectors": None,
                               "speaker_ids": None
                           }):  # pylint: disable=dangerous-default-value
        """
        This is similar to teacher forcing in Tacotron.
        It was proposed in: https://arxiv.org/abs/2104.05557

        Shapes:
            - x: :math:`[B, T]`
            - x_lengths: :math:`B`
            - y: :math:`[B, T, C]`
            - y_lengths: :math:`B`
            - g: :math:`[B, C]` or :math:`B`
        """
        y = y.transpose(1, 2)
        y_max_length = y.size(2)
        # norm speaker embeddings
        g = aux_input[
            "d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
        if self.use_speaker_embedding or self.use_d_vector_file:
            if not self.use_d_vector_file:
                g = F.normalize(g).unsqueeze(-1)
            else:
                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
        # embedding pass
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                              x_lengths,
                                                              g=g)
        # drop residual frames w.r.t. num_squeeze and set y_lengths.
        y, y_lengths, y_max_length, attn = self.preprocess(
            y, y_lengths, y_max_length, None)
        # create masks
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                                 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        # decoder pass
        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
        # find the alignment path between z and encoder output
        o_scale = torch.exp(-2 * o_log_scale)
        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
                          [1]).unsqueeze(-1)  # [b, t, 1]
        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
                             (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
                             z)  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
                          [1]).unsqueeze(-1)  # [b, t, 1]
        logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
            attn, o_mean, o_log_scale, x_mask)
        attn = attn.squeeze(1).permute(0, 2, 1)

        # get the predicted aligned distribution
        z = y_mean * y_mask

        # reverse the decoder and predict using the aligned distribution
        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
        outputs = {
            "model_outputs": z.transpose(1, 2),
            "logdet": logdet,
            "y_mean": y_mean.transpose(1, 2),
            "y_log_scale": y_log_scale.transpose(1, 2),
            "alignments": attn,
            "durations_log": o_dur_log.transpose(1, 2),
            "total_durations_log": o_attn_dur.transpose(1, 2),
        }
        return outputs
Example #30
 def forward(self,
             x,
             x_lengths,
             y,
             y_lengths=None,
             aux_input={
                 "d_vectors": None,
                 "speaker_ids": None
             }):  # pylint: disable=dangerous-default-value
     """
      Shapes:
          - x: :math:`[B, T]`
          - x_lengths: :math:`B`
          - y: :math:`[B, T, C]`
          - y_lengths: :math:`B`
          - g: :math:`[B, C]` or :math:`B`
      """
     # [B, T, C] -> [B, C, T]
     y = y.transpose(1, 2)
     y_max_length = y.size(2)
     # norm speaker embeddings
     g = aux_input[
         "d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
     if self.use_speaker_embedding or self.use_d_vector_file:
         if not self.use_d_vector_file:
             g = F.normalize(g).unsqueeze(-1)
         else:
             g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
     # embedding pass
     o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                           x_lengths,
                                                           g=g)
      # drop residual frames w.r.t. num_squeeze and set y_lengths.
     y, y_lengths, y_max_length, attn = self.preprocess(
         y, y_lengths, y_max_length, None)
     # create masks
     y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                              1).to(x_mask.dtype)
     # [B, 1, T_en, T_de]
     attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
     # decoder pass
     z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
     # find the alignment path
     with torch.no_grad():
         o_scale = torch.exp(-2 * o_log_scale)
         logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
                           [1]).unsqueeze(-1)  # [b, t, 1]
         logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
                              (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
         logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
                              z)  # [b, t, d] x [b, d, t'] = [b, t, t']
         logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
                           [1]).unsqueeze(-1)  # [b, t, 1]
         logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
         attn = maximum_path(logp,
                             attn_mask.squeeze(1)).unsqueeze(1).detach()
     y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
         attn, o_mean, o_log_scale, x_mask)
     attn = attn.squeeze(1).permute(0, 2, 1)
     outputs = {
         "z": z.transpose(1, 2),
         "logdet": logdet,
         "y_mean": y_mean.transpose(1, 2),
         "y_log_scale": y_log_scale.transpose(1, 2),
         "alignments": attn,
         "durations_log": o_dur_log.transpose(1, 2),
         "total_durations_log": o_attn_dur.transpose(1, 2),
     }
     return outputs
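The four `logp` terms above are the per-channel expansion of a diagonal Gaussian log-density, log N(z; mu, sigma) = -0.5*log(2*pi) - log(sigma) - (z - mu)^2 / (2*sigma^2), arranged so the cross terms become matrix products over the channel dimension. A small self-contained check (shapes assumed):

    import math
    import torch

    b, d, t, t2 = 2, 3, 5, 7
    o_mean = torch.randn(b, d, t)
    o_log_scale = torch.randn(b, d, t)
    z = torch.randn(b, d, t2)

    o_scale = torch.exp(-2 * o_log_scale)  # 1 / sigma^2
    logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
    logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))
    logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)
    logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)
    logp = logp1 + logp2 + logp3 + logp4   # [b, t, t']

    # direct evaluation of the same log-density, broadcast over (t, t') pairs
    mu = o_mean.unsqueeze(-1)              # [b, d, t, 1]
    ls = o_log_scale.unsqueeze(-1)         # [b, d, t, 1]
    zz = z.unsqueeze(2)                    # [b, d, 1, t']
    direct = torch.sum(
        -0.5 * math.log(2 * math.pi) - ls
        - 0.5 * (zz - mu) ** 2 * torch.exp(-2 * ls),
        dim=1)
    assert torch.allclose(logp, direct, atol=1e-4)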