def _gta_forward(self, inp, tar, stop_prob, xvectors, training):
    """Run the model on shifted targets and compute the weighted loss under a tape.

    ``tar`` without its last step along axis 1 is subsampled every
    ``self.r`` steps and passed to the model as ``targets``; ``tar`` and
    ``stop_prob`` without their first step along axis 1 are the references
    the loss is computed against.

    Args:
        inp: Forwarded unchanged to the model call as ``inputs``.
        tar: Target tensor, indexed here with three axes.
            NOTE(review): presumably (batch, time, mel channels) -- confirm.
        stop_prob: Reference stop values; sliced along axis 1 only.
        xvectors: Forwarded unchanged to the model call as ``xvectors``.
        training: Forwarded unchanged to the model call as ``training``.

    Returns:
        A ``(model_out, tape)`` pair. ``model_out`` is the model output
        extended with the keys ``'loss'``, ``'losses'`` and
        ``'reduced_target'``; ``tape`` is the ``tf.GradientTape`` that
        recorded the model call and the loss computation. No gradients are
        computed or applied here.
    """
    shifted_input = tar[:, :-1]
    expected_mel = tar[:, 1:]
    expected_stop = stop_prob[:, 1:]

    # Number of steps along axis 1 of the shifted input; model outputs are
    # cropped to this length before being compared with the references.
    mel_len = int(tf.shape(shifted_input)[1])
    # Keep every self.r-th step of the shifted input.
    reduced_target = shifted_input[:, 0::self.r, :]

    with tf.GradientTape() as tape:
        # The x-vectors are passed to the model alongside the regular inputs.
        model_out = self.__call__(
            inputs=inp,
            targets=reduced_target,
            xvectors=xvectors,
            training=training,
        )
        references = (expected_mel, expected_stop, expected_mel)
        predictions = tuple(
            model_out[key][:, :mel_len, :]
            for key in ('final_output', 'stop_prob', 'mel_linear')
        )
        loss, loss_vals = weighted_sum_losses(
            references, predictions, self.loss, self.loss_weights
        )

    model_out.update({
        'loss': loss,
        'losses': {
            'output': loss_vals[0],
            'stop_prob': loss_vals[1],
            'mel_linear': loss_vals[2],
        },
        'reduced_target': reduced_target,
    })
    return model_out, tape
def _val_step(self, input_sequence, target_sequence, target_durations):
    """Evaluate the model on one batch and attach the loss values to its output.

    The model is called with ``training=False``; no gradients are recorded
    and no variables are updated.

    Args:
        input_sequence: Passed to the model call as its first argument.
        target_sequence: Reference for the ``'mel'`` output; its size along
            axis 1 determines how much of the prediction is compared.
        target_durations: Reference durations. A trailing axis is appended
            before they are given to the model and to the loss.

    Returns:
        The model output extended with ``'loss'`` (the value returned first
        by ``weighted_sum_losses``) and ``'losses'`` (a dict with the
        ``'mel'`` and ``'duration'`` components).
    """
    durations = tf.expand_dims(target_durations, -1)
    mel_len = int(tf.shape(target_sequence)[1])

    model_out = self.__call__(input_sequence, durations, training=False)

    # Crop the predicted mel along axis 1 to the reference length.
    predicted_mel = model_out['mel'][:, :mel_len, :]
    references = (target_sequence, durations)
    predictions = (predicted_mel, model_out['duration'])
    loss, loss_vals = weighted_sum_losses(
        references, predictions, self.loss, self.loss_weights
    )

    model_out.update({
        'loss': loss,
        'losses': {'mel': loss_vals[0], 'duration': loss_vals[1]},
    })
    return model_out
def _train_step(self, input_sequence, target_sequence, target_durations):
    """Run one optimisation step on a batch and return the annotated model output.

    The model call (``training=True``) and the loss computation are recorded
    on a ``tf.GradientTape``. Gradients of the loss with respect to
    ``self.trainable_variables`` are then applied through ``self.optimizer``.

    Args:
        input_sequence: Passed to the model call as its first argument.
        target_sequence: Reference for the ``'mel'`` output; its size along
            axis 1 determines how much of the prediction is compared.
        target_durations: Reference durations. A trailing axis is appended
            before they are given to the model and to the loss.

    Returns:
        The model output extended with ``'loss'`` (the value returned first
        by ``weighted_sum_losses``) and ``'losses'`` (a dict with the
        ``'mel'`` and ``'duration'`` components).
    """
    durations = tf.expand_dims(target_durations, -1)
    mel_len = int(tf.shape(target_sequence)[1])

    with tf.GradientTape() as tape:
        model_out = self.__call__(input_sequence, durations, training=True)
        # Crop the predicted mel along axis 1 to the reference length.
        predicted_mel = model_out['mel'][:, :mel_len, :]
        references = (target_sequence, durations)
        predictions = (predicted_mel, model_out['duration'])
        loss, loss_vals = weighted_sum_losses(
            references, predictions, self.loss, self.loss_weights
        )

    model_out.update({
        'loss': loss,
        'losses': {'mel': loss_vals[0], 'duration': loss_vals[1]},
    })

    # Differentiate outside the tape context, then update the variables.
    variables = self.trainable_variables
    gradients = tape.gradient(model_out['loss'], variables)
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    return model_out