Example 1
def Noise_mbatch(noise_list, mbatch_size, clean_seq_len):
    '''
	Creates a padded mini-batch of noise speech wavs.

	Inputs:
		noise_list - training list for the noise files.
		mbatch_size - size of the mini-batch.
		clean_seq_len - sequence length of each clean speech file in the mini-batch.

	Outputs:
		mbatch - matrix of padded wavs stored as a numpy array.
		seq_len - length of each wav stored as a numpy array.
	'''

    mbatch_list = random.sample(noise_list, mbatch_size)  # get mini-batch list from training list.
    for i in range(len(clean_seq_len)):
        flag = True
        while flag:
            if mbatch_list[i]['seq_len'] < clean_seq_len[i]:
                mbatch_list[i] = random.choice(noise_list)
            else:
                flag = False
    maxlen = max([dic['seq_len'] for dic in mbatch_list])  # find maximum length wav in mini-batch.
    seq_len = []  # list of the wav lengths.
    mbatch = np.zeros([len(mbatch_list), maxlen], np.int16)  # numpy array for wav matrix.
    for i in range(len(mbatch_list)):
        (wav, _) = read_wav(mbatch_list[i]['file_path'])  # read wav from given file path.
        mbatch[i, :mbatch_list[i]['seq_len']] = wav  # add wav to numpy array.
        seq_len.append(mbatch_list[i]['seq_len'])  # append length of wav to list.
    return mbatch, np.array(seq_len, np.int32)
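
These snippets assume module-level imports of random, numpy (as np), glob, and os, and that read_wav is the project's helper for reading a wav file into an int16 array. A hypothetical usage sketch for Noise_mbatch follows (the paths and lengths are made up); each entry of noise_list is a dict with 'file_path' and 'seq_len' keys:

# Hypothetical noise training list; 'seq_len' is the wav length in samples.
noise_list = [
    {'file_path': 'noise/cafe_01.wav', 'seq_len': 64000},
    {'file_path': 'noise/street_02.wav', 'seq_len': 80000},
    {'file_path': 'noise/babble_03.wav', 'seq_len': 48000},
]
clean_seq_len = [32000, 40000]  # lengths of the clean wavs already in the mini-batch.
noise_mbatch, noise_seq_len = Noise_mbatch(noise_list, 2, clean_seq_len)
# noise_mbatch: (2, longest selected noise wav) int16 matrix, zero-padded per row.
# Each selected noise wav is at least as long as the clean wav it is paired with.
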
Example 2
def Clean_mbatch(clean_list, mbatch_size, start_idx, end_idx):
    '''
	Creates a padded mini-batch of clean speech wavs.

	Inputs:
		clean_list - training list for the clean speech files.
		mbatch_size - size of the mini-batch.
		start_idx - start index of the mini-batch within the training list.
		end_idx - end index of the mini-batch within the training list.

	Outputs:
		mbatch - matrix of padded wavs stored as a numpy array.
		seq_len - length of each wav stored as a numpy array.
	'''
    mbatch_list = clean_list[start_idx:end_idx]  # get mini-batch list from training list.
    maxlen = max([dic['seq_len'] for dic in mbatch_list])  # find maximum length wav in mini-batch.
    seq_len = []  # list of the wav lengths.
    mbatch = np.zeros([len(mbatch_list), maxlen], np.int16)  # numpy array for wav matrix.
    for i in range(len(mbatch_list)):
        (wav, _) = read_wav(mbatch_list[i]['file_path'])  # read wav from given file path.
        mbatch[i, :mbatch_list[i]['seq_len']] = wav  # add wav to numpy array.
        seq_len.append(mbatch_list[i]['seq_len'])  # append length of wav to list.
    return mbatch, np.array(seq_len, np.int32)
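
A minimal sketch of how Clean_mbatch might be driven over an epoch, assuming clean_list has the same dict structure as the hypothetical noise_list above and has already been shuffled or sorted upstream:

# Hypothetical epoch loop: slice clean_list into consecutive mini-batches.
mbatch_size = 10
for start_idx in range(0, len(clean_list), mbatch_size):
    end_idx = min(start_idx + mbatch_size, len(clean_list))  # last mini-batch may be smaller.
    clean_mbatch, clean_seq_len = Clean_mbatch(clean_list, mbatch_size, start_idx, end_idx)
    # clean_seq_len can then be passed to Noise_mbatch() so that each noise wav is
    # at least as long as the clean wav it will be mixed with.
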
Example 3
def Batch(fdir, snr_l):
    '''
	REQUIRES REWRITING.

	Places all of the test waveforms from the list into a numpy array. 
	SPHERE format cannot be used. 'glob' is used to support Unix style pathname 
	pattern expansions. Waveforms are padded to the maximum waveform length. The 
	waveform lengths are recorded so that the correct lengths can be sliced 
	for feature extraction. The SNR levels of each test file are placed into a
	numpy array. Also returns a list of the file names.

	Inputs:
		fdir - directory containing the waveforms (searched for .wav, .flac, and .mp3 files).
		snr_l - list of the SNR levels used.

	Outputs:
		wav_np - matrix of padded waveforms stored as a numpy array.
		len_np - length of each waveform stored as a numpy array.
		snr_test_np - numpy array of all the SNR levels for the test set.
		fname_l - list of filenames.
	'''
    fname_l = []  # list of file names.
    wav_l = []  # list for waveforms.
    snr_test_l = []  # list of SNR levels for the test set.
    # if isinstance(fnames, str): fnames = [fnames] # if string, put into list.
    fnames = ['*.wav', '*.flac', '*.mp3']
    for fname in fnames:
        for fpath in glob.glob(os.path.join(fdir, fname)):
            for snr in snr_l:
                if fpath.find('_' + str(snr) + 'dB') != -1:
                    snr_test_l.append(snr)  # append SNR level.
            (wav, _) = read_wav(fpath)  # read waveform from given file path.
            if np.isnan(wav).any() or np.isinf(wav).any():
                raise ValueError('Error: NaN or Inf value. File path: %s.' % fpath)
            wav_l.append(wav)  # append waveform.
            fname_l.append(os.path.basename(os.path.splitext(fpath)[0]))  # append name without extension.
    len_l = []  # list of the waveform lengths.
    maxlen = max(len(wav) for wav in wav_l)  # maximum length of waveforms.
    wav_np = np.zeros([len(wav_l), maxlen], np.int16)  # numpy array for waveform matrix.
    for (i, wav) in enumerate(wav_l):
        wav_np[i, :len(wav)] = wav  # add waveform to numpy array.
        len_l.append(len(wav))  # append length of waveform to list.
    return wav_np, np.array(len_l, np.int32), np.array(snr_test_l, np.int32), fname_l
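
The SNR labels above are recovered purely from the filename convention '_<snr>dB'. A hedged usage sketch, assuming a test directory whose files follow that convention (the directory path and filenames here are illustrative):

# Hypothetical test directory containing files such as 'p232_001_5dB.wav'.
snr_l = [-5, 0, 5, 10, 15]  # SNR levels used to create the test set.
test_x, test_x_len, test_snr, test_fnames = Batch('set/test_noisy_speech', snr_l)
# test_x      - (num files, max length) int16 matrix of zero-padded waveforms.
# test_x_len  - true length of each waveform, used to slice off the padding.
# test_snr    - SNR level parsed from each filename.
# test_fnames - basenames without extensions, e.g. 'p232_001_5dB'.
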
Example 4
def infer2(sess, net, args):
	print("Inference...", )
	print (args.test_x_list)
	net.saver.restore(sess, args.model_path + '/epoch-' + str(args.epoch)) # load model from epoch.
	
	if args.out_type == 'xi_hat': args.out_path = args.out_path + '/xi_hat'
	elif args.out_type == 'y': args.out_path = args.out_path + '/' + args.gain + '/y'
	elif args.out_type == 'ibm_hat': args.out_path = args.out_path + '/ibm_hat'
	else: raise ValueError('Incorrect output type.')

	if not os.path.exists(args.out_path): os.makedirs(args.out_path) # make output directory.

	for j in tqdm(args.test_x_list):
		(wav, _) = read_wav(j['file_path']) # read wav from given file path.		
		input_feat = sess.run(net.infer_feat, feed_dict={net.s_ph: [wav], net.s_len_ph: [j['seq_len']]}) # sample of training set.
		xi_bar_hat = sess.run(net.infer_output, feed_dict={net.input_ph: input_feat[0], 
			net.nframes_ph: input_feat[1], net.training_ph: False}) # output of network.
		xi_hat = xi.xi_hat(xi_bar_hat, args.stats['mu_hat'], args.stats['sigma_hat'])

		file_name = j['file_path'].rsplit('/',1)[1].split('.')[0]

		if args.out_type == 'xi_hat':
			spio.savemat(args.out_path + '/' + file_name + '.mat', {'xi_hat':xi_hat})

		elif args.out_type == 'y':
			y_MAG = np.multiply(input_feat[0], gain.gfunc(xi_hat, xi_hat+1, gtype=args.gain))
			y = np.squeeze(sess.run(net.y, feed_dict={net.y_MAG_ph: y_MAG, 
				net.x_PHA_ph: input_feat[2], net.nframes_ph: input_feat[1], net.training_ph: False})) # output of network.
			if np.isnan(y).any(): raise ValueError('NaN values found in enhanced speech.')
			if np.isinf(y).any(): raise ValueError('Inf values found in enhanced speech.')
			print(args.out_path + '/' + file_name + '.wav')
			utils.save_wav(args.out_path + '/' + file_name + '.wav', args.f_s, y)

		elif args.out_type == 'ibm_hat':
			ibm_hat = np.greater(xi_hat, 1.0)
			spio.savemat(args.out_path + '/' + file_name + '.mat', {'ibm_hat':ibm_hat})

	print('Inference complete.')
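
The 'y' branch of infer2 multiplies the noisy magnitude spectrum by a gain computed from the estimated a priori SNR via the project's gain.gfunc. As a rough numpy illustration only, the sketch below assumes a Wiener-type gain as one possible choice of args.gain; the actual gain function and its arguments depend on the project, and the array shapes are stand-ins:

import numpy as np

def wiener_gain(xi_hat):
    # Wiener gain G = xi / (1 + xi), computed from the estimated a priori SNR.
    return xi_hat / (1.0 + xi_hat)

# Stand-in arrays with hypothetical shapes (frames, frequency bins).
x_MAG = np.abs(np.random.randn(100, 257))   # noisy magnitude spectrum stand-in.
xi_hat = np.abs(np.random.randn(100, 257))  # estimated a priori SNR stand-in.
y_MAG = x_MAG * wiener_gain(xi_hat)         # enhanced magnitude spectrum.
# In infer2, y_MAG is combined with the noisy phase (input_feat[2]) by net.y to
# produce the enhanced waveform y, which is then written out with utils.save_wav.
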