Esempio n. 1
0
    def forward(self, phonemes, spectrograms, len_phonemes, training=False):
        """Encode text and audio, attend over the text, and decode spectrograms.

        :param phonemes: padded phonemes (originally documented as
            (batch, alphabet, time))
        :param spectrograms: padded spectrograms (originally documented as
            (batch, freq, time)).
            NOTE(review): the code below pads and slices dim 1 as the step
            axis, which suggests (batch, time, freq) -- confirm the layout.
        :param len_phonemes: list of phoneme lengths
        :param training: not read anywhere in this method
        :return: decoded_spectrograms, attention_weights
        """
        # Prepend one zero row along the second-to-last axis and drop the last
        # row, i.e. feed the audio encoder the input shifted by one step
        # (assuming ZeroPad2d is torch.nn.ZeroPad2d).
        # TODO: move this to encoder?
        shift_by_one = ZeroPad2d((0, 0, 1, 0))
        delayed = shift_by_one(spectrograms)[:, :-1, :]

        keys, values = self.txt_encoder(phonemes)
        queries = self.audio_encoder(delayed)

        # One mask entry per (batch item, query step, key step); built from
        # the phoneme lengths along the key axis.
        mask_shape = (len(keys), queries.shape[1], keys.shape[1])
        att_mask = mask(shape=mask_shape, lengths=len_phonemes, dim=-1)
        att_mask = att_mask.to(self.device)

        if hp.positional_encoding:
            # In-place adds are kept on purpose: keys/queries are updated
            # where they live instead of being rebound to new tensors.
            key_pe = positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w)
            keys += key_pe.to(self.device)
            query_pe = positional_encoding(queries.shape[-1],
                                           queries.shape[1],
                                           w=1)
            queries += query_pe.to(self.device)

        attention, weights = self.attention(queries,
                                            keys,
                                            values,
                                            mask=att_mask)
        # Residual connection: the decoder sees attention output plus queries.
        return self.audio_decoder(attention + queries), weights
Esempio n. 2
0
def binary_divergence_masked(input, target, lengths):
    """Masked binary divergence between predicted and target spectrograms.

    Provides non-vanishing gradient, but does not equal zero if spectrograms are the same
    Inspired by https://github.com/r9y9/deepvoice3_pytorch/blob/897f31e57eb6ec2f0cafa8dc62968e60f6a96407/train.py#L537

    :param input: predictions; they are passed through `logit`, so presumably
        values in (0, 1) -- confirm against the caller
    :param target: targets, multiplied element-wise with the logits of `input`
    :param lengths: lengths used to build the mask along dim 1
    :return: `masked_mean` of the element-wise divergence under that mask
    """
    input_logits = logit(input)

    # softplus(x) is log(1 + exp(x)) computed without overflow: the naive
    # log1p(exp(x)) turns into inf as soon as exp(x) exceeds the float range.
    z = -target * input_logits + torch.nn.functional.softplus(input_logits)
    m = mask(input.shape, lengths, dim=1).float().to(input.device)

    return masked_mean(z, m)
Esempio n. 3
0
def masked_huber(input, target, lengths):
    """Smooth-L1 (Huber) loss summed over masked tensors, divided by the mask sum.

    The mask is always built along dim 1, the first non-batch dimension
    (usually time).

    :param input: predicted tensor
    :param target: reference tensor, multiplied by the same mask as `input`
    :param lengths: lengths handed to `mask` for dim 1
    :return: summed smooth-L1 loss divided by the sum of the mask
    """
    valid = mask(input.shape, lengths, dim=1).float().to(input.device)
    # Both operands are multiplied by the mask, so positions where the mask
    # is zero contribute nothing to the sum.
    summed = F.smooth_l1_loss(input * valid, target * valid, reduction='sum')
    return summed / valid.sum()
Esempio n. 4
0
    def generate_naive(self, phonemes, len_phonemes, steps=1, window=(0, 1)):
        """Naive generation without layer-level caching for testing purposes.

        Every step re-encodes all frames generated so far and keeps only the
        newest decoded frame and its attention weights.

        :param phonemes: padded phoneme indices (converted with torch.as_tensor)
        :param len_phonemes: phoneme lengths used for the initial attention mask
        :param steps: number of frames to generate
        :param window: passed to `median_mask` to rebuild the attention mask
            after each step; None keeps the initial length-based mask
        :return: generated frames (without the initial zero frame) and the
            attention weights collected over all steps
        """

        self.train(False)

        with torch.no_grad():
            phonemes = torch.as_tensor(phonemes)

            keys, values = self.txt_encoder(phonemes)

            if hp.positional_encoding:
                keys += positional_encoding(keys.shape[-1],
                                            keys.shape[1],
                                            w=hp.w).to(self.device)
                pe = positional_encoding(hp.channels, steps,
                                         w=1).to(self.device)

            # Generation is seeded with a single all-zero frame.
            dec = torch.zeros(len(phonemes),
                              1,
                              hp.out_channels,
                              device=self.device)

            weights = None

            att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                            lengths=len_phonemes,
                            dim=-1).to(self.device)

            for i in range(steps):
                print(i)
                queries = self.audio_encoder(dec)
                if hp.positional_encoding:
                    # `queries` covers every frame produced so far, so each
                    # row must receive the encoding of its own position.
                    # Adding pe[i] to all rows would stamp the earlier frames
                    # with the current step's position, unlike forward().
                    queries += pe[:queries.shape[1]]

                att, w = self.attention(queries, keys, values, att_mask)
                d = self.audio_decoder(att + queries)
                # Only the newest frame / attention row is new information.
                d = d[:, -1:]
                w = w[:, -1:]
                weights = w if weights is None else torch.cat(
                    (weights, w), dim=1)
                dec = torch.cat((dec, d), dim=1)

                if window is not None:
                    att_mask = median_mask(weights, window=window)

        # Drop the zero seed frame.
        return dec[:, 1:, :], weights
Esempio n. 5
0
    def synthesize(self, input_text):
        """Turn one text string into an int16 waveform array.

        Pipeline: text processor -> self.speedyspeech (spectrogram and
        durations) -> inverse normalization -> padding -> self.melgan.

        :param input_text: text to synthesize; it is echoed to stdout and
            stripped of surrounding whitespace
        :return: numpy int16 array holding the audio of the first (only)
            batch item
        """
        print(input_text)
        phonemes, plen = self.txt_processor([input_text.strip()])

        # append more zeros - avoid cutoff at the end of the largest sequence
        tail = torch.zeros(len(phonemes), 5).long()
        phonemes = torch.cat((phonemes, tail), dim=-1).to(self.device)

        # generate spectrograms
        with torch.no_grad():
            spec, durations = self.speedyspeech((phonemes, plen))

        # invert to log(mel-spectrogram)
        spec = self.speedyspeech.collate.norm.inverse(spec)

        # pad value expected by MelGan (per the original comment)
        pad_value = -11.5129

        # total duration per item gives the number of frames to keep
        frame_counts = durations.sum(dim=-1).long()
        msk = mask(spec.shape, frame_counts, dim=1).to(self.device)
        spec = spec.masked_fill(~msk, pad_value)

        # Append more pad frames to improve end of the longest sequence
        extra = pad_value * torch.ones(len(spec), HPStft.n_mel,
                                       5).to(self.device)
        spec = torch.cat((spec.transpose(2, 1), extra), dim=-1)

        # generate audio
        with torch.no_grad():
            audio = self.melgan(spec).squeeze(1)
            audio = audio.detach().cpu().numpy()[0]

        # denormalize
        # NOTE(review): the scale is 2**bit_depth - 1 while the result is cast
        # to int16 -- confirm the value of self.bit_depth against that range.
        x = 2**self.bit_depth - 1
        return np.int16(audio * x)
Esempio n. 6
0
    def generate(self,
                 phonemes,
                 len_phonemes,
                 steps=False,
                 window=3,
                 spectrograms=None):
        """Sequentially generate spectrogram from phonemes

        If spectrograms are provided, they are used on input instead of self-generated frames (teacher forcing)
        If steps are provided with spectrograms, only 'steps' frames will be generated in supervised fashion
        Uses layer-level caching for faster inference.

        Generation mode is switched on via `self.generating(True)` and is
        always switched off again, even when a step raises.

        :param phonemes: Padded phoneme indices
        :param len_phonemes: Length of each sentence in `phonemes` (list of lengths)
        :param steps: How many steps to generate; when falsy, the length of
            dim 1 of `spectrograms` is used
        :param window: Window size for attention masking; None falls back to
            a mask built from `len_phonemes`
        :param spectrograms: Padded spectrograms
        :return: Generated spectrograms and the attention weights of all steps
        """
        self.generating(True)
        try:
            self.train(False)

            assert steps or (spectrograms is not None), \
                'either `steps` or `spectrograms` must be provided'
            steps = steps if steps else spectrograms.shape[1]

            with torch.no_grad():
                phonemes = torch.as_tensor(phonemes)
                keys, values = self.txt_encoder(phonemes)

                if hp.positional_encoding:
                    keys += positional_encoding(keys.shape[-1],
                                                keys.shape[1],
                                                w=hp.w).to(self.device)
                    pe = positional_encoding(hp.channels, steps,
                                             w=1).to(self.device)

                if spectrograms is None:
                    # Free-running mode starts from a single all-zero frame.
                    dec = torch.zeros(len(phonemes),
                                      1,
                                      hp.out_channels,
                                      device=self.device)
                else:
                    # Teacher forcing: shift the reference by one step so the
                    # frame fed at step i is the one preceding frame i.
                    teacher_input = ZeroPad2d(
                        (0, 0, 1, 0))(spectrograms)[:, :-1, :]

                weights, decoded = None, None

                if window is not None:
                    # Start with the window placed at index 0 for every item.
                    shape = (len(phonemes), 1, phonemes.shape[-1])
                    idx = torch.zeros(len(phonemes), 1,
                                      phonemes.shape[-1]).to(phonemes.device)
                    att_mask = idx_mask(shape, idx, window)
                else:
                    att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                                    lengths=len_phonemes,
                                    dim=-1).to(self.device)

                for i in range(steps):
                    if spectrograms is None:
                        # Feed back the frame decoded in the previous step.
                        queries = self.audio_encoder(dec)
                    else:
                        queries = self.audio_encoder(
                            teacher_input[:, i:i + 1, :])

                    if hp.positional_encoding:
                        queries += pe[i]

                    att, w = self.attention(queries, keys, values, att_mask)
                    dec = self.audio_decoder(att + queries)
                    weights = w if weights is None else torch.cat(
                        (weights, w), dim=1)
                    decoded = dec if decoded is None else torch.cat(
                        (decoded, dec), dim=1)
                    if window is not None:
                        # Re-centre the window on the most attended position.
                        idx = torch.argmax(w, dim=-1).unsqueeze(2).float()
                        att_mask = idx_mask(shape, idx, window)
        finally:
            # Without this, an exception anywhere above would leave the
            # module stuck in generation mode for every later call.
            self.generating(False)

        return decoded, weights
Esempio n. 7
0
def l1_masked(input, target, lengths):
    """L1 loss summed over masked tensors, divided by the sum of the mask.

    :param input: predicted tensor
    :param target: reference tensor, multiplied by the same mask as `input`
    :param lengths: lengths handed to `mask` for dim 1
    :return: summed absolute error divided by the mask sum
    """
    valid = mask(input.shape, lengths, dim=1).float().to(input.device)
    # Positions where the mask is zero become 0 in both operands.
    total = F.l1_loss(input * valid, target * valid, reduction='sum')
    return total / valid.sum()
Esempio n. 8
0
    def forward(self, input, target, lengths):
        """Apply `self.l1` to masked tensors and divide by the sum of the mask.

        :param input: predicted tensor
        :param target: reference tensor, multiplied by the same mask as `input`
        :param lengths: lengths handed to `mask` for dim 1
        :return: output of `self.l1` divided by the mask sum
            (presumably `self.l1` is a sum-reduced L1 loss -- confirm)
        """
        valid = mask(input.shape, lengths, dim=1).float().to(input.device)
        total = self.l1(input * valid, target * valid)
        return total / valid.sum()
Esempio n. 9
0
def masked_ssim(input, target, lengths):
    """Return 1 - SSIM computed on masked copies of `input` and `target`.

    :param input: predicted tensor
    :param target: reference tensor, multiplied by the same mask as `input`
    :param lengths: lengths handed to `mask` for dim 1
    :return: one minus the value returned by `ssim`
    """
    valid = mask(input.shape, lengths, dim=1).float().to(input.device)
    # `ssim` receives both tensors with an extra axis inserted at position 1.
    masked_input = (input * valid).unsqueeze(1)
    masked_target = (target * valid).unsqueeze(1)
    return 1 - ssim(masked_input, masked_target)
Esempio n. 10
0
# Script fragment: text -> phonemes -> spectrogram (model `m`) -> audio (melgan).
# `text`, `txt_processor`, `m`, `melgan` and `args` are defined outside this fragment.
phonemes, plen = txt_processor(text)
# append more zeros - avoid cutoff at the end of the largest sequence
# (five extra zero ids are concatenated along the last dimension)
phonemes = torch.cat((phonemes, torch.zeros(len(phonemes), 5).long()), dim=-1)
phonemes = phonemes.to(args.device)

print('Synthesizing')
# generate spectrograms (no gradients are needed for inference)
with torch.no_grad():
    spec, durations = m((phonemes, plen))

# invert to log(mel-spectrogram)
spec = m.collate.norm.inverse(spec)

# mask with pad value expected by MelGan
# The summed durations of each item give the length passed to `mask` for dim 1;
# every position outside the mask is overwritten with the pad value.
msk = mask(spec.shape, durations.sum(dim=-1).long(), dim=1).to(args.device)
spec = spec.masked_fill(~msk, -11.5129)

# Append more pad frames to improve end of the longest sequence
# (dims 1 and 2 are swapped first, then 5 pad frames are added on the last dim)
spec = torch.cat((spec.transpose(
    2, 1), -11.5129 * torch.ones(len(spec), HPStft.n_mel, 5).to(args.device)),
                 dim=-1)

# generate audio
with torch.no_grad():
    audio = melgan(spec).squeeze(1)

print('Saving audio')
# TODO: cut audios to proper length
for i, a in enumerate(audio.detach().cpu().numpy()):
    write_wav(os.path.join(args.audio_folder, f'{i}.wav'),
Esempio n. 11
0
def mask_durations(durations, plen):
    """Multiply durations by a float mask built from the phoneme lengths.

    :param durations: duration tensor; its shape and device define the mask
    :param plen: lengths handed to `mask` for the last dimension
    :return: `durations` multiplied element-wise by the mask (presumably
        zeroing positions past each sequence's length -- depends on `mask`)
    """
    valid = mask(durations.shape, plen, dim=-1)
    valid = valid.to(durations.device).float()
    return durations * valid