def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask):
    # find the max alignment path
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    log_p = self.compute_log_probs(mu, log_sigma, y)  # [B, T_en, T_dec]
    attn = maximum_path(log_p, attn_mask.squeeze(1)).unsqueeze(1)
    dr_mas = torch.sum(attn, -1)
    return dr_mas.squeeze(1), log_p
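# `maximum_path` above runs monotonic alignment search (MAS): among all monotonic,
# surjective text-to-frame paths it returns the one maximizing the summed log-likelihood.
# Below is a minimal single-sample NumPy reference sketch of the dynamic program
# (assumptions: T_de >= T_en and no masking; the library version is batched, masked,
# and typically Cython/Numba-accelerated):
import numpy as np

def maximum_path_reference(log_p: np.ndarray) -> np.ndarray:
    """log_p: [T_en, T_de] log-likelihood grid. Returns a binary path of the same shape."""
    T_en, T_de = log_p.shape
    Q = np.full((T_en, T_de), -np.inf)
    Q[0, 0] = log_p[0, 0]
    for j in range(1, T_de):
        for i in range(min(j + 1, T_en)):
            # each frame either stays on the current token or advances to the next one
            stay = Q[i, j - 1]
            move = Q[i - 1, j - 1] if i > 0 else -np.inf
            Q[i, j] = max(stay, move) + log_p[i, j]
    # backtrack from the forced endpoint (last token, last frame)
    path = np.zeros_like(log_p)
    i = T_en - 1
    for j in range(T_de - 1, -1, -1):
        path[i, j] = 1.0
        if i > 0 and (i == j or Q[i - 1, j - 1] > Q[i, j - 1]):
            i -= 1
    return path

# toy check: every frame maps to exactly one token, every token gets at least one frame
p = maximum_path_reference(np.random.default_rng(0).normal(size=(3, 5)))
assert p.sum() == 5 and (p.sum(axis=1) >= 1).all()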
def _forward_aligner(
    self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor
) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Aligner forward pass.

    1. Compute a mask to apply to the attention map.
    2. Run the alignment network.
    3. Apply MAS to compute the hard alignment map.
    4. Compute the durations from the hard alignment map.

    Args:
        x (torch.FloatTensor): Input sequence.
        y (torch.FloatTensor): Output sequence.
        x_mask (torch.IntTensor): Input sequence mask.
        y_mask (torch.IntTensor): Output sequence mask.

    Returns:
        Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
            Durations from the hard alignment map, soft alignment potentials, log scale alignment
            potentials, hard alignment map.

    Shapes:
        - x: :math:`[B, T_en, C_en]`
        - y: :math:`[B, T_de, C_de]`
        - x_mask: :math:`[B, 1, T_en]`
        - y_mask: :math:`[B, 1, T_de]`
        - o_alignment_dur: :math:`[B, T_en]`
        - alignment_soft: :math:`[B, T_en, T_de]`
        - alignment_logprob: :math:`[B, 1, T_de, T_en]`
        - alignment_mas: :math:`[B, T_en, T_de]`
    """
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
    alignment_mas = maximum_path(
        alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
    )
    o_alignment_dur = torch.sum(alignment_mas, -1).int()
    alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
    return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
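# A tiny illustration (hypothetical values) of the duration extraction performed by
# `_forward_aligner`: summing the binary hard alignment map over the decoder axis
# counts how many frames each input token covers.
import torch

alignment_mas = torch.tensor([[  # [B=1, T_en=3, T_de=5]
    [1, 1, 0, 0, 0],             # token 0 covers frames 0-1
    [0, 0, 1, 0, 0],             # token 1 covers frame 2
    [0, 0, 0, 1, 1],             # token 2 covers frames 3-4
]])
o_alignment_dur = torch.sum(alignment_mas, -1).int()
print(o_alignment_dur)           # tensor([[2, 1, 2]], dtype=torch.int32)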
def inference_with_MAS(
    self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
):  # pylint: disable=dangerous-default-value
    """
    It's similar to the teacher forcing in Tacotron.
    It was proposed in: https://arxiv.org/abs/2104.05557

    Shapes:
        - x: :math:`[B, T]`
        - x_lengths: :math:`B`
        - y: :math:`[B, T, C]`
        - y_lengths: :math:`B`
        - g: :math:`[B, C] or B`
    """
    y = y.transpose(1, 2)
    y_max_length = y.size(2)
    # norm speaker embeddings
    g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
    if self.use_speaker_embedding or self.use_d_vector_file:
        if not self.use_d_vector_file:
            g = F.normalize(g).unsqueeze(-1)
        else:
            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]

    # embedding pass
    o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
    # drop residual frames wrt num_squeeze and set y_lengths.
    y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
    # create masks
    y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    # decoder pass
    z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
    # find the alignment path between z and encoder output
    o_scale = torch.exp(-2 * o_log_scale)
    logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
    logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
    logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
    logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
    logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
    attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

    y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
    attn = attn.squeeze(1).permute(0, 2, 1)

    # get predicted aligned distribution
    z = y_mean * y_mask

    # reverse the decoder and predict using the aligned distribution
    y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
    outputs = {
        "model_outputs": z.transpose(1, 2),
        "logdet": logdet,
        "y_mean": y_mean.transpose(1, 2),
        "y_log_scale": y_log_scale.transpose(1, 2),
        "alignments": attn,
        "durations_log": o_dur_log.transpose(1, 2),
        "total_durations_log": o_attn_dur.transpose(1, 2),
    }
    return outputs
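# The four `logp*` terms above are just the channel-summed diagonal Gaussian log-density
# log N(z_t'; mu_t, sigma_t), expanded so that the full [T_en, T_de] grid comes out of
# two matmuls instead of an explicit pairwise loop. A quick numerical sanity check of
# that identity (a standalone sketch with random shapes, not part of the model):
import math
import torch

B, C, T_en, T_de = 2, 4, 5, 7
o_mean = torch.randn(B, C, T_en)
o_log_scale = 0.1 * torch.randn(B, C, T_en)
z = torch.randn(B, C, T_de)

o_scale = torch.exp(-2 * o_log_scale)  # 1 / sigma^2
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2))
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)
logp = logp1 + logp2 + logp3 + logp4  # [B, T_en, T_de]

# direct evaluation via broadcasting: sum over channels of log N(z_t'; mu_t, sigma_t)
direct = torch.distributions.Normal(
    o_mean.transpose(1, 2).unsqueeze(2),                 # [B, T_en, 1, C]
    torch.exp(o_log_scale).transpose(1, 2).unsqueeze(2),
).log_prob(z.transpose(1, 2).unsqueeze(1)).sum(-1)       # [B, T_en, T_de]

assert torch.allclose(logp, direct, atol=1e-4)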
def forward(
    self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
):  # pylint: disable=dangerous-default-value
    """
    Shapes:
        - x: :math:`[B, T]`
        - x_lengths: :math:`B`
        - y: :math:`[B, T, C]`
        - y_lengths: :math:`B`
        - g: :math:`[B, C] or B`
    """
    # [B, T, C] -> [B, C, T]
    y = y.transpose(1, 2)
    y_max_length = y.size(2)
    # norm speaker embeddings
    g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
    if self.use_speaker_embedding or self.use_d_vector_file:
        if not self.use_d_vector_file:
            g = F.normalize(g).unsqueeze(-1)
        else:
            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]

    # embedding pass
    o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
    # drop residual frames wrt num_squeeze and set y_lengths.
    y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
    # create masks
    y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
    # [B, 1, T_en, T_de]
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    # decoder pass
    z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
    # find the alignment path
    with torch.no_grad():
        o_scale = torch.exp(-2 * o_log_scale)
        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
        logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
    y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
    attn = attn.squeeze(1).permute(0, 2, 1)
    outputs = {
        "z": z.transpose(1, 2),
        "logdet": logdet,
        "y_mean": y_mean.transpose(1, 2),
        "y_log_scale": y_log_scale.transpose(1, 2),
        "alignments": attn,
        "durations_log": o_dur_log.transpose(1, 2),
        "total_durations_log": o_attn_dur.transpose(1, 2),
    }
    return outputs
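# How the mask shapes above compose: [B, 1, T_en, 1] * [B, 1, 1, T_de] broadcasts to the
# [B, 1, T_en, T_de] attention mask. A minimal sketch with a `sequence_mask`-style helper
# (the helper name and values here are illustrative, not the library implementation):
import torch

def sequence_mask_ref(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # [B, max_len], True where position < length
    return torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]

x_mask = sequence_mask_ref(torch.tensor([2, 4]), 4).unsqueeze(1).float()  # [B, 1, T_en]
y_mask = sequence_mask_ref(torch.tensor([3, 5]), 5).unsqueeze(1).float()  # [B, 1, T_de]
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
assert attn_mask.shape == (2, 1, 4, 5)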
def forward(
    self,
    x: torch.tensor,
    x_lengths: torch.tensor,
    y: torch.tensor,
    y_lengths: torch.tensor,
    aux_input={"d_vectors": None, "speaker_ids": None},
) -> Dict:
    """Forward pass of the model.

    Args:
        x (torch.tensor): Batch of input character sequence IDs.
        x_lengths (torch.tensor): Batch of input character sequence lengths.
        y (torch.tensor): Batch of input spectrograms.
        y_lengths (torch.tensor): Batch of input spectrogram lengths.
        aux_input (dict, optional): Auxiliary inputs for multi-speaker training.
            Defaults to {"d_vectors": None, "speaker_ids": None}.

    Returns:
        Dict: model outputs keyed by the output name.

    Shapes:
        - x: :math:`[B, T_seq]`
        - x_lengths: :math:`[B]`
        - y: :math:`[B, C, T_spec]`
        - y_lengths: :math:`[B]`
        - d_vectors: :math:`[B, C, 1]`
        - speaker_ids: :math:`[B]`
    """
    outputs = {}
    sid, g = self._set_cond_input(aux_input)
    x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)

    # speaker embedding
    if self.num_speakers > 1 and sid is not None:
        g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

    # posterior encoder
    z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)

    # flow layers
    z_p = self.flow(z, y_mask, g=g)

    # find the alignment path
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    with torch.no_grad():
        o_scale = torch.exp(-2 * logs_p)
        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
        logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
        logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
        logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
        logp = logp2 + logp3 + logp1 + logp4
        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

    # duration predictor
    attn_durations = attn.sum(3)
    if self.args.use_sdp:
        loss_duration = self.duration_predictor(
            x.detach() if self.args.detach_dp_input else x,
            x_mask,
            attn_durations,
            g=g.detach() if self.args.detach_dp_input and g is not None else g,
        )
        loss_duration = loss_duration / torch.sum(x_mask)
    else:
        attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
        log_durations = self.duration_predictor(
            x.detach() if self.args.detach_dp_input else x,
            x_mask,
            g=g.detach() if self.args.detach_dp_input and g is not None else g,
        )
        loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
    outputs["loss_duration"] = loss_duration

    # expand prior
    m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
    logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

    # select a random feature segment for the waveform decoder
    z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
    o = self.waveform_decoder(z_slice, g=g)
    outputs.update(
        {
            "model_outputs": o,
            "alignments": attn.squeeze(1),
            "slice_ids": slice_ids,
            "z": z,
            "z_p": z_p,
            "m_p": m_p,
            "logs_p": logs_p,
            "m_q": m_q,
            "logs_q": logs_q,
        }
    )
    return outputs
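# The einsum forms in this VITS-style forward are equivalent to the matmul forms used in
# the Glow-TTS code above; a quick identity check with random tensors:
import torch

a = torch.randn(2, 4, 5)  # [k, l, m], e.g. o_scale
b = torch.randn(2, 4, 7)  # [k, l, n], e.g. -0.5 * z_p**2
assert torch.allclose(
    torch.einsum("klm, kln -> kmn", a, b),
    torch.matmul(a.transpose(1, 2), b),
    atol=1e-4,
)

# the "expand prior" einsum spreads each token's statistics over its aligned frames
attn = torch.randn(2, 1, 5, 7)  # [k, 1, T_en, T_de]
m_p = torch.randn(2, 4, 5)      # [k, C, T_en]
assert torch.allclose(
    torch.einsum("klmn, kjm -> kjn", attn, m_p),  # [k, C, T_de]
    torch.matmul(m_p, attn.squeeze(1)),
    atol=1e-4,
)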
def forward(
    self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
):  # pylint: disable=dangerous-default-value
    """
    Args:
        x (torch.Tensor): Input text sequence ids. :math:`[B, T_en]`
        x_lengths (torch.Tensor): Lengths of input text sequences. :math:`[B]`
        y (torch.Tensor): Target mel-spectrogram frames. :math:`[B, T_de, C_mel]`
        y_lengths (torch.Tensor): Lengths of target mel-spectrogram frames. :math:`[B]`
        aux_input (Dict): Auxiliary inputs. `d_vectors` is speaker embedding vectors for a multi-speaker
            model. :math:`[B, D_vec]`. `speaker_ids` is speaker ids for a multi-speaker model using a
            speaker-embedding layer. :math:`B`

    Returns:
        Dict:
            - z: :math:`[B, T_de, C]`
            - logdet: :math:`B`
            - y_mean: :math:`[B, T_de, C]`
            - y_log_scale: :math:`[B, T_de, C]`
            - alignments: :math:`[B, T_en, T_de]`
            - durations_log: :math:`[B, T_en, 1]`
            - total_durations_log: :math:`[B, T_en, 1]`
    """
    # [B, T, C] -> [B, C, T]
    y = y.transpose(1, 2)
    y_max_length = y.size(2)
    # norm speaker embeddings
    g = self._speaker_embedding(aux_input)
    # embedding pass
    o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
    # drop residual frames wrt num_squeeze and set y_lengths.
    y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
    # create masks
    y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
    # [B, 1, T_en, T_de]
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    # decoder pass
    z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
    # find the alignment path
    with torch.no_grad():
        o_scale = torch.exp(-2 * o_log_scale)
        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
        logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
    y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
    attn = attn.squeeze(1).permute(0, 2, 1)
    outputs = {
        "z": z.transpose(1, 2),
        "logdet": logdet,
        "y_mean": y_mean.transpose(1, 2),
        "y_log_scale": y_log_scale.transpose(1, 2),
        "alignments": attn,
        "durations_log": o_dur_log.transpose(1, 2),
        "total_durations_log": o_attn_dur.transpose(1, 2),
    }
    return outputs
def forward(
    self,
    x: torch.tensor,
    x_lengths: torch.tensor,
    y: torch.tensor,
    y_lengths: torch.tensor,
    waveform: torch.tensor,
    aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
) -> Dict:
    """Forward pass of the model.

    Args:
        x (torch.tensor): Batch of input character sequence IDs.
        x_lengths (torch.tensor): Batch of input character sequence lengths.
        y (torch.tensor): Batch of input spectrograms.
        y_lengths (torch.tensor): Batch of input spectrogram lengths.
        waveform (torch.tensor): Batch of ground truth waveforms per sample.
        aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training.
            Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}.

    Returns:
        Dict: model outputs keyed by the output name.

    Shapes:
        - x: :math:`[B, T_seq]`
        - x_lengths: :math:`[B]`
        - y: :math:`[B, C, T_spec]`
        - y_lengths: :math:`[B]`
        - waveform: :math:`[B, T_wav, 1]`
        - d_vectors: :math:`[B, C, 1]`
        - speaker_ids: :math:`[B]`
        - language_ids: :math:`[B]`
    """
    outputs = {}
    sid, g, lid = self._set_cond_input(aux_input)
    # speaker embedding
    if self.args.use_speaker_embedding and sid is not None:
        g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

    # language embedding
    lang_emb = None
    if self.args.use_language_embedding and lid is not None:
        lang_emb = self.emb_l(lid).unsqueeze(-1)

    x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)

    # posterior encoder
    z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)

    # flow layers
    z_p = self.flow(z, y_mask, g=g)

    # find the alignment path
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    with torch.no_grad():
        o_scale = torch.exp(-2 * logs_p)
        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
        logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
        logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
        logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
        logp = logp2 + logp3 + logp1 + logp4
        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

    # duration predictor
    attn_durations = attn.sum(3)
    if self.args.use_sdp:
        loss_duration = self.duration_predictor(
            x.detach() if self.args.detach_dp_input else x,
            x_mask,
            attn_durations,
            g=g.detach() if self.args.detach_dp_input and g is not None else g,
            lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
        )
        loss_duration = loss_duration / torch.sum(x_mask)
    else:
        attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
        log_durations = self.duration_predictor(
            x.detach() if self.args.detach_dp_input else x,
            x_mask,
            g=g.detach() if self.args.detach_dp_input and g is not None else g,
            lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
        )
        loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
    outputs["loss_duration"] = loss_duration

    # expand prior
    m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
    logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

    # select a random feature segment for the waveform decoder
    z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
    o = self.waveform_decoder(z_slice, g=g)

    wav_seg = segment(
        waveform,
        slice_ids * self.config.audio.hop_length,
        self.args.spec_segment_size * self.config.audio.hop_length,
    )

    if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
        # concatenate generated and GT waveforms
        wavs_batch = torch.cat((wav_seg, o), dim=0)

        # resample audio to speaker encoder sample_rate
        # pylint: disable=W0105
        if self.audio_transform is not None:
            wavs_batch = self.audio_transform(wavs_batch)

        pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)

        # split generated and GT speaker embeddings
        gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
    else:
        gt_spk_emb, syn_spk_emb = None, None

    outputs.update(
        {
            "model_outputs": o,
            "alignments": attn.squeeze(1),
            "z": z,
            "z_p": z_p,
            "m_p": m_p,
            "logs_p": logs_p,
            "m_q": m_q,
            "logs_q": logs_q,
            "waveform_seg": wav_seg,
            "gt_spk_emb": gt_spk_emb,
            "syn_spk_emb": syn_spk_emb,
        }
    )
    return outputs
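# A minimal sketch (a hypothetical standalone version mirroring the `rand_segments` /
# `segment` calls above) of the frame-to-sample correspondence: each spectrogram frame
# spans `hop_length` waveform samples, so frame indices scale by `hop_length`.
import torch

hop_length, spec_segment_size = 256, 32
z = torch.randn(2, 192, 100)                    # [B, C, T_spec]
waveform = torch.randn(2, 1, 100 * hop_length)  # [B, 1, T_wav]

starts = torch.randint(0, 100 - spec_segment_size, (2,))
z_slice = torch.stack([z[i, :, s:s + spec_segment_size] for i, s in enumerate(starts.tolist())])
wav_seg = torch.stack([
    waveform[i, :, s * hop_length:(s + spec_segment_size) * hop_length]
    for i, s in enumerate(starts.tolist())
])
assert wav_seg.shape[-1] == spec_segment_size * hop_length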