def forward(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ],
]:
    """Calculate forward propagation.

    Args:
        text (Tensor): Text index tensor (B, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        feats (Tensor): Feature tensor (B, aux_channels, T_feats).
        feats_lengths (Tensor): Feature length tensor (B,).
        sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
        spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
        lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

    Returns:
        Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
        Tensor: Duration negative log-likelihood (NLL) tensor (B,).
        Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text).
        Tensor: Segments start index tensor (B,).
        Tensor: Text mask tensor (B, 1, T_text).
        Tensor: Feature mask tensor (B, 1, T_feats).
        tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
            - Tensor: Posterior encoder hidden representation (B, H, T_feats).
            - Tensor: Flow hidden representation (B, H, T_feats).
            - Tensor: Expanded text encoder projected mean (B, H, T_feats).
            - Tensor: Expanded text encoder projected scale (B, H, T_feats).
            - Tensor: Posterior encoder projected mean (B, H, T_feats).
            - Tensor: Posterior encoder projected scale (B, H, T_feats).

    """
    # forward text encoder
    x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)

    # calculate global conditioning
    g = None
    if self.spks is not None:
        # speaker one-hot vector embedding: (B, global_channels, 1)
        g = self.global_emb(sids.view(-1)).unsqueeze(-1)
    if self.spk_embed_dim is not None:
        # pretrained speaker embedding, e.g., X-vector (B, global_channels, 1)
        g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
        if g is None:
            g = g_
        else:
            g = g + g_
    if self.langs is not None:
        # language one-hot vector embedding: (B, global_channels, 1)
        g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
        if g is None:
            g = g_
        else:
            g = g + g_

    # forward posterior encoder
    z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)

    # forward flow
    z_p = self.flow(z, y_mask, g=g)  # (B, H, T_feats)

    # monotonic alignment search
    with torch.no_grad():
        # negative cross-entropy
        s_p_sq_r = torch.exp(-2 * logs_p)  # (B, H, T_text)
        # (B, 1, T_text)
        neg_x_ent_1 = torch.sum(
            -0.5 * math.log(2 * math.pi) - logs_p,
            [1],
            keepdim=True,
        )
        # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
        neg_x_ent_2 = torch.matmul(
            -0.5 * (z_p**2).transpose(1, 2),
            s_p_sq_r,
        )
        # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
        neg_x_ent_3 = torch.matmul(
            z_p.transpose(1, 2),
            (m_p * s_p_sq_r),
        )
        # (B, 1, T_text)
        neg_x_ent_4 = torch.sum(
            -0.5 * (m_p**2) * s_p_sq_r,
            [1],
            keepdim=True,
        )
        # (B, T_feats, T_text)
        neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
        # (B, 1, T_feats, T_text)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        # monotonic attention weight: (B, 1, T_feats, T_text)
        attn = (
            self.maximum_path(
                neg_x_ent,
                attn_mask.squeeze(1),
            )
            .unsqueeze(1)
            .detach()
        )

    # forward duration predictor
    w = attn.sum(2)  # (B, 1, T_text)
    dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
    dur_nll = dur_nll / torch.sum(x_mask)

    # expand the length to match with the feature sequence
    # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
    m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
    # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
    logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

    # get random segments
    z_segments, z_start_idxs = get_random_segments(
        z,
        feats_lengths,
        self.segment_size,
    )

    # forward decoder with random segments
    wav = self.decoder(z_segments, g=g)

    return (
        wav,
        dur_nll,
        attn,
        z_start_idxs,
        x_mask,
        y_mask,
        (z, z_p, m_p, logs_p, m_q, logs_q),
    )
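
# NOTE: The four ``neg_x_ent_*`` terms in the monotonic alignment search above
# are the expansion of the diagonal-Gaussian log-likelihood
# ``log N(z_p; m_p, exp(logs_p))`` summed over the hidden dimension and
# evaluated for every (frame, token) pair. A minimal, self-contained sketch of
# the same pairwise log-likelihood (illustrative only, not part of this module;
# tensor names and shapes follow the code above):
#
#     import math
#     import torch
#
#     B, H, T_text, T_feats = 2, 4, 5, 7
#     m_p = torch.randn(B, H, T_text)     # prior means per token
#     logs_p = torch.randn(B, H, T_text)  # prior log-stddevs per token
#     z_p = torch.randn(B, H, T_feats)    # flow output per frame
#     s_p_sq_r = torch.exp(-2 * logs_p)   # reciprocal variances
#     loglik = (
#         torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, 1, keepdim=True)
#         + torch.matmul(-0.5 * (z_p**2).transpose(1, 2), s_p_sq_r)
#         + torch.matmul(z_p.transpose(1, 2), m_p * s_p_sq_r)
#         + torch.sum(-0.5 * (m_p**2) * s_p_sq_r, 1, keepdim=True)
#     )  # (B, T_feats, T_text), used as the score matrix for maximum_path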
def _forward_discrminator(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    **kwargs,
) -> Dict[str, Any]:
    """Perform discriminator forward.

    Args:
        text (Tensor): Text index tensor (B, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        feats (Tensor): Feature tensor (B, T_feats, aux_channels).
        feats_lengths (Tensor): Feature length tensor (B,).
        speech (Tensor): Speech waveform tensor (B, T_wav).
        speech_lengths (Tensor): Speech length tensor (B,).

    Returns:
        Dict[str, Any]:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
            * weight (Tensor): Weight tensor to summarize losses.
            * optim_idx (int): Optimizer index (0 for G and 1 for D).

    """
    # setup
    batch_size = text.size(0)
    speech = speech.unsqueeze(1)

    # calculate generator outputs
    reuse_cache = True
    if not self.cache_generator_outputs or self._cache is None:
        reuse_cache = False

        # calculate text2mel outputs
        text2mel_loss, stats, feats_gen = self.generator["text2mel"](
            text=text,
            text_lengths=text_lengths,
            feats=feats,
            feats_lengths=feats_lengths,
            joint_training=True,
            **kwargs,
        )

        # get random segments
        feats_gen_, start_idxs = get_random_segments(
            x=feats_gen.transpose(1, 2),
            x_lengths=feats_lengths,
            segment_size=self.segment_size,
        )

        # calculate vocoder outputs
        speech_hat_ = self.generator["vocoder"](feats_gen_)
        if self.use_pqmf:
            speech_hat_ = self.pqmf.synthesis(speech_hat_)
    else:
        _, _, speech_hat_, start_idxs = self._cache

    # store cache
    if self.cache_generator_outputs and not reuse_cache:
        self._cache = (text2mel_loss, stats, speech_hat_, start_idxs)

    # parse outputs
    speech_ = get_segments(
        x=speech,
        start_idxs=start_idxs * self.generator["vocoder"].upsample_factor,
        segment_size=self.segment_size * self.generator["vocoder"].upsample_factor,
    )

    # calculate discriminator outputs
    p_hat = self.discriminator(speech_hat_.detach())
    p = self.discriminator(speech_)

    # calculate losses
    real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
    loss = real_loss + fake_loss

    stats = dict(
        discriminator_loss=loss.item(),
        real_loss=real_loss.item(),
        fake_loss=fake_loss.item(),
    )
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

    # reset cache
    if reuse_cache or not self.training:
        self._cache = None

    return {
        "loss": loss,
        "stats": stats,
        "weight": weight,
        "optim_idx": 1,  # needed for trainer
    }
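
# NOTE: ``self.discriminator_adv_loss(p_hat, p)`` above returns the two halves
# of the discriminator objective: ``real_loss`` pushes the discriminator output
# for real audio towards 1 and ``fake_loss`` pushes the output for generated
# audio towards 0. A minimal sketch of such a criterion, assuming a
# least-squares (LSGAN-style) formulation over plain logit tensors
# (illustrative only; the actual loss class and its configuration live elsewhere):
#
#     import torch
#     import torch.nn.functional as F
#
#     def lsgan_discriminator_loss(outputs_hat, outputs):
#         """outputs_hat: D(generated) logits, outputs: D(real) logits."""
#         real_loss = F.mse_loss(outputs, outputs.new_ones(outputs.size()))
#         fake_loss = F.mse_loss(outputs_hat, outputs_hat.new_zeros(outputs_hat.size()))
#         return real_loss, fake_loss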
def forward(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    pitch: torch.Tensor,
    pitch_lengths: torch.Tensor,
    energy: torch.Tensor,
    energy_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Calculate forward propagation.

    Args:
        text (Tensor): Text index tensor (B, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        feats (Tensor): Feature tensor (B, T_feats, aux_channels).
        feats_lengths (Tensor): Feature length tensor (B,).
        pitch (Tensor): Batch of padded token-averaged pitch (B, T_text, 1).
        pitch_lengths (LongTensor): Batch of pitch lengths (B, T_text).
        energy (Tensor): Batch of padded token-averaged energy (B, T_text, 1).
        energy_lengths (LongTensor): Batch of energy lengths (B, T_text).
        sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
        spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
        lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

    Returns:
        Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
        Tensor: Binarization loss ().
        Tensor: Log probability attention matrix (B, T_feats, T_text).
        Tensor: Segments start index tensor (B,).
        Tensor: Predicted duration (B, T_text).
        Tensor: Ground-truth duration obtained from an alignment module (B, T_text).
        Tensor: Predicted pitch (B, T_text, 1).
        Tensor: Ground-truth averaged pitch (B, T_text, 1).
        Tensor: Predicted energy (B, T_text, 1).
        Tensor: Ground-truth averaged energy (B, T_text, 1).

    """
    text = text[:, : text_lengths.max()]  # for data-parallel
    feats = feats[:, : feats_lengths.max()]  # for data-parallel
    pitch = pitch[:, : pitch_lengths.max()]  # for data-parallel
    energy = energy[:, : energy_lengths.max()]  # for data-parallel

    # forward encoder
    x_masks = self._source_mask(text_lengths)
    hs, _ = self.encoder(text, x_masks)  # (B, T_text, adim)

    # integrate with GST
    if self.use_gst:
        style_embs = self.gst(feats)
        hs = hs + style_embs.unsqueeze(1)

    # integrate with SID and LID embeddings
    if self.spks is not None:
        sid_embs = self.sid_emb(sids.view(-1))
        hs = hs + sid_embs.unsqueeze(1)
    if self.langs is not None:
        lid_embs = self.lid_emb(lids.view(-1))
        hs = hs + lid_embs.unsqueeze(1)

    # integrate speaker embedding
    if self.spk_embed_dim is not None:
        hs = self._integrate_with_spk_embed(hs, spembs)

    # forward alignment module and obtain duration, averaged pitch, energy
    h_masks = make_pad_mask(text_lengths).to(hs.device)
    log_p_attn = self.alignment_module(hs, feats, h_masks)
    ds, bin_loss = viterbi_decode(log_p_attn, text_lengths, feats_lengths)
    ps = average_by_duration(
        ds, pitch.squeeze(-1), text_lengths, feats_lengths
    ).unsqueeze(-1)
    es = average_by_duration(
        ds, energy.squeeze(-1), text_lengths, feats_lengths
    ).unsqueeze(-1)

    # forward duration predictor and variance predictors
    if self.stop_gradient_from_pitch_predictor:
        p_outs = self.pitch_predictor(hs.detach(), h_masks.unsqueeze(-1))
    else:
        p_outs = self.pitch_predictor(hs, h_masks.unsqueeze(-1))
    if self.stop_gradient_from_energy_predictor:
        e_outs = self.energy_predictor(hs.detach(), h_masks.unsqueeze(-1))
    else:
        e_outs = self.energy_predictor(hs, h_masks.unsqueeze(-1))
    d_outs = self.duration_predictor(hs, h_masks)

    # use groundtruth in training
    p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
    e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)
    hs = hs + e_embs + p_embs

    # upsampling
    h_masks = make_non_pad_mask(feats_lengths).to(hs.device)
    d_masks = make_non_pad_mask(text_lengths).to(ds.device)
    hs = self.length_regulator(hs, ds, h_masks, d_masks)  # (B, T_feats, adim)

    # forward decoder
    h_masks = self._source_mask(feats_lengths)
    zs, _ = self.decoder(hs, h_masks)  # (B, T_feats, adim)

    # get random segments
    z_segments, z_start_idxs = get_random_segments(
        zs.transpose(1, 2),
        feats_lengths,
        self.segment_size,
    )

    # forward generator
    wav = self.generator(z_segments)

    return (
        wav,
        bin_loss,
        log_p_attn,
        z_start_idxs,
        d_outs,
        ds,
        p_outs,
        ps,
        e_outs,
        es,
    )
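
# NOTE: ``average_by_duration`` above collapses the frame-level pitch/energy
# contours to one value per input token by averaging over the frames that the
# alignment assigns to that token. A minimal single-utterance sketch of the
# idea (illustrative only; the library version is batched and length-masked):
#
#     import torch
#
#     def average_by_duration_single(ds: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
#         """ds: (T_text,) integer durations, xs: (T_feats,) frame-level values."""
#         avgs, start = [], 0
#         for d in ds.tolist():
#             seg = xs[start:start + d]
#             # tokens with zero duration receive 0.0
#             avgs.append(seg.mean() if d > 0 else xs.new_zeros(()))
#             start += d
#         return torch.stack(avgs)  # (T_text,)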