def power_loss(self, wav_dict):
    """Build the 'power_loss' term from STFT features of two waveforms.

    Reads the predicted waveform from ``wav_dict['x']`` and the reference
    waveform from ``wav_dict['wav']``, trims whichever has the larger
    static size along axis 1, and compares the (optionally normalized)
    ``self.stft_feat_fn`` features of their STFTs.

    Args:
        wav_dict: mapping providing the 'x' and 'wav' tensors.
            NOTE(review): presumably both are (batch, time) -- confirm.

    Returns:
        A dict with the single key 'power_loss' holding the averaged loss.
    """
    feature_of = self.stft_feat_fn
    predicted = wav_dict['x']
    reference = wav_dict['wav']

    n_predicted = predicted.get_shape().as_list()[1]
    n_reference = reference.get_shape().as_list()[1]

    # Bring both waves to a common length by trimming the longer one.
    if n_predicted > n_reference:
        predicted = self._trim(predicted, n_predicted - n_reference)
    elif n_predicted < n_reference:
        reference = self._trim(reference, n_reference - n_predicted)

    # Reference is processed before prediction at every stage.
    reference_stft = mel_extractor._tf_stft(reference)
    predicted_stft = mel_extractor._tf_stft(predicted)

    reference_feat = PWNHelper.norm_or_not_fn(self, feature_of(reference_stft))
    predicted_feat = PWNHelper.norm_or_not_fn(self, feature_of(predicted_stft))

    return {
        'power_loss': PWNHelper.avg_loss_fn(
            PWNHelper.diff_fn(reference_feat, predicted_feat)),
    }
def spec_feat_mean_std(train_path, feat_fn=lambda x: tf.pow(tf.abs(x), 2.0)):
    """Compute the mean and std of an STFT-derived feature on one batch.

    A private graph is built so that the ops created here do not leak into
    the caller's default graph. The feature is evaluated once on a batch
    fetched with ``get_init_batch`` and reduced over its first two axes.

    Args:
        train_path: forwarded to ``get_init_batch`` as the data location.
        feat_fn: callable applied to the STFT tensor; defaults to the
            squared magnitude.

    Returns:
        A ``(mean_val, std_val)`` tuple of numpy values obtained by reducing
        the feature over axes (0, 1).
        NOTE(review): presumably these are the batch and frame axes, giving
        per-frequency-bin statistics -- confirm against ``_tf_stft``.
    """
    batch_size = 4096
    seq_len = 7680

    local_graph = tf.Graph()
    with local_graph.as_default():
        input_vals = get_init_batch(
            train_path, batch_size=batch_size, seq_len=seq_len,
            first_n=10000)['wav']
        ph = tf.placeholder(dtype=np.float32, shape=[batch_size, seq_len])
        feat = feat_fn(mel_extractor._tf_stft(ph))

    tf.logging.info('Calculating mean and std for stft feat.')
    # Setting the GPU device count to 0 keeps this one-off computation off
    # any GPU.
    config = tf.ConfigProto(device_count={'GPU': 0})
    # The context manager closes the session even if ``run`` raises; the
    # original never closed it and so leaked its resources.
    with tf.Session(config=config, graph=local_graph) as sess:
        feat_val = sess.run(feat, feed_dict={ph: input_vals})

    mean_val = np.mean(feat_val, axis=(0, 1))
    std_val = np.std(feat_val, axis=(0, 1))
    tf.logging.info('Done calculating mean and std for stft feat.')
    return mean_val, std_val