def forward(self, phonemes, spectrograms, len_phonemes, training=False):
    """Encode text and right-shifted spectrograms, attend, and decode.

    :param phonemes: (batch, alphabet, time), padded phonemes
    :param spectrograms: (batch, freq, time), padded spectrograms
    :param len_phonemes: list of phoneme lengths
    :param training: not read anywhere in this method
    :return: decoded_spectrograms, attention_weights
    """
    # Prepend one zero entry along dim 1 and drop the last entry of dim 1, so
    # position t of the audio encoder input holds entry t-1 of the target.
    # NOTE(review): the shift runs along dim 1, which suggests time is dim 1
    # here rather than the last dim stated in the docstring - confirm.
    shift_right = ZeroPad2d((0, 0, 1, 0))
    spectrs = shift_right(spectrograms)[:, :-1, :]  # move this to encoder?

    keys, values = self.txt_encoder(phonemes)
    queries = self.audio_encoder(spectrs)

    # One mask entry per (batch item, query position, key position), built
    # from the phoneme lengths along the key axis.
    mask_shape = (len(keys), queries.shape[1], keys.shape[1])
    att_mask = mask(shape=mask_shape, lengths=len_phonemes, dim=-1)
    att_mask = att_mask.to(self.device)

    if hp.positional_encoding:
        key_pe = positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w)
        query_pe = positional_encoding(queries.shape[-1], queries.shape[1], w=1)
        keys += key_pe.to(self.device)
        queries += query_pe.to(self.device)

    attention, weights = self.attention(queries, keys, values, mask=att_mask)
    # The attention output is summed with the queries before decoding.
    decoded = self.audio_decoder(attention + queries)
    return decoded, weights
def binary_divergence_masked(input, target, lengths):
    """Masked binary divergence between predicted and target values.

    Provides non-vanishing gradient, but does not equal zero if spectrograms
    are the same.
    Inspired by https://github.com/r9y9/deepvoice3_pytorch/blob/897f31e57eb6ec2f0cafa8dc62968e60f6a96407/train.py#L537

    :param input: predictions; passed through `logit`, so presumably
        probabilities in (0, 1) - confirm against caller
    :param target: targets, same shape as `input`
    :param lengths: lengths used to build the mask along dim 1
    :return: `masked_mean` of the element-wise divergence under that mask
    """
    input_logits = logit(input)
    # softplus(x) == log(1 + exp(x)), but it does not overflow to inf for
    # large positive logits the way log1p(exp(x)) does.
    z = -target * input_logits + torch.nn.functional.softplus(input_logits)
    m = mask(input.shape, lengths, dim=1).float().to(input.device)
    return masked_mean(z, m)
def masked_huber(input, target, lengths):
    """Smooth-L1 loss over masked tensors, normalised by the mask sum.

    The mask is always built along dim 1 (the first non-batch dimension,
    usually time).

    :param input: predicted tensor
    :param target: reference tensor, same shape as `input`
    :param lengths: lengths handed to `mask` for dim 1
    :return: summed smooth-L1 loss divided by the sum of the mask
    """
    keep = mask(input.shape, lengths, dim=1)
    keep = keep.float().to(input.device)
    # Both operands are multiplied by the same mask before the loss is taken.
    total = F.smooth_l1_loss(input * keep, target * keep, reduction='sum')
    return total / keep.sum()
def generate_naive(self, phonemes, len_phonemes, steps=1, window=(0, 1)):
    """Naive generation without layer-level caching for testing purposes

    Every step re-encodes all frames produced so far and keeps only the
    newest decoder frame and the newest row of attention weights.

    :param phonemes: padded phoneme indices (converted with `torch.as_tensor`)
    :param len_phonemes: lengths used for the initial attention mask
    :param steps: number of frames to generate
    :param window: passed to `median_mask` after each step; None keeps the
        initial length-based mask for all steps
    :return: generated frames (without the initial zero frame), weights
    """
    self.train(False)
    use_pe = hp.positional_encoding

    with torch.no_grad():
        phonemes = torch.as_tensor(phonemes)
        batch = len(phonemes)

        keys, values = self.txt_encoder(phonemes)
        if use_pe:
            keys += positional_encoding(
                keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
            pe = positional_encoding(hp.channels, steps, w=1).to(self.device)

        # Generation starts from a single all-zero frame.
        dec = torch.zeros(batch, 1, hp.out_channels, device=self.device)
        att_mask = mask(shape=(batch, 1, keys.shape[1]),
                        lengths=len_phonemes,
                        dim=-1).to(self.device)
        weights = None

        for i in range(steps):
            print(i)
            queries = self.audio_encoder(dec)
            if use_pe:
                # NOTE(review): pe[i] is broadcast over every query position,
                # whereas forward() adds a per-position encoding - confirm
                # this is intended.
                queries += pe[i]

            att, w = self.attention(queries, keys, values, att_mask)
            newest_frame = self.audio_decoder(att + queries)[:, -1:]
            newest_w = w[:, -1:]

            if weights is None:
                weights = newest_w
            else:
                weights = torch.cat((weights, newest_w), dim=1)
            dec = torch.cat((dec, newest_frame), dim=1)

            if window is not None:
                att_mask = median_mask(weights, window=window)

        return dec[:, 1:, :], weights
def synthesize(self, input_text):
    """Synthesize one text string and return its waveform as an int16 array.

    Steps: text -> phonemes -> spectrogram (`self.speedyspeech`) ->
    audio (`self.melgan`) -> scaled int16 samples.

    :param input_text: text to synthesize; surrounding whitespace is stripped
    :return: numpy int16 array holding the first (only) item of the batch
    """
    print(input_text)

    pad_value = -11.5129  # pad value expected by MelGan
    n_extra = 5

    phonemes, plen = self.txt_processor([input_text.strip()])

    # append more zeros - avoid cutoff at the end of the largest sequence
    zero_tail = torch.zeros(len(phonemes), n_extra).long()
    phonemes = torch.cat((phonemes, zero_tail), dim=-1).to(self.device)

    # generate spectrograms
    with torch.no_grad():
        spec, durations = self.speedyspeech((phonemes, plen))

    # invert to log(mel-spectrogram)
    spec = self.speedyspeech.collate.norm.inverse(spec)

    # Frames beyond each item's summed durations get the pad value.
    frame_counts = durations.sum(dim=-1).long()
    msk = mask(spec.shape, frame_counts, dim=1).to(self.device)
    spec = spec.masked_fill(~msk, pad_value)

    # Append more pad frames to improve end of the longest sequence
    pad_frames = pad_value * torch.ones(len(spec), HPStft.n_mel, n_extra)
    spec = torch.cat((spec.transpose(2, 1), pad_frames.to(self.device)), dim=-1)

    # generate audio
    with torch.no_grad():
        audio = self.melgan(spec).squeeze(1)
    audio = audio.detach().cpu().numpy()[0]

    # denormalize
    # NOTE(review): the scale is 2**bit_depth - 1 while the cast is always
    # int16; for bit_depth == 16 values above ~0.5 would not fit - confirm.
    scale = 2**self.bit_depth - 1
    return np.int16(audio * scale)
def generate(self, phonemes, len_phonemes, steps=False, window=3, spectrograms=None):
    """Sequentially generate spectrogram from phonemes

    If spectrograms are provided, they are used on input instead of
    self-generated frames (teacher forcing).
    If steps are provided with spectrograms, only 'steps' frames will be
    generated in supervised fashion.
    Uses layer-level caching for faster inference.

    :param phonemes: Padded phoneme indices
    :param len_phonemes: Length of each sentence in `phonemes` (list of
        lengths); only used when `window` is None
    :param steps: How many steps to generate; when falsy, the size of dim 1
        of `spectrograms` is used
    :param window: Window size for attention masking; None switches to a
        fixed mask built from `len_phonemes`
    :param spectrograms: Padded spectrograms
    :return: Generated spectrograms and attention weights (both None when
        zero steps are run)
    :raises AssertionError: if neither `steps` nor `spectrograms` is given
    """
    # Validate before touching model state, so a bad call leaves it unchanged.
    assert steps or (spectrograms is not None)
    steps = steps if steps else spectrograms.shape[1]

    self.generating(True)
    try:
        self.train(False)

        with torch.no_grad():
            phonemes = torch.as_tensor(phonemes)
            keys, values = self.txt_encoder(phonemes)

            if hp.positional_encoding:
                keys += positional_encoding(
                    keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
                pe = positional_encoding(hp.channels, steps, w=1).to(self.device)

            if spectrograms is None:
                # Free-running mode starts from a single all-zero frame.
                dec = torch.zeros(len(phonemes), 1, hp.out_channels,
                                  device=self.device)
            else:
                # Teacher forcing: prepend a zero entry along dim 1 and drop
                # the last one, so step i is fed entry i-1 of the target.
                teacher_frames = ZeroPad2d((0, 0, 1, 0))(spectrograms)[:, :-1, :]

            weights, decoded = None, None

            if window is not None:
                shape = (len(phonemes), 1, phonemes.shape[-1])
                # The first windowed mask is built around index 0.
                idx = torch.zeros(shape).to(phonemes.device)
                att_mask = idx_mask(shape, idx, window)
            else:
                att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                                lengths=len_phonemes,
                                dim=-1).to(self.device)

            for i in range(steps):
                if spectrograms is None:
                    # Feed back the frame decoded in the previous step.
                    queries = self.audio_encoder(dec)
                else:
                    queries = self.audio_encoder(teacher_frames[:, i:i + 1, :])

                if hp.positional_encoding:
                    queries += pe[i]

                att, w = self.attention(queries, keys, values, att_mask)
                dec = self.audio_decoder(att + queries)

                weights = w if weights is None else torch.cat((weights, w), dim=1)
                decoded = dec if decoded is None else torch.cat((decoded, dec), dim=1)

                if window is not None:
                    # Re-centre the window on the most attended key index.
                    idx = torch.argmax(w, dim=-1).unsqueeze(2).float()
                    att_mask = idx_mask(shape, idx, window)
    finally:
        # Always leave generating mode, even if a step raised.
        self.generating(False)

    return decoded, weights
def l1_masked(input, target, lengths):
    """L1 loss over masked tensors, normalised by the mask sum.

    :param input: predicted tensor
    :param target: reference tensor, same shape as `input`
    :param lengths: lengths handed to `mask` for dim 1
    :return: summed absolute error of the masked tensors divided by the sum
        of the mask
    """
    keep = mask(input.shape, lengths, dim=1)
    keep = keep.float().to(input.device)
    # Both operands are multiplied by the same mask before the loss is taken.
    total = F.l1_loss(input * keep, target * keep, reduction='sum')
    return total / keep.sum()
def forward(self, input, target, lengths):
    """Apply `self.l1` to the masked tensors and divide by the mask sum.

    :param input: predicted tensor
    :param target: reference tensor, same shape as `input`
    :param lengths: lengths handed to `mask` for dim 1
    :return: `self.l1` of the masked tensors divided by the sum of the mask
        (presumably `self.l1` is a sum-reduced L1 loss - confirm in __init__)
    """
    keep = mask(input.shape, lengths, dim=1)
    keep = keep.float().to(input.device)
    loss = self.l1(input * keep, target * keep)
    return loss / keep.sum()
def masked_ssim(input, target, lengths):
    """Return ``1 - ssim`` computed on the masked input and target.

    :param input: predicted tensor
    :param target: reference tensor, same shape as `input`
    :param lengths: lengths handed to `mask` for dim 1
    :return: one minus the `ssim` of the two masked tensors
    """
    keep = mask(input.shape, lengths, dim=1).float().to(input.device)
    masked_input = input * keep
    masked_target = target * keep
    # A new axis is inserted at position 1 on both tensors before `ssim`
    # (presumably a channel axis - confirm against the ssim helper).
    return 1 - ssim(masked_input.unsqueeze(1), masked_target.unsqueeze(1))
phonemes, plen = txt_processor(text) # append more zeros - avoid cutoff at the end of the largest sequence phonemes = torch.cat((phonemes, torch.zeros(len(phonemes), 5).long()), dim=-1) phonemes = phonemes.to(args.device) print('Synthesizing') # generate spectrograms with torch.no_grad(): spec, durations = m((phonemes, plen)) # invert to log(mel-spectrogram) spec = m.collate.norm.inverse(spec) # mask with pad value expected by MelGan msk = mask(spec.shape, durations.sum(dim=-1).long(), dim=1).to(args.device) spec = spec.masked_fill(~msk, -11.5129) # Append more pad frames to improve end of the longest sequence spec = torch.cat((spec.transpose( 2, 1), -11.5129 * torch.ones(len(spec), HPStft.n_mel, 5).to(args.device)), dim=-1) # generate audio with torch.no_grad(): audio = melgan(spec).squeeze(1) print('Saving audio') # TODO: cut audios to proper length for i, a in enumerate(audio.detach().cpu().numpy()): write_wav(os.path.join(args.audio_folder, f'{i}.wav'),
def mask_durations(durations, plen):
    """Multiply durations by a mask built from `plen` along the last dim.

    :param durations: durations tensor
    :param plen: lengths handed to `mask` for the last dimension
    :return: `durations` multiplied element-wise by the float mask
    """
    keep = mask(durations.shape, plen, dim=-1)
    keep = keep.to(durations.device).float()
    return durations * keep