def sample_stats(self, stats_path='data', sample_size=1000,
        train_s_list=None, train_d_list=None):
    """
    Computes statistics for each frequency component of the instantaneous
    a priori SNR in dB over a sample of the training set. The statistics
    are then used to map the instantaneous a priori SNR in dB between 0 and
    1 using its cumulative distribution function. This forms the mapped
    a priori SNR (the training target).

    Argument/s:
        stats_path - path to the saved statistics.
        sample_size - number of training examples to compute the statistics from.
        train_s_list - train clean-speech list.
        train_d_list - train noise list.
    """
    if os.path.exists(stats_path + '/stats.npz'):
        print('Loading sample statistics...')
        with np.load(stats_path + '/stats.npz') as stats:
            self.mu = stats['mu_hat']
            self.sigma = stats['sigma_hat']
    elif train_s_list is None:
        raise ValueError('No stats.npz file exists. data/stats.npz is available here: '
                         'https://github.com/anicolson/DeepXi/blob/master/data/stats.npz.')
    else:
        print('Finding sample statistics...')
        s_sample_list = random.sample(self.train_s_list, sample_size)
        d_sample_list = random.sample(self.train_d_list, sample_size)
        s_sample, d_sample, s_sample_len, d_sample_len, snr_sample = \
            self.wav_batch(s_sample_list, d_sample_list)
        snr_sample = np.array(random.choices(self.snr_levels, k=sample_size))
        # snr_sample = np.random.randint(self.min_snr, self.max_snr + 1, sample_size)
        samples = []
        for i in tqdm(range(s_sample.shape[0])):
            s_STMS, d_STMS, _, _ = self.mix(
                s_sample[i:i + 1], d_sample[i:i + 1],
                s_sample_len[i:i + 1], d_sample_len[i:i + 1],
                snr_sample[i:i + 1])
            xi_db = self.xi_db(s_STMS, d_STMS)  # instantaneous a priori SNR (dB).
            samples.append(np.squeeze(xi_db.numpy()))
        samples = np.vstack(samples)
        if samples.ndim != 2:
            raise ValueError('Incorrect shape for sample.')
        stats = {'mu_hat': np.mean(samples, axis=0),
                 'sigma_hat': np.std(samples, axis=0)}
        self.mu, self.sigma = stats['mu_hat'], stats['sigma_hat']
        if not os.path.exists(stats_path):
            os.makedirs(stats_path)
        np.savez(stats_path + '/stats.npz',
                 mu_hat=stats['mu_hat'], sigma_hat=stats['sigma_hat'])
        save_mat(stats_path + '/stats.mat', stats, 'stats')
        print('Sample statistics saved.')
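# Illustrative sketch (not part of the class): how the statistics computed by
# sample_stats() map the instantaneous a priori SNR in dB to [0, 1], per the
# docstring above. Assumes each frequency component of xi_db is modelled as
# Gaussian with the sample mean mu and standard deviation sigma, so the
# cumulative distribution function is the normal CDF.
import numpy as np
from scipy.special import erf

def xi_bar_sketch(xi_db, mu, sigma):
    """Mapped a priori SNR: normal CDF of xi_db with per-frequency mu/sigma."""
    return 0.5 * (1.0 + erf((xi_db - mu) / (sigma * np.sqrt(2.0))))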
def sample(self, sample_size, sample_dir='data'):
    """
    Gathers a sample of the training set. The sample can be used to compute
    statistics for mapping functions.

    Argument/s:
        sample_size - number of training examples included in the sample.
        sample_dir - path to the saved sample.
    """
    sample_path = sample_dir + '/sample'
    if os.path.exists(sample_path + '.npz'):
        print('Loading sample...')
        with np.load(sample_path + '.npz') as sample:
            s_sample = sample['s_sample']
            d_sample = sample['d_sample']
            x_sample = sample['x_sample']
            wav_len = sample['wav_len']
    elif self.train_s_list is None:
        raise ValueError('No sample.npz file exists.')
    else:
        if sample_size is None:
            raise ValueError("sample_size is not set.")
        print('Gathering a sample of the training set...')
        s_sample_list = random.sample(self.train_s_list, sample_size)
        d_sample_list = random.sample(self.train_d_list, sample_size)
        s_sample_int, d_sample_int, s_sample_len, d_sample_len, snr_sample = \
            self.wav_batch(s_sample_list, d_sample_list)
        s_sample = np.zeros_like(s_sample_int, np.float32)
        d_sample = np.zeros_like(s_sample_int, np.float32)
        x_sample = np.zeros_like(s_sample_int, np.float32)
        for i in tqdm(range(s_sample.shape[0])):
            s, d, x, _ = self.inp_tgt.mix(
                s_sample_int[i:i + 1], d_sample_int[i:i + 1],
                s_sample_len[i:i + 1], d_sample_len[i:i + 1],
                snr_sample[i:i + 1])
            s_sample[i, 0:s_sample_len[i]] = s
            d_sample[i, 0:s_sample_len[i]] = d
            x_sample[i, 0:s_sample_len[i]] = x
        wav_len = s_sample_len
        if not os.path.exists(sample_dir):
            os.makedirs(sample_dir)
        np.savez(sample_path + '.npz', s_sample=s_sample,
                 d_sample=d_sample, x_sample=x_sample, wav_len=wav_len)
        sample = {'s_sample': s_sample, 'd_sample': d_sample,
                  'x_sample': x_sample, 'wav_len': wav_len}
        save_mat(sample_path + '.mat', sample, 'stats')
        print('Sample of the training set saved.')
    return s_sample, d_sample, x_sample, wav_len
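# Minimal usage sketch: reload the sample written by sample() above. The key
# names match those passed to np.savez(); the path assumes the default
# sample_dir of 'data'.
import numpy as np

with np.load('data/sample.npz') as sample:
    s_sample = sample['s_sample']  # clean speech, zero-padded per example
    d_sample = sample['d_sample']  # noise
    x_sample = sample['x_sample']  # noisy mixture at the sampled SNR
    wav_len = sample['wav_len']    # valid length (samples) of each example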
def sample_old(self, sample_size, sample_dir='data'):
    """
    Gathers a sample of the training set. The sample can be used to compute
    statistics for mapping functions.

    Argument/s:
        sample_size - number of training examples included in the sample.
        sample_dir - path to the saved sample.
    """
    sample_path = sample_dir + '/sample'
    if os.path.exists(sample_path + '.npz'):
        print('Loading sample...')
        with np.load(sample_path + '.npz') as sample:
            samples_s_STMS = sample['samples_s_STMS']
            samples_d_STMS = sample['samples_d_STMS']
            samples_x_STMS = sample['samples_x_STMS']
    elif self.train_s_list is None:
        raise ValueError('No sample.npz file exists. data/sample.npz is available here: '
                         'https://github.com/anicolson/DeepXi/blob/master/data/sample.npz.')
    else:
        if sample_size is None:
            raise ValueError("sample_size is not set.")
        print('Gathering a sample of the training set...')
        s_sample_list = random.sample(self.train_s_list, sample_size)
        d_sample_list = random.sample(self.train_d_list, sample_size)
        s_sample, d_sample, s_sample_len, d_sample_len, snr_sample = \
            self.wav_batch(s_sample_list, d_sample_list)
        snr_sample = np.array(random.choices(self.snr_levels, k=sample_size))
        samples_s_STMS = []
        samples_d_STMS = []
        samples_x_STMS = []
        for i in tqdm(range(s_sample.shape[0])):
            s, d, x, _ = self.inp_tgt.mix(
                s_sample[i:i + 1], d_sample[i:i + 1],
                s_sample_len[i:i + 1], d_sample_len[i:i + 1],
                snr_sample[i:i + 1])
            s_STMS, _ = self.inp_tgt.polar_analysis(s)
            d_STMS, _ = self.inp_tgt.polar_analysis(d)
            x_STMS, _ = self.inp_tgt.polar_analysis(x)
            samples_s_STMS.append(np.squeeze(s_STMS.numpy()))
            samples_d_STMS.append(np.squeeze(d_STMS.numpy()))
            samples_x_STMS.append(np.squeeze(x_STMS.numpy()))
        samples_s_STMS = np.vstack(samples_s_STMS)
        samples_d_STMS = np.vstack(samples_d_STMS)
        samples_x_STMS = np.vstack(samples_x_STMS)
        if samples_s_STMS.ndim != 2:
            raise ValueError('Incorrect shape for s_STMS sample.')
        if samples_d_STMS.ndim != 2:
            raise ValueError('Incorrect shape for d_STMS sample.')
        if samples_x_STMS.ndim != 2:
            raise ValueError('Incorrect shape for x_STMS sample.')
        if not os.path.exists(sample_dir):
            os.makedirs(sample_dir)
        np.savez(sample_path + '.npz', samples_s_STMS=samples_s_STMS,
                 samples_d_STMS=samples_d_STMS, samples_x_STMS=samples_x_STMS)
        sample = {'samples_s_STMS': samples_s_STMS,
                  'samples_d_STMS': samples_d_STMS,
                  'samples_x_STMS': samples_x_STMS}
        save_mat(sample_path + '.mat', sample, 'stats')
        print('Sample of the training set saved.')
    return samples_s_STMS, samples_d_STMS, samples_x_STMS
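# Sketch: deriving per-frequency statistics from the stacked STMS frames that
# sample_old() returns, analogous to what sample_stats() computes for xi_db.
# The toy random array stands in for samples_x_STMS (frames x frequency bins).
import numpy as np

samples_x_STMS = np.abs(np.random.randn(1000, 257)).astype(np.float32)  # toy stand-in
mu_hat = np.mean(samples_x_STMS, axis=0)    # per-frequency-bin mean
sigma_hat = np.std(samples_x_STMS, axis=0)  # per-frequency-bin standard deviation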
def train(self, train_s_list, train_d_list, model_path='model',
        val_s=None, val_d=None, val_s_len=None, val_d_len=None, val_snr=None,
        val_flag=True, val_save_path=None, mbatch_size=8, max_epochs=200,
        resume_epoch=0, stats_path=None, sample_size=None,
        eval_example=False, save_model=True, log_iter=False):
    """
    Deep Xi training.

    Argument/s:
        train_s_list - clean-speech training list.
        train_d_list - noise training list.
        model_path - model save path.
        val_s - clean-speech validation batch.
        val_d - noise validation batch.
        val_s_len - clean-speech validation sequence length batch.
        val_d_len - noise validation sequence length batch.
        val_snr - SNR validation batch.
        val_flag - perform validation.
        val_save_path - validation batch save path.
        mbatch_size - mini-batch size.
        max_epochs - maximum number of epochs.
        resume_epoch - epoch to resume training from.
        stats_path - path to save sample statistics.
        sample_size - sample size.
        eval_example - evaluate a mini-batch of training examples.
        save_model - save architecture, weights, and training configuration.
        log_iter - log training loss for each training iteration.
    """
    self.train_s_list = train_s_list
    self.train_d_list = train_d_list
    self.mbatch_size = mbatch_size
    self.n_examples = len(self.train_s_list)
    self.n_iter = math.ceil(self.n_examples / mbatch_size)

    self.sample_stats(stats_path, sample_size, train_s_list, train_d_list)
    train_dataset = self.dataset(max_epochs - resume_epoch)

    if val_flag:
        val_set = self.val_batch(val_save_path, val_s, val_d,
                                 val_s_len, val_d_len, val_snr)
        val_steps = len(val_set[0])
    else:
        val_set, val_steps = None, None

    if eval_example:
        print("Saving a mini-batch of training examples in .mat files...")
        x_STMS_batch, xi_bar_batch, seq_mask_batch = \
            list(train_dataset.take(1).as_numpy_iterator())[0]
        save_mat('./x_STMS_batch.mat', x_STMS_batch, 'x_STMS_batch')
        save_mat('./xi_bar_batch.mat', xi_bar_batch, 'xi_bar_batch')
        save_mat('./seq_mask_batch.mat', seq_mask_batch, 'seq_mask_batch')

        print("Testing if add_noise() works correctly...")
        s, d, s_len, d_len, snr_tgt = self.wav_batch(
            train_s_list[0:mbatch_size], train_d_list[0:mbatch_size])
        (_, s, d) = self.add_noise_batch(self.normalise(s), self.normalise(d),
                                         s_len, d_len, snr_tgt)
        for (i, _) in enumerate(s):
            snr_act = self.snr_db(s[i][0:s_len[i]], d[i][0:d_len[i]])
            print('SNR target|actual: {:.2f}|{:.2f} (dB).'.format(snr_tgt[i], snr_act))

    if not os.path.exists(model_path):
        os.makedirs(model_path)
    if not os.path.exists("log"):
        os.makedirs("log")
    if not os.path.exists("log/iter"):
        os.makedirs("log/iter")

    callbacks = []
    callbacks.append(CSVLogger("log/" + self.ver + ".csv",
                               separator=',', append=True))
    if save_model:
        callbacks.append(SaveWeights(model_path))
    # if log_iter: callbacks.append(CSVLoggerIter("log/iter/" + self.ver + ".csv", separator=',', append=True))

    if resume_epoch > 0:
        self.model.load_weights(model_path + "/epoch-" +
                                str(resume_epoch - 1) + "/variables/variables")

    self.model.compile(sample_weight_mode="temporal",
                       loss="binary_crossentropy",
                       optimizer=Adam(lr=0.001, clipvalue=1.0))
    self.model.fit(x=train_dataset, initial_epoch=resume_epoch,
                   epochs=max_epochs, steps_per_epoch=self.n_iter,
                   callbacks=callbacks, validation_data=val_set,
                   validation_steps=val_steps)
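# Standalone sketch of the SNR sanity check performed when eval_example is
# True: the actual SNR between clean speech s and scaled noise d should match
# the target SNR produced by add_noise_batch(). SNR in dB is the log-ratio of
# signal energies.
import numpy as np

def snr_db_sketch(s, d):
    """Actual SNR (dB) between clean speech s and noise d (1-D arrays)."""
    return 10.0 * np.log10(np.sum(np.square(s.astype(np.float64)))
                           / np.sum(np.square(d.astype(np.float64))))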
def infer(self, test_x, test_x_len, test_x_base_names, test_epoch,
        model_path='model', out_type='y', gain='mmse-lsa', out_path='out',
        stats_path=None, n_filters=40):
    """
    Deep Xi inference. The specified 'out_type' is saved.

    Argument/s:
        test_x - noisy-speech test batch.
        test_x_len - noisy-speech test batch lengths.
        test_x_base_names - noisy-speech base names.
        test_epoch - epoch to test.
        model_path - path to model directory.
        out_type - output type (see deepxi/args.py).
        gain - gain function (see deepxi/args.py).
        out_path - path to save output files.
        stats_path - path to the saved statistics.
        n_filters - number of mel-scale filters.
    """
    if out_type == 'xi_hat':
        out_path = out_path + '/xi_hat'
    elif out_type == 'y':
        out_path = out_path + '/y/' + gain
    elif out_type == 'deepmmse':
        out_path = out_path + '/deepmmse'
    elif out_type == 'ibm_hat':
        out_path = out_path + '/ibm_hat'
    elif out_type == 'subband_ibm_hat':
        out_path = out_path + '/subband_ibm_hat'
    else:
        raise ValueError('Invalid output type.')
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    if test_epoch < 1:
        raise ValueError("test_epoch must be greater than 0.")

    # The mel-scale filter bank is used to compute an ideal binary mask (IBM)
    # estimate for log-spectral subband energies (LSSE).
    if out_type == 'subband_ibm_hat':
        mel_filter_bank = self.mel_filter_bank(n_filters)

    self.sample_stats(stats_path)
    self.model.load_weights(model_path + '/epoch-' + str(test_epoch - 1) +
                            '/variables/variables')

    print("Processing observations...")
    x_STMS_batch, x_STPS_batch, n_frames = self.observation_batch(test_x, test_x_len)
    print("Performing inference...")
    xi_bar_hat_batch = self.model.predict(x_STMS_batch, batch_size=1, verbose=1)

    print("Performing synthesis...")
    batch_size = len(test_x_len)
    for i in tqdm(range(batch_size)):
        base_name = test_x_base_names[i]
        x_STMS = x_STMS_batch[i, :n_frames[i], :]
        x_STPS = x_STPS_batch[i, :n_frames[i], :]
        xi_bar_hat = xi_bar_hat_batch[i, :n_frames[i], :]
        xi_hat = self.xi_hat(xi_bar_hat)
        if out_type == 'xi_hat':
            save_mat(out_path + '/' + base_name + '.mat', xi_hat, 'xi_hat')
        elif out_type == 'y':
            y_STMS = np.multiply(x_STMS, gfunc(xi_hat, xi_hat + 1, gtype=gain))
            y = self.polar_synthesis(y_STMS, x_STPS).numpy()
            save_wav(out_path + '/' + base_name + '.wav', y, self.f_s)
        elif out_type == 'deepmmse':
            d_PSD_hat = np.multiply(np.square(x_STMS),
                                    gfunc(xi_hat, xi_hat + 1, gtype='deepmmse'))
            save_mat(out_path + '/' + base_name + '.mat', d_PSD_hat, 'd_psd_hat')
        elif out_type == 'ibm_hat':
            ibm_hat = np.greater(xi_hat, 1.0).astype(bool)
            save_mat(out_path + '/' + base_name + '.mat', ibm_hat, 'ibm_hat')
        elif out_type == 'subband_ibm_hat':
            xi_hat_subband = np.matmul(xi_hat, mel_filter_bank.transpose())
            subband_ibm_hat = np.greater(xi_hat_subband, 1.0).astype(bool)
            save_mat(out_path + '/' + base_name + '.mat',
                     subband_ibm_hat, 'subband_ibm_hat')
        else:
            raise ValueError('Invalid output type.')
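# Sketch of the enhancement step for out_type 'y', using the Wiener gain as a
# simple stand-in for gfunc (the repository also provides gains such as
# mmse-lsa): each time-frequency magnitude is weighted by G(xi) = xi / (1 + xi),
# where xi is the a priori SNR estimate.
import numpy as np

def wiener_gain(xi_hat):
    """Wiener gain from the a priori SNR estimate, per time-frequency bin."""
    return xi_hat / (1.0 + xi_hat)

xi_hat = np.array([[0.25, 4.0]])                   # toy a priori SNR estimates
x_STMS = np.array([[1.0, 1.0]])                    # toy noisy magnitudes
y_STMS = np.multiply(x_STMS, wiener_gain(xi_hat))  # enhanced magnitudes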
def infer(self, test_x, test_x_len, test_x_base_names, test_epoch,
        model_path='model', out_type='y', gain='mmse-lsa', out_path='out',
        n_filters=40, saved_data_path=None):
    """
    Deep Xi inference. The specified 'out_type' is saved.

    Argument/s:
        test_x - noisy-speech test batch.
        test_x_len - noisy-speech test batch lengths.
        test_x_base_names - noisy-speech base names.
        test_epoch - epoch to test.
        model_path - path to model directory.
        out_type - output type (see deepxi/args.py).
        gain - gain function (see deepxi/args.py).
        out_path - path to save output files.
        n_filters - number of mel-scale filters.
        saved_data_path - path to saved data necessary for enhancement.
    """
    out_path_base = out_path
    if not isinstance(test_epoch, list): test_epoch = [test_epoch]
    if not isinstance(gain, list): gain = [gain]

    # The mel-scale filter bank is used to compute an ideal binary mask (IBM)
    # estimate for log-spectral subband energies (LSSE).
    if out_type == 'subband_ibm_hat':
        mel_filter_bank = self.mel_filter_bank(n_filters)

    for e in test_epoch:
        if e < 1:
            raise ValueError("test_epoch must be greater than 0.")
        for g in gain:
            out_path = out_path_base + '/' + self.ver + '/' + 'e' + str(e)  # output path.
            if out_type == 'xi_hat':
                out_path = out_path + '/xi_hat'
            elif out_type == 'gamma_hat':
                out_path = out_path + '/gamma_hat'
            elif out_type == 's_STPS_hat':
                out_path = out_path + '/s_STPS_hat'
            elif out_type == 'y':
                if self.inp_tgt_type == 'MagIRM':
                    out_path = out_path + '/y'
                else:
                    out_path = out_path + '/y/' + g
            elif out_type == 'deepmmse':
                out_path = out_path + '/deepmmse'
            elif out_type == 'ibm_hat':
                out_path = out_path + '/ibm_hat'
            elif out_type == 'subband_ibm_hat':
                out_path = out_path + '/subband_ibm_hat'
            elif out_type == 'cd_hat':
                out_path = out_path + '/cd_hat'
            else:
                raise ValueError('Invalid output type.')
            if not os.path.exists(out_path):
                os.makedirs(out_path)

            self.model.load_weights(model_path + '/epoch-' + str(e - 1) +
                                    '/variables/variables')

            print("Processing observations...")
            inp_batch, supplementary_batch, n_frames = \
                self.observation_batch(test_x, test_x_len)

            print("Performing inference...")
            tgt_hat_batch = self.model.predict(inp_batch, batch_size=1, verbose=1)

            print("Saving outputs...")
            batch_size = len(test_x_len)
            for i in tqdm(range(batch_size)):
                base_name = test_x_base_names[i]
                inp = inp_batch[i, :n_frames[i], :]
                tgt_hat = tgt_hat_batch[i, :n_frames[i], :]

                # if tf.is_tensor(supplementary_batch):
                supplementary = supplementary_batch[i, :n_frames[i], :]

                if saved_data_path is not None:
                    saved_data = read_mat(saved_data_path + '/' + base_name + '.mat')
                    supplementary = (supplementary, saved_data)

                if out_type == 'xi_hat':
                    xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                    save_mat(out_path + '/' + base_name + '.mat', xi_hat, 'xi_hat')
                elif out_type == 'gamma_hat':
                    gamma_hat = self.inp_tgt.gamma_hat(tgt_hat)
                    save_mat(out_path + '/' + base_name + '.mat', gamma_hat, 'gamma_hat')
                elif out_type == 's_STPS_hat':
                    s_STPS_hat = self.inp_tgt.s_stps_hat(tgt_hat)
                    save_mat(out_path + '/' + base_name + '.mat', s_STPS_hat, 's_STPS_hat')
                elif out_type == 'y':
                    y = self.inp_tgt.enhanced_speech(inp, supplementary, tgt_hat, g).numpy()
                    save_wav(out_path + '/' + base_name + '.wav', y, self.inp_tgt.f_s)
                elif out_type == 'deepmmse':
                    xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                    d_PSD_hat = np.multiply(np.square(inp),
                                            gfunc(xi_hat, xi_hat + 1.0, gtype='deepmmse'))
                    save_mat(out_path + '/' + base_name + '.mat', d_PSD_hat, 'd_psd_hat')
                elif out_type == 'ibm_hat':
                    xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                    ibm_hat = np.greater(xi_hat, 1.0).astype(bool)
                    save_mat(out_path + '/' + base_name + '.mat', ibm_hat, 'ibm_hat')
                elif out_type == 'subband_ibm_hat':
                    xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                    xi_hat_subband = np.matmul(xi_hat, mel_filter_bank.transpose())
                    subband_ibm_hat = np.greater(xi_hat_subband, 1.0).astype(bool)
                    save_mat(out_path + '/' + base_name + '.mat',
                             subband_ibm_hat, 'subband_ibm_hat')
                elif out_type == 'cd_hat':
                    cd_hat = self.inp_tgt.cd_hat(tgt_hat)
                    save_mat(out_path + '/' + base_name + '.mat', cd_hat, 'cd_hat')
                else:
                    raise ValueError('Invalid output type.')
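# Sketch: the IBM estimate above thresholds the a priori SNR estimate at unity
# (0 dB), marking time-frequency bins where speech is estimated to dominate
# the noise.
import numpy as np

xi_hat = np.array([[0.5, 2.0], [1.5, 0.1]])     # toy a priori SNR estimates
ibm_hat = np.greater(xi_hat, 1.0).astype(bool)  # True where xi_hat > 0 dB
print(ibm_hat)  # [[False  True] [ True False]]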
def train(self, train_s_list, train_d_list, mbatch_size, max_epochs, loss_fnc,
        model_path='model', val_s=None, val_d=None, val_s_len=None,
        val_d_len=None, val_snr=None, val_flag=True, val_save_path=None,
        resume_epoch=0, eval_example=False, save_model=True, log_iter=False):
    """
    Deep Xi training.

    Argument/s:
        train_s_list - clean-speech training list.
        train_d_list - noise training list.
        mbatch_size - mini-batch size.
        max_epochs - maximum number of epochs.
        loss_fnc - loss function.
        model_path - model save path.
        val_s - clean-speech validation batch.
        val_d - noise validation batch.
        val_s_len - clean-speech validation sequence length batch.
        val_d_len - noise validation sequence length batch.
        val_snr - SNR validation batch.
        val_flag - perform validation.
        val_save_path - validation batch save path.
        resume_epoch - epoch to resume training from.
        eval_example - evaluate a mini-batch of training examples.
        save_model - save architecture, weights, and training configuration.
        log_iter - log training loss for each training iteration.
    """
    self.train_s_list = train_s_list
    self.train_d_list = train_d_list
    self.mbatch_size = mbatch_size
    self.n_examples = len(self.train_s_list)
    self.n_iter = math.ceil(self.n_examples / mbatch_size)

    train_dataset = self.dataset(max_epochs - resume_epoch)

    if val_flag:
        val_set = self.val_batch(val_save_path, val_s, val_d,
                                 val_s_len, val_d_len, val_snr)
        val_steps = len(val_set[0])
    else:
        val_set, val_steps = None, None

    if not os.path.exists(model_path):
        os.makedirs(model_path)
    if not os.path.exists("log/loss"):
        os.makedirs("log/loss")

    callbacks = []
    callbacks.append(CSVLogger("log/loss/" + self.ver + ".csv",
                               separator=',', append=True))
    if save_model:
        callbacks.append(SaveWeights(model_path))
    # if log_iter: callbacks.append(CSVLoggerIter("log/iter/" + self.ver + ".csv", separator=',', append=True))

    if resume_epoch > 0:
        self.model.load_weights(model_path + "/epoch-" +
                                str(resume_epoch - 1) + "/variables/variables")

    if eval_example:
        print("Saving a mini-batch of training examples in .mat files...")
        inp_batch, tgt_batch, seq_mask_batch = \
            list(train_dataset.take(1).as_numpy_iterator())[0]
        save_mat('./inp_batch.mat', inp_batch, 'inp_batch')
        save_mat('./tgt_batch.mat', tgt_batch, 'tgt_batch')
        save_mat('./seq_mask_batch.mat', seq_mask_batch, 'seq_mask_batch')

        print("Testing if add_noise() works correctly...")
        s, d, s_len, d_len, snr_tgt = self.wav_batch(
            self.train_s_list[0:mbatch_size], self.train_d_list[0:mbatch_size])
        (_, s, d) = self.inp_tgt.add_noise_batch(self.inp_tgt.normalise(s),
                                                 self.inp_tgt.normalise(d),
                                                 s_len, d_len, snr_tgt)
        for (i, _) in enumerate(s):
            snr_act = self.inp_tgt.snr_db(s[i][0:s_len[i]], d[i][0:d_len[i]])
            print('SNR target|actual: {:.2f}|{:.2f} (dB).'.format(snr_tgt[i], snr_act))

    if self.network_type == "MHANet":
        lr_schedular = TransformerSchedular(self.network.d_model,
                                            self.network.warmup_steps)
        opt = Adam(learning_rate=lr_schedular, clipvalue=1.0,
                   beta_1=0.9, beta_2=0.98, epsilon=1e-9)
    else:
        opt = Adam(learning_rate=0.001, clipvalue=1.0)

    if loss_fnc == "BinaryCrossentropy":
        loss = BinaryCrossentropy()
    elif loss_fnc == "MeanSquaredError":
        loss = MeanSquaredError()
    else:
        raise ValueError("Invalid loss function")

    self.model.compile(sample_weight_mode="temporal", loss=loss, optimizer=opt)
    print("SNR levels used for training:")
    print(self.snr_levels)
    self.model.fit(x=train_dataset, initial_epoch=resume_epoch,
                   epochs=max_epochs, steps_per_epoch=self.n_iter,
                   callbacks=callbacks, validation_data=val_set,
                   validation_steps=val_steps)
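# Sketch of the learning-rate schedule presumably implemented by
# TransformerSchedular for the MHANet: the standard Transformer schedule of
# Vaswani et al. (2017), which warms up linearly for warmup_steps and then
# decays with the inverse square root of the step.
def transformer_lr(step, d_model, warmup_steps):
    """Learning rate at a given step (step >= 1)."""
    return (d_model ** -0.5) * min(step ** -0.5, step * (warmup_steps ** -1.5))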