def output_types(self):
    """Define this model's output ports.

    Pure inference (no loss computation, eval mode) omits the ground-truth
    spectrogram targets and instead reports the predicted lengths;
    training/validation includes the target fields needed by the loss.
    """
    # Ports produced in every mode. Declared once instead of duplicating
    # them in both branches; axis specs use proper 1-tuples ('B',) for
    # consistency with the rest of the file.
    common = {
        "spec_pred_dec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        "spec_pred_postnet": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        "gate_pred": NeuralType(('B', 'T'), LogitsType()),
    }
    if not self.calculate_loss and not self.training:
        return {
            **common,
            "alignments": NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()),
            "pred_length": NeuralType(('B',), LengthsType()),
        }
    return {
        **common,
        "spec_target": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        "spec_target_len": NeuralType(('B',), LengthsType()),
        "alignments": NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()),
    }
def output_types(self):
    """Output ports: mel spectrogram frames, stop-gate logits, and
    attention alignments; eval mode additionally reports mel lengths."""
    common = {
        "mel_outputs": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        "gate_outputs": NeuralType(('B', 'T'), LogitsType()),
        "alignments": NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()),
    }
    if self.training:
        return common
    # Inference also exposes how many frames were actually generated.
    return {**common, "mel_lengths": NeuralType(('B'), LengthsType())}
def output_types(self):
    """Output ports: predicted spectrogram with its lengths and mask,
    plus per-token duration (linear and log) and pitch predictions."""
    return {
        "spect": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        # Fixed: per-batch frame counts are lengths, not an alignment map.
        # SequenceToSequenceAlignmentType is used only for ('B','T','T')
        # attention tensors elsewhere in this file; every other ('B')
        # length port uses LengthsType.
        "spect_lens": NeuralType(('B'), LengthsType()),
        "spect_mask": NeuralType(('B', 'D', 'T'), MaskType()),
        "durs_predicted": NeuralType(('B', 'T'), TokenDurationType()),
        "log_durs_predicted": NeuralType(('B', 'T'), TokenLogDurationType()),
        "pitch_predicted": NeuralType(('B', 'T'), RegressionValuesType()),
    }
def output_types(self):
    """Output ports of the GlowTTS forward pass: flow latents and their
    prior statistics, flow log-determinant, predicted/extracted log
    durations, spectrogram lengths, and the text-to-frame alignment."""
    dist_axes = ('B', 'D', 'T')
    token_axes = ('B', 'T')
    return {
        "z": NeuralType(dist_axes, NormalDistributionSamplesType()),
        "y_m": NeuralType(dist_axes, NormalDistributionMeanType()),
        "y_logs": NeuralType(dist_axes, NormalDistributionLogVarianceType()),
        "logdet": NeuralType(('B'), LogDeterminantType()),
        "log_durs_predicted": NeuralType(token_axes, TokenLogDurationType()),
        "log_durs_extracted": NeuralType(token_axes, TokenLogDurationType()),
        "spect_lengths": NeuralType(('B'), LengthsType()),
        "attn": NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()),
    }
class GlowTTSModule(NeuralModule):
    def __init__(
        self, encoder_module: NeuralModule, decoder_module: NeuralModule, n_speakers: int = 1, gin_channels: int = 0
    ):
        """
        Main GlowTTS module. Contains the encoder and decoder.

        Args:
            encoder_module (NeuralModule): Text encoder for predicting latent distribution statistics
            decoder_module (NeuralModule): Invertible spectrogram decoder
            n_speakers (int): Number of speakers
            gin_channels (int): Channels in speaker embeddings
        """
        super().__init__()

        self.encoder = encoder_module
        self.decoder = decoder_module

        if n_speakers > 1:
            # Speaker embedding table — only created for multi-speaker models,
            # so single-speaker instances have no `emb_g` attribute.
            self.emb_g = nn.Embedding(n_speakers, gin_channels)
            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

    @property
    def input_types(self):
        """Input ports for the training-mode forward pass."""
        return {
            "text": NeuralType(('B', 'T'), TokenIndex()),
            "text_lengths": NeuralType(('B'), LengthsType()),
            "spect": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            "spect_lengths": NeuralType(('B'), LengthsType()),
            "speaker": NeuralType(('B'), IntType(), optional=True),
        }

    @property
    def output_types(self):
        """Output ports for the training-mode forward pass (order matches the
        tuple returned by :meth:`forward`)."""
        return {
            "z": NeuralType(('B', 'D', 'T'), NormalDistributionSamplesType()),
            "y_m": NeuralType(('B', 'D', 'T'), NormalDistributionMeanType()),
            "y_logs": NeuralType(('B', 'D', 'T'), NormalDistributionLogVarianceType()),
            "logdet": NeuralType(('B'), LogDeterminantType()),
            "log_durs_predicted": NeuralType(('B', 'T'), TokenLogDurationType()),
            "log_durs_extracted": NeuralType(('B', 'T'), TokenLogDurationType()),
            "spect_lengths": NeuralType(('B'), LengthsType()),
            "attn": NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()),
        }

    @typecheck()
    def forward(self, *, text, text_lengths, spect, spect_lengths, speaker=None):
        """Training/validation pass.

        Encodes the text into per-token prior statistics, runs the invertible
        decoder forward (spectrogram -> latent z), and extracts the monotonic
        text-to-frame alignment with `maximum_path` on the prior
        log-likelihood. Returns the tuple described by `output_types`.
        """
        if speaker is not None:
            speaker = F.normalize(self.emb_g(speaker)).unsqueeze(-1)  # [b, h]

        x_m, x_logs, log_durs_predicted, x_mask = self.encoder(
            text=text, text_lengths=text_lengths, speaker_embeddings=speaker
        )

        # Truncate spectrogram (and lengths) to a multiple of the decoder's
        # squeeze factor so the flow's squeeze/unsqueeze ops divide evenly.
        y_max_length = spect.size(2)
        y_max_length = (y_max_length // self.decoder.n_sqz) * self.decoder.n_sqz
        spect = spect[:, :, :y_max_length]
        spect_lengths = (spect_lengths // self.decoder.n_sqz) * self.decoder.n_sqz

        y_mask = torch.unsqueeze(glow_tts_submodules.sequence_mask(spect_lengths, y_max_length), 1).to(x_mask.dtype)
        # [b, 1, t, t'] joint mask of valid (token, frame) pairs.
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)

        z, logdet = self.decoder(spect=spect, spect_mask=y_mask, speaker_embeddings=speaker, reverse=False)

        with torch.no_grad():
            # Per-(token, frame) Gaussian log-likelihood of z under the
            # encoder priors, expanded as a sum of four terms.
            x_s_sq_r = torch.exp(-2 * x_logs)
            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_logs, [1]).unsqueeze(-1)  # [b, t, 1]
            logp2 = torch.matmul(x_s_sq_r.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
            logp3 = torch.matmul((x_m * x_s_sq_r).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
            logp4 = torch.sum(-0.5 * (x_m ** 2) * x_s_sq_r, [1]).unsqueeze(-1)  # [b, t, 1]
            logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']

            # Hard monotonic alignment maximizing logp; detached so the
            # alignment search itself is not differentiated through.
            attn = (glow_tts_submodules.maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()).squeeze(1)

        # NOTE(review): the collapsed source does not show indentation here;
        # these are placed outside the no_grad block (as in upstream Glow-TTS)
        # so gradients reach the encoder priors through y_m / y_logs — confirm.
        y_m = torch.matmul(x_m, attn)
        y_logs = torch.matmul(x_logs, attn)
        # Per-token frame counts from the alignment; 1e-8 guards log(0).
        log_durs_extracted = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask.squeeze()

        return z, y_m, y_logs, logdet, log_durs_predicted, log_durs_extracted, spect_lengths, attn

    @typecheck(
        input_types={
            "text": NeuralType(('B', 'T'), TokenIndex()),
            "text_lengths": NeuralType(('B',), LengthsType()),
            "speaker": NeuralType(('B'), IntType(), optional=True),
            "noise_scale": NeuralType(optional=True),
            "length_scale": NeuralType(optional=True),
        },
        output_types={
            "y": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            "attn": NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()),
        },
    )
    def generate_spect(self, *, text, text_lengths, speaker=None, noise_scale=0.3, length_scale=1.0):
        """Inference pass: predict durations, build the alignment from them,
        sample a latent from the prior (scaled by `noise_scale`), and run the
        decoder in reverse to produce a spectrogram.

        `length_scale` multiplies the predicted durations (values > 1 slow
        speech down). Returns the spectrogram and the generated alignment.
        """
        if speaker is not None:
            speaker = F.normalize(self.emb_g(speaker)).unsqueeze(-1)  # [b, h]

        x_m, x_logs, log_durs_predicted, x_mask = self.encoder(
            text=text, text_lengths=text_lengths, speaker_embeddings=speaker
        )

        # Durations in frames: exp of predicted log-durations, masked to the
        # valid tokens, then rounded up; at least one frame per utterance.
        w = torch.exp(log_durs_predicted) * x_mask.squeeze() * length_scale
        w_ceil = torch.ceil(w)
        spect_lengths = torch.clamp_min(torch.sum(w_ceil, [1]), 1).long()

        y_max_length = None
        # Truncate to a multiple of the decoder's squeeze factor.
        spect_lengths = (spect_lengths // self.decoder.n_sqz) * self.decoder.n_sqz

        y_mask = torch.unsqueeze(glow_tts_submodules.sequence_mask(spect_lengths, y_max_length), 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)

        # Expand durations into a hard monotonic alignment path.
        attn = glow_tts_submodules.generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1))

        # Frame-level prior statistics via the alignment.
        y_m = torch.matmul(x_m, attn)
        y_logs = torch.matmul(x_logs, attn)

        # Sample latent from the prior and invert the flow to a spectrogram.
        z = (y_m + torch.exp(y_logs) * torch.randn_like(y_m) * noise_scale) * y_mask
        y, _ = self.decoder(spect=z, spect_mask=y_mask, speaker_embeddings=speaker, reverse=True)

        return y, attn

    def save_to(self, save_path: str):
        """TODO: Implement"""

    @classmethod
    def restore_from(cls, restore_path: str):
        """TODO: Implement"""