Example #1
0
def expand_positional_encodings(durations, channels, repeat=False):
    """Expand positional encoding to align with phoneme durations.

    Example:
        If repeat:
        phonemes a, b, c have durations 3,5,4
        The expanded encoding is
          a   a   a   b   b   b   b   b   c   c   c   c
        [e1, e2, e3, e1, e2, e3, e4, e5, e1, e2, e3, e4]

        Otherwise a single encoding runs across the whole expanded sequence:
        [e1, e2, e3, ..., e12]

    Use Pad from transforms to get batched tensor.

    :param durations: (batch, time), 0-masked tensor of phoneme durations;
        converted with ``.long()``, so fractional values are truncated
    :param channels: number of encoding channels, forwarded to
        ``positional_encoding``
    :param repeat: if True, restart the encoding at the beginning of every
        phoneme; otherwise run it over the whole utterance
    :return: positional_encodings as list of tensors, one per batch item,
        each with ``durations[b].sum()`` rows
    """
    durations = durations.long()

    if repeat:
        # The longest single phoneme decides how many encoding rows are needed.
        # int() hands positional_encoding a plain length (not a 0-d tensor),
        # matching the other call sites that pass Python ints.
        max_len = int(durations.max())
        pe = positional_encoding(channels, max_len)
        # Per batch item: 0..d-1 for every phoneme duration d, concatenated.
        indices = [
            list(itertools.chain.from_iterable(range(int(dd)) for dd in row))
            for row in durations
        ]
        return [pe[i] for i in indices]

    # Non-repeat: one encoding over the total length of each utterance.
    lengths = durations.sum(dim=-1)
    pe = positional_encoding(channels, int(lengths.max()))
    return [pe[:int(s)] for s in lengths]
Example #2
0
    def generate_naive(self, phonemes, len_phonemes, steps=1, window=(0,1)):
        """Naive generation without layer-level caching for testing purposes.

        The whole decoded prefix is re-encoded at every step, so each step
        costs more than the last, but no per-layer cache is needed.

        :param phonemes: padded phoneme input for ``self.txt_encoder``
        :param len_phonemes: length of each sentence in ``phonemes``
        :param steps: number of frames to generate
        :param window: passed to ``median_mask`` to rebuild the attention mask
            from the weights gathered so far; ``None`` keeps the initial
            length mask for every step
        :return: (generated frames, shape (batch, steps, hp.out_channels);
            attention weights concatenated along dim 1, one entry per step)
        """
        self.train(False)

        with torch.no_grad():
            phonemes = torch.as_tensor(phonemes)

            keys, values = self.txt_encoder(phonemes)

            if hp.positional_encoding:
                keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
                pe = positional_encoding(hp.channels, steps, w=1).to(self.device)

            # Seed the sequence with one all-zero frame; it is stripped on return.
            dec = torch.zeros(len(phonemes), 1, hp.out_channels, device=self.device)

            weights = None

            att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                            lengths=len_phonemes,
                            dim=-1).to(self.device)

            for i in range(steps):
                # Without caching, the encoder sees the whole prefix (i + 1 frames).
                queries = self.audio_encoder(dec)
                if hp.positional_encoding:
                    # Every prefix position gets its own encoding. Adding pe[i]
                    # would stamp the current step's encoding onto all positions.
                    queries += pe[:queries.shape[1]]

                att, w = self.attention(queries, keys, values, att_mask)
                d = self.audio_decoder(att + queries)
                # Keep only the newest frame and its attention row.
                d = d[:, -1:]
                w = w[:, -1:]
                weights = w if weights is None else torch.cat((weights, w), dim=1)
                dec = torch.cat((dec, d), dim=1)

                if window is not None:
                    att_mask = median_mask(weights, window=window)

        return dec[:, 1:, :], weights
Example #3
0
    def forward(self, phonemes, spectrograms, len_phonemes, training=False):
        """Teacher-forced pass: decode spectrograms while attending over phonemes.

        :param phonemes: padded phonemes; described elsewhere as
            (batch, alphabet, time) -- NOTE(review): confirm against txt_encoder
        :param spectrograms: padded spectrograms with time on dim 1, i.e.
            (batch, time, freq); the one-frame shift below pads and slices dim 1
        :param len_phonemes: list of phoneme lengths, used to build the
            attention mask over keys
        :param training: unused by this method
        :return: decoded_spectrograms, attention_weights
        """
        # Prepend a zero frame along dim 1 and drop the last frame, so the input
        # at position t is ground-truth frame t - 1 (teacher forcing).
        spectrs = ZeroPad2d((0,0,1,0))(spectrograms)[:, :-1, :]  # move this to encoder?
        keys, values = self.txt_encoder(phonemes)
        queries = self.audio_encoder(spectrs)

        # Mask shaped (batch, query_time, key_time); presumably hides key
        # positions beyond each len_phonemes entry (dim=-1) -- verify in mask().
        att_mask = mask(shape=(len(keys), queries.shape[1], keys.shape[1]),
                        lengths=len_phonemes,
                        dim=-1).to(self.device)

        # NOTE(review): += modifies the encoder outputs in place; if txt_encoder
        # ever returns keys aliased with values, values are shifted too.
        if hp.positional_encoding:
            keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
            queries += positional_encoding(queries.shape[-1], queries.shape[1], w=1).to(self.device)

        attention, weights = self.attention(queries, keys, values, mask=att_mask)
        # Residual connection: the decoder sees attention output plus the
        # (position-encoded) queries.
        decoded = self.audio_decoder(attention + queries)
        return decoded, weights
Example #4
0
 def expand_enc(self, encodings, durations):
     """Repeat each phoneme encoding for its predicted duration.

     The expanded encodings are padded with ``self.pad``. Depending on
     ``hp.pos_enc``, a positional encoding is then added in place:

     * ``'ours'``: per-phoneme encodings from ``expand_positional_encodings``
     * ``'standard'``: one encoding over the whole padded sequence
     * falsy or any other value: nothing is added

     :param encodings: phoneme encodings, forwarded to ``expand_encodings``
     :param durations: per-phoneme durations
     :return: padded, expanded encodings
     """
     expanded = self.pad(expand_encodings(encodings, durations))
     mode = hp.pos_enc
     if not mode:
         return expanded

     n_channels = expanded.shape[-1]
     if mode == 'ours':
         pos = self.pad(expand_positional_encodings(durations, n_channels))
     elif mode == 'standard':
         pos = positional_encoding(n_channels, expanded.shape[1])
     else:
         # Unrecognised mode: leave the encodings untouched.
         return expanded

     expanded += pos.to(expanded.device)
     return expanded
Example #5
0
    def generate(self, phonemes, len_phonemes, steps=False, window=3, spectrograms=None):
        """Sequentially generate spectrogram from phonemes

        If spectrograms are provided, they are used on input instead of self-generated frames (teacher forcing)
        If steps are provided with spectrograms, only 'steps' frames will be generated in supervised fashion
        Uses layer-level caching for faster inference.

        :param phonemes: Padded phoneme indices
        :param len_phonemes: Length of each sentence in `phonemes` (list of lengths)
        :param steps: How many steps to generate; a falsy value means one step
            per frame of `spectrograms` (spectrograms.shape[1])
        :param window: Window size for attention masking, passed to `idx_mask`
            together with the previous step's most-attended phoneme index;
            None uses a plain length mask for every step
        :param spectrograms: Padded spectrograms with time on dim 1
        :return: (generated spectrograms, attention weights), each concatenated
            along dim 1 with one entry per step
        :raises AssertionError: if neither `steps` nor `spectrograms` is given
        """
        # Validate before touching module state, so a bad call leaves the
        # model exactly as it was.
        assert steps or (spectrograms is not None)
        steps = steps if steps else spectrograms.shape[1]

        self.generating(True)
        self.train(False)

        try:
            with torch.no_grad():
                phonemes = torch.as_tensor(phonemes)
                keys, values = self.txt_encoder(phonemes)

                if hp.positional_encoding:
                    keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
                    pe = positional_encoding(hp.channels, steps, w=1).to(self.device)

                if spectrograms is None:
                    dec = torch.zeros(len(phonemes), 1, hp.out_channels, device=self.device)
                else:
                    # Teacher forcing: shift right by one frame along dim 1, so
                    # step i is fed ground-truth frame i - 1.
                    shifted = ZeroPad2d((0, 0, 1, 0))(spectrograms)[:, :-1, :]

                weights, decoded = None, None

                if window is not None:
                    shape = (len(phonemes), 1, phonemes.shape[-1])
                    idx = torch.zeros(len(phonemes), 1, phonemes.shape[-1]).to(phonemes.device)
                    att_mask = idx_mask(shape, idx, window)
                else:
                    att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                                    lengths=len_phonemes,
                                    dim=-1).to(self.device)

                for i in range(steps):
                    # Only one frame is encoded per step. Earlier context
                    # presumably lives in the layer caches enabled by
                    # self.generating(True).
                    if spectrograms is None:
                        queries = self.audio_encoder(dec)
                    else:
                        queries = self.audio_encoder(shifted[:, i:i+1, :])

                    if hp.positional_encoding:
                        queries += pe[i]

                    att, w = self.attention(queries, keys, values, att_mask)
                    dec = self.audio_decoder(att + queries)
                    weights = w if weights is None else torch.cat((weights, w), dim=1)
                    decoded = dec if decoded is None else torch.cat((decoded, dec), dim=1)
                    if window is not None:
                        # Build the next step's mask from this step's
                        # most-attended phoneme index.
                        idx = torch.argmax(w, dim=-1).unsqueeze(2).float()
                        att_mask = idx_mask(shape, idx, window)
        finally:
            # Always leave generation mode, even if decoding raised.
            self.generating(False)

        return decoded, weights