Example #1
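All three snippets are method excerpts from a Glow-TTS model implementation and are not standalone: they assume import math, import torch, import torch.nn.functional as F, and the library's glow_tts_submodules helpers (sequence_mask, generate_path, maximum_path) are in scope.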
    def generate_spect(self, *, text, text_lengths, speaker=None, noise_scale=0.3, length_scale=1.0):

        if speaker is not None:
            speaker = F.normalize(self.emb_g(speaker)).unsqueeze(-1)  # [b, h, 1]

        x_m, x_logs, log_durs_predicted, x_mask = self.encoder(
            text=text, text_lengths=text_lengths, speaker_embeddings=speaker
        )

        # Predicted per-token durations in frames, scaled by length_scale
        w = torch.exp(log_durs_predicted) * x_mask.squeeze() * length_scale
        w_ceil = torch.ceil(w)
        spect_lengths = torch.clamp_min(torch.sum(w_ceil, [1]), 1).long()
        y_max_length = None

        # Round output lengths down to a multiple of the decoder's squeeze factor
        spect_lengths = (spect_lengths // self.decoder.n_sqz) * self.decoder.n_sqz

        y_mask = torch.unsqueeze(glow_tts_submodules.sequence_mask(spect_lengths, y_max_length), 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)

        # Expand the predicted durations into a hard monotonic alignment
        attn = glow_tts_submodules.generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1))

        # Expand token-level statistics to frame level via the alignment
        y_m = torch.matmul(x_m, attn)
        y_logs = torch.matmul(x_logs, attn)

        # Sample latents from the prior and invert the flow to get a spectrogram
        z = (y_m + torch.exp(y_logs) * torch.randn_like(y_m) * noise_scale) * y_mask
        y, _ = self.decoder(spect=z, spect_mask=y_mask, speaker_embeddings=speaker, reverse=True)

        return y, attn
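
The hard alignment returned above comes from glow_tts_submodules.generate_path. For reference, here is a minimal self-contained sketch of the usual Glow-TTS semantics of that helper: per-token duration counts are expanded into a hard monotonic alignment matrix. This is an illustration under assumed semantics, not the library code.

import torch
import torch.nn.functional as F

def generate_path_sketch(durations, attn_mask):
    # durations: [b, t_x] frame counts per token
    # attn_mask: [b, t_x, t_y]
    # Returns a hard alignment [b, t_x, t_y] whose row i is 1 exactly
    # over the frames assigned to token i.
    t_y = attn_mask.shape[-1]
    cum = torch.cumsum(durations, dim=1)  # end frame of each token
    frame = torch.arange(t_y, device=durations.device).view(1, 1, t_y)
    path = (frame < cum.unsqueeze(-1)).to(attn_mask.dtype)  # frames before each token's end
    # frame j belongs to token i iff cum[i - 1] <= j < cum[i]
    path = path - F.pad(path, (0, 0, 1, 0))[:, :-1]
    return path * attn_mask

dur = torch.tensor([[2.0, 1.0, 3.0]])
print(generate_path_sketch(dur, torch.ones(1, 3, 6)))
# token 0 -> frames 0-1, token 1 -> frame 2, token 2 -> frames 3-5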
Example #2
    def forward(self, *, text, text_lengths, spect, spect_lengths, speaker=None):

        if speaker is not None:
            speaker = F.normalize(self.emb_g(speaker)).unsqueeze(-1)  # [b, h, 1]

        x_m, x_logs, log_durs_predicted, x_mask = self.encoder(
            text=text, text_lengths=text_lengths, speaker_embeddings=speaker)

        # Trim spectrogram and lengths to a multiple of the decoder's squeeze factor
        y_max_length = spect.size(2)
        y_max_length = (y_max_length // self.decoder.n_sqz) * self.decoder.n_sqz
        spect = spect[:, :, :y_max_length]

        spect_lengths = (spect_lengths // self.decoder.n_sqz) * self.decoder.n_sqz

        y_mask = torch.unsqueeze(glow_tts_submodules.sequence_mask(spect_lengths, y_max_length), 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)

        # Forward direction of the flow: spectrogram -> latent z, with log-determinant
        z, logdet = self.decoder(spect=spect, spect_mask=y_mask, speaker_embeddings=speaker, reverse=False)

        with torch.no_grad():
            # Log-likelihood of every frame's latent under every token's
            # Gaussian prior, factored into four terms so the [b, t, t']
            # score matrix reduces to two matmuls
            x_s_sq_r = torch.exp(-2 * x_logs)  # 1 / sigma^2
            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_logs, [1]).unsqueeze(-1)  # [b, t, 1]
            logp2 = torch.matmul(x_s_sq_r.transpose(1, 2), -0.5 * (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
            logp3 = torch.matmul((x_m * x_s_sq_r).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
            logp4 = torch.sum(-0.5 * (x_m**2) * x_s_sq_r, [1]).unsqueeze(-1)  # [b, t, 1]
            logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']

            # Monotonic alignment search (hard Viterbi-style alignment) over the scores
            attn = glow_tts_submodules.maximum_path(logp, attn_mask.squeeze(1)).detach()

        # Expand token-level statistics to frame level via the hard alignment
        y_m = torch.matmul(x_m, attn)
        y_logs = torch.matmul(x_logs, attn)

        # Durations extracted from the alignment supervise the duration predictor
        log_durs_extracted = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask.squeeze()

        return z, y_m, y_logs, logdet, log_durs_predicted, log_durs_extracted, spect_lengths, attn
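
The four logp terms inside the no_grad block factor the diagonal-Gaussian log-density log N(z; x_m, exp(x_logs)) so that the [b, t, t'] token-by-frame score matrix falls out of two matmuls. The following standalone check confirms the identity; variable names mirror the example, but the snippet is illustrative and not part of the model.

import math
import torch

b, d, t, t_y = 2, 4, 5, 7
x_m = torch.randn(b, d, t)     # token-wise means
x_logs = torch.randn(b, d, t)  # token-wise log standard deviations
z = torch.randn(b, d, t_y)     # frame-wise latents

x_s_sq_r = torch.exp(-2 * x_logs)  # 1 / sigma^2
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_logs, [1]).unsqueeze(-1)
logp2 = torch.matmul(x_s_sq_r.transpose(1, 2), -0.5 * (z**2))
logp3 = torch.matmul((x_m * x_s_sq_r).transpose(1, 2), z)
logp4 = torch.sum(-0.5 * (x_m**2) * x_s_sq_r, [1]).unsqueeze(-1)
logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']

# Direct evaluation of the same density, broadcasting token i against frame j
direct = torch.distributions.Normal(
    x_m.transpose(1, 2).unsqueeze(-2),              # [b, t, 1, d]
    torch.exp(x_logs).transpose(1, 2).unsqueeze(-2),
).log_prob(z.transpose(1, 2).unsqueeze(1)).sum(-1)  # [b, t, t']
assert torch.allclose(logp, direct, atol=1e-4)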
Example #3
    def forward(self, *, text, text_lengths, speaker_embeddings=None):

        x = self.emb(text) * math.sqrt(self.hidden_channels)  # [b, t, h]

        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(glow_tts_submodules.sequence_mask(text_lengths, x.size(2)), 1).to(x.dtype)

        if self.prenet:
            x = self.pre(x, x_mask)

        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(self.n_layers):
            x = x * x_mask
            # Self-attention sub-layer with residual connection and LayerNorm
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            # Feed-forward sub-layer with residual connection and LayerNorm
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask

        # The duration predictor input is detached so its gradients do not
        # flow back into the text encoder
        if speaker_embeddings is not None:
            g_exp = speaker_embeddings.expand(-1, -1, x.size(-1))
            x_dp = torch.cat([torch.detach(x), g_exp], 1)
        else:
            x_dp = torch.detach(x)

        # Project encoder output to the prior's per-token mean and (optionally) log std dev
        x_m = self.proj_m(x) * x_mask
        if not self.mean_only:
            x_logs = self.proj_s(x) * x_mask
        else:
            x_logs = torch.zeros_like(x_m)

        # Predict log-durations from the detached features
        logw = self.proj_w(spect=x_dp, mask=x_mask)

        return x_m, x_logs, logw, x_mask
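
All three examples lean on glow_tts_submodules.sequence_mask; note that example #1 even passes max_length=None, so the helper must infer it from the lengths. Assuming the standard semantics (true for positions before each sequence's length), a minimal equivalent looks like this; the real helper may differ in dtype or defaults.

import torch

def sequence_mask_sketch(lengths, max_length=None):
    # mask[i, j] is True iff j < lengths[i]; shape [b, max_length]
    if max_length is None:
        max_length = int(lengths.max())
    positions = torch.arange(max_length, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)

print(sequence_mask_sketch(torch.tensor([3, 5])))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])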