import gc

import tensorflow as tf

# Per-example loss helpers provided by the surrounding project.
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss


def compute_per_example_losses(self, batch, outputs):
    mel_before, mel_after, duration_outputs, f0_outputs, energy_outputs = outputs

    log_duration = tf.math.log(
        tf.cast(tf.math.add(batch["duration_gts"], 1), tf.float32)
    )
    duration_loss = calculate_2d_loss(log_duration, duration_outputs, self.mse)
    f0_loss = calculate_2d_loss(batch["f0_gts"], f0_outputs, self.mse)
    energy_loss = calculate_2d_loss(batch["energy_gts"], energy_outputs, self.mse)
    mel_loss_before = calculate_3d_loss(batch["mel_gts"], mel_before, self.mae)
    mel_loss_after = calculate_3d_loss(batch["mel_gts"], mel_after, self.mae)

    per_example_losses = (
        duration_loss + f0_loss + energy_loss + mel_loss_before + mel_loss_after
    )

    dict_metrics_losses = {
        "duration_loss": duration_loss,
        "f0_loss": f0_loss,
        "energy_loss": energy_loss,
        "mel_loss_before": mel_loss_before,
        "mel_loss_after": mel_loss_after,
    }

    return per_example_losses, dict_metrics_losses
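# The calculate_2d_loss/calculate_3d_loss helpers above come from the
# surrounding project. A minimal sketch of the contract they are assumed to
# satisfy (not the library source): apply a Keras loss built with
# Reduction.NONE, then average every axis except the batch axis, so the
# result always has shape [B]. The 2d variant is analogous with one fewer
# axis; the name below is hypothetical.
def _sketch_calculate_3d_loss(y_gt, y_pred, loss_fn):
    # Trim to a common time length in case of small length mismatches.
    t = tf.minimum(tf.shape(y_gt)[1], tf.shape(y_pred)[1])
    loss = loss_fn(y_gt[:, :t, :], y_pred[:, :t, :])
    # Collapse any remaining non-batch axes down to shape [B].
    return tf.reduce_mean(loss, axis=list(range(1, len(loss.shape))))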
def compute_per_example_generator_losses(self, batch, outputs):
    """Compute per example generator losses and return dict_metrics_losses

    Note that every element of the loss MUST have a shape [batch_size]
    and the keys of dict_metrics_losses MUST be in self.list_metrics_name.

    Args:
        batch: dictionary batch input returned from dataloader
        outputs: outputs of the model

    Returns:
        per_example_losses: per example losses for each GPU, shape [B]
        dict_metrics_losses: dictionary loss.
    """
    dict_metrics_losses = {}
    per_example_losses = 0.0

    audios = batch["audios"]
    y_hat = outputs

    # calculate multi-resolution stft loss
    sc_loss, mag_loss = calculate_2d_loss(
        audios, tf.squeeze(y_hat, -1), self.stft_loss
    )

    # trick to prevent the loss from exploding here
    sc_loss = tf.where(sc_loss >= 15.0, 0.0, sc_loss)
    mag_loss = tf.where(mag_loss >= 15.0, 0.0, mag_loss)

    # compute generator loss
    gen_loss = 0.5 * (sc_loss + mag_loss)

    if self.steps >= self.config["discriminator_train_start_steps"]:
        p_hat = self._discriminator(y_hat)
        p = self._discriminator(tf.expand_dims(audios, 2))
        adv_loss = 0.0
        for i in range(len(p_hat)):
            adv_loss += calculate_3d_loss(
                tf.ones_like(p_hat[i][-1]), p_hat[i][-1], loss_fn=self.mse_loss
            )
        # average over the number of discriminator scales
        adv_loss /= i + 1

        # define feature-matching loss
        fm_loss = 0.0
        for i in range(len(p_hat)):
            for j in range(len(p_hat[i]) - 1):
                fm_loss += calculate_3d_loss(
                    p[i][j], p_hat[i][j], loss_fn=self.mae_loss
                )
        # average over scales times feature maps per scale
        fm_loss /= (i + 1) * (j + 1)
        adv_loss += self.config["lambda_feat_match"] * fm_loss

        gen_loss += self.config["lambda_adv"] * adv_loss

        dict_metrics_losses.update({"adversarial_loss": adv_loss})
        dict_metrics_losses.update({"fm_loss": fm_loss})

    dict_metrics_losses.update({"gen_loss": gen_loss})
    dict_metrics_losses.update({"spectral_convergence_loss": sc_loss})
    dict_metrics_losses.update({"log_magnitude_loss": mag_loss})

    per_example_losses = gen_loss
    return per_example_losses, dict_metrics_losses
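# Quick standalone check of the tf.where guard used above: any per-example
# STFT loss term at or above 15.0 (e.g., an early-training blow-up) is
# zeroed so a single bad example cannot dominate the batch gradient.
def _demo_loss_clamp():
    sc_loss = tf.constant([0.7, 23.4, 1.2])
    sc_loss = tf.where(sc_loss >= 15.0, 0.0, sc_loss)
    print(sc_loss.numpy())  # [0.7 0.  1.2]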
def compute_per_example_losses(self, batch, outputs):
    (
        decoder_output,
        post_mel_outputs,
        stop_token_predictions,
        alignment_historys,
    ) = outputs

    mel_loss_before = calculate_3d_loss(
        batch["mel_gts"], decoder_output, loss_fn=self.mae
    )
    mel_loss_after = calculate_3d_loss(
        batch["mel_gts"], post_mel_outputs, loss_fn=self.mae
    )

    # calculate stop_loss
    max_mel_length = tf.reduce_max(batch["mel_lengths"])
    stop_gts = tf.expand_dims(
        tf.range(tf.reduce_max(max_mel_length), dtype=tf.int32), 0
    )  # [1, max_len]
    stop_gts = tf.tile(
        stop_gts, [tf.shape(batch["mel_lengths"])[0], 1]
    )  # [B, max_len]
    stop_gts = tf.cast(
        tf.math.greater_equal(stop_gts, tf.expand_dims(batch["mel_lengths"], 1)),
        tf.float32,
    )

    stop_token_loss = calculate_2d_loss(
        stop_gts, stop_token_predictions, loss_fn=self.binary_crossentropy
    )

    # calculate guided attention loss
    attention_masks = tf.cast(
        tf.math.not_equal(batch["g_attentions"], -1.0), tf.float32
    )
    loss_att = tf.reduce_sum(
        tf.abs(alignment_historys * batch["g_attentions"]) * attention_masks,
        axis=[1, 2],
    )
    loss_att /= tf.reduce_sum(attention_masks, axis=[1, 2])

    per_example_losses = (
        stop_token_loss + mel_loss_before + mel_loss_after + loss_att
    )

    dict_metrics_losses = {
        "stop_token_loss": stop_token_loss,
        "mel_loss_before": mel_loss_before,
        "mel_loss_after": mel_loss_after,
        "guided_attention_loss": loss_att,
    }

    return per_example_losses, dict_metrics_losses
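# Standalone walk-through of the stop-token target construction above, with
# made-up lengths: frames at or beyond each utterance's true length get
# label 1.0 ("stop"), earlier frames get 0.0.
def _demo_stop_gts():
    mel_lengths = tf.constant([3, 5], dtype=tf.int32)
    max_len = tf.reduce_max(mel_lengths)  # 5
    stop_gts = tf.tile(tf.range(max_len, dtype=tf.int32)[None, :], [2, 1])
    stop_gts = tf.cast(stop_gts >= mel_lengths[:, None], tf.float32)
    print(stop_gts.numpy())
    # [[0. 0. 0. 1. 1.]
    #  [0. 0. 0. 0. 0.]]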
def compute_per_example_losses(self, batch, outputs):
    """Compute per example losses and return dict_metrics_losses

    Note that every element of the loss MUST have a shape [batch_size]
    and the keys of dict_metrics_losses MUST be in self.list_metrics_name.

    Args:
        batch: dictionary batch input returned from dataloader
        outputs: outputs of the model

    Returns:
        per_example_losses: per example losses for each GPU, shape [B]
        dict_metrics_losses: dictionary loss.
    """
    mel_before, mel_after, duration_outputs, f0_outputs, energy_outputs = outputs

    log_duration = tf.math.log(
        tf.cast(tf.math.add(batch["duration_gts"], 1), tf.float32)
    )
    duration_loss = calculate_2d_loss(log_duration, duration_outputs, self.mse)
    f0_loss = calculate_2d_loss(batch["f0_gts"], f0_outputs, self.mse)
    energy_loss = calculate_2d_loss(batch["energy_gts"], energy_outputs, self.mse)
    mel_loss_before = calculate_3d_loss(batch["mel_gts"], mel_before, self.mae)
    mel_loss_after = calculate_3d_loss(batch["mel_gts"], mel_after, self.mae)

    per_example_losses = (
        duration_loss + f0_loss + energy_loss + mel_loss_before + mel_loss_after
    )

    dict_metrics_losses = {
        "duration_loss": duration_loss,
        "f0_loss": f0_loss,
        "energy_loss": energy_loss,
        "mel_loss_before": mel_loss_before,
        "mel_loss_after": mel_loss_after,
    }

    # reset metric states and free memory between evaluation runs
    self.reset_states_eval()
    gc.collect()

    return per_example_losses, dict_metrics_losses
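# The duration predictor is regressed in log space; the +1 keeps the log
# defined for zero-length tokens. A quick check with made-up durations:
def _demo_log_duration():
    duration_gts = tf.constant([0, 2, 9], dtype=tf.int32)
    log_duration = tf.math.log(tf.cast(duration_gts + 1, tf.float32))
    print(log_duration.numpy())  # [0.        1.0986123 2.3025851]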
def compute_per_example_generator_losses(self, batch, outputs):
    """Compute per example generator losses and return dict_metrics_losses

    Note that every element of the loss MUST have a shape [batch_size]
    and the keys of dict_metrics_losses MUST be in self.list_metrics_name.

    Args:
        batch: dictionary batch input returned from dataloader
        outputs: outputs of the model

    Returns:
        per_example_losses: per example losses for each GPU, shape [B]
        dict_metrics_losses: dictionary loss.
    """
    audios = batch["audios"]
    y_hat = outputs

    p_hat = self._discriminator(y_hat)
    p = self._discriminator(tf.expand_dims(audios, 2))

    adv_loss = 0.0
    for i in range(len(p_hat)):
        adv_loss += calculate_3d_loss(
            tf.ones_like(p_hat[i][-1]), p_hat[i][-1], loss_fn=self.mse_loss
        )
    # average over the number of discriminator scales
    adv_loss /= i + 1

    # define feature-matching loss
    fm_loss = 0.0
    for i in range(len(p_hat)):
        for j in range(len(p_hat[i]) - 1):
            fm_loss += calculate_3d_loss(
                p[i][j], p_hat[i][j], loss_fn=self.mae_loss
            )
    # average over scales times feature maps per scale
    fm_loss /= (i + 1) * (j + 1)
    adv_loss += self.config["lambda_feat_match"] * fm_loss

    per_example_losses = adv_loss

    dict_metrics_losses = {
        "adversarial_loss": adv_loss,
        "fm_loss": fm_loss,
        "gen_loss": adv_loss,
        "mels_spectrogram_loss": calculate_2d_loss(
            audios, tf.squeeze(y_hat, -1), loss_fn=self.mels_loss
        ),
    }

    return per_example_losses, dict_metrics_losses
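# The indexing in the loops above implies (an assumption, not stated in the
# source) that the discriminator returns one list per scale, each holding
# intermediate feature maps with the final logits as its last element:
#
#   p_hat = [[feat_0_0, ..., logits_0],   # scale 0
#            [feat_1_0, ..., logits_1],   # scale 1
#            ...]
#
# Under that layout, the loops amount to this equivalent restatement (the
# fm normalization matches when every scale has the same number of feature
# maps); the name below is hypothetical.
def _sketch_adv_and_fm_losses(p, p_hat, mse_loss, mae_loss):
    adv_terms, fm_terms = [], []
    for real_feats, fake_feats in zip(p, p_hat):
        # Push the final logits of each scale toward 1.0 ("real").
        adv_terms.append(
            calculate_3d_loss(
                tf.ones_like(fake_feats[-1]), fake_feats[-1], loss_fn=mse_loss
            )
        )
        # L1-match every intermediate feature map against the real pass.
        for r, f in zip(real_feats[:-1], fake_feats[:-1]):
            fm_terms.append(calculate_3d_loss(r, f, loss_fn=mae_loss))
    adv_loss = tf.add_n(adv_terms) / len(adv_terms)
    fm_loss = tf.add_n(fm_terms) / len(fm_terms)
    return adv_loss, fm_loss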
def compute_per_example_generator_losses(self, batch, outputs):
    """Compute per example generator losses and return dict_metrics_losses

    Note that every element of the loss MUST have a shape [batch_size]
    and the keys of dict_metrics_losses MUST be in self.list_metrics_name.

    Args:
        batch: dictionary batch input returned from dataloader
        outputs: outputs of the model

    Returns:
        per_example_losses: per example losses for each GPU, shape [B]
        dict_metrics_losses: dictionary loss.
    """
    dict_metrics_losses = {}
    per_example_losses = 0.0

    audios = batch["audios"]
    y_hat = outputs

    # calculate multi-resolution stft loss
    sc_loss, mag_loss = calculate_2d_loss(
        audios, tf.squeeze(y_hat, -1), self.stft_loss
    )

    gen_loss = 0.5 * (sc_loss + mag_loss)

    if self.steps >= self.config["discriminator_train_start_steps"]:
        p_hat = self._discriminator(y_hat)
        p = self._discriminator(tf.expand_dims(audios, 2))
        adv_loss = 0.0
        adv_loss += calculate_3d_loss(
            tf.ones_like(p_hat), p_hat, loss_fn=self.mse_loss
        )
        gen_loss += self.config["lambda_adv"] * adv_loss

        # update dict_metrics_losses
        dict_metrics_losses.update({"adversarial_loss": adv_loss})

    dict_metrics_losses.update({"gen_loss": gen_loss})
    dict_metrics_losses.update({"spectral_convergence_loss": sc_loss})
    dict_metrics_losses.update({"log_magnitude_loss": mag_loss})

    per_example_losses = gen_loss
    return per_example_losses, dict_metrics_losses
def compute_per_example_losses(self, batch, outputs):
    """Compute per example losses and return dict_metrics_losses

    Note that every element of the loss MUST have a shape [batch_size]
    and the keys of dict_metrics_losses MUST be in self.list_metrics_name.

    Args:
        batch: dictionary batch input returned from dataloader
        outputs: outputs of the model

    Returns:
        per_example_losses: per example losses for each GPU, shape [B]
        dict_metrics_losses: dictionary loss.
    """
    (
        decoder_output,
        post_mel_outputs,
        stop_token_predictions,
        alignment_historys,
    ) = outputs

    mel_loss_before = calculate_3d_loss(
        batch["mel_gts"], decoder_output, loss_fn=self.mae
    )
    mel_loss_after = calculate_3d_loss(
        batch["mel_gts"], post_mel_outputs, loss_fn=self.mae
    )

    # calculate stop_loss
    max_mel_length = (
        tf.reduce_max(batch["mel_lengths"])
        if self.config["use_fixed_shapes"] is False
        else [self.config["max_mel_length"]]
    )
    stop_gts = tf.expand_dims(
        tf.range(tf.reduce_max(max_mel_length), dtype=tf.int32), 0
    )  # [1, max_len]
    stop_gts = tf.tile(
        stop_gts, [tf.shape(batch["mel_lengths"])[0], 1]
    )  # [B, max_len]
    stop_gts = tf.cast(
        tf.math.greater_equal(stop_gts, tf.expand_dims(batch["mel_lengths"], 1)),
        tf.float32,
    )

    stop_token_loss = calculate_2d_loss(
        stop_gts, stop_token_predictions, loss_fn=self.binary_crossentropy
    )

    # calculate guided attention loss.
    attention_masks = tf.cast(
        tf.math.not_equal(batch["g_attentions"], -1.0), tf.float32
    )
    loss_att = tf.reduce_sum(
        tf.abs(alignment_historys * batch["g_attentions"]) * attention_masks,
        axis=[1, 2],
    )
    loss_att /= tf.reduce_sum(attention_masks, axis=[1, 2])

    per_example_losses = (
        stop_token_loss + mel_loss_before + mel_loss_after + loss_att
    )

    dict_metrics_losses = {
        "stop_token_loss": stop_token_loss,
        "mel_loss_before": mel_loss_before,
        "mel_loss_after": mel_loss_after,
        "guided_attention_loss": loss_att,
    }

    return per_example_losses, dict_metrics_losses
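# The guided-attention targets (batch["g_attentions"]) are assumed to be
# weight matrices in the style of Tachibana et al. (2017), with padding
# positions filled with -1.0 -- that is what the `!= -1.0` mask above keys
# on. A hedged sketch of how such a matrix could be built; the function name
# and the sigma default g=0.2 are assumptions:
import numpy as np

def _sketch_guided_attention(char_len, mel_len, max_char, max_mel, g=0.2):
    W = np.full((max_char, max_mel), -1.0, dtype=np.float32)
    n = np.arange(char_len)[:, None] / char_len  # normalized text position
    t = np.arange(mel_len)[None, :] / mel_len    # normalized mel position
    # Penalize attention far from the diagonal; zero cost on the diagonal.
    W[:char_len, :mel_len] = 1.0 - np.exp(-((n - t) ** 2) / (2.0 * g * g))
    return W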
def compute_per_example_generator_losses(self, batch, outputs):
    """Compute per example generator losses and return dict_metrics_losses

    Note that every element of the loss MUST have a shape [batch_size]
    and the keys of dict_metrics_losses MUST be in self.list_metrics_name.

    Args:
        batch: dictionary batch input returned from dataloader
        outputs: outputs of the model

    Returns:
        per_example_losses: per example losses for each GPU, shape [B]
        dict_metrics_losses: dictionary loss.
    """
    dict_metrics_losses = {}
    per_example_losses = 0.0

    audios = batch["audios"]
    y_mb_hat = outputs
    y_hat = self.pqmf.synthesis(y_mb_hat)

    y_mb = self.pqmf.analysis(tf.expand_dims(audios, -1))
    y_mb = tf.transpose(y_mb, (0, 2, 1))  # [B, subbands, T//subbands]
    y_mb = tf.reshape(y_mb, (-1, tf.shape(y_mb)[-1]))  # [B * subbands, T']

    y_mb_hat = tf.transpose(y_mb_hat, (0, 2, 1))  # [B, subbands, T//subbands]
    y_mb_hat = tf.reshape(
        y_mb_hat, (-1, tf.shape(y_mb_hat)[-1])
    )  # [B * subbands, T']

    # calculate sub/full band spectral_convergence and log mag loss.
    sub_sc_loss, sub_mag_loss = calculate_2d_loss(
        y_mb, y_mb_hat, self.sub_band_stft_loss
    )
    sub_sc_loss = tf.reduce_mean(
        tf.reshape(sub_sc_loss, [-1, self.pqmf.subbands]), -1
    )
    sub_mag_loss = tf.reduce_mean(
        tf.reshape(sub_mag_loss, [-1, self.pqmf.subbands]), -1
    )
    full_sc_loss, full_mag_loss = calculate_2d_loss(
        audios, tf.squeeze(y_hat, -1), self.full_band_stft_loss
    )

    # define generator loss
    gen_loss = 0.5 * (sub_sc_loss + sub_mag_loss) + 0.5 * (
        full_sc_loss + full_mag_loss
    )

    if self.steps >= self.config["discriminator_train_start_steps"]:
        p_hat = self._discriminator(y_hat)
        p = self._discriminator(tf.expand_dims(audios, 2))
        adv_loss = 0.0
        adv_loss += calculate_3d_loss(
            tf.ones_like(p_hat), p_hat, loss_fn=self.mse_loss
        )
        gen_loss += self.config["lambda_adv"] * adv_loss

        # update dict_metrics_losses
        dict_metrics_losses.update({"adversarial_loss": adv_loss})

    dict_metrics_losses.update({"gen_loss": gen_loss})
    dict_metrics_losses.update({"subband_spectral_convergence_loss": sub_sc_loss})
    dict_metrics_losses.update({"subband_log_magnitude_loss": sub_mag_loss})
    dict_metrics_losses.update({"fullband_spectral_convergence_loss": full_sc_loss})
    dict_metrics_losses.update({"fullband_log_magnitude_loss": full_mag_loss})

    per_example_losses = gen_loss
    return per_example_losses, dict_metrics_losses
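# Shape walk-through for the PQMF reshaping above (a sketch with made-up
# sizes; pqmf.analysis is assumed to return [B, T // subbands, subbands]).
# Flattening to [B * subbands, T'] lets the sub-band STFT loss treat every
# sub-band signal as an independent batch item; the tf.reshape +
# reduce_mean(..., -1) afterwards folds the result back to shape [B].
def _demo_pqmf_reshape():
    B, T, subbands = 2, 16000, 4
    y_mb = tf.random.normal([B, T // subbands, subbands])  # stand-in for pqmf.analysis
    y_mb = tf.transpose(y_mb, (0, 2, 1))                   # [B, subbands, T // subbands]
    y_mb = tf.reshape(y_mb, (-1, tf.shape(y_mb)[-1]))      # [B * subbands, T // subbands]
    print(y_mb.shape)  # (8, 4000)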