def _forward_discrminator(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    **kwargs,
) -> Dict[str, Any]:
    """Compute the discriminator loss for joint text2mel + vocoder training.

    The text2mel generator produces features, a random window of those
    features is vocoded, and the discriminator scores the detached
    synthetic waveform against the matching window of the reference
    waveform. When ``self.cache_generator_outputs`` is set and
    ``self._cache`` holds a value, the cached vocoder output and segment
    start indices are used instead of re-running the generator.

    Args:
        text (Tensor): Text index tensor (B, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        feats (Tensor): Feature tensor (B, T_feats, aux_channels).
        feats_lengths (Tensor): Feature length tensor (B,).
        speech (Tensor): Speech waveform tensor (B, T_wav).
        speech_lengths (Tensor): Speech length tensor (B,). Not read here.
        **kwargs: Forwarded unchanged to the text2mel generator.

    Returns:
        Dict[str, Any]:
            * loss (Tensor): Discriminator loss (real + fake).
            * stats (Dict[str, float]): Values to be monitored.
            * weight (Tensor): Weight tensor to summarize losses.
            * optim_idx (int): Always 1 (0 is G, 1 is D).
    """
    n_utts = text.size(0)
    wav = speech.unsqueeze(1)  # insert a channel axis after the batch axis

    # Reuse cached generator outputs only when caching is enabled AND a cache exists.
    use_cached = bool(self.cache_generator_outputs) and self._cache is not None
    if use_cached:
        _, _, wav_hat_seg, seg_starts = self._cache
    else:
        t2m_loss, t2m_stats, feats_out = self.generator["text2mel"](
            text=text,
            text_lengths=text_lengths,
            feats=feats,
            feats_lengths=feats_lengths,
            joint_training=True,
            **kwargs,
        )
        # Only a random window of the generated features is vocoded.
        feats_out_seg, seg_starts = get_random_segments(
            x=feats_out.transpose(1, 2),
            x_lengths=feats_lengths,
            segment_size=self.segment_size,
        )
        wav_hat_seg = self.generator["vocoder"](feats_out_seg)
        if self.use_pqmf:
            wav_hat_seg = self.pqmf.synthesis(wav_hat_seg)
        # Freshly computed outputs are cached whenever caching is enabled.
        if self.cache_generator_outputs:
            self._cache = (t2m_loss, t2m_stats, wav_hat_seg, seg_starts)

    # Segment indices are scaled by the vocoder's upsample factor to index the waveform.
    hop = self.generator["vocoder"].upsample_factor
    wav_seg = get_segments(
        x=wav,
        start_idxs=seg_starts * hop,
        segment_size=self.segment_size * hop,
    )

    # The fake branch is detached so no gradient reaches the generator.
    fake_out = self.discriminator(wav_hat_seg.detach())
    real_out = self.discriminator(wav_seg)

    real_loss, fake_loss = self.discriminator_adv_loss(fake_out, real_out)
    total = real_loss + fake_loss

    monitored = {
        "discriminator_loss": total.item(),
        "real_loss": real_loss.item(),
        "fake_loss": fake_loss.item(),
    }
    total, monitored, weight = force_gatherable(
        (total, monitored, n_utts), total.device
    )

    # Drop the cache once it has been consumed, and never keep it outside training.
    if use_cached or not self.training:
        self._cache = None

    return {
        "loss": total,
        "stats": monitored,
        "weight": weight,
        "optim_idx": 1,  # needed for trainer
    }
def _forward_generator(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
) -> Dict[str, Any]:
    """Compute the generator loss (mel + KL + duration + adversarial + feature matching).

    Generator outputs are taken from ``self._cache`` when caching is enabled
    and a cache exists; otherwise the generator is run. Freshly computed
    outputs are cached only while training.

    Args:
        text (Tensor): Text index tensor (B, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        feats (Tensor): Feature tensor (B, T_feats, aux_channels).
        feats_lengths (Tensor): Feature length tensor (B,).
        speech (Tensor): Speech waveform tensor (B, T_wav).
        speech_lengths (Tensor): Speech length tensor (B,). Not read here.
        sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
        spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
        lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

    Returns:
        Dict[str, Any]:
            * loss (Tensor): Weighted sum of all generator losses.
            * stats (Dict[str, float]): Values to be monitored.
            * weight (Tensor): Weight tensor to summarize losses.
            * optim_idx (int): Always 0 (0 is G, 1 is D).
    """
    n_utts = text.size(0)
    feats_t = feats.transpose(1, 2)  # swap the time and channel axes for the generator
    wav = speech.unsqueeze(1)  # insert a channel axis after the batch axis

    use_cached = bool(self.cache_generator_outputs) and self._cache is not None
    if use_cached:
        gen_outs = self._cache
    else:
        gen_outs = self.generator(
            text=text,
            text_lengths=text_lengths,
            feats=feats_t,
            feats_lengths=feats_lengths,
            sids=sids,
            spembs=spembs,
            lids=lids,
        )
        # Outside training the cache would be cleared at the end of this call anyway.
        if self.training and self.cache_generator_outputs:
            self._cache = gen_outs

    wav_hat_seg, dur_nll, _, seg_starts, _, z_mask, latents = gen_outs
    _, z_p, m_p, logs_p, _, logs_q = latents

    hop = self.generator.upsample_factor
    wav_seg = get_segments(
        x=wav,
        start_idxs=seg_starts * hop,
        segment_size=self.generator.segment_size * hop,
    )

    fake_out = self.discriminator(wav_hat_seg)
    with torch.no_grad():
        # Real-branch outputs are only loss targets; do not build a graph for them.
        real_out = self.discriminator(wav_seg)

    # Losses are evaluated with autocast disabled.
    with autocast(enabled=False):
        mel_term = self.mel_loss(wav_hat_seg, wav_seg) * self.lambda_mel
        kl_term = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * self.lambda_kl
        dur_term = torch.sum(dur_nll.float()) * self.lambda_dur
        adv_term = self.generator_adv_loss(fake_out) * self.lambda_adv
        fm_term = self.feat_match_loss(fake_out, real_out) * self.lambda_feat_match
        total = mel_term + kl_term + dur_term + adv_term + fm_term

    monitored = {
        "generator_loss": total.item(),
        "generator_mel_loss": mel_term.item(),
        "generator_kl_loss": kl_term.item(),
        "generator_dur_loss": dur_term.item(),
        "generator_adv_loss": adv_term.item(),
        "generator_feat_match_loss": fm_term.item(),
    }
    total, monitored, weight = force_gatherable(
        (total, monitored, n_utts), total.device
    )

    # Drop the cache once it has been consumed, and never keep it outside training.
    if use_cached or not self.training:
        self._cache = None

    return {
        "loss": total,
        "stats": monitored,
        "weight": weight,
        "optim_idx": 0,  # needed for trainer
    }
def _forward_discrminator(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
) -> Dict[str, Any]:
    """Compute the discriminator adversarial loss on real vs. generated waveforms.

    Generator outputs are taken from ``self._cache`` when caching is enabled
    and a cache exists; otherwise the generator is run and, if caching is
    enabled, its outputs are stored in ``self._cache``.

    Args:
        text (Tensor): Text index tensor (B, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        feats (Tensor): Feature tensor (B, T_feats, aux_channels).
        feats_lengths (Tensor): Feature length tensor (B,).
        speech (Tensor): Speech waveform tensor (B, T_wav).
        speech_lengths (Tensor): Speech length tensor (B,). Not read here.
        sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
        spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
        lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

    Returns:
        Dict[str, Any]:
            * loss (Tensor): Discriminator loss (real + fake).
            * stats (Dict[str, float]): Values to be monitored.
            * weight (Tensor): Weight tensor to summarize losses.
            * optim_idx (int): Always 1 (0 is G, 1 is D).
    """
    n_utts = text.size(0)
    feats_t = feats.transpose(1, 2)  # swap the time and channel axes for the generator
    wav = speech.unsqueeze(1)  # insert a channel axis after the batch axis

    use_cached = bool(self.cache_generator_outputs) and self._cache is not None
    if use_cached:
        gen_outs = self._cache
    else:
        gen_outs = self.generator(
            text=text,
            text_lengths=text_lengths,
            feats=feats_t,
            feats_lengths=feats_lengths,
            sids=sids,
            spembs=spembs,
            lids=lids,
        )
        if self.cache_generator_outputs:
            self._cache = gen_outs

    # Only the generated waveform window and its start indices are needed here.
    wav_hat_seg, _, _, seg_starts, *_ = gen_outs

    hop = self.generator.upsample_factor
    wav_seg = get_segments(
        x=wav,
        start_idxs=seg_starts * hop,
        segment_size=self.generator.segment_size * hop,
    )

    # The fake branch is detached so no gradient reaches the generator.
    fake_out = self.discriminator(wav_hat_seg.detach())
    real_out = self.discriminator(wav_seg)

    # Losses are evaluated with autocast disabled.
    with autocast(enabled=False):
        real_loss, fake_loss = self.discriminator_adv_loss(fake_out, real_out)
        total = real_loss + fake_loss

    monitored = {
        "discriminator_loss": total.item(),
        "discriminator_real_loss": real_loss.item(),
        "discriminator_fake_loss": fake_loss.item(),
    }
    total, monitored, weight = force_gatherable(
        (total, monitored, n_utts), total.device
    )

    # Drop the cache once it has been consumed, and never keep it outside training.
    if use_cached or not self.training:
        self._cache = None

    return {
        "loss": total,
        "stats": monitored,
        "weight": weight,
        "optim_idx": 1,  # needed for trainer
    }
def _forward_generator(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    feats: torch.Tensor,
    feats_lengths: torch.Tensor,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Compute the generator loss as the sum of three groups of terms.

    The groups are:
      * g: mel + adversarial + feature matching (each individually weighted),
      * var: duration + pitch + energy (weighted by ``lambda_var``),
      * align: forward-sum + ``bin_loss`` (weighted by ``lambda_align``);
        ``bin_loss`` comes straight from the generator outputs.

    Generator outputs are taken from ``self._cache`` when caching is enabled
    and a cache exists; otherwise the generator is run. Freshly computed
    outputs are cached only while training.

    Args:
        text (Tensor): Text index tensor (B, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        feats (Tensor): Feature tensor (B, T_feats, aux_channels).
        feats_lengths (Tensor): Feature length tensor (B,).
        speech (Tensor): Speech waveform tensor (B, T_wav).
        speech_lengths (Tensor): Speech length tensor (B,). Not read here.
        sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
        spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
        lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
        **kwargs: Forwarded unchanged to the generator.

    Returns:
        Dict[str, Any]:
            * loss (Tensor): Total generator loss.
            * stats (Dict[str, float]): Values to be monitored.
            * weight (Tensor): Weight tensor to summarize losses.
            * optim_idx (int): Always 0 (0 is G, 1 is D).
    """
    n_utts = text.size(0)
    wav = speech.unsqueeze(1)  # insert a channel axis after the batch axis

    use_cached = bool(self.cache_generator_outputs) and self._cache is not None
    if use_cached:
        gen_outs = self._cache
    else:
        gen_outs = self.generator(
            text=text,
            text_lengths=text_lengths,
            feats=feats,
            feats_lengths=feats_lengths,
            sids=sids,
            spembs=spembs,
            lids=lids,
            **kwargs,
        )
        # Outside training the cache would be cleared at the end of this call anyway.
        if self.training and self.cache_generator_outputs:
            self._cache = gen_outs

    wav_hat_seg, bin_loss, log_p_attn, seg_starts = gen_outs[:4]
    d_outs, ds, p_outs, ps, e_outs, es = gen_outs[4:]

    hop = self.generator.upsample_factor
    wav_seg = get_segments(
        x=wav,
        start_idxs=seg_starts * hop,
        segment_size=self.generator.segment_size * hop,
    )

    fake_out = self.discriminator(wav_hat_seg)
    with torch.no_grad():
        # Real-branch outputs are only loss targets; do not build a graph for them.
        real_out = self.discriminator(wav_seg)

    mel_w = self.mel_loss(wav_hat_seg, wav_seg) * self.lambda_mel
    adv_w = self.generator_adv_loss(fake_out) * self.lambda_adv
    fm_w = self.feat_match_loss(fake_out, real_out) * self.lambda_feat_match
    dur_loss, pitch_loss, energy_loss = self.var_loss(
        d_outs, ds, p_outs, ps, e_outs, es, text_lengths
    )
    forwardsum_loss = self.forwardsum_loss(log_p_attn, text_lengths, feats_lengths)

    g_group = mel_w + adv_w + fm_w
    var_group = (dur_loss + pitch_loss + energy_loss) * self.lambda_var
    align_group = (forwardsum_loss + bin_loss) * self.lambda_align
    total = g_group + var_group + align_group

    # The var/align sub-terms are reported before their group weights are applied.
    monitored = {
        "generator_loss": total.item(),
        "generator_g_loss": g_group.item(),
        "generator_var_loss": var_group.item(),
        "generator_align_loss": align_group.item(),
        "generator_g_mel_loss": mel_w.item(),
        "generator_g_adv_loss": adv_w.item(),
        "generator_g_feat_match_loss": fm_w.item(),
        "generator_var_dur_loss": dur_loss.item(),
        "generator_var_pitch_loss": pitch_loss.item(),
        "generator_var_energy_loss": energy_loss.item(),
        "generator_align_forwardsum_loss": forwardsum_loss.item(),
        "generator_align_bin_loss": bin_loss.item(),
    }
    total, monitored, weight = force_gatherable(
        (total, monitored, n_utts), total.device
    )

    # Drop the cache once it has been consumed, and never keep it outside training.
    if use_cached or not self.training:
        self._cache = None

    return {
        "loss": total,
        "stats": monitored,
        "weight": weight,
        "optim_idx": 0,  # needed for trainer
    }