Ejemplo n.º 1
0
def binary_divergence_masked(input, target, lengths):
    """Masked binary divergence between predicted and target spectrograms.

    Provides non-vanishing gradient, but does not equal zero if spectrograms are the same.
    Inspired by https://github.com/r9y9/deepvoice3_pytorch/blob/897f31e57eb6ec2f0cafa8dc62968e60f6a96407/train.py#L537

    :param input: predicted values, passed through ``logit`` (presumably in (0, 1) — confirm)
    :param target: target values, same shape as ``input``
    :param lengths: per-item valid lengths along dim 1; positions past them are masked out
    :return: ``masked_mean`` of the element-wise divergence over the valid positions
    """
    input_logits = logit(input)
    # softplus(x) == log(1 + exp(x)), but unlike log1p(exp(x)) it does not
    # overflow to inf (and produce nan gradients) for large logits.
    z = -target * input_logits + torch.nn.functional.softplus(input_logits)
    m = mask(input.shape, lengths, dim=1).float().to(input.device)

    return masked_mean(z, m)
Ejemplo n.º 2
0
def masked_huber(input, target, lengths):
    """Smooth-L1 (Huber) loss averaged over the unmasked positions only.

    The mask is always built along dim 1, the first non-batch dimension
    (usually time).

    :param input: predictions
    :param target: targets, same shape as ``input``
    :param lengths: per-item valid lengths along dim 1
    :return: summed smooth-L1 loss over valid positions divided by the
        number of valid mask elements
    """
    valid = mask(input.shape, lengths, dim=1).float().to(input.device)
    # Masked positions are zero in both tensors, so they add nothing to the sum.
    total = F.smooth_l1_loss(input * valid, target * valid, reduction='sum')
    return total / valid.sum()
Ejemplo n.º 3
0
def speedyspeech_tts(text_str, device_str, output_path='output.wav'):
    """Synthesize ``text_str`` with SpeedySpeech + MelGAN and write it to a WAV file.

    Both checkpoints are loaded from the hard-coded paths below on every call.

    :param text_str: text to synthesize
    :param device_str: torch device string; both models AND every tensor fed
        to them are placed on this device
    :param output_path: destination of the WAV file (defaults to the
        previously hard-coded 'output.wav')
    """
    print('Loading model checkpoints')
    model = SpeedySpeech(device=device_str).load('models/speedyspeech.pth',
                                                 device_str)
    model.eval()

    checkpoint = torch.load('models/melgan.pth', device_str)
    hp = HParam("mikuai/speedyspeech/melgan/config/default.yaml")
    melgan = Generator(hp.audio.n_mel_channels).to(device_str)
    melgan.load_state_dict(checkpoint["model_g"])
    melgan.eval(inference=False)

    print('Processing text')
    txt_processor = TextProcessor(HPText.graphemes,
                                  phonemize=HPText.use_phonemes)
    phonemes, plen = txt_processor([text_str])

    # append more zeros - avoid cutoff at the end of the largest sequence
    phonemes = torch.cat((phonemes, torch.zeros(len(phonemes), 5).long()),
                         dim=-1)
    # Inputs must live on the model's device; hard-coding 'cpu' here broke
    # inference whenever device_str was not 'cpu'.
    phonemes = phonemes.to(device_str)

    print('Synthesizing')
    # generate spectrograms
    with torch.no_grad():
        spec, durations = model((phonemes, plen))

    # invert to log(mel-spectrogram)
    spec = model.collate.norm.inverse(spec)

    # mask with pad value expected by MelGan
    pad_value = -11.5129
    msk = mask(spec.shape, durations.sum(dim=-1).long(), dim=1).to(spec.device)
    spec = spec.masked_fill(~msk, pad_value)

    # Append more pad frames to improve end of the longest sequence
    pad_frames = torch.full((len(spec), HPStft.n_mel, 5), pad_value,
                            dtype=spec.dtype, device=spec.device)
    spec = torch.cat((spec.transpose(2, 1), pad_frames), dim=-1)

    # generate audio
    with torch.no_grad():
        audio = melgan(spec).squeeze(1)

    print('Saving audio')
    # TODO: cut audios to proper length
    # Only one sentence is synthesized, so the batch holds a single clip.
    for clip in audio.detach().cpu().numpy():
        write_wav(output_path, clip, HPStft.sample_rate, norm=False)
Ejemplo n.º 4
0
    def generate_naive(self, phonemes, len_phonemes, steps=1, window=(0,1)):
        """Naive generation without layer-level caching for testing purposes.

        At every step, all frames generated so far are re-encoded and
        re-decoded, and only the newest decoded frame is kept.

        :param phonemes: padded phonemes (converted with ``torch.as_tensor``)
        :param len_phonemes: length of each sentence in ``phonemes``; used for
            the initial padding mask
        :param steps: number of frames to generate
        :param window: passed to ``median_mask`` after every step to restrict
            attention; ``None`` keeps the padding mask throughout
        :return: (generated frames without the initial zero frame,
            per-step attention weights concatenated along dim 1)
        """

        self.train(False)

        with torch.no_grad():
            phonemes = torch.as_tensor(phonemes)

            keys, values = self.txt_encoder(phonemes)

            if hp.positional_encoding:
                keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
                pe = positional_encoding(hp.channels, steps, w=1).to(self.device)

            # start from a single all-zero "go" frame
            dec = torch.zeros(len(phonemes), 1, hp.out_channels, device=self.device)

            weights = None

            att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                            lengths=len_phonemes,
                            dim=-1).to(self.device)

            for i in range(steps):
                # queries cover all i + 1 frames produced so far
                queries = self.audio_encoder(dec)
                if hp.positional_encoding:
                    # Each frame gets the encoding of its own position, as in
                    # forward(); adding pe[i] to every frame was incorrect.
                    # Assumes audio_encoder preserves the time length — confirm.
                    queries += pe[:i + 1]

                att, w = self.attention(queries, keys, values, att_mask)
                d = self.audio_decoder(att + queries)
                # keep only the newest frame and its attention row
                d = d[:, -1:]
                w = w[:, -1:]
                weights = w if weights is None else torch.cat((weights, w), dim=1)
                dec = torch.cat((dec, d), dim=1)

                if window is not None:
                    att_mask = median_mask(weights, window=window)

        return dec[:, 1:, :], weights
Ejemplo n.º 5
0
    def forward(self, phonemes, spectrograms, len_phonemes, training=False):
        """Teacher-forced pass: decode spectrograms from phonemes and shifted ground truth.

        :param phonemes: padded phonemes as consumed by ``self.txt_encoder``
            (earlier docs said (batch, alphabet, time), while ``generate`` calls
            them indices — confirm the layout)
        :param spectrograms: padded spectrograms, (batch, time, freq); the
            one-frame shift below is applied along dim 1, so dim 1 is time
        :param len_phonemes: list of phoneme lengths; key positions past each
            length are masked out of attention
        :param training: unused
        :return: decoded_spectrograms, attention_weights
        """
        # Shift right by one frame along dim 1: prepend a zero frame and drop the
        # last one, so input position t holds ground-truth frame t - 1.
        spectrs = ZeroPad2d((0,0,1,0))(spectrograms)[:, :-1, :]  # move this to encoder?
        keys, values = self.txt_encoder(phonemes)
        queries = self.audio_encoder(spectrs)

        # (batch, query_time, key_time) mask hiding padded phoneme positions
        att_mask = mask(shape=(len(keys), queries.shape[1], keys.shape[1]),
                        lengths=len_phonemes,
                        dim=-1).to(self.device)

        if hp.positional_encoding:
            # NOTE(review): in-place += on autograd-tracked encoder outputs; this
            # only works while the encoders' last ops don't save their outputs
            # for backward — consider out-of-place if the encoders change.
            keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
            queries += positional_encoding(queries.shape[-1], queries.shape[1], w=1).to(self.device)

        attention, weights = self.attention(queries, keys, values, mask=att_mask)
        decoded = self.audio_decoder(attention + queries)
        return decoded, weights
Ejemplo n.º 6
0
    def generate(self, phonemes, len_phonemes, steps=False, window=3, spectrograms=None):
        """Sequentially generate spectrogram from phonemes

        If spectrograms are provided, they are used on input instead of self-generated frames (teacher forcing)
        If steps are provided with spectrograms, only 'steps' frames will be generated in supervised fashion
        Uses layer-level caching for faster inference.

        The model is put into generating mode for the duration of the call and
        is always switched back, even if generation raises.

        :param phonemes: Padded phoneme indices
        :param len_phonemes: Length of each sentence in `phonemes` (list of lengths); only used when `window` is None
        :param steps: How many steps to generate; falsy means `spectrograms.shape[1]`
        :param window: Window size for attention masking around the previous step's argmax attention;
            None masks only padded phonemes
        :param spectrograms: Padded spectrograms
        :return: (generated spectrograms, attention weights), each concatenated over steps along dim 1;
            (None, None) if zero steps are generated
        """
        self.generating(True)
        self.train(False)

        try:
            assert steps or (spectrograms is not None)
            steps = steps if steps else spectrograms.shape[1]

            with torch.no_grad():
                phonemes = torch.as_tensor(phonemes)
                keys, values = self.txt_encoder(phonemes)

                if hp.positional_encoding:
                    keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
                    pe = positional_encoding(hp.channels, steps, w=1).to(self.device)

                if spectrograms is None:
                    # start from a single all-zero "go" frame
                    dec = torch.zeros(len(phonemes), 1, hp.out_channels, device=self.device)
                else:
                    # teacher forcing: input position i holds ground-truth frame i - 1
                    shifted = ZeroPad2d((0, 0, 1, 0))(spectrograms)[:, :-1, :]

                if window is not None:
                    shape = (len(phonemes), 1, phonemes.shape[-1])
                    idx = torch.zeros(len(phonemes), 1, phonemes.shape[-1]).to(phonemes.device)
                    att_mask = idx_mask(shape, idx, window)
                else:
                    att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                                    lengths=len_phonemes,
                                    dim=-1).to(self.device)

                # Collect per-step outputs and concatenate once at the end;
                # torch.cat inside the loop re-copied everything each step (O(steps^2)).
                frames, step_weights = [], []
                for i in range(steps):
                    step_input = dec if spectrograms is None else shifted[:, i:i + 1, :]
                    queries = self.audio_encoder(step_input)

                    if hp.positional_encoding:
                        queries += pe[i]

                    att, w = self.attention(queries, keys, values, att_mask)
                    dec = self.audio_decoder(att + queries)
                    step_weights.append(w)
                    frames.append(dec)

                    if window is not None:
                        # re-center the attention window on this step's most attended phoneme
                        idx = torch.argmax(w, dim=-1).unsqueeze(2).float()
                        att_mask = idx_mask(shape, idx, window)

                if not frames:
                    return None, None
                return torch.cat(frames, dim=1), torch.cat(step_weights, dim=1)
        finally:
            # always leave generating mode, even on error
            self.generating(False)
Ejemplo n.º 7
0
def l1_masked(input, target, lengths):
    """Mean absolute error over the valid (unpadded) positions only.

    :param input: predictions
    :param target: targets, same shape as ``input``
    :param lengths: per-item valid lengths along dim 1
    :return: summed L1 error over valid positions divided by the number of
        valid mask elements
    """
    valid = mask(input.shape, lengths, dim=1).float().to(input.device)
    abs_error_sum = F.l1_loss(input * valid, target * valid, reduction='sum')
    return abs_error_sum / valid.sum()
Ejemplo n.º 8
0
    def forward(self, input, target, lengths):
        """Apply ``self.l1`` to masked tensors and normalize by the valid-element count.

        :param input: predictions
        :param target: targets, same shape as ``input``
        :param lengths: per-item valid lengths along dim 1
        :return: ``self.l1`` of the masked tensors divided by the mask sum
            (presumably ``self.l1`` sums rather than averages — confirm)
        """
        valid = mask(input.shape, lengths, dim=1).float().to(input.device)
        masked_input = input * valid
        masked_target = target * valid
        return self.l1(masked_input, masked_target) / valid.sum()
Ejemplo n.º 9
0
def masked_ssim(input, target, lengths):
    """Structural dissimilarity (1 - SSIM) between masked input and target.

    Positions past ``lengths`` along dim 1 are zeroed in both tensors, then a
    singleton axis is inserted at dim 1 before calling ``ssim``.

    :param input: predictions
    :param target: targets, same shape as ``input``
    :param lengths: per-item valid lengths along dim 1
    :return: ``1 - ssim(...)`` of the masked tensors
    """
    valid = mask(input.shape, lengths, dim=1).float().to(input.device)
    masked_input = (input * valid).unsqueeze(1)
    masked_target = (target * valid).unsqueeze(1)
    return 1 - ssim(masked_input, masked_target)
Ejemplo n.º 10
0
def mask_durations(durations, plen):
    """Zero out durations at positions past each sequence's phoneme length.

    :param durations: per-phoneme durations; the mask is built along the last dim
    :param plen: per-item phoneme lengths
    :return: ``durations`` multiplied by a float mask of the same shape
    """
    keep = mask(durations.shape, plen, dim=-1).to(durations.device).float()
    return keep * durations