示例#1
0
 def generate_attn(dr, x_mask, y_mask=None):
     """Build an alignment between encoder steps and decoder frames.

     Args:
         dr: per-step durations; summed over dim 1 to get decoder lengths.
         x_mask: encoder-side mask.
         y_mask: optional decoder-side mask. When omitted it is derived
             from ``dr``.

     Returns:
         The output of ``generate_path`` cast to ``dr``'s dtype.
     """
     if y_mask is None:
         # Total frames per item, floored at one so no item ends up with an
         # empty decoder mask.
         frame_counts = torch.clamp(dr.sum(1).long(), min=1)
         y_mask = sequence_mask(frame_counts, None).unsqueeze(1).to(dr.dtype)
     # Outer product of the two masks marks the valid (encoder, decoder) cells.
     pair_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
     return generate_path(dr, pair_mask.squeeze(1)).to(dr.dtype)
示例#2
0
    def inference(self, x, aux_input=None):
        """Run the inference pass and return the decoder output.

        The input goes through the text encoder, the duration predictor, a
        duration-based alignment path, the flow in reverse mode and finally
        the waveform decoder.

        Args:
            x: input token ids.
            aux_input: optional dict of conditioning inputs forwarded to
                ``self._set_cond_input``. Defaults to
                ``{"d_vectors": None, "speaker_ids": None}``; a fresh dict is
                created on every call so no state is shared between calls.

        Returns:
            dict with the keys ``model_outputs``, ``alignments``, ``z``,
            ``z_p``, ``m_p`` and ``logs_p``.

        Shapes:
            - x: :math:`[B, T_seq]`
            - d_vectors: :math:`[B, C, 1]`
            - speaker_ids: :math:`[B]`
        """
        if aux_input is None:
            aux_input = {"d_vectors": None, "speaker_ids": None}
        sid, g = self._set_cond_input(aux_input)
        # One-element tensor holding T_seq (the second dimension of ``x``).
        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)

        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)

        # A speaker id, when available, overrides whatever conditioning
        # ``_set_cond_input`` returned.
        if self.num_speakers > 0 and (sid is not None):
            g = self.emb_g(sid).unsqueeze(-1)

        if self.args.use_sdp:
            logw = self.duration_predictor(
                x,
                x_mask,
                g=g,
                reverse=True,
                noise_scale=self.inference_noise_scale_dp)
        else:
            logw = self.duration_predictor(x, x_mask, g=g)

        # Exponentiate the predictor output, zero the masked positions and
        # apply the length scale; durations are then rounded up.
        w = torch.exp(logw) * x_mask * self.length_scale
        w_ceil = torch.ceil(w)
        # Output length is the summed duration, floored at one.
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = generate_path(w_ceil.squeeze(1),
                             attn_mask.squeeze(1).transpose(1, 2))

        # Expand the encoder statistics along the alignment path.
        m_p = torch.matmul(attn.transpose(1, 2),
                           m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.transpose(1, 2),
                              logs_p.transpose(1, 2)).transpose(1, 2)

        # Sample around the expanded mean, with the noise scaled by
        # ``inference_noise_scale``.
        z_p = (m_p + torch.randn_like(m_p) * torch.exp(logs_p) *
               self.inference_noise_scale)
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        # Only the first ``max_inference_len`` positions are decoded.
        o = self.waveform_decoder((z * y_mask)[:, :, :self.max_inference_len],
                                  g=g)

        outputs = {
            "model_outputs": o,
            "alignments": attn.squeeze(1),
            "z": z,
            "z_p": z_p,
            "m_p": m_p,
            "logs_p": logs_p,
        }
        return outputs
示例#3
0
    def inference(self, x, aux_input=None):
        """Run the inference pass: encoder, durations, alignment, decoder.

        Args:
            x: encoder input.
            aux_input: optional dict with the optional keys ``"x_lengths"``
                and ``"d_vectors"``. ``None`` or a missing key is treated the
                same as a key whose value is ``None``.

        Returns:
            dict with the keys ``model_outputs``, ``logdet``, ``y_mean``,
            ``y_log_scale``, ``alignments``, ``durations_log`` and
            ``total_durations_log``.
        """
        # Normalise ``aux_input`` once so every lookup below is optional; the
        # previous code subscripted it before checking it for ``None``.
        if aux_input is None:
            aux_input = {}
        x_lengths = aux_input.get("x_lengths")
        g = aux_input.get("d_vectors")

        if g is not None:
            if self.d_vector_dim:
                g = F.normalize(g).unsqueeze(-1)
            else:
                # NOTE(review): the value stored under "d_vectors" is fed to
                # ``emb_g`` here, so it presumably holds speaker ids in this
                # configuration - confirm against callers.
                # presumably [b, h] before the unsqueeze, [b, h, 1] after
                g = F.normalize(self.emb_g(g)).unsqueeze(-1)

        # embedding pass
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                              x_lengths,
                                                              g=g)
        # compute output durations
        w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
        w_ceil = torch.ceil(w)
        # output length is the summed duration, floored at one
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = None
        # compute masks
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                                 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        # compute the alignment path from the rounded durations
        attn = generate_path(w_ceil.squeeze(1),
                             attn_mask.squeeze(1)).unsqueeze(1)
        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
            attn, o_mean, o_log_scale, x_mask)

        # sample around the mean, with the noise scaled by
        # ``inference_noise_scale``, and zero the masked positions
        z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
             self.inference_noise_scale) * y_mask
        # decoder pass
        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
        attn = attn.squeeze(1).permute(0, 2, 1)
        outputs = {
            "model_outputs": y.transpose(1, 2),
            "logdet": logdet,
            "y_mean": y_mean.transpose(1, 2),
            "y_log_scale": y_log_scale.transpose(1, 2),
            "alignments": attn,
            "durations_log": o_dur_log.transpose(1, 2),
            "total_durations_log": o_attn_dur.transpose(1, 2),
        }
        return outputs
示例#4
0
    def generate_attn(dr, x_mask, y_mask=None):
        """Turn durations into an encoder-to-decoder alignment.

        When ``y_mask`` is not supplied, the decoder mask is derived from the
        per-item sum of ``dr`` (at least one frame per item).

        Shapes
           - dr: :math:`(B, T_{en})`
           - x_mask: presumably :math:`(B, 1, T_{en})` - the unsqueeze/squeeze
             pattern below suggests a singleton channel axis; TODO confirm
           - y_mask: presumably :math:`(B, 1, T_{de})` - TODO confirm
        """
        if y_mask is None:
            # decoder length per item, never below one frame
            total_frames = dr.sum(1).long().clamp(min=1)
            frame_mask = sequence_mask(total_frames, None)
            y_mask = frame_mask.unsqueeze(1).to(dr.dtype)
        # pairwise validity of every (encoder step, decoder frame) cell
        valid_cells = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
        alignment = generate_path(dr, valid_cells.squeeze(1))
        return alignment.to(dr.dtype)
示例#5
0
def generate_path_test():
    """Check that ``generate_path`` lays durations out as contiguous runs.

    Random durations are masked by random input lengths. For every batch item
    and input step the returned path must be 1 on exactly
    ``durations[b, t]`` consecutive output positions, starting where the
    previous step ended, and 0 everywhere else.
    """
    batch_size, max_x_len = 10, 21
    durations = T.randint(1, 4, (batch_size, max_x_len))
    x_length = T.randint(18, 22, (batch_size, ))
    # Pin the mask width to ``max_x_len``. Without an explicit width the mask
    # is presumably sized to the longest sampled length, which reaches 21 only
    # by chance and would make the multiplication below fail intermittently.
    x_mask = sequence_mask(x_length, max_x_len).unsqueeze(1).long()
    # Zero the durations of padded input steps.
    durations = durations * x_mask.squeeze(1)
    y_length = durations.sum(1)
    y_mask = sequence_mask(y_length).unsqueeze(1).long()
    attn_mask = (T.unsqueeze(x_mask, -1) *
                 T.unsqueeze(y_mask, 2)).squeeze(1).long()
    path = generate_path(durations, attn_mask)
    assert path.shape == (batch_size, max_x_len, y_length.max().item())
    for b in range(batch_size):
        current_idx = 0
        for t in range(max_x_len):
            next_idx = current_idx + durations[b, t].item()
            # the run belonging to step t is all ones ...
            assert all(path[b, t, current_idx:next_idx] == 1.0)
            # ... and everything before and after it is zero
            assert all(path[b, t, :current_idx] == 0.0)
            assert all(path[b, t, next_idx:] == 0.0)
            current_idx = next_idx