def main(_):
    json_dir = './config.json'
    with open(json_dir) as config_json:
        config = json.load(config_json)

    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    phase_specs = tf.placeholder(
        tf.float32,
        shape=[None, config['context_window_width'], 129, 4],
        name='phase_specs')

    model_settings = model.create_model_settings(
        dim_direction_label=config['dim_direction_label'],
        sample_rate=config["sample_rate"],
        win_len=config['win_len'],
        win_shift=config['win_shift'],
        nDFT=config['nDFT'],
        context_window_width=config['context_window_width'])

    with tf.variable_scope('CNN'):
        predict_logits = model.doa_cnn(phase_specs=phase_specs,
                                       model_settings=model_settings,
                                       is_training=True)

    CNN_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='CNN')
    print('-' * 80)
    print('CNN vars')
    nparams = 0
    for v in CNN_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(
        nparams, (float(nparams) * 4) / (1024 * 1024)))
    print('-' * 80)

    tf.global_variables_initializer().run()
    init_local_variable = tf.local_variables_initializer()
    init_local_variable.run()

    if config['start_checkpoint']:
        model.load_variables_from_checkpoint(sess,
                                             config['start_checkpoint'],
                                             var_list=CNN_vars)

    rir_data_dir = config['rir_data_dir']
    rir_file_list = glob.glob(os.path.join(rir_data_dir, "*.wav"))
    reverb = config['reverb']
    reverb.sort()
    room_index = config['room_idx']
    room_index.sort()

    # find testing files
    testing_file_list = glob.glob(
        os.path.join(config['testing_data_dir'], "*.wav"))
    if not len(testing_file_list):
        raise Exception("No wav files found at " + config['testing_data_dir'])
    if not len(rir_file_list):
        raise Exception("No wav files found at " + rir_data_dir)

    for room_idx, room in enumerate(room_index):
        for reverb_idx, reverb_percent in enumerate(reverb):
            reverb_wav = input_data.gen_moving_direct_wav(
                wav_dir=config['testing_data_dir'],
                rir_dir=config['rir_data_dir'],
                doa_interval=config['direction_range'],
                deg_per_sec=config['deg_per_sec'],
                reverb_percent=reverb_percent,
                room_index=room)

            voiced_idx, voiced_percent = input_data.get_dual_channel_voiced_idx(
                reverb_wav,
                win_len=config['win_len'],
                win_shift=config['win_shift'],
                nDFT=config['nDFT'],
                context_window_width=config['context_window_width'],
                rms_thre=3e-1)

            duration = reverb_wav.shape[1] / 16e3

            testing_specs = input_data.get_reverb_specs(
                reverb_wav=reverb_wav,
                win_len=config['win_len'],
                win_shift=config['win_shift'],
                nDFT=config['nDFT'],
                context_window_width=config['context_window_width'])

            num_frames = testing_specs.shape[0]
            print(num_frames)

            label, label_argmax = input_data.get_moving_wav_labels(
                num_frames, config['win_shift'], config['deg_per_sec'],
                config['direction_range'])

            logits = sess.run(predict_logits,
                              feed_dict={phase_specs: testing_specs})
            testing_predict = eval.get_deg_from_logits(
                logits,
                doa_interval=config['direction_range'],
                num_doa_class=config['dim_direction_label'])

            wavfile.write(filename='./moving.wav',
                          data=np.transpose(reverb_wav),
                          rate=16000)

            time_idx = np.arange(0, len(reverb_wav[0, :]),
                                 math.floor(len(reverb_wav[0, :]) / 5))
            time_text = time_idx * duration / len(reverb_wav[0, :])
            time_text = [str(round(float(label), 2)) for label in time_text]

            idx = range(len(label_argmax))
            label_idx = np.arange(0, len(label_argmax),
                                  math.floor(len(label_argmax) / 5))
            label_text = label_idx * duration / len(label_argmax)
            label_text = [str(round(float(label), 2)) for label in label_text]

            plt.figure(figsize=(20, 10))
            plt.subplot(311)
            plt.xlabel('time (s)')
            plt.xticks(time_idx, time_text)
            plt.ylabel('X')
            plt.ylim(-1, 1)
            plt.plot(reverb_wav[0, :])
            ax = plt.gca()
            ax.xaxis.set_label_coords(1.05, -0.025)

            plt.subplot(312)
            plt.xlabel('time (s)')
            plt.xticks(time_idx, time_text)
            plt.ylim(-1, 1)
            plt.ylabel('Y')
            plt.plot(reverb_wav[1, :])
            ax = plt.gca()
            ax.xaxis.set_label_coords(1.05, -0.025)

            # only plot results for the voiced part
            testing_predict = testing_predict.astype(float)
            silent_idx = np.logical_not(voiced_idx)
            testing_predict[silent_idx] = np.nan
            label_argmax = label_argmax.astype(float)
            label_argmax[silent_idx] = np.nan

            plt.subplot(313)
            plt.ylim(0, 140)
            plt.ylabel('DOA / degree')
            plt.xlabel('time (s)')
            plt.xticks(label_idx, label_text)
            plt.plot(idx, label_argmax, 'bs', label='ground truth',
                     markersize=2.15)
            plt.plot(idx, testing_predict, 'r.', label='predict', markersize=2)
            plt.legend(loc='upper left')
            plt.grid(True)
            ax = plt.gca()
            ax.xaxis.set_label_coords(1.05, -0.025)

            fig_save_path = os.path.join(
                './figures', 'v4_voiced',
                os.path.basename(config['testing_data_dir']))
            if not os.path.exists(fig_save_path):
                os.makedirs(fig_save_path)
            file_name = ('moving_plot_reverb' + str(reverb_percent) +
                         '_room' + str(room) + '.png')
            save_path = os.path.join(fig_save_path, file_name)
            plt.savefig(save_path)
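
# For reference, a minimal sketch of the config.json this script reads. The
# key names are taken from the config[...] lookups above; every value below
# is an illustrative assumption, not the project's original setting.
EXAMPLE_DOA_TEST_CONFIG = {
    "context_window_width": 11,    # assumed context width (frames)
    "dim_direction_label": 37,     # assumed number of DOA classes
    "sample_rate": 16000,          # matches the 16 kHz assumed in main()
    "win_len": 256,                # assumed STFT window length (samples)
    "win_shift": 128,              # assumed STFT hop (samples)
    "nDFT": 256,                   # 256-point DFT -> 129 bins, as in phase_specs
    "start_checkpoint": "",        # path to a model.ckpt-XXXX, or empty
    "rir_data_dir": "./rir",
    "testing_data_dir": "./test_wav",
    "reverb": [35, 50, 65],        # reverb percentages to test
    "room_idx": [1, 2],
    "direction_range": [0, 180],   # assumed DOA interval in degrees
    "deg_per_sec": 10              # assumed angular speed of the moving source
}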
def main(_):
    # import config
    json_dir = './config.json'
    with open(json_dir) as config_json:
        config = json.load(config_json)

    # define noisy specs
    input_specs = tf.placeholder(
        tf.float32,
        shape=[None, config['context_window_width'], 257, 2],
        name='specs')
    # define clean specs
    target_specs = tf.placeholder(tf.float32,
                                  shape=[None, 257, 2],
                                  name='ground_truth')

    # create SE-FCN
    with tf.variable_scope('SEFCN'):
        model_out = model.se_fcn(input_specs, config['nDFT'],
                                 config['context_window_width'])
    model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='SEFCN')
    print('-' * 80)
    print('SE-FCN vars')
    nparams = 0
    for v in model_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(
        nparams, (float(nparams) * 4) / (1024 * 1024)))
    print('-' * 80)

    # define the loss
    mse = tf.losses.mean_squared_error(target_specs, model_out)

    sess = tf.InteractiveSession()
    model_path = os.path.join(config['model_dir'], config['param_file'])
    # load model parameters from checkpoint
    model.load_variables_from_checkpoint(sess, model_path)

    # run the test & save the test results
    tf.logging.set_verbosity(tf.logging.ERROR)
    testing_file_list = glob.glob(
        os.path.join(config['test_tedlium_wav_dir'], "*.wav"))
    print('testing set size: ', len(testing_file_list))
    test_snr = config['test_snr']

    for file_idx, testing_file in enumerate(testing_file_list):
        _, clean_wav = input_data.read_wav(testing_file,
                                           config['sampling_rate'])
        stm_path = os.path.join(
            config['test_stm_path'],
            os.path.basename(testing_file).split(".wav")[0] + '.stm')
        utter_pos = input_data.get_utter_pos(stm_path,
                                             config['sampling_rate'])

        for noise_idx in range(len(config['test_noise'])):
            noise_wav_path = os.path.join(
                config['test_noise_path'],
                config['test_noise'][noise_idx] + '.wav')
            _, noise_wav = input_data.read_wav(noise_wav_path,
                                               config['sampling_rate'])

            for snr_idx in range(len(test_snr)):
                for utter_index in range(config['how_many_testing_utter']):
                    utter_wav, noisy_wav, _, _ = input_data.get_noisy_wav_tedlium(
                        clean_wav=clean_wav,
                        noise_wav=noise_wav,
                        utter_pos=utter_pos,
                        pos_index=utter_index,
                        snr=test_snr[snr_idx],
                        utter_percentage=config['speech_percentage'])

                    segment = int(
                        math.ceil(
                            len(noisy_wav) /
                            (config['wav_length_per_batch'] *
                             config['sampling_rate'])))

                    for segment_idx in range(segment):
                        noisy_specs, clean_specs = input_data.get_seg_specs(
                            mix_wav=noisy_wav,
                            utter_wav=utter_wav,
                            wav_length_per_seg=config['wav_length_per_batch'],
                            seg_idx=segment_idx,
                            win_len=config['win_len'],
                            win_shift=config['win_shift'],
                            context_window_width=config['context_window_width'],
                            fs=config['sampling_rate'],
                            nDFT=config['nDFT'])

                        seg_specs, seg_mse = sess.run(
                            [model_out, mse],
                            feed_dict={
                                input_specs: noisy_specs,
                                target_specs: clean_specs
                            })
                        print("processing file: " + testing_file, " " * 5,
                              "seg:", "{}/{}".format(segment_idx + 1, segment),
                              " " * 5, "proc num batch:",
                              noisy_specs.shape[0], " " * 5,
                              "seg mse:", format(seg_mse, '.5f'))

                        seg_specs = np.vstack(seg_specs)
                        seg_specs_real = seg_specs[:, :, 0]
                        seg_specs_imag = seg_specs[:, :, 1]
                        if segment_idx == 0:
                            rec_test_out_real = seg_specs_real
                            rec_test_out_imag = seg_specs_imag
                        else:
                            rec_test_out_real = np.concatenate(
                                (rec_test_out_real, seg_specs_real), axis=0)
                            rec_test_out_imag = np.concatenate(
                                (rec_test_out_imag, seg_specs_imag), axis=0)

                    rec_wav = output_data.rec_wav(
                        mag_spec=rec_test_out_real,
                        spec_imag=rec_test_out_imag,
                        win_len=config['win_len'],
                        win_shift=config['win_shift'],
                        nDFT=config['nDFT'])

                    save_path = os.path.join(
                        config['save_testing_results_dir'], 'test',
                        str(config['test_noise'][noise_idx]),
                        str(test_snr[snr_idx]))
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    comp_save_path = os.path.join(
                        save_path,
                        os.path.basename(testing_file).split(".wav")[0] +
                        '_U' + str(utter_index) + '.wav')
                    output_data.save_wav_file(comp_save_path, rec_wav,
                                              config['sampling_rate'])

                    save_path = os.path.join(
                        config['save_testing_results_dir'], 'mix',
                        str(config['test_noise'][noise_idx]),
                        str(test_snr[snr_idx]))
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    comp_save_path = os.path.join(
                        save_path,
                        os.path.basename(testing_file).split(".wav")[0] +
                        '_U' + str(utter_index) + '.wav')
                    output_data.save_wav_file(comp_save_path, noisy_wav,
                                              config['sampling_rate'])

                    save_path = os.path.join(
                        config['save_testing_results_dir'], 'clean',
                        str(config['test_noise'][noise_idx]),
                        str(test_snr[snr_idx]))
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    comp_save_path = os.path.join(
                        save_path,
                        os.path.basename(testing_file).split(".wav")[0] +
                        '_U' + str(utter_index) + '.wav')
                    output_data.save_wav_file(comp_save_path, utter_wav,
                                              config['sampling_rate'])

    np.set_printoptions(precision=3, suppress=True)
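
# Every main(_) in this listing takes the single throwaway argument that
# tf.app.run() passes along. A typical TF 1.x entry point would look like the
# following (the original launcher is not shown here, so this is an assumption):
if __name__ == '__main__':
    tf.app.run(main=main)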
def main(_):
    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    labels = FLAGS.labels.split(',')
    label_count = len(labels)

    # Place data loading and preprocessing on the cpu
    with tf.device('/cpu:0'):
        raw_data = Data(FLAGS.data_dir, labels, FLAGS.validation_percentage,
                        FLAGS.testing_percentage)
        tr_data = ImageDataGenerator(raw_data.get_data('training'),
                                     raw_data.get_label_to_index(),
                                     FLAGS.batch_size)
        val_data = ImageDataGenerator(raw_data.get_data('validation'),
                                      raw_data.get_label_to_index(),
                                      FLAGS.batch_size)
        te_data = ImageDataGenerator(raw_data.get_data('testing'),
                                     raw_data.get_label_to_index(),
                                     FLAGS.batch_size)
        # create a reinitializable iterator given the dataset structure
        iterator = tf.data.Iterator.from_structure(
            tr_data.dataset.output_types, tr_data.dataset.output_shapes)
        next_batch = iterator.get_next()

    # Ops for initializing the three different iterators
    training_init_op = iterator.make_initializer(tr_data.dataset)
    validation_init_op = iterator.make_initializer(val_data.dataset)
    testing_init_op = iterator.make_initializer(te_data.dataset)

    # Figure out the learning rates for each training phase. Since it's often
    # effective to have high learning rates at the start of training, followed by
    # lower levels towards the end, the number of steps and learning rates can be
    # specified as comma-separated lists to define the rate at each stage. For
    # example --how_many_training_epochs=10000,3000 --learning_rate=0.001,0.0001
    # will run 13,000 training loops in total, with a rate of 0.001 for the first
    # 10,000, and 0.0001 for the final 3,000.
    training_epochs_list = list(
        map(int, FLAGS.how_many_training_epochs.split(',')))
    learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
    if len(training_epochs_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_epochs and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' %
            (len(training_epochs_list), len(learning_rates_list)))

    input_xs = tf.placeholder(tf.float32,
                              [None, FLAGS.image_hw, FLAGS.image_hw, 3],
                              name='input_xs')
    logits, dropout_prob = model.create_model(input_xs,
                                              label_count,
                                              FLAGS.model_architecture,
                                              is_training=True)

    # Define loss and optimizer
    ground_truth_input = tf.placeholder(tf.int64, [None],
                                        name='groundtruth_input')

    # Optionally we can add runtime checks to spot when NaNs or other symptoms of
    # numerical errors start occurring during training.
    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    # Create the back propagation and training evaluation machinery in the graph.
    with tf.name_scope('cross_entropy'):
        cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
            labels=ground_truth_input, logits=logits)
    tf.summary.scalar('cross_entropy', cross_entropy_mean)

    with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
        learning_rate_input = tf.placeholder(tf.float32, [],
                                             name='learning_rate_input')
        momentum = tf.placeholder(tf.float32, [], name='momentum')
        # Alternative optimizers that were tried:
        # train_step = tf.train.GradientDescentOptimizer(learning_rate_input).minimize(cross_entropy_mean)
        # train_step = tf.train.MomentumOptimizer(learning_rate_input, momentum, use_nesterov=True).minimize(cross_entropy_mean)
        # train_step = tf.train.AdamOptimizer(learning_rate_input).minimize(cross_entropy_mean)
        # train_step = tf.train.AdadeltaOptimizer(learning_rate_input).minimize(cross_entropy_mean)
        train_step = tf.train.RMSPropOptimizer(
            learning_rate_input, momentum).minimize(cross_entropy_mean)

    predicted_indices = tf.argmax(logits, 1)
    correct_prediction = tf.equal(predicted_indices, ground_truth_input)
    confusion_matrix = tf.confusion_matrix(ground_truth_input,
                                           predicted_indices,
                                           num_classes=label_count)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', evaluation_step)

    global_step = tf.train.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    saver = tf.train.Saver(tf.global_variables())

    # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
    merged_summaries = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)
    validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir +
                                              '/validation')

    tf.global_variables_initializer().run()

    start_epoch = 1
    start_checkpoint_epoch = 0
    if FLAGS.start_checkpoint:
        model.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
        # the epoch number is the suffix of the checkpoint file name
        tmp = FLAGS.start_checkpoint.split('-')
        tmp.reverse()
        start_checkpoint_epoch = int(tmp[0])
        start_epoch = start_checkpoint_epoch + 1

    # calculate max training epochs
    training_epochs_max = np.sum(training_epochs_list)

    if start_checkpoint_epoch != training_epochs_max:
        tf.logging.info('Training from epoch: %d ', start_epoch)

    # Saving as Protocol Buffer (pb)
    # tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
    #                      FLAGS.model_architecture + '.pbtxt')
    tf.train.write_graph(sess.graph_def,
                         FLAGS.train_dir,
                         FLAGS.model_architecture + '.pb',
                         as_text=False)

    # Save list of words.
    with gfile.GFile(
            os.path.join(FLAGS.train_dir,
                         FLAGS.model_architecture + '_labels.txt'), 'w') as f:
        f.write('\n'.join(raw_data.labels_list))

    # Get the number of training/validation/testing steps per epoch
    tr_batches_per_epoch = int(tr_data.data_size / FLAGS.batch_size)
    if tr_data.data_size % FLAGS.batch_size > 0:
        tr_batches_per_epoch += 1
    val_batches_per_epoch = int(val_data.data_size / FLAGS.batch_size)
    if val_data.data_size % FLAGS.batch_size > 0:
        val_batches_per_epoch += 1
    te_batches_per_epoch = int(te_data.data_size / FLAGS.batch_size)
    if te_data.data_size % FLAGS.batch_size > 0:
        te_batches_per_epoch += 1

    ############################
    # Training loop.
    ############################
    for training_epoch in xrange(start_epoch, training_epochs_max + 1):
        # Figure out what the current learning rate is.
        training_epochs_sum = 0
        for i in range(len(training_epochs_list)):
            training_epochs_sum += training_epochs_list[i]
            if training_epoch <= training_epochs_sum:
                learning_rate_value = learning_rates_list[i]
                break

        # Initialize iterator with the training dataset
        sess.run(training_init_op)
        for step in range(tr_batches_per_epoch):
            # Pull the image samples we'll use for training.
            train_batch_xs, train_batch_ys = sess.run(next_batch)
            # Run the graph with this batch of training data.
            train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
                [
                    merged_summaries, evaluation_step, cross_entropy_mean,
                    train_step, increment_global_step
                ],
                feed_dict={
                    input_xs: train_batch_xs,
                    ground_truth_input: train_batch_ys,
                    learning_rate_input: learning_rate_value,
                    momentum: 0.95,
                    dropout_prob: 0.5
                })
            train_writer.add_summary(train_summary, step)
            tf.logging.info(
                'Epoch #%d, Step #%d: rate %f, accuracy %.1f%%, cross entropy %f'
                % (training_epoch, step, learning_rate_value,
                   train_accuracy * 100, cross_entropy_value))

        # Validate the model on the entire validation set
        print("{} Start validation".format(datetime.datetime.now()))
        # Reinitialize iterator with the validation dataset
        sess.run(validation_init_op)
        total_val_accuracy = 0
        validation_count = 0
        total_conf_matrix = None
        for i in range(val_batches_per_epoch):
            validation_batch_xs, validation_batch_ys = sess.run(next_batch)
            # Run a validation step and capture training summaries for TensorBoard
            # with the `merged` op.
            validation_summary, validation_accuracy, conf_matrix = sess.run(
                [merged_summaries, evaluation_step, confusion_matrix],
                feed_dict={
                    input_xs: validation_batch_xs,
                    ground_truth_input: validation_batch_ys,
                    dropout_prob: 1.0
                })
            validation_writer.add_summary(validation_summary, training_epoch)

            total_val_accuracy += validation_accuracy
            validation_count += 1
            if total_conf_matrix is None:
                total_conf_matrix = conf_matrix
            else:
                total_conf_matrix += conf_matrix

        total_val_accuracy /= validation_count
        tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
        tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                        (training_epoch, total_val_accuracy * 100,
                         raw_data.get_size('validation')))

        # Save the model checkpoint periodically.
        if (training_epoch % FLAGS.save_step_interval == 0
                or training_epoch == training_epochs_max):
            checkpoint_path = os.path.join(FLAGS.train_dir,
                                           FLAGS.model_architecture + '.ckpt')
            tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                            training_epoch)
            saver.save(sess, checkpoint_path, global_step=training_epoch)

    ############################
    # Evaluate
    ############################
    start = datetime.datetime.now()
    print("{} Start testing".format(start))
    # Reinitialize iterator with the testing dataset
    sess.run(testing_init_op)
    total_test_accuracy = 0
    test_count = 0
    total_conf_matrix = None
    for i in range(te_batches_per_epoch):
        test_batch_xs, test_batch_ys = sess.run(next_batch)
        test_accuracy, conf_matrix = sess.run(
            [evaluation_step, confusion_matrix],
            feed_dict={
                input_xs: test_batch_xs,
                ground_truth_input: test_batch_ys,
                dropout_prob: 1.0
            })

        total_test_accuracy += test_accuracy
        test_count += 1
        if total_conf_matrix is None:
            total_conf_matrix = conf_matrix
        else:
            total_conf_matrix += conf_matrix

    total_test_accuracy /= test_count
    tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
    tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
                    (total_test_accuracy * 100, raw_data.get_size('testing')))

    end = datetime.datetime.now()
    print('End testing: ', end)
    print('total testing time: ', end - start)

    ############################
    # Start prediction
    ############################
    print("{} Start prediction".format(datetime.datetime.now()))
    id2name = {i: name for i, name in enumerate(labels)}
    submission = dict()

    # Place data loading and preprocessing on the cpu
    raw_data2 = prediction_data.Data(FLAGS.prediction_data_dir)
    pre_data = prediction_data.ImageDataGenerator(raw_data2.get_data(),
                                                  FLAGS.prediction_batch_size)
    # create a reinitializable iterator given the dataset structure
    iterator = tf.data.Iterator.from_structure(pre_data.dataset.output_types,
                                               pre_data.dataset.output_shapes)
    next_batch = iterator.get_next()

    # Op for initializing the prediction iterator
    prediction_init_op = iterator.make_initializer(pre_data.dataset)

    # Get the number of prediction steps per epoch
    pre_batches_per_epoch = int(pre_data.data_size /
                                FLAGS.prediction_batch_size)
    if pre_data.data_size % FLAGS.prediction_batch_size > 0:
        pre_batches_per_epoch += 1

    count = 0
    # Initialize iterator with the prediction dataset
    sess.run(prediction_init_op)
    for i in range(pre_batches_per_epoch):
        fingerprints, fnames = sess.run(next_batch)
        prediction = sess.run([predicted_indices],
                              feed_dict={
                                  input_xs: fingerprints,
                                  dropout_prob: 1.0
                              })
        size = len(fnames)
        for n in xrange(0, size):
            submission[fnames[n].decode('UTF-8')] = id2name[prediction[0][n]]
        count += size
        print(count, ' completed')

    # make submission.csv
    if not os.path.exists(FLAGS.result_dir):
        os.makedirs(FLAGS.result_dir)
    fout = open(os.path.join(
        FLAGS.result_dir, 'submission_' + FLAGS.model_architecture + '_' +
        FLAGS.how_many_training_epochs + '.csv'),
                'w',
                encoding='utf-8',
                newline='')
    writer = csv.writer(fout)
    writer.writerow(['file', 'species'])
    for key in sorted(submission.keys()):
        writer.writerow([key, submission[key]])
    fout.close()
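
# The epoch loop above picks its learning rate by walking the comma-separated
# flag lists. The same rule as a standalone sketch (names are illustrative):
def staged_learning_rate(epoch, epochs_list, rates_list):
    """Return the learning rate for `epoch` under a staged schedule.

    With epochs_list=[10000, 3000] and rates_list=[0.001, 0.0001] the run
    lasts 13,000 epochs: 0.001 for the first 10,000, then 0.0001.
    """
    epochs_sum = 0
    for n_epochs, rate in zip(epochs_list, rates_list):
        epochs_sum += n_epochs
        if epoch <= epochs_sum:
            return rate
    return rates_list[-1]  # past the schedule: keep the final rate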
def main(_):
    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    labels = FLAGS.labels.split(',')
    label_count = len(labels)

    training_epochs_list = list(
        map(int, FLAGS.how_many_training_epochs.split(',')))
    learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
    if len(training_epochs_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_epochs and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' %
            (len(training_epochs_list), len(learning_rates_list)))

    input_xs = tf.placeholder(tf.float32,
                              [None, IMAGE_HEIGHT, IMAGE_WIDTH, 3],
                              name='input_xs')
    logits, dropout_prob = models.create_model(input_xs,
                                               label_count,
                                               FLAGS.model_architecture,
                                               is_training=True)

    # Define loss and optimizer
    ground_truth_input = tf.placeholder(tf.int64, [None],
                                        name='groundtruth_input')

    # Optionally we can add runtime checks to spot when NaNs or other symptoms of
    # numerical errors start occurring during training.
    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    # Create the back propagation and training evaluation machinery in the graph.
    with tf.name_scope('cross_entropy'):
        cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
            labels=ground_truth_input, logits=logits)
    tf.summary.scalar('cross_entropy', cross_entropy_mean)

    with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
        learning_rate_input = tf.placeholder(tf.float32, [],
                                             name='learning_rate_input')
        momentum = tf.placeholder(tf.float32, [], name='momentum')
        # Alternative optimizers that were tried:
        # train_step = tf.train.GradientDescentOptimizer(learning_rate_input).minimize(cross_entropy_mean)
        # train_step = tf.train.AdamOptimizer(learning_rate_input).minimize(cross_entropy_mean)
        # train_step = tf.train.AdadeltaOptimizer(learning_rate_input).minimize(cross_entropy_mean)
        # train_step = tf.train.RMSPropOptimizer(learning_rate_input, momentum).minimize(cross_entropy_mean)
        train_step = tf.train.MomentumOptimizer(
            learning_rate_input, momentum,
            use_nesterov=True).minimize(cross_entropy_mean)

    predicted_indices = tf.argmax(logits, 1)
    correct_prediction = tf.equal(predicted_indices, ground_truth_input)
    confusion_matrix = tf.confusion_matrix(ground_truth_input,
                                           predicted_indices,
                                           num_classes=label_count)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', evaluation_step)

    global_step = tf.train.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    saver = tf.train.Saver(tf.global_variables())
    merged_summaries = tf.summary.merge_all()

    tf.global_variables_initializer().run()

    ############################
    # Start prediction
    ############################
    print("{} Start prediction".format(datetime.datetime.now()))
    id2name = {i: name for i, name in enumerate(labels)}
    submission = dict()

    # Place data loading and preprocessing on the cpu
    raw_data2 = prediction_data.Data(FLAGS.prediction_data_dir)
    pre_data = prediction_data.ImageDataGenerator(raw_data2.get_data(),
                                                  FLAGS.prediction_batch_size)
    # create a reinitializable iterator given the dataset structure
    iterator = tf.data.Iterator.from_structure(pre_data.dataset.output_types,
                                               pre_data.dataset.output_shapes)
    next_batch = iterator.get_next()

    # Op for initializing the prediction iterator
    prediction_init_op = iterator.make_initializer(pre_data.dataset)

    # Get the number of prediction steps per epoch
    pre_batches_per_epoch = int(pre_data.data_size /
                                FLAGS.prediction_batch_size)
    if pre_data.data_size % FLAGS.prediction_batch_size > 0:
        pre_batches_per_epoch += 1
    print("Test Size : {}".format(raw_data2.get_size()))

    count = 0
    sess.run(prediction_init_op)

    ckpt_list = FLAGS.ckpt_list.split(',')
    ckpt_size = len(ckpt_list)

    for i in range(pre_batches_per_epoch):
        pred_labels = []
        pred_xs, fnames = sess.run(next_batch)
        for j in range(ckpt_size):
            models.load_variables_from_checkpoint(sess, ckpt_list[j])
            prediction, predicted_logits = sess.run(
                [predicted_indices, logits],
                feed_dict={
                    input_xs: pred_xs,
                    dropout_prob: 1.0
                })
            # Note: building a softmax op inside the loop grows the graph;
            # acceptable for a one-off prediction run.
            pred_prob = tf.nn.softmax(predicted_logits)
            pred_labels.append(sess.run(pred_prob))
        pred_label_array = np.array(pred_labels)
        # average the per-checkpoint class probabilities, then take the argmax
        ensemble_pred_labels = np.mean(pred_label_array, axis=0)
        ensemble_class_pred = np.argmax(ensemble_pred_labels, axis=1)

        size = len(fnames)
        for n in xrange(0, size):
            submission[fnames[n].decode('UTF-8')] = id2name[
                ensemble_class_pred[n]]
        count += size
        print(count, ' completed')

    # make submission.csv
    if not os.path.exists(FLAGS.result_dir):
        os.makedirs(FLAGS.result_dir)
    fout = open(os.path.join(
        FLAGS.result_dir, 'submission_' + FLAGS.model_architecture +
        '_ensemble_1_3.csv'),
                'w',
                encoding='utf-8',
                newline='')
    writer = csv.writer(fout)
    writer.writerow(['file', 'species'])
    for key in sorted(submission.keys()):
        writer.writerow([key, submission[key]])
    fout.close()
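
# The loop above implements checkpoint ensembling: softmax probabilities are
# collected per checkpoint and averaged before the argmax. The same rule in
# plain NumPy (a sketch; the [n_ckpt, batch, n_classes] shape is assumed):
def ensemble_average(per_checkpoint_probs):
    """Average class probabilities across checkpoints, then take the argmax."""
    probs = np.asarray(per_checkpoint_probs)      # [n_ckpt, batch, n_classes]
    return np.argmax(probs.mean(axis=0), axis=1)  # [batch]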
def main(_):
    # import config
    json_dir = './config.json'
    with open(json_dir) as config_json:
        config = json.load(config_json)

    # define noisy specs
    input_specs = tf.placeholder(
        tf.float32,
        shape=[None, config['context_window_width'], 257, 2],
        name='specs')

    # create SE-FCN
    with tf.variable_scope('SEFCN'):
        model_out = model.se_fcn(input_specs, config['nDFT'],
                                 config['context_window_width'])
    model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='SEFCN')
    print('-' * 80)
    print('SE-FCN vars')
    nparams = 0
    for v in model_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(
        nparams, (float(nparams) * 4) / (1024 * 1024)))
    print('-' * 80)

    sess = tf.InteractiveSession()
    model_path = os.path.join(config['model_dir'], config['param_file'])
    # load model parameters from checkpoint
    model.load_variables_from_checkpoint(sess, model_path)

    for root, dirs, files in os.walk(config['recording_dir']):
        for basename in files:
            if not fnmatch.fnmatch(basename, '*.wav'):
                continue
            subdir = os.path.basename(os.path.normpath(root))
            filename = os.path.join(root, basename)
            fs, mix_wav = wavfile.read(filename)
            # scale 16-bit PCM to [-1, 1]
            mix_wav = mix_wav / (2**15 - 1)
            max_amp = np.max(np.abs(mix_wav))

            segment = int(
                math.ceil(
                    len(mix_wav) / (config['seg_recording_length'] * fs)))
            for segment_idx in range(segment):
                testing_specs = input_data.get_seg_testing_specs(
                    mix_wav=mix_wav,
                    fs=fs,
                    wav_length_per_seg=config['seg_recording_length'],
                    seg_idx=segment_idx,
                    win_len=config['win_len'],
                    win_shift=config['win_shift'],
                    nDFT=config['nDFT'],
                    context_window=config['context_window_width'])

                seg_specs = sess.run([model_out],
                                     feed_dict={input_specs: testing_specs})
                print("processing file: " + filename, " " * 5, "seg:",
                      "{}/{}".format(segment_idx + 1, segment), " " * 5,
                      "proc num batch:", testing_specs.shape[0])

                seg_specs = np.vstack(seg_specs)
                seg_specs_real = seg_specs[:, :, 0]
                seg_specs_imag = seg_specs[:, :, 1]
                if segment_idx == 0:
                    rec_test_out_real = seg_specs_real
                    rec_test_out_imag = seg_specs_imag
                else:
                    rec_test_out_real = np.concatenate(
                        (rec_test_out_real, seg_specs_real), axis=0)
                    rec_test_out_imag = np.concatenate(
                        (rec_test_out_imag, seg_specs_imag), axis=0)

            rec_wav = output_data.rec_wav(mag_spec=rec_test_out_real,
                                          spec_imag=rec_test_out_imag,
                                          win_len=config['win_len'],
                                          win_shift=config['win_shift'],
                                          nDFT=config['nDFT'])
            # restore the original peak amplitude
            rec_wav = rec_wav * max_amp

            save_path = os.path.join(config['save_processed_recordings_dir'],
                                     subdir)
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            comp_save_path = os.path.join(save_path, basename)
            output_data.save_wav_file(comp_save_path, rec_wav, fs)

    np.set_printoptions(precision=3, suppress=True)
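
# scipy.io.wavfile.read returns int16 samples for 16-bit PCM, which is why
# main() divides by 2**15 - 1 before feeding the network and multiplies the
# reconstruction by the input peak afterwards. The round trip in isolation:
def normalize_int16(pcm):
    """Map int16 PCM samples to floats in [-1, 1]."""
    return pcm / (2**15 - 1)

def denormalize_to_int16(x, max_amp):
    """Restore the original peak and quantize back to int16."""
    return (x * max_amp * (2**15 - 1)).astype(np.int16)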
def main(_):
    json_dir = './config.json'
    with open(json_dir) as config_json:
        config = json.load(config_json)

    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    phase_specs = tf.placeholder(
        tf.float32,
        shape=[None, config['context_window_width'], 129, 4],
        name='phase_specs')
    ground_truth_doa_label = tf.placeholder(
        tf.float32,
        shape=[None, config['dim_direction_label']],
        name='ground_truth_input')

    model_settings = model.create_model_settings(
        dim_direction_label=config['dim_direction_label'],
        sample_rate=config["sample_rate"],
        win_len=config['win_len'],
        win_shift=config['win_shift'],
        nDFT=config['nDFT'],
        context_window_width=config['context_window_width'])

    with tf.variable_scope('CNN'):
        predict_logits = model.doa_cnn(phase_specs=phase_specs,
                                       model_settings=model_settings,
                                       is_training=True)

    CNN_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='CNN')
    print('-' * 80)
    print('CNN vars')
    nparams = 0
    for v in CNN_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(
        nparams, (float(nparams) * 4) / (1024 * 1024)))
    print('-' * 80)

    cross_entropy = tf.losses.softmax_cross_entropy(
        onehot_labels=ground_truth_doa_label, logits=predict_logits)
    mean_cross_entropy = tf.reduce_mean(cross_entropy)

    acc, acc_op = tf.metrics.accuracy(
        labels=tf.argmax(ground_truth_doa_label, 1),
        predictions=tf.argmax(predict_logits, 1))
    pc_acc, pc_acc_op = tf.metrics.mean_per_class_accuracy(
        labels=tf.argmax(ground_truth_doa_label, 1),
        predictions=tf.argmax(predict_logits, 1),
        num_classes=config['dim_direction_label'])

    tf.summary.scalar('cross_entropy', mean_cross_entropy)
    tf.summary.scalar('class_accuracy', acc_op)
    tf.summary.histogram('per_class_accuracy', pc_acc_op)

    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    global_step = tf.train.get_or_create_global_step()
    with tf.name_scope('train'), tf.control_dependencies(extra_update_ops):
        adam = tf.train.AdamOptimizer(config['Adam_learn_rate'])
        # rms = tf.train.RMSPropOptimizer(config['Adam_learn_rate'])
        train_step = adam.minimize(cross_entropy,
                                   global_step=global_step,
                                   var_list=CNN_vars)
        # train_step = rms.minimize(cross_entropy, global_step=global_step,
        #                           var_list=CNN_vars)

    saver = tf.train.Saver(tf.global_variables())

    # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
    merged_summaries = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(config['summaries_dir'], sess.graph)

    start_step = 1
    tf.global_variables_initializer().run()
    init_local_variable = tf.local_variables_initializer()
    init_local_variable.run()
    if config['start_checkpoint']:
        model.load_variables_from_checkpoint(sess, config['start_checkpoint'])
        start_step = global_step.eval(session=sess)
    tf.logging.info('Training from step: %d ', start_step)

    # Save graph.pbtxt.
    tf.train.write_graph(sess.graph_def, config['train_dir'], 'model.pbtxt')

    # find training files
    training_data_dir = config['training_data_dir']
    training_file_list = glob.glob(os.path.join(training_data_dir, "*.wav"))
    training_speech_dir = config['training_speech_dir']
    training_speech_list = glob.glob(os.path.join(training_speech_dir, "**",
                                                  "*.wav"),
                                     recursive=True)
    rir_data_dir = config['rir_data_dir']
    rir_file_list = glob.glob(os.path.join(rir_data_dir, "*.wav"))
    reverb = config['reverb']
    reverb.sort()
    room_index = config['room_idx']
    room_index.sort()

    # find testing files
    testing_file_list = glob.glob(
        os.path.join(config['testing_data_dir'], "*.wav"))

    if not len(training_file_list):
        raise Exception("No wav files found at " + training_data_dir)
    if not len(rir_file_list):
        raise Exception("No wav files found at " + rir_data_dir)
    tf.logging.info("Number of training wav files: %d",
                    len(training_file_list))

    # Training loop.
    how_many_training_steps = config['how_many_training_steps']
    for training_step in range(start_step, int(how_many_training_steps + 1)):
        training_file_idx = random.randint(0, len(training_file_list) - 1)
        # rir_idx = random.randint(0, len(rir_file_list)-1)
        # rir_idx = training_step % (1+config['direction_range'][1])
        rir_idx = training_step % len(rir_file_list)

        training_filename = training_file_list[training_file_idx]
        rir_filename = rir_file_list[rir_idx]
        reverb_percent = int(
            rir_filename.split('reverb_')[1].split('Percent_')[0])
        # for the most reverberant conditions, replace half of the training
        # samples with files from the clean speech corpus
        if reverb_percent == 75 or reverb_percent == 65:
            if random.randint(0, 1):
                speech_file_idx = random.randint(
                    0, len(training_speech_list) - 1)
                training_filename = training_speech_list[speech_file_idx]

        reverb_wav, training_phase_specs = input_data.get_input_specs(
            training_filename, rir_filename, config['win_len'],
            config['win_shift'], config['nDFT'],
            config['context_window_width'], config['max_wav_length'])
        num_frames = training_phase_specs.shape[0]
        training_doa_label = input_data.get_direction_label(
            rir_filename, config['dim_direction_label'],
            config['direction_range'], num_frames)

        training_summary, training_cross_entropy, _, _ = sess.run(
            [
                merged_summaries, mean_cross_entropy, train_step,
                init_local_variable
            ],
            feed_dict={
                phase_specs: training_phase_specs,
                ground_truth_doa_label: training_doa_label
            })
        print("Step: ", training_step, " " * 10, "cross entropy: ",
              format(training_cross_entropy, '.5f'), " " * 10, "rir: ",
              format(reverb_percent, '2.0f'), " " * 10, "training file: ",
              os.path.basename(training_filename))
        train_writer.add_summary(training_summary, training_step)

        # Save the model checkpoint periodically.
        if (training_step % config['save_step_interval'] == 0
                or training_step == how_many_training_steps):
            checkpoint_path = os.path.join(config['train_dir'], 'model.ckpt')
            tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                            training_step)
            saver.save(sess, checkpoint_path, global_step=training_step)

    set_size = len(testing_file_list)
    tf.logging.info('testing set size=%d', set_size)

    doa_per_reverb = int(
        max(config['direction_range']) - min(config['direction_range']) + 1)
    test_acc = np.zeros(
        [len(testing_file_list), doa_per_reverb,
         len(reverb), len(room_index)])
    test_adj_acc = np.zeros(
        [len(testing_file_list), doa_per_reverb,
         len(reverb), len(room_index)])
    test_frame_acc = np.zeros(
        [len(testing_file_list), doa_per_reverb,
         len(reverb), len(room_index)])
    test_adj_frame_acc = np.zeros(
        [len(testing_file_list), doa_per_reverb,
         len(reverb), len(room_index)])

    for testing_file_idx, testing_file in enumerate(testing_file_list):
        print("testing file:", os.path.basename(testing_file))
        for rir_file_idx, rir_file in enumerate(rir_file_list):
            rir_filename = os.path.basename(rir_file)
            degree = int(rir_filename.split('angle_')[1].split('deg_')[0])
            reverb_percent = int(
                rir_filename.split('reverb_')[1].split('Percent_')[0])
            room_num = int(rir_filename.split('_ROOM')[1].split('.wav')[0])
            if reverb_percent not in reverb or room_num not in room_index:
                continue
            reverb_idx = reverb.index(reverb_percent)
            room_idx = room_index.index(room_num)

            reverb_wav, testing_phase_specs = input_data.get_input_specs(
                testing_file, rir_file, config['win_len'],
                config['win_shift'], config['nDFT'],
                config['context_window_width'], config['max_wav_length'])
            num_frames = testing_phase_specs.shape[0]
            testing_doa_label = input_data.get_direction_label(
                rir_file, config['dim_direction_label'],
                config['direction_range'], num_frames)

            logits, class_acc, _ = sess.run(
                [predict_logits, acc_op, init_local_variable],
                feed_dict={
                    phase_specs: testing_phase_specs,
                    ground_truth_doa_label: testing_doa_label
                })

            adjacent_class = 2
            how_many_previous_frame = 15
            testing_predict = eval.get_label_from_logits(logits)
            adj_acc = eval.eval_adjacent_accuracy(testing_predict,
                                                  testing_doa_label,
                                                  adjacent_class)
            frame_acc = eval.eval_frame_accuracy(testing_predict,
                                                 testing_doa_label,
                                                 how_many_previous_frame)
            adj_frame_acc = eval.eval_joint_deg_frame(testing_predict,
                                                      testing_doa_label,
                                                      adjacent_class,
                                                      how_many_previous_frame)

            test_acc[testing_file_idx, degree, reverb_idx,
                     room_idx] = class_acc
            test_adj_acc[testing_file_idx, degree, reverb_idx,
                         room_idx] = adj_acc
            test_frame_acc[testing_file_idx, degree, reverb_idx,
                           room_idx] = frame_acc
            test_adj_frame_acc[testing_file_idx, degree, reverb_idx,
                               room_idx] = adj_frame_acc

            print("degree:", format(degree, '5.1f'), " " * 6, "reverb:",
                  format(reverb_percent, '5.0f'), " " * 6, "room:",
                  format(room_num, '5.0f'), " " * 6, "acc:",
                  format(class_acc, '5.5f'), " " * 6, "adj acc:",
                  format(adj_acc, '5.5f'), " " * 6, "frame acc:",
                  format(frame_acc, '5.5f'), " " * 6, "adj frame acc:",
                  format(adj_frame_acc, '5.5f'))

    print("overall acc:", np.mean(test_acc))
    print("overall adj_acc:", np.mean(test_adj_acc))
    print("overall frame_acc:", np.mean(test_frame_acc))
    print("overall adj frame acc:", np.mean(test_adj_frame_acc))
    print("-" * 30)
    print("Degree accuracy")
    print(format("deg", '10.10s'), format("acc", '10.10s'),
          format("deg acc", '15.10s'), format("frame acc", '15.10s'),
          format("deg frame acc", "15.10s"))
    for i in range(doa_per_reverb):
        print(format(i, '.1f'), " " * 5,
              format(np.mean(test_acc[:, i, :]), '.4f'), " " * 6,
              format(np.mean(test_adj_acc[:, i, :]), '.4f'), " " * 6,
              format(np.mean(test_frame_acc[:, i, :]), '.4f'), " " * 6,
              format(np.mean(test_adj_frame_acc[:, i, :]), '.4f'))

    deg_idx = range(doa_per_reverb)
    print("-" * 30)
    for room in range(len(room_index)):
        for i in range(len(reverb)):
            print("reverb: ", reverb[i])
            print("room: ", room_index[room])
            print("acc:", np.mean(test_acc[:, :, i, room]))
            print("adj_acc:", np.mean(test_adj_acc[:, :, i, room]))
            print("frame_acc:", np.mean(test_frame_acc[:, :, i, room]))
            print("adj frame acc:", np.mean(test_adj_frame_acc[:, :, i,
                                                               room]))
            for j in range(doa_per_reverb):
                print(format(j, '.1f'), " " * 5,
                      format(np.mean(test_acc[:, j, i, room]), '.4f'),
                      " " * 6,
                      format(np.mean(test_adj_acc[:, j, i, room]), '.4f'),
                      " " * 6,
                      format(np.mean(test_frame_acc[:, j, i, room]), '.4f'),
                      " " * 6,
                      format(np.mean(test_adj_frame_acc[:, j, i, room]),
                             '.4f'))
            print("-" * 30)

    for room in range(len(room_index)):
        for i in range(len(reverb)):
            plt.figure(i)
            plt.plot(deg_idx, np.mean(test_adj_acc[:, :, i, room], axis=0),
                     '.')
            plt.yscale('linear')
            plt.xlabel('degree')
            plt.ylabel('accuracy')
            plt.title('room ' + str(room_index[room]) + ', reverb ' +
                      str(reverb[i]) + ' percent')
            plt.grid(True)
            filename = ('room_' + str(room_index[room]) + '_reverb_' +
                        str(reverb[i]) + '_acc.png')
            fig_save_path = os.path.join(
                './figures', 'v4',
                os.path.basename(config['testing_data_dir']))
            if not os.path.exists(fig_save_path):
                os.makedirs(fig_save_path)
            filename = os.path.join(fig_save_path, filename)
            plt.savefig(filename)
    plt.show()
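
# The evaluation loop above recovers angle, reverb percentage, and room id
# from each RIR file name with chained str.split calls. The same parsing as
# one helper, applied to a hypothetical name of the expected shape:
def parse_rir_filename(name):
    """Parse e.g. 'speech_angle_90deg_reverb_50Percent_ROOM1.wav'."""
    degree = int(name.split('angle_')[1].split('deg_')[0])
    reverb_percent = int(name.split('reverb_')[1].split('Percent_')[0])
    room_num = int(name.split('_ROOM')[1].split('.wav')[0])
    return degree, reverb_percent, room_num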
def main(_):
    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    # Begin by making sure we have the training data we need. If you already
    # have training data of your own, use `--data_url= ` on the command line
    # to avoid downloading.
    model_settings = model.prepare_model_settings(
        len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
        FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
    audio_processor = input_data.AudioProcessor(
        FLAGS.data_url, FLAGS.data_dir, FLAGS.silence_percentage,
        FLAGS.unknown_percentage, FLAGS.wanted_words.split(','),
        FLAGS.validation_percentage, FLAGS.testing_percentage,
        model_settings)
    fingerprint_size = model_settings['fingerprint_size']
    label_count = model_settings['label_count']
    time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)

    # Figure out the learning rates for each training phase. Since it's often
    # effective to have high learning rates at the start of training, followed by
    # lower levels towards the end, the number of steps and learning rates can be
    # specified as comma-separated lists to define the rate at each stage. For
    # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
    # will run 13,000 training loops in total, with a rate of 0.001 for the first
    # 10,000, and 0.0001 for the final 3,000.
    training_steps_list = list(
        map(int, FLAGS.how_many_training_steps.split(',')))
    learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
    if len(training_steps_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_steps and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' %
            (len(training_steps_list), len(learning_rates_list)))

    fingerprint_input = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')

    logits, dropout_prob = model.create_conv_model(fingerprint_input,
                                                   model_settings,
                                                   is_training=True)

    # Define loss and optimizer
    ground_truth_input = tf.placeholder(tf.float32, [None, label_count],
                                        name='groundtruth_input')

    # Optionally we can add runtime checks to spot when NaNs or other symptoms of
    # numerical errors start occurring during training.
    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    # Create the back propagation and training evaluation machinery in the graph.
    with tf.name_scope('cross_entropy'):
        cross_entropy_mean = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                labels=ground_truth_input, logits=logits))
    tf.summary.scalar('cross_entropy', cross_entropy_mean)
    with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
        learning_rate_input = tf.placeholder(tf.float32, [],
                                             name='learning_rate_input')
        train_step = tf.train.GradientDescentOptimizer(
            learning_rate_input).minimize(cross_entropy_mean)
    predicted_indices = tf.argmax(logits, 1)
    expected_indices = tf.argmax(ground_truth_input, 1)
    correct_prediction = tf.equal(predicted_indices, expected_indices)
    confusion_matrix = tf.confusion_matrix(expected_indices,
                                           predicted_indices,
                                           num_classes=label_count)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', evaluation_step)

    global_step = tf.contrib.framework.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    saver = tf.train.Saver(tf.global_variables())

    # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
    merged_summaries = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)
    validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir +
                                              '/validation')

    tf.global_variables_initializer().run()

    start_step = 1

    if FLAGS.start_checkpoint:
        model.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
        start_step = global_step.eval(session=sess)

    tf.logging.info('Training from step: %d ', start_step)

    # Save graph.pbtxt.
    tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                         FLAGS.model_architecture + '.pbtxt')

    # Save list of words.
    with gfile.GFile(
            os.path.join(FLAGS.train_dir,
                         FLAGS.model_architecture + '_labels.txt'), 'w') as f:
        f.write('\n'.join(audio_processor.words_list))

    # Training loop.
    training_steps_max = np.sum(training_steps_list)
    for training_step in xrange(start_step, training_steps_max + 1):
        # Figure out what the current learning rate is.
        training_steps_sum = 0
        for i in range(len(training_steps_list)):
            training_steps_sum += training_steps_list[i]
            if training_step <= training_steps_sum:
                learning_rate_value = learning_rates_list[i]
                break
        # Pull the audio samples we'll use for training.
        train_fingerprints, train_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
            FLAGS.background_volume, time_shift_samples, 'training', sess)
        # Run the graph with this batch of training data.
        train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
            [
                merged_summaries, evaluation_step, cross_entropy_mean,
                train_step, increment_global_step
            ],
            feed_dict={
                fingerprint_input: train_fingerprints,
                ground_truth_input: train_ground_truth,
                learning_rate_input: learning_rate_value,
                dropout_prob: 0.5
            })
        train_writer.add_summary(train_summary, training_step)
        tf.logging.info(
            'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
            (training_step, learning_rate_value, train_accuracy * 100,
             cross_entropy_value))
        is_last_step = (training_step == training_steps_max)
        if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
            set_size = audio_processor.set_size('validation')
            total_accuracy = 0
            total_conf_matrix = None
            for i in xrange(0, set_size, FLAGS.batch_size):
                validation_fingerprints, validation_ground_truth = (
                    audio_processor.get_data(FLAGS.batch_size, i,
                                             model_settings, 0.0, 0.0, 0,
                                             'validation', sess))
                # Run a validation step and capture training summaries for
                # TensorBoard with the `merged` op.
                validation_summary, validation_accuracy, conf_matrix = sess.run(
                    [merged_summaries, evaluation_step, confusion_matrix],
                    feed_dict={
                        fingerprint_input: validation_fingerprints,
                        ground_truth_input: validation_ground_truth,
                        dropout_prob: 1.0
                    })
                validation_writer.add_summary(validation_summary,
                                              training_step)
                batch_size = min(FLAGS.batch_size, set_size - i)
                total_accuracy += (validation_accuracy * batch_size) / set_size
                if total_conf_matrix is None:
                    total_conf_matrix = conf_matrix
                else:
                    total_conf_matrix += conf_matrix
            tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
            tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                            (training_step, total_accuracy * 100, set_size))

        # Save the model checkpoint periodically.
        if (training_step % FLAGS.save_step_interval == 0
                or training_step == training_steps_max):
            checkpoint_path = os.path.join(FLAGS.train_dir,
                                           FLAGS.model_architecture + '.ckpt')
            tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                            training_step)
            saver.save(sess, checkpoint_path, global_step=training_step)

    set_size = audio_processor.set_size('testing')
    tf.logging.info('set_size=%d', set_size)
    total_accuracy = 0
    total_conf_matrix = None
    for i in xrange(0, set_size, FLAGS.batch_size):
        test_fingerprints, test_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
        test_accuracy, conf_matrix = sess.run(
            [evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: test_fingerprints,
                ground_truth_input: test_ground_truth,
                dropout_prob: 1.0
            })
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (test_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
            total_conf_matrix = conf_matrix
        else:
            total_conf_matrix += conf_matrix
    tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
    tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
                    (total_accuracy * 100, set_size))
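
# A typical invocation of this trainer, assembled from the flags it reads
# (script name and values are assumptions for illustration):
#
#   python train.py --data_dir=/tmp/speech_dataset \
#       --wanted_words=yes,no,up,down \
#       --how_many_training_steps=10000,3000 \
#       --learning_rate=0.001,0.0001 \
#       --model_architecture=conv \
#       --train_dir=/tmp/speech_commands_train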
def main(_):
    # random seed
    RANDOM_SEED = 3233

    # import config
    json_dir = './config.json'
    with open(json_dir) as config_json:
        config = json.load(config_json)

    # define noisy specs
    input_specs = tf.placeholder(
        tf.float32,
        shape=[None, config['context_window_width'], 257, 2],
        name='specs')
    # define clean specs
    train_target = tf.placeholder(tf.float32,
                                  shape=[None, 257, 2],
                                  name='ground_truth')

    # create SE-FCN
    with tf.variable_scope('SEFCN'):
        model_out = model.se_fcn(input_specs, config['nDFT'],
                                 config['context_window_width'])
    model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='SEFCN')
    print('-' * 80)
    print('SE-FCN vars')
    nparams = 0
    for v in model_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(
        nparams, (float(nparams) * 4) / (1024 * 1024)))
    print('-' * 80)

    # define loss and the optimizer
    mse = tf.losses.mean_squared_error(train_target, model_out)
    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    global_step = tf.train.get_or_create_global_step()
    with tf.name_scope('train'), tf.control_dependencies(extra_update_ops):
        adam = tf.train.AdamOptimizer(config['Adam_learn_rate'])
        train_op = adam.minimize(mse,
                                 global_step=global_step,
                                 var_list=model_vars)

    # make summaries
    tf.summary.scalar('mse', mse)

    # train the model
    sess = tf.InteractiveSession()
    saver = tf.train.Saver(tf.global_variables())

    # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
    merged_summaries = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(config['summaries_dir'], sess.graph)
    tf.train.write_graph(sess.graph_def, config['train_dir'], 'model.pbtxt')

    tf.global_variables_initializer().run()
    init_local_variable = tf.local_variables_initializer()
    init_local_variable.run()

    start_step = 0
    if config['start_checkpoint']:
        model.load_variables_from_checkpoint(sess, config['start_checkpoint'])
        start_step = global_step.eval(session=sess)
    print('Training from step:', start_step)

    tf.logging.set_verbosity(tf.logging.ERROR)

    snr = range(config['snr_range'][0], config['snr_range'][1] + 1)
    sensor_snr = range(config['sensor_snr_range'][0],
                       config['sensor_snr_range'][1] + 1)

    speech_file_list = glob.glob(os.path.join(config['training_data_dir'],
                                              "**", "*.wav"),
                                 recursive=True)
    noise_file_list = glob.glob(os.path.join(config['noise_dir'], "**",
                                             "*.wav"),
                                recursive=True)
    if not len(speech_file_list):
        raise Exception("No wav files found at " +
                        config['training_data_dir'])
    if not len(noise_file_list):
        raise Exception("No wav files found at " + config['noise_dir'])

    # get amp normalized sensor noise data
    if len(config['sensor_noise_path']):
        fs_sensor, sensor_wav = input_data.read_wav(
            config['sensor_noise_path'], config['sampling_rate'])
    else:
        sensor_wav = None

    print("Number of training speech wav files: ", len(speech_file_list))
    print("Number of training noise wav files: ", len(noise_file_list))

    how_many_training_steps = config['how_many_training_steps']

    # pre-draw the random file and SNR indices so runs are reproducible
    random.seed(RANDOM_SEED)
    rand_noise_file_idx_list = [
        random.randint(0, len(noise_file_list) - 1)
        for i in range(int(how_many_training_steps + 1) * config['batch_size'])
    ]
    random.seed(RANDOM_SEED)
    rand_speech_file_idx_list = [
        random.randint(0, len(speech_file_list) - 1)
        for i in range(int(how_many_training_steps + 1) * config['batch_size'])
    ]
    random.seed(RANDOM_SEED)
    snr_idx_list = [
        random.randint(0, len(snr) - 1)
        for i in range(int(how_many_training_steps + 1) * config['batch_size'])
    ]
    random.seed(RANDOM_SEED)
    sensor_snr_idx_list = [
        random.randint(0, len(sensor_snr) - 1)
        for i in range(int(how_many_training_steps + 1) * config['batch_size'])
    ]

    for training_step in range(start_step + 1,
                               int(how_many_training_steps + 1)):
        # get training data
        _, batch_noisy_specs, speech_specs, _ = input_data.get_training_specs(
            speech_file_list,
            noise_file_list,
            snr,
            rand_speech_file_idx_list,
            rand_noise_file_idx_list,
            snr_idx_list,
            sensor_noise_wav=sensor_wav,
            sensor_snr=sensor_snr,
            sensor_snr_idx_list=sensor_snr_idx_list,
            training_step=training_step,
            batch_size=config['batch_size'],
            context_window_width=config['context_window_width'])

        # train the model
        _, training_summary, train_mse = sess.run(
            [train_op, merged_summaries, mse],
            feed_dict={
                input_specs: batch_noisy_specs,
                train_target: speech_specs
            })
        print("training step:", training_step, " " * 10, "mse:",
              format(train_mse, '.5f'))
        train_writer.add_summary(training_summary, training_step)

        # Save the model checkpoint periodically.
        if (training_step % config['save_checkpoint_steps'] == 0
                or training_step == how_many_training_steps):
            checkpoint_path = os.path.join(config['train_dir'], 'sefcn.ckpt')
            tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                            training_step)
            saver.save(sess, checkpoint_path, global_step=training_step)

    # run the test & save the test results
    # find testing files
    testing_file_list = glob.glob(
        os.path.join(config['testing_data_dir'], "*.wav"))
    print('testing set size: ', len(testing_file_list))
    test_snr = config['test_snr']
    overall_testing_mse = np.zeros([len(test_snr),
                                    len(config['test_noise'])])

    for file_idx, testing_file in enumerate(testing_file_list):
        _, clean_wav = input_data.read_wav(testing_file,
                                           config['sampling_rate'])
        for noise_idx in range(len(config['test_noise'])):
            noise_wav_path = os.path.join(
                config['test_noise_path'],
                config['test_noise'][noise_idx] + '.wav')
            _, noise_wav = input_data.read_wav(noise_wav_path,
                                               config['sampling_rate'])
            for snr_idx in range(len(test_snr)):
                utter_wav, noisy_wav = input_data.get_noisy_wav(
                    clean_wav=clean_wav,
                    noise_wav=noise_wav,
                    snr=test_snr[snr_idx])
                _, batched_noisy_specs, speech_specs, _ = input_data.get_testing_specs(
                    utter_wav,
                    noisy_wav,
                    context_window_width=config['context_window_width'])
                estimate_specs, test_mse = sess.run(
                    [model_out, mse],
                    feed_dict={
                        input_specs: batched_noisy_specs,
                        train_target: speech_specs
                    })
                rec_wav = output_data.rec_wav(
                    mag_spec=estimate_specs[:, :, 0],
                    spec_imag=estimate_specs[:, :, 1],
                    win_len=config['win_len'],
                    win_shift=config['win_shift'],
                    nDFT=config['nDFT'])

                save_path = os.path.join(config['save_testing_results_dir'],
                                         'test',
                                         str(config['test_noise'][noise_idx]),
                                         str(test_snr[snr_idx]))
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
                comp_save_path = os.path.join(save_path,
                                              os.path.basename(testing_file))
                output_data.save_wav_file(comp_save_path, rec_wav,
                                          config['sampling_rate'])

                save_path = os.path.join(config['save_testing_results_dir'],
                                         'mix',
                                         str(config['test_noise'][noise_idx]),
                                         str(test_snr[snr_idx]))
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
                comp_save_path = os.path.join(save_path,
                                              os.path.basename(testing_file))
                output_data.save_wav_file(comp_save_path, noisy_wav,
                                          config['sampling_rate'])

                save_path = os.path.join(config['save_testing_results_dir'],
                                         'clean',
                                         str(config['test_noise'][noise_idx]),
                                         str(test_snr[snr_idx]))
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
                comp_save_path = os.path.join(save_path,
                                              os.path.basename(testing_file))
                output_data.save_wav_file(comp_save_path, utter_wav,
                                          config['sampling_rate'])

                print("Testing file #", file_idx,
                      os.path.basename(testing_file), "SNR :",
                      format(test_snr[snr_idx], '5.1f'), " " * 10, "noise:",
                      format(config['test_noise'][noise_idx], '10.10s'),
                      " " * 10, "mse:", format(test_mse, '.5f'))

                # accumulate the running average over testing files
                overall_testing_mse[snr_idx][noise_idx] += test_mse / len(
                    testing_file_list)

    np.set_printoptions(precision=3, suppress=True)