def example(self, s, d, s_len, d_len, snr):
    """
    Build one Deep Xi example: the observation (noisy-speech STMS) and the
    target (gain).

    Argument/s:
        s - clean speech (dtype=tf.int32).
        d - noise (dtype=tf.int32).
        s_len - clean-speech length without padding (samples).
        d_len - noise length without padding (samples).
        snr - SNR level.

    Returns:
        x_STMS - noisy-speech short-time magnitude spectrum.
        gain - gain.
        n_frames - number of time-domain frames.
    """
    s, d, x, n_frames = self.mix(s, d, s_len, d_len, snr)

    # Only the magnitude half of each polar analysis is needed here; the
    # phase spectra are discarded.
    s_STMS, d_STMS, x_STMS = [
        self.polar_analysis(signal)[0] for signal in (s, d, x)
    ]

    # Instantaneous a priori (xi) and a posteriori (gamma) SNRs feed the
    # configured gain function.
    target_gain = gfunc(
        xi=self.xi(s_STMS, d_STMS),
        gamma=self.gamma(x_STMS, d_STMS),
        gtype=self.gain,
    )
    return x_STMS, target_gain, n_frames
def enhanced_speech(self, x_STMS, x_STPS, xi_bar_hat, gtype):
    """
    Synthesise enhanced speech from a mapped a priori SNR estimate.

    The a posteriori SNR is approximated as xi_hat + 1. The resulting gain
    scales the noisy magnitude spectrum, and the noisy phase spectrum is
    reused for synthesis.

    Argument/s:
        x_STMS - noisy-speech short-time magnitude spectrum.
        x_STPS - noisy-speech short-time phase spectrum.
        xi_bar_hat - mapped a priori SNR estimate.
        gtype - gain function type.

    Returns:
        enhanced speech.
    """
    xi_hat = self.xi_map.inverse(xi_bar_hat)
    spectral_gain = gfunc(xi_hat, tf.math.add(xi_hat, self.one), gtype)
    return self.polar_synthesis(
        tf.math.multiply(x_STMS, spectral_gain), x_STPS)
def enhanced_speech(self, x_STMS, x_STPS_xi_hat_mat, gamma_bar_hat, gtype):
    """
    Compute enhanced speech from a mapped a posteriori SNR estimate and an
    a priori SNR estimate that was loaded from a .mat file.

    Argument/s:
        x_STMS - noisy-speech short-time magnitude spectrum.
        x_STPS_xi_hat_mat - tuple of (noisy-speech short-time phase spectrum,
            loaded .mat data). The a priori SNR estimate is read from the
            'xi_hat' entry of the .mat data.
        gamma_bar_hat - mapped a posteriori SNR estimate.
        gtype - gain function type.

    Returns:
        enhanced speech.
    """
    gamma_hat = self.gamma_map.inverse(gamma_bar_hat)
    x_STPS, xi_hat_mat = x_STPS_xi_hat_mat
    xi_hat = xi_hat_mat['xi_hat']
    # Apply the gain to the noisy magnitude spectrum; the noisy phase is
    # reused unchanged for synthesis.
    y_STMS = tf.math.multiply(x_STMS, gfunc(xi_hat, gamma_hat, gtype))
    return self.polar_synthesis(y_STMS, x_STPS)
def enhanced_speech(self, x_STMS, x_STPS, xi_gamma_bar_hat, gtype):
    """
    Compute enhanced speech from jointly estimated a priori and a posteriori
    SNRs.

    Argument/s:
        x_STMS - noisy-speech short-time magnitude spectrum.
        x_STPS - noisy-speech short-time phase spectrum.
        xi_gamma_bar_hat - mapped a priori and a posteriori SNR estimates,
            concatenated along the last axis (first half: a priori SNR,
            second half: a posteriori SNR).
        gtype - gain function type.

    Returns:
        enhanced speech.
    """
    xi_bar_hat, gamma_bar_hat = tf.split(xi_gamma_bar_hat,
                                         num_or_size_splits=2, axis=-1)
    xi_hat = self.xi_map.inverse(xi_bar_hat)
    gamma_hat = self.gamma_map.inverse(gamma_bar_hat)
    # Apply the gain to the noisy magnitude spectrum; the noisy phase is
    # reused unchanged for synthesis.
    y_STMS = tf.math.multiply(x_STMS, gfunc(xi_hat, gamma_hat, gtype))
    return self.polar_synthesis(y_STMS, x_STPS)
def enhanced_speech(self, x_STMS_STPS, dummy, xi_s_stps_bar_hat, gtype):
    """
    Synthesise enhanced speech from an estimated a priori SNR and an
    estimated clean-speech phase spectrum.

    Argument/s:
        x_STMS_STPS - noisy-speech short-time magnitude and phase spectrum,
            concatenated along the last axis.
        dummy - dummy variable (not used).
        xi_s_stps_bar_hat - mapped a priori SNR and mapped clean-speech STPS
            estimates, concatenated along the last axis.
        gtype - gain function type.

    Returns:
        enhanced speech.
    """
    # Keep only the magnitude half of the input; the noisy phase is replaced
    # by the clean-speech phase estimate below.
    noisy_mag = tf.split(x_STMS_STPS, num_or_size_splits=2, axis=-1)[0]
    mapped_xi, mapped_phase = tf.split(xi_s_stps_bar_hat,
                                       num_or_size_splits=2, axis=-1)

    xi_est = self.xi_map.inverse(mapped_xi)
    gamma_est = tf.math.add(xi_est, self.one)
    phase_est = self.s_stps_map.inverse(mapped_phase)

    enhanced_mag = tf.math.multiply(noisy_mag,
                                    gfunc(xi_est, gamma_est, gtype))
    return self.polar_synthesis(enhanced_mag, phase_est)
def enhanced_speech(self, x_STDCT, dummy, xi_cd_bar_hat, gtype):
    """
    Compute enhanced speech in the short-time DCT (STDCT) domain.

    Argument/s:
        x_STDCT - noisy-speech STDCT representation (passed, after applying
            the gain, directly to self.stdct_synthesis).
        dummy - dummy variable (not used).
        xi_cd_bar_hat - mapped a priori SNR estimate and mapped 'cd'
            estimate, concatenated along the last axis (first half: a priori
            SNR, second half: cd).
        gtype - gain function type.

    Returns:
        enhanced speech.
    """
    xi_bar_hat, cd_bar_hat = tf.split(xi_cd_bar_hat, num_or_size_splits=2,
                                      axis=-1)
    xi_hat = self.xi_map.inverse(xi_bar_hat)
    gamma_hat = tf.math.add(xi_hat, self.one)
    cd_hat = self.cd_map.inverse(cd_bar_hat)
    # Boolean mask (cd_hat > 0) handed to the gain function as its fourth
    # argument. NOTE(review): presumably a constructive/deconstructive
    # indicator — confirm against gfunc.
    cdm_hat = tf.math.greater(cd_hat, 0.0)
    y_STDCT = tf.math.multiply(x_STDCT, gfunc(xi_hat, gamma_hat, gtype,
                                              cdm_hat))
    return self.stdct_synthesis(y_STDCT)
def test(self,
         test_x,
         test_x_len,
         test_x_base_names,
         test_s,
         test_s_len,
         test_s_base_names,
         test_epoch,
         model_path='model',
         gain='mmse-lsa',
         stats_path=None):
    """
    Deep Xi testing. Objective measures are used to evaluate the performance
    of Deep Xi.

    For each test epoch, the model weights are loaded and inference is run
    once; the resulting a priori SNR estimates are reused for every gain
    function in 'gain'. Per noise source / SNR level scores are written to
    log/results/<ver>_e<epoch>_<gain>.csv, and the scores for SNR levels in
    [self.min_snr, self.max_snr] are averaged and appended to
    log/results/average.csv.

    Argument/s:
        test_x - noisy-speech test batch.
        test_x_len - noisy-speech test batch lengths.
        test_x_base_names - noisy-speech base names.
        test_s - clean-speech test batch.
        test_s_len - clean-speech test batch lengths.
        test_s_base_names - clean-speech base names.
        test_epoch - epoch to test (int or list of ints).
        model_path - path to model directory.
        gain - gain function (str or list of str; see deepxi/args.py).
        stats_path - path to the saved statistics.

    Raises:
        ValueError - if a test epoch is less than 1, or if a noisy-speech
            base name contains none of the clean-speech base names.
    """
    if not isinstance(test_epoch, list):
        test_epoch = [test_epoch]
    if not isinstance(gain, list):
        gain = [gain]

    # Fail fast: validate every epoch before any expensive work is done.
    for e in test_epoch:
        if e < 1:
            raise ValueError("test_epoch must be greater than 0.")
    if not gain:
        return

    batch_size = len(test_x_len)

    # Pair each noisy utterance with its clean reference once. The matching
    # rule is unchanged (the last clean base name that is a substring of the
    # noisy base name). A missing match is now an error; previously it
    # silently reused the previous utterance's reference.
    ref_indices = []
    for i in range(batch_size):
        ref_idx = None
        for j, s_base_name in enumerate(test_s_base_names):
            if s_base_name in test_x_base_names[i]:
                ref_idx = j
        if ref_idx is None:
            raise ValueError("No clean-speech reference found for '{}'.".format(
                test_x_base_names[i]))
        ref_indices.append(ref_idx)

    for e in test_epoch:
        # Inference does not depend on the gain function, so it is run once
        # per epoch and shared by every gain below.
        self.sample_stats(stats_path)
        self.model.load_weights(model_path + '/epoch-' + str(e - 1) +
                                '/variables/variables')

        print("Processing observations...")
        x_STMS_batch, x_STPS_batch, n_frames = self.observation_batch(
            test_x, test_x_len)
        print("Performing inference...")
        xi_bar_hat_batch = self.model.predict(x_STMS_batch, batch_size=1,
                                              verbose=1)

        for g in gain:
            print("Performing synthesis and objective scoring...")
            results = {}
            for i in tqdm(range(batch_size)):
                x_STMS = x_STMS_batch[i, :n_frames[i], :]
                x_STPS = x_STPS_batch[i, :n_frames[i], :]
                xi_bar_hat = xi_bar_hat_batch[i, :n_frames[i], :]
                xi_hat = self.xi_hat(xi_bar_hat)
                y_STMS = np.multiply(x_STMS,
                                     gfunc(xi_hat, xi_hat + 1, gtype=g))
                y = self.polar_synthesis(y_STMS, x_STPS).numpy()

                ref_idx = ref_indices[i]
                s = self.normalise(
                    test_s[ref_idx, 0:test_s_len[ref_idx]]).numpy()
                y = y[0:len(s)]

                # The base name's last two '_'-separated fields are the noise
                # source and the SNR level; the SNR field's last two
                # characters (presumably 'dB') are stripped before parsing.
                name_fields = test_x_base_names[i].split("_")
                noise_source = name_fields[-2]
                snr_level = int(name_fields[-1][:-2])

                scores = (
                    ('STOI', 100 * stoi(s, y, self.f_s, extended=False)),
                    ('eSTOI', 100 * stoi(s, y, self.f_s, extended=True)),
                    ('PESQ', pesq(self.f_s, s, y, 'nb')),
                    ('MOS-LQO', pesq(self.f_s, s, y, 'wb')),
                )
                for metric, score in scores:
                    results = self.add_score(
                        results, (noise_source, snr_level, metric), score)

            noise_sources = {key[0] for key in results}
            snr_levels = {key[1] for key in results}
            metrics = {key[2] for key in results}

            os.makedirs("log/results", exist_ok=True)

            with open(
                    "log/results/" + self.ver + "_e" + str(e) + '_' + g +
                    ".csv", "w") as f:
                f.write("noise,snr_db")
                for k in sorted(metrics):
                    f.write(',' + k)
                f.write('\n')
                for noise in sorted(noise_sources):
                    for snr in sorted(snr_levels):
                        f.write("{},{}".format(noise, snr))
                        for k in sorted(metrics):
                            if (noise, snr, k) in results:
                                f.write(",{:.2f}".format(
                                    np.mean(results[(noise, snr, k)])))
                        f.write('\n')

            # Pool the scores over all noise sources and the SNR levels that
            # fall inside [self.min_snr, self.max_snr].
            avg_results = {}
            for noise in sorted(noise_sources):
                for snr in sorted(snr_levels):
                    if self.min_snr <= snr <= self.max_snr:
                        for k in sorted(metrics):
                            if (noise, snr, k) in results:
                                avg_results = self.add_score(
                                    avg_results, k, results[(noise, snr, k)])

            if not os.path.exists("log/results/average.csv"):
                with open("log/results/average.csv", "w") as f:
                    f.write("ver")
                    for k in sorted(metrics):
                        f.write("," + k)
                    f.write('\n')

            with open("log/results/average.csv", "a") as f:
                f.write(self.ver + "_e" + str(e) + '_' + g)
                for k in sorted(metrics):
                    if k in avg_results:
                        f.write(",{:.2f}".format(np.mean(avg_results[k])))
                f.write('\n')
def infer(  # NEED TO ADD DeepMMSE
        self,
        test_x,
        test_x_len,
        test_x_base_names,
        test_epoch,
        model_path='model',
        out_type='y',
        gain='mmse-lsa',
        out_path='out',
        stats_path=None,
        n_filters=40,
):
    """
    Deep Xi inference. The specified 'out_type' is saved.

    Argument/s:
        test_x - noisy-speech test batch.
        test_x_len - noisy-speech test batch lengths.
        test_x_base_names - noisy-speech base names.
        test_epoch - epoch to test.
        model_path - path to model directory.
        out_type - output type (see deepxi/args.py).
        gain - gain function (see deepxi/args.py).
        out_path - path to save output files.
        stats_path - path to the saved statistics.
        n_filters - passed to self.mel_filter_bank() when out_type is
            'subband_ibm_hat'.

    Raises:
        ValueError - if out_type is not recognised or test_epoch < 1.
    """
    if out_type == 'xi_hat':
        out_path = out_path + '/xi_hat'
    elif out_type == 'y':
        out_path = out_path + '/y/' + gain
    elif out_type == 'deepmmse':
        out_path = out_path + '/deepmmse'
    elif out_type == 'ibm_hat':
        out_path = out_path + '/ibm_hat'
    elif out_type == 'subband_ibm_hat':
        out_path = out_path + '/subband_ibm_hat'
    else:
        raise ValueError('Invalid output type.')

    # Validate before creating any output directory.
    if test_epoch < 1:
        raise ValueError("test_epoch must be greater than 0.")

    os.makedirs(out_path, exist_ok=True)

    # The mel-scale filter bank is to compute an ideal binary mask (IBM)
    # estimate for log-spectral subband energies (LSSE).
    if out_type == 'subband_ibm_hat':
        mel_filter_bank = self.mel_filter_bank(n_filters)

    self.sample_stats(stats_path)
    self.model.load_weights(model_path + '/epoch-' + str(test_epoch - 1) +
                            '/variables/variables')

    print("Processing observations...")
    x_STMS_batch, x_STPS_batch, n_frames = self.observation_batch(
        test_x, test_x_len)
    print("Performing inference...")
    xi_bar_hat_batch = self.model.predict(x_STMS_batch, batch_size=1,
                                          verbose=1)

    print("Performing synthesis...")
    batch_size = len(test_x_len)
    for i in tqdm(range(batch_size)):
        base_name = test_x_base_names[i]
        x_STMS = x_STMS_batch[i, :n_frames[i], :]
        x_STPS = x_STPS_batch[i, :n_frames[i], :]
        xi_bar_hat = xi_bar_hat_batch[i, :n_frames[i], :]
        xi_hat = self.xi_hat(xi_bar_hat)
        if out_type == 'xi_hat':
            # Previously used the undefined global 'args.out_path', which
            # also bypassed the '/xi_hat' subdirectory created above.
            save_mat(out_path + '/' + base_name + '.mat', xi_hat, 'xi_hat')
        elif out_type == 'y':
            y_STMS = np.multiply(x_STMS, gfunc(xi_hat, xi_hat + 1,
                                               gtype=gain))
            y = self.polar_synthesis(y_STMS, x_STPS).numpy()
            save_wav(out_path + '/' + base_name + '.wav', y, self.f_s)
        elif out_type == 'deepmmse':
            d_PSD_hat = np.multiply(
                np.square(x_STMS), gfunc(xi_hat, xi_hat + 1,
                                         gtype='deepmmse'))
            save_mat(out_path + '/' + base_name + '.mat', d_PSD_hat,
                     'd_psd_hat')
        elif out_type == 'ibm_hat':
            ibm_hat = np.greater(xi_hat, 1.0).astype(bool)
            save_mat(out_path + '/' + base_name + '.mat', ibm_hat, 'ibm_hat')
        elif out_type == 'subband_ibm_hat':
            xi_hat_subband = np.matmul(xi_hat, mel_filter_bank.transpose())
            subband_ibm_hat = np.greater(xi_hat_subband, 1.0).astype(bool)
            save_mat(out_path + '/' + base_name + '.mat', subband_ibm_hat,
                     'subband_ibm_hat')
        else:
            raise ValueError('Invalid output type.')
def infer(
        self,
        test_x,
        test_x_len,
        test_x_base_names,
        test_epoch,
        model_path='model',
        out_type='y',
        gain='mmse-lsa',
        out_path='out',
        n_filters=40,
        saved_data_path=None,
):
    """
    Deep Xi inference. The specified 'out_type' is saved.

    Outputs are written below <out_path>/<ver>/e<epoch>/. Inference is run
    once per epoch. Only the enhanced speech ('y') depends on the gain
    function, so only 'y' is produced once per gain. Every other output
    type is written once per epoch.

    Argument/s:
        test_x - noisy-speech test batch.
        test_x_len - noisy-speech test batch lengths.
        test_x_base_names - noisy-speech base names.
        test_epoch - epoch to test (int or list of ints).
        model_path - path to model directory.
        out_type - output type (see deepxi/args.py).
        gain - gain function (str or list of str; see deepxi/args.py).
        out_path - path to save output files.
        n_filters - passed to self.mel_filter_bank() when out_type is
            'subband_ibm_hat'.
        saved_data_path - path to saved data necessary for enhancement.

    Raises:
        ValueError - if out_type is not recognised or a test epoch is < 1.
    """
    # Output subdirectory for every out_type whose location does not depend
    # on the gain function ('y' is handled separately below).
    out_subdirs = {
        'xi_hat': '/xi_hat',
        'gamma_hat': '/gamma_hat',
        's_STPS_hat': '/s_STPS_hat',
        'deepmmse': '/deepmmse',
        'ibm_hat': '/ibm_hat',
        'subband_ibm_hat': '/subband_ibm_hat',
        'cd_hat': '/cd_hat',
    }
    if out_type != 'y' and out_type not in out_subdirs:
        raise ValueError('Invalid output type.')

    out_path_base = out_path
    if not isinstance(test_epoch, list):
        test_epoch = [test_epoch]
    if not isinstance(gain, list):
        gain = [gain]

    # Non-'y' outputs are identical for every gain, so they are produced once.
    gains = gain if out_type == 'y' else gain[:1]

    # The mel-scale filter bank is to compute an ideal binary mask (IBM)
    # estimate for log-spectral subband energies (LSSE).
    if out_type == 'subband_ibm_hat':
        mel_filter_bank = self.mel_filter_bank(n_filters)

    for e in test_epoch:
        if e < 1:
            raise ValueError("test_epoch must be greater than 0.")
        if not gains:
            continue

        epoch_path = out_path_base + '/' + self.ver + '/' + 'e' + str(e)

        # Inference does not depend on the gain function: run it once per
        # epoch and share the estimates across all gains.
        self.model.load_weights(model_path + '/epoch-' + str(e - 1) +
                                '/variables/variables')
        print("Processing observations...")
        inp_batch, supplementary_batch, n_frames = self.observation_batch(
            test_x, test_x_len)
        print("Performing inference...")
        tgt_hat_batch = self.model.predict(inp_batch, batch_size=1,
                                           verbose=1)

        for g in gains:
            if out_type != 'y':
                out_dir = epoch_path + out_subdirs[out_type]
            elif self.inp_tgt_type == 'MagIRM':
                out_dir = epoch_path + '/y'
            else:
                out_dir = epoch_path + '/y/' + g
            os.makedirs(out_dir, exist_ok=True)

            print("Saving outputs...")
            batch_size = len(test_x_len)
            for i in tqdm(range(batch_size)):
                base_name = test_x_base_names[i]
                inp = inp_batch[i, :n_frames[i], :]
                tgt_hat = tgt_hat_batch[i, :n_frames[i], :]
                supplementary = supplementary_batch[i, :n_frames[i], :]
                if saved_data_path is not None:
                    saved_data = read_mat(saved_data_path + '/' + base_name +
                                          '.mat')
                    supplementary = (supplementary, saved_data)

                mat_path = out_dir + '/' + base_name + '.mat'
                if out_type == 'xi_hat':
                    xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                    save_mat(mat_path, xi_hat, 'xi_hat')
                elif out_type == 'gamma_hat':
                    gamma_hat = self.inp_tgt.gamma_hat(tgt_hat)
                    save_mat(mat_path, gamma_hat, 'gamma_hat')
                elif out_type == 's_STPS_hat':
                    s_STPS_hat = self.inp_tgt.s_stps_hat(tgt_hat)
                    save_mat(mat_path, s_STPS_hat, 's_STPS_hat')
                elif out_type == 'y':
                    y = self.inp_tgt.enhanced_speech(inp, supplementary,
                                                     tgt_hat, g).numpy()
                    save_wav(out_dir + '/' + base_name + '.wav', y,
                             self.inp_tgt.f_s)
                elif out_type == 'deepmmse':
                    xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                    d_PSD_hat = np.multiply(
                        np.square(inp),
                        gfunc(xi_hat, xi_hat + 1.0, gtype='deepmmse'))
                    save_mat(mat_path, d_PSD_hat, 'd_psd_hat')
                elif out_type == 'ibm_hat':
                    xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                    ibm_hat = np.greater(xi_hat, 1.0).astype(bool)
                    save_mat(mat_path, ibm_hat, 'ibm_hat')
                elif out_type == 'subband_ibm_hat':
                    xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                    xi_hat_subband = np.matmul(xi_hat,
                                               mel_filter_bank.transpose())
                    subband_ibm_hat = np.greater(xi_hat_subband,
                                                 1.0).astype(bool)
                    save_mat(mat_path, subband_ibm_hat, 'subband_ibm_hat')
                elif out_type == 'cd_hat':
                    cd_hat = self.inp_tgt.cd_hat(tgt_hat)
                    save_mat(mat_path, cd_hat, 'cd_hat')
                else:
                    raise ValueError('Invalid output type.')