def kl_loss(self, ff_dict, num_samples=100):
    """Monte-Carlo estimate of the student/teacher KL term.

    Student samples are drawn as ``noise * scale + mean`` using
    ``num_samples`` noise rows per batch element. They are then scored under
    the teacher's ``out_params``. The teacher computes those parameters from
    the (clipped) student waveform ``ff_dict['x']`` and from
    ``ff_dict['mel']``.

    Args:
        ff_dict: dict from the student forward pass. It must contain the keys
            'mel', 'x', 'mean_tot', 'scale_tot' and 'log_scale_tot'.
            ``x`` must have a fully defined static rank-2 shape, because its
            dims are used as Python ints (presumably (batch, time) -- TODO
            confirm).
        num_samples: number of noise samples drawn per batch element.

    Returns:
        dict with keys:
          'kl_loss': H_Ps_Pt - H_Ps.
          'H_Ps': mean(log_scale_tot) + 2.
          'H_Ps_Pt': mean of -log P_teacher over the samples.
    """
    teacher, quant_chann, use_mu_law = (
        self.teacher, self.quant_chann, self.use_mu_law)

    mel, x = ff_dict['mel'], ff_dict['x']
    mean, scale, log_scale = (
        ff_dict['mean_tot'], ff_dict['scale_tot'], ff_dict['log_scale_tot'])

    # Static dims are multiplied and passed to reshape as Python ints, so the
    # shape of x must be fully known at graph-construction time.
    batch_size, length = x.get_shape().as_list()
    n_rows = batch_size * num_samples

    # Presumably standard logistic noise -- confirm against _logistic_0_1.
    noise = self._logistic_0_1(n_rows, length)
    mean_rep = utils.tf_repeat(mean, [num_samples, 1])
    scale_rep = utils.tf_repeat(scale, [num_samples, 1])

    # Student samples of x_i given x_<i (reparameterised).
    samples = noise * scale_rep + mean_rep

    # Clip both signals to the real audio range [-1.0, 1.0). When use_mu_law
    # is set, the student output is treated as a mu-law encoded signal.
    x_for_teacher = self._clip_quant_scale(x, quant_chann, use_mu_law)
    samples_for_teacher = self._clip_quant_scale(
        samples, quant_chann, use_mu_law)

    teacher_out = teacher.feed_forward({'wav': x_for_teacher, 'mel': mel})
    teacher_params = utils.tf_repeat(
        teacher_out['out_params'], [num_samples, 1, 1])

    # The teacher always uses log-scales, so mol_log_probs keeps its default
    # use_log_scale=True.
    teacher_log_probs = loss_func.mol_log_probs(
        teacher_params, samples_for_teacher, quant_chann)

    # Average -log P_t over the sample axis. This assumes tf_repeat keeps
    # the copies of each batch row contiguous, so that this reshape groups
    # them together -- TODO confirm against utils.tf_repeat.
    per_step_xent = -tf.reduce_mean(
        tf.reshape(teacher_log_probs, [batch_size, num_samples, length]),
        axis=1)

    # mean(log s) + 2 matches the closed-form entropy of a logistic with
    # scale s (ln s + 2), assuming log_scale_tot is that log-scale.
    student_entropy = tf.reduce_mean(log_scale) + 2
    cross_entropy = tf.reduce_mean(per_step_xent)
    kl = cross_entropy - student_entropy
    return {'kl_loss': kl, 'H_Ps': student_entropy, 'H_Ps_Pt': cross_entropy}
def kl_loss(self, ff_dict, num_samples=100):
    """Monte-Carlo estimate of the student/teacher KL term.

    Student samples are drawn as ``noise * scale + mean`` using
    ``num_samples`` noise rows per batch element. They are then scored under
    the teacher's ``out_params``. The teacher computes those parameters from
    ``ff_dict['x']`` (after ``PWNHelper.clip_or_not_fn``) and from
    ``ff_dict['mel']``.

    Args:
        ff_dict: dict from the student forward pass. It must contain the keys
            'mel', 'x', 'mean_tot', 'scale_tot' and 'log_scale_tot'.
            ``x`` must have a fully defined static rank-2 shape, because its
            dims are used as Python ints (presumably (batch, time) -- TODO
            confirm).
        num_samples: number of noise samples drawn per batch element.

    Returns:
        dict with keys:
          'kl_loss': H_Ps_Pt - H_Ps.
          'H_Ps': mean(log_scale_tot) + 2.
          'H_Ps_Pt': mean of -log P_teacher over the samples.
    """
    teacher, quant_chann = self.teacher, self.quant_chann

    mel, x = ff_dict['mel'], ff_dict['x']
    mean, scale, log_scale = (
        ff_dict['mean_tot'], ff_dict['scale_tot'], ff_dict['log_scale_tot'])

    # Static dims are multiplied and passed to reshape as Python ints, so the
    # shape of x must be fully known at graph-construction time.
    batch_size, length = x.get_shape().as_list()
    n_rows = batch_size * num_samples

    # Presumably standard logistic noise -- confirm against _logistic_0_1.
    noise = self._logistic_0_1(n_rows, length)
    mean_rep = utils.tf_repeat(mean, [num_samples, 1])
    scale_rep = utils.tf_repeat(scale, [num_samples, 1])

    # Student samples of x_i given x_<i (reparameterised).
    samples = noise * scale_rep + mean_rep

    # clip_or_not_fn presumably clips/scales depending on the model's
    # configuration -- confirm against PWNHelper.
    x_for_teacher = PWNHelper.clip_or_not_fn(self, x)
    samples_for_teacher = PWNHelper.clip_or_not_fn(self, samples)

    teacher_out = teacher.feed_forward(
        {'wav_scaled': x_for_teacher, 'mel': mel})
    teacher_params = utils.tf_repeat(
        teacher_out['out_params'], [num_samples, 1, 1])

    # The teacher always uses log-scales, so mol_log_probs keeps its default
    # use_log_scale=True.
    teacher_log_probs = loss_func.mol_log_probs(
        teacher_params, samples_for_teacher, quant_chann)

    # Average -log P_t over the sample axis. This assumes tf_repeat keeps
    # the copies of each batch row contiguous, so that this reshape groups
    # them together -- TODO confirm against utils.tf_repeat.
    per_step_xent = -tf.reduce_mean(
        tf.reshape(teacher_log_probs, [batch_size, num_samples, length]),
        axis=1)

    # mean(log s) + 2 matches the closed-form entropy of a logistic with
    # scale s (ln s + 2), assuming log_scale_tot is that log-scale.
    student_entropy = tf.reduce_mean(log_scale) + 2
    cross_entropy = tf.reduce_mean(per_step_xent)
    kl = cross_entropy - student_entropy
    return {'kl_loss': kl, 'H_Ps': student_entropy, 'H_Ps_Pt': cross_entropy}