import itertools

import torch
# `positional_encoding` and, in the methods below, `hp`, `mask`, `median_mask`,
# `idx_mask` and `expand_encodings` are defined elsewhere in the repo.
from torch.nn import ZeroPad2d


def expand_positional_encodings(durations, channels, repeat=False):
    """Expand positional encodings to align with phoneme durations.

    Example: with repeat=True, phonemes a, b, c with durations 3, 5, 4
    expand to
         a   a   a   b   b   b   b   b   c   c   c   c
       [e1, e2, e3, e1, e2, e3, e4, e5, e1, e2, e3, e4]

    Use Pad from transforms to get a batched tensor.

    :param durations: (batch, time), 0-masked tensor of phoneme durations
    :param channels: number of encoding channels
    :param repeat: restart the encoding at each phoneme boundary
    :return: list (len batch) of positional encodings, each (time, channels)
    """
    durations = durations.long()

    def rng(l):
        return list(range(l))

    if repeat:
        # Restart positions at every phoneme boundary: e1..e_d for duration d.
        max_len = torch.max(durations)
        pe = positional_encoding(channels, max_len)
        idx = []
        for d in durations:
            idx.append(list(itertools.chain.from_iterable([rng(dd) for dd in d])))
        return [pe[i] for i in idx]
    else:
        # One continuous encoding spanning the whole expanded utterance.
        max_len = torch.max(durations.sum(dim=-1))
        pe = positional_encoding(channels, max_len)
        return [pe[:s] for s in durations.sum(dim=-1)]
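# --- A minimal sketch of the `positional_encoding` helper assumed above. The
# call sites use positional_encoding(channels, length, w=1) and index the
# result by time, so it must return a (length, channels) tensor; the sinusoid
# form and the role of `w` as a position-scaling factor are assumptions, not
# the repo's definitive implementation.
import math

def positional_encoding_sketch(channels, length, w=1):
    """Return (length, channels) sinusoidal encodings; assumes even channels."""
    position = torch.arange(length, dtype=torch.float).unsqueeze(1)      # (length, 1)
    div_term = torch.exp(torch.arange(0, channels, 2, dtype=torch.float)
                         * (-math.log(10000.0) / channels))              # (channels // 2,)
    pe = torch.zeros(length, channels)
    pe[:, 0::2] = torch.sin(w * position * div_term)
    pe[:, 1::2] = torch.cos(w * position * div_term)
    return pe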
def forward(self, phonemes, spectrograms, len_phonemes, training=False):
    """
    :param phonemes: (batch, time), padded phoneme indices
    :param spectrograms: (batch, time, freq), padded spectrograms
    :param len_phonemes: list of phoneme lengths
    :return: decoded_spectrograms, attention_weights
    """
    # Shift spectrograms one frame to the right so the decoder only sees
    # past frames (teacher forcing).
    spectrs = ZeroPad2d((0, 0, 1, 0))(spectrograms)[:, :-1, :]  # move this to encoder?
    keys, values = self.txt_encoder(phonemes)
    queries = self.audio_encoder(spectrs)

    # Disallow attention beyond each sentence's phoneme length.
    att_mask = mask(shape=(len(keys), queries.shape[1], keys.shape[1]),
                    lengths=len_phonemes,
                    dim=-1).to(self.device)

    if hp.positional_encoding:
        keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
        queries += positional_encoding(queries.shape[-1], queries.shape[1], w=1).to(self.device)

    attention, weights = self.attention(queries, keys, values, mask=att_mask)
    decoded = self.audio_decoder(attention + queries)
    return decoded, weights
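# --- A minimal sketch of the `mask` helper used in forward() above -- an
# assumption inferred from its call site, not the repo's actual code. It
# builds a boolean (batch, query_time, key_time) tensor that is True where
# the key position lies inside the corresponding sentence, so attention
# cannot land on padded phonemes.
def mask_sketch(shape, lengths, dim=-1):
    assert dim == -1, "sketch only covers masking along the last (key) axis"
    lengths = torch.as_tensor(lengths)
    key_pos = torch.arange(shape[-1])                         # (key_time,)
    valid = key_pos[None, None, :] < lengths[:, None, None]   # (batch, 1, key_time)
    return valid.expand(shape)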
def generate_naive(self, phonemes, len_phonemes, steps=1, window=(0, 1)):
    """Naive generation without layer-level caching, for testing purposes."""
    self.train(False)

    with torch.no_grad():
        phonemes = torch.as_tensor(phonemes)
        keys, values = self.txt_encoder(phonemes)

        if hp.positional_encoding:
            keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
            pe = positional_encoding(hp.channels, steps, w=1).to(self.device)

        # Start decoding from a single all-zero frame.
        dec = torch.zeros(len(phonemes), 1, hp.out_channels, device=self.device)
        weights = None

        att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                        lengths=len_phonemes,
                        dim=-1).to(self.device)

        for i in range(steps):
            print(i)  # progress printout (debugging aid)
            queries = self.audio_encoder(dec)
            if hp.positional_encoding:
                queries += pe[i]
            att, w = self.attention(queries, keys, values, att_mask)
            d = self.audio_decoder(att + queries)

            # Without caching, the whole prefix is re-encoded each step;
            # keep only the newly generated frame and its attention row.
            d = d[:, -1:]
            w = w[:, -1:]

            weights = w if weights is None else torch.cat((weights, w), dim=1)
            dec = torch.cat((dec, d), dim=1)

            if window is not None:
                att_mask = median_mask(weights, window=window)

    # Drop the initial zero frame.
    return dec[:, 1:, :], weights
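# --- A minimal sketch of `median_mask` -- a loud assumption about the helper
# used in generate_naive() above. The idea: take the median of the argmax
# attention positions over the frames generated so far, and only allow
# attention within window = (before, after) keys around it, keeping the
# attention focused during free-running generation. The exact window
# semantics are a guess.
def median_mask_sketch(weights, window=(0, 1)):
    # weights: (batch, frames_so_far, key_time) attention weights
    batch, _, key_time = weights.shape
    argmax = weights.argmax(dim=-1).float()                    # (batch, frames_so_far)
    med = argmax.median(dim=-1).values.long()                  # (batch,)
    key_pos = torch.arange(key_time, device=weights.device)
    allowed = ((key_pos[None, :] >= (med + window[0])[:, None]) &
               (key_pos[None, :] <= (med + window[1])[:, None]))
    return allowed[:, None, :]                                 # (batch, 1, key_time)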
def expand_enc(self, encodings, durations):
    """Copy each phoneme encoding as many times as the duration predictor predicts."""
    encodings = self.pad(expand_encodings(encodings, durations))
    if hp.pos_enc:
        if hp.pos_enc == 'ours':
            encodings += self.pad(expand_positional_encodings(
                durations, encodings.shape[-1])).to(encodings.device)
        elif hp.pos_enc == 'standard':
            encodings += positional_encoding(
                encodings.shape[-1], encodings.shape[1]).to(encodings.device)
    return encodings
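# --- A minimal sketch of `expand_encodings` (assumed; defined elsewhere in
# the repo). Each phoneme vector is repeated along the time axis as many
# times as its predicted duration; a zero duration drops the phoneme. The
# list-of-tensors return matches the `self.pad(...)` batching in expand_enc().
def expand_encodings_sketch(encodings, durations):
    # encodings: (batch, time, channels); durations: (batch, time), 0-masked
    return [torch.repeat_interleave(enc, dur.long(), dim=0)
            for enc, dur in zip(encodings, durations)]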
def generate(self, phonemes, len_phonemes, steps=False, window=3, spectrograms=None):
    """Sequentially generate a spectrogram from phonemes.

    If spectrograms are provided, they are fed to the decoder instead of
    self-generated frames (teacher forcing). If steps are provided together
    with spectrograms, only `steps` frames are generated in this supervised
    fashion. Uses layer-level caching for faster inference.

    :param phonemes: Padded phoneme indices
    :param len_phonemes: Length of each sentence in `phonemes` (list of lengths)
    :param steps: How many steps to generate
    :param window: Window size for attention masking
    :param spectrograms: Padded spectrograms
    :return: Generated spectrograms, attention weights
    """
    self.generating(True)
    self.train(False)

    assert steps or (spectrograms is not None)
    steps = steps if steps else spectrograms.shape[1]

    with torch.no_grad():
        phonemes = torch.as_tensor(phonemes)
        keys, values = self.txt_encoder(phonemes)

        if hp.positional_encoding:
            keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
            pe = positional_encoding(hp.channels, steps, w=1).to(self.device)

        if spectrograms is None:
            dec = torch.zeros(len(phonemes), 1, hp.out_channels, device=self.device)
        else:
            # Shift ground-truth frames right by one for teacher forcing
            # (`inputs` rather than `input` to avoid shadowing the builtin).
            inputs = ZeroPad2d((0, 0, 1, 0))(spectrograms)[:, :-1, :]

        weights, decoded = None, None

        if window is not None:
            # Start with attention forced to a window around phoneme 0.
            shape = (len(phonemes), 1, phonemes.shape[-1])
            idx = torch.zeros(len(phonemes), 1, phonemes.shape[-1]).to(phonemes.device)
            att_mask = idx_mask(shape, idx, window)
        else:
            att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                            lengths=len_phonemes,
                            dim=-1).to(self.device)

        for i in range(steps):
            if spectrograms is None:
                queries = self.audio_encoder(dec)
            else:
                queries = self.audio_encoder(inputs[:, i:i + 1, :])

            if hp.positional_encoding:
                queries += pe[i]

            att, w = self.attention(queries, keys, values, att_mask)
            dec = self.audio_decoder(att + queries)

            weights = w if weights is None else torch.cat((weights, w), dim=1)
            decoded = dec if decoded is None else torch.cat((decoded, dec), dim=1)

            if window is not None:
                # Re-center the attention window on the last argmax position.
                idx = torch.argmax(w, dim=-1).unsqueeze(2).float()
                att_mask = idx_mask(shape, idx, window)

    self.generating(False)
    return decoded, weights
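# --- A minimal sketch of `idx_mask` -- an assumption about the helper used
# in generate() above, inferred from its call sites. It permits attention
# only within `window` keys of the previous frame's argmax attention
# position, which keeps inference attention roughly monotonic; the exact
# window geometry (centered vs. one-sided) is a guess.
def idx_mask_sketch(shape, idx, window):
    # shape: (batch, 1, key_time); idx broadcasts against it
    key_pos = torch.arange(shape[-1], device=idx.device).float().expand(shape)
    return (key_pos - idx).abs() <= window

# Hypothetical usage of generate() (the `model`, `phonemes` and `gt_specs`
# names are placeholders, not part of this file):
#   spec, att = model.generate(phonemes, len_phonemes, steps=500, window=3)    # free-running
#   spec, att = model.generate(phonemes, len_phonemes, spectrograms=gt_specs)  # teacher forcing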