import torch as T

# `segment` may also live in `TTS.tts.utils.helpers` depending on the TTS version
from TTS.tts.models.vits import segment


def segment_test():
    # 12 values 0..11, tiled to a batch of 8 single-channel "signals": [8, 1, 12]
    x = T.range(0, 11)
    x = x.repeat(8, 1).unsqueeze(1)
    segment_ids = T.randint(0, 7, (8,))

    # slicing 4 frames from each start index must preserve the values
    segments = segment(x, segment_ids, segment_size=4)
    for idx, start_indx in enumerate(segment_ids):
        assert x[idx, :, start_indx : start_indx + 4].sum() == segments[idx, :, :].sum()

    # requesting a segment longer than the input must fail when `pad_short` is False
    try:
        segments = segment(x, segment_ids, segment_size=10)
    except Exception:  # pylint: disable=broad-except
        pass
    else:
        raise Exception("Should have failed")

    # with `pad_short=True` short inputs are zero-padded, so the sums still match
    segments = segment(x, segment_ids, segment_size=10, pad_short=True)
    for idx, start_indx in enumerate(segment_ids):
        assert x[idx, :, start_indx : start_indx + 10].sum() == segments[idx, :, :].sum()
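# Illustrative reference for what `segment` is expected to do, inferred from the test
# above: slice `segment_size` frames from each batch item of `x` ([B, C, T]) starting
# at `segment_indices[i]`, optionally zero-padding inputs that are too short. This is
# a sketch under those assumptions, not the library's actual implementation, and the
# name `segment_sketch` is hypothetical.
def segment_sketch(x, segment_indices, segment_size=4, pad_short=False):
    segments = T.zeros_like(x[:, :, :segment_size])
    for i, start in enumerate(segment_indices.tolist()):
        end = start + segment_size
        x_i = x[i]
        if pad_short and end > x_i.size(1):
            # zero-pad the time axis so the requested slice fits
            x_i = T.nn.functional.pad(x_i, (0, end - x_i.size(1)))
        # raises if the slice is shorter than `segment_size` and `pad_short` is False
        segments[i] = x_i[:, start:end]
    return segments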
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
    """Perform a single training step. Run the model forward pass and compute losses.

    Args:
        batch (Dict): Input tensors.
        criterion (nn.Module): Loss layer designed for the model.
        optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.

    Returns:
        Tuple[Dict, Dict]: Model outputs and computed losses.
    """
    # pylint: disable=attribute-defined-outside-init
    if optimizer_idx not in [0, 1]:
        raise ValueError(" [!] Unexpected `optimizer_idx`.")

    if optimizer_idx == 0:
        text_input = batch["text_input"]
        text_lengths = batch["text_lengths"]
        mel_lengths = batch["mel_lengths"]
        linear_input = batch["linear_input"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
        waveform = batch["waveform"]

        # generator pass
        outputs = self.forward(
            text_input,
            text_lengths,
            linear_input.transpose(1, 2),
            mel_lengths,
            aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
        )

        # cache tensors for the discriminator
        self.y_disc_cache = None
        self.wav_seg_disc_cache = None
        self.y_disc_cache = outputs["model_outputs"]
        wav_seg = segment(
            waveform.transpose(1, 2),
            outputs["slice_ids"] * self.config.audio.hop_length,
            self.args.spec_segment_size * self.config.audio.hop_length,
        )
        self.wav_seg_disc_cache = wav_seg
        outputs["waveform_seg"] = wav_seg

        # compute discriminator scores and features
        (
            outputs["scores_disc_fake"],
            outputs["feats_disc_fake"],
            _,
            outputs["feats_disc_real"],
        ) = self.disc(outputs["model_outputs"], wav_seg)

        # compute losses
        with autocast(enabled=False):  # use float32 for the criterion
            loss_dict = criterion[optimizer_idx](
                waveform_hat=outputs["model_outputs"].float(),
                waveform=wav_seg.float(),
                z_p=outputs["z_p"].float(),
                logs_q=outputs["logs_q"].float(),
                m_p=outputs["m_p"].float(),
                logs_p=outputs["logs_p"].float(),
                z_len=mel_lengths,
                scores_disc_fake=outputs["scores_disc_fake"],
                feats_disc_fake=outputs["feats_disc_fake"],
                feats_disc_real=outputs["feats_disc_real"],
                loss_duration=outputs["loss_duration"],
            )

    elif optimizer_idx == 1:
        # discriminator pass
        outputs = {}

        # compute scores and features
        outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(
            self.y_disc_cache.detach(), self.wav_seg_disc_cache
        )

        # compute loss
        with autocast(enabled=False):  # use float32 for the criterion
            loss_dict = criterion[optimizer_idx](
                outputs["scores_disc_real"],
                outputs["scores_disc_fake"],
            )
    return outputs, loss_dict
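# Illustrative call pattern for `train_step` (a sketch, not the project's actual trainer):
# the generator pass (optimizer_idx=0) must run first so that `y_disc_cache` and
# `wav_seg_disc_cache` are populated before the discriminator pass (optimizer_idx=1).
# `model`, `batch`, `criterion`, and `optimizers` are assumed to be set up elsewhere,
# `criterion` is assumed to be indexable by optimizer index (as above), and each loss
# dict is assumed to expose its total under the "loss" key. AMP scaling, gradient
# clipping, and logging are omitted.
def run_train_iteration(model, batch, criterion, optimizers):
    outputs, losses = [], []
    for optimizer_idx, optimizer in enumerate(optimizers):
        optimizer.zero_grad()
        step_outputs, loss_dict = model.train_step(batch, criterion, optimizer_idx)
        loss_dict["loss"].backward()
        optimizer.step()
        outputs.append(step_outputs)
        losses.append(loss_dict)
    return outputs, losses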
def forward(
    self,
    x: torch.tensor,
    x_lengths: torch.tensor,
    y: torch.tensor,
    y_lengths: torch.tensor,
    waveform: torch.tensor,
    aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
) -> Dict:
    """Forward pass of the model.

    Args:
        x (torch.tensor): Batch of input character sequence IDs.
        x_lengths (torch.tensor): Batch of input character sequence lengths.
        y (torch.tensor): Batch of input spectrograms.
        y_lengths (torch.tensor): Batch of input spectrogram lengths.
        waveform (torch.tensor): Batch of ground truth waveforms per sample.
        aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training.
            Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}.

    Returns:
        Dict: model outputs keyed by the output name.

    Shapes:
        - x: :math:`[B, T_seq]`
        - x_lengths: :math:`[B]`
        - y: :math:`[B, C, T_spec]`
        - y_lengths: :math:`[B]`
        - waveform: :math:`[B, T_wav, 1]`
        - d_vectors: :math:`[B, C, 1]`
        - speaker_ids: :math:`[B]`
        - language_ids: :math:`[B]`
    """
    outputs = {}
    sid, g, lid = self._set_cond_input(aux_input)

    # speaker embedding
    if self.args.use_speaker_embedding and sid is not None:
        g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

    # language embedding
    lang_emb = None
    if self.args.use_language_embedding and lid is not None:
        lang_emb = self.emb_l(lid).unsqueeze(-1)

    x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)

    # posterior encoder
    z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)

    # flow layers
    z_p = self.flow(z, y_mask, g=g)

    # find the alignment path
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    with torch.no_grad():
        o_scale = torch.exp(-2 * logs_p)
        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
        logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
        logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
        logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
        logp = logp2 + logp3 + logp1 + logp4
        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

    # duration predictor
    attn_durations = attn.sum(3)
    if self.args.use_sdp:
        loss_duration = self.duration_predictor(
            x.detach() if self.args.detach_dp_input else x,
            x_mask,
            attn_durations,
            g=g.detach() if self.args.detach_dp_input and g is not None else g,
            lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
        )
        loss_duration = loss_duration / torch.sum(x_mask)
    else:
        attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
        log_durations = self.duration_predictor(
            x.detach() if self.args.detach_dp_input else x,
            x_mask,
            g=g.detach() if self.args.detach_dp_input and g is not None else g,
            lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
        )
        loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
    outputs["loss_duration"] = loss_duration

    # expand prior
    m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
    logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

    # select a random feature segment for the waveform decoder
    z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
    o = self.waveform_decoder(z_slice, g=g)

    wav_seg = segment(
        waveform,
        slice_ids * self.config.audio.hop_length,
        self.args.spec_segment_size * self.config.audio.hop_length,
    )

    if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
        # concatenate generated and GT waveforms
        wavs_batch = torch.cat((wav_seg, o), dim=0)

        # resample audio to the speaker encoder sample rate  # pylint: disable=W0105
        if self.audio_transform is not None:
            wavs_batch = self.audio_transform(wavs_batch)

        pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)

        # split generated and GT speaker embeddings
        gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
    else:
        gt_spk_emb, syn_spk_emb = None, None

    outputs.update(
        {
            "model_outputs": o,
            "alignments": attn.squeeze(1),
            "z": z,
            "z_p": z_p,
            "m_p": m_p,
            "logs_p": logs_p,
            "m_q": m_q,
            "logs_q": logs_q,
            "waveform_seg": wav_seg,
            "gt_spk_emb": gt_spk_emb,
            "syn_spk_emb": syn_spk_emb,
        }
    )
    return outputs
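# Shape sketch for calling `forward`, following the Shapes block in its docstring.
# `model` is assumed to be an initialized VITS-style instance configured elsewhere;
# batch size, sequence lengths, and hop length below are arbitrary example values,
# and the call itself is left commented because no concrete instance exists here.
import torch

B, T_seq, C_spec, T_spec, hop_length = 2, 32, 80, 64, 256
x = torch.randint(0, 100, (B, T_seq))              # character IDs: [B, T_seq]
x_lengths = torch.full((B,), T_seq)                # all items use the full length here
y = torch.randn(B, C_spec, T_spec)                 # spectrograms: [B, C, T_spec]
y_lengths = torch.full((B,), T_spec)
waveform = torch.randn(B, T_spec * hop_length, 1)  # ground-truth audio: [B, T_wav, 1]
# outputs = model.forward(
#     x, x_lengths, y, y_lengths, waveform,
#     aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
# )
# outputs["model_outputs"] is the decoded waveform segment; outputs["alignments"] is the MAS attention.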