def vad_test(m_eval, sess_eval, batch_size_eval, eval_file_dir, norm_dir, data_len, eval_type):
    """Run VAD evaluation over `eval_file_dir` and return (softout, label).

    Labels are read from the '/Labels' subdirectory of the input dir.
    The reader pads the data so its length becomes a multiple of the
    evaluation batch size; `evaluation` consumes the padded stream.
    """
    label_dir = eval_file_dir + '/Labels'
    # Amount of padding needed to round data_len up to a batch multiple.
    padding = batch_size_eval - data_len % batch_size_eval
    # eval_type == 2 selects the plain-DNN reader; anything else uses the default one.
    reader_module = dnn_dr if eval_type == 2 else dr
    eval_data_set = reader_module.DataReader(eval_file_dir, label_dir, norm_dir,
                                             w=19, u=9, name="eval", pad=padding)
    final_softout, final_label = evaluation(m_eval, eval_data_set, sess_eval,
                                            batch_size_eval, eval_type)
    return final_softout, final_label
def do_validation(m_valid, sess, valid_batch_size, valid_file_dir, norm_dir, model_config, type='DNN'):
    """Run one full pass over the validation set and return averaged metrics.

    Accuracy/cost are averaged per validation file first, then across files.

    Args:
        m_valid: validation model exposing inputs/labels/keep_probability/cost/accuracy.
        sess: live TF session with the model's variables loaded.
        valid_batch_size: mini-batch size for the validation reader.
        valid_file_dir: directory of validation inputs; labels in '<dir>/Labels'.
        norm_dir: directory of normalization statistics for the reader.
        model_config: dict providing window parameters 'w' and 'u'.
        type: model/reader kind; only 'DNN' is supported. (Name shadows the
            builtin but is kept for caller compatibility.)

    Returns:
        (total_avg_valid_accuracy, total_avg_valid_cost) as Python floats.

    Raises:
        ValueError: if `type` is not 'DNN' (previously this silently left the
            reader unbound and crashed later with NameError).
    """
    # Fix: original used `type is 'DNN'` — identity comparison on a string
    # literal is interning-dependent; equality is the correct check.
    if type == 'DNN':
        valid_data_set = dnn_dr.DataReader(valid_file_dir, valid_file_dir + '/Labels', norm_dir,
                                           w=model_config['w'], u=model_config['u'], name="eval")
    else:
        raise ValueError("unsupported model type: %r" % (type,))

    avg_valid_accuracy = 0.
    avg_valid_cost = 0.
    itr_sum = 0.
    accuracy_list = [0 for _ in range(valid_data_set._file_len)]
    cost_list = [0 for _ in range(valid_data_set._file_len)]
    itr_file = 0

    while True:
        valid_inputs, valid_labels = valid_data_set.next_batch(valid_batch_size)

        if valid_data_set.file_change_checker():
            # Crossed a file boundary: record this file's mean metrics and
            # reset the running sums for the next file.
            accuracy_list[itr_file] = avg_valid_accuracy / itr_sum
            cost_list[itr_file] = avg_valid_cost / itr_sum
            avg_valid_cost = 0.
            avg_valid_accuracy = 0.
            itr_sum = 0
            itr_file += 1
            valid_data_set.file_change_initialize()

        if valid_data_set.eof_checker():
            # Reset eof flag, file counter and start index for the next pass.
            valid_data_set.reader_initialize()
            print('Valid data reader was initialized!')
            break

        one_hot_labels = valid_labels.reshape((-1, 1))
        one_hot_labels = dense_to_one_hot(one_hot_labels, num_classes=2)
        feed_dict = {m_valid.inputs: valid_inputs,
                     m_valid.labels: one_hot_labels,
                     m_valid.keep_probability: 1}  # no dropout at validation time

        valid_cost, valid_accuracy = sess.run([m_valid.cost, m_valid.accuracy],
                                              feed_dict=feed_dict)
        avg_valid_accuracy += valid_accuracy
        avg_valid_cost += valid_cost
        itr_sum += 1

    # np.asscalar was deprecated in NumPy 1.16 and removed in 1.23;
    # float() on a 0-d result is the portable equivalent.
    total_avg_valid_accuracy = float(np.mean(np.asarray(accuracy_list)))
    total_avg_valid_cost = float(np.mean(np.asarray(cost_list)))

    return total_avg_valid_accuracy, total_avg_valid_cost
def main(argv=None):
    """Build the VAD graph and either train (mode 'train') or test (mode 'test').

    In 'train' mode, runs up to `max_epoch` steps, logging train metrics every
    10 steps and checkpointing/validating every 50 steps. In 'test' mode,
    restores a checkpoint and returns (softout, label) from `utils.vad_test`,
    truncated to `data_len` rows when it is not None.

    Relies on module-level configuration: device, mode, logs_dir,
    initial_logs_dir, ckpt_name, input/output/norm dirs, w, u, batch sizes,
    max_epoch, dropout_rate, model_config, test_file_dir, data_len, eval_type.
    """
    # Graph Part #
    print("Graph initialization...")
    with tf.device(device):
        with tf.variable_scope("model", reuse=None):
            m_train = Model(is_training=True)
        with tf.variable_scope("model", reuse=True):  # shared weights, eval graph
            m_valid = Model(is_training=False)
    print("Done")

    # Summary Part #
    print("Setting up summary op...")
    summary_ph = tf.placeholder(dtype=tf.float32)
    with tf.variable_scope("Training_procedure"):
        cost_summary_op = tf.summary.scalar("cost", summary_ph)
        accuracy_summary_op = tf.summary.scalar("accuracy", summary_ph)
    # Fix: original compared strings with `is` (interning-dependent identity);
    # `==` is the correct comparison throughout this function.
    if mode == 'train':
        train_summary_writer = tf.summary.FileWriter(logs_dir + '/train/', max_queue=2)
        valid_summary_writer = tf.summary.FileWriter(logs_dir + '/valid/', max_queue=2)
    print("Done")

    # Model Save Part #
    print("Setting up Saver...")
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(logs_dir)
    print("Done")

    # Session Part #
    sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    sess_config.gpu_options.allow_growth = True  # grab GPU memory on demand
    sess = tf.Session(config=sess_config)

    if ckpt and ckpt.model_checkpoint_path:  # model restore
        print("Model restored...")
        if mode == 'train':
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Test mode restores a specific named checkpoint instead of the latest.
            saver.restore(sess, initial_logs_dir + ckpt_name)
        print("Done")
    else:
        # No checkpoint: start from freshly initialized variables.
        sess.run(tf.global_variables_initializer())

    if mode == 'train':
        # Training data reader initialization.
        train_data_set = dr.DataReader(input_dir, output_dir, norm_dir, w=w, u=u, name="train")

    if mode == 'train':
        for itr in range(max_epoch):
            train_inputs, train_labels = train_data_set.next_batch(batch_size)
            one_hot_labels = train_labels.reshape((-1, 1))
            one_hot_labels = dense_to_one_hot(one_hot_labels, num_classes=2)
            feed_dict = {m_train.inputs: train_inputs,
                         m_train.labels: one_hot_labels,
                         m_train.keep_probability: dropout_rate}
            sess.run(m_train.train_op, feed_dict=feed_dict)

            if itr % 10 == 0:  # (`and itr >= 0` was always true and dropped)
                train_cost, train_accuracy = sess.run([m_train.cost, m_train.accuracy],
                                                      feed_dict=feed_dict)
                print("Step: %d, train_cost: %.4f, train_accuracy=%4.4f"
                      % (itr, train_cost, train_accuracy * 100))
                train_cost_summary_str = sess.run(cost_summary_op,
                                                  feed_dict={summary_ph: train_cost})
                train_accuracy_summary_str = sess.run(accuracy_summary_op,
                                                      feed_dict={summary_ph: train_accuracy})
                # Write the train-phase summaries to event files.
                train_summary_writer.add_summary(train_cost_summary_str, itr)
                train_summary_writer.add_summary(train_accuracy_summary_str, itr)

            if itr % 50 == 0 and itr > 0:
                saver.save(sess, logs_dir + "/model.ckpt", itr)  # model save
                print('validation start!')
                valid_accuracy, valid_cost = \
                    utils.do_validation(m_valid, sess, valid_batch_size, valid_file_dir,
                                        norm_dir, model_config, type='DNN')
                print("valid_cost: %.4f, valid_accuracy=%4.4f"
                      % (valid_cost, valid_accuracy * 100))
                valid_cost_summary_str = sess.run(cost_summary_op,
                                                  feed_dict={summary_ph: valid_cost})
                valid_accuracy_summary_str = sess.run(accuracy_summary_op,
                                                      feed_dict={summary_ph: valid_accuracy})
                # Write the validation-phase summaries to event files.
                valid_summary_writer.add_summary(valid_cost_summary_str, itr)
                valid_summary_writer.add_summary(valid_accuracy_summary_str, itr)

    elif mode == 'test':
        final_softout, final_label = utils.vad_test(m_valid, sess, valid_batch_size,
                                                    test_file_dir, norm_dir, data_len,
                                                    eval_type)
        if data_len is None:
            return final_softout, final_label
        else:
            # Strip the batch padding added by the test reader.
            return final_softout[0:data_len, :], final_label[0:data_len, :]
def full_evaluation(m_eval, sess_eval, batch_size_eval, eval_file_dir, summary_writer, summary_dic, itr):
    """Evaluate the model on every noise-type subdirectory of `eval_file_dir`.

    For each noise type: runs `evaluation`, prints cost/accuracy/AUC overall
    and per SNR, and writes the per-noise summaries from `summary_dic`.
    Afterwards prints and logs the mean cost/accuracy/AUC and the accuracy
    variance across all noise types.
    """
    cost_per_noise = []
    accuracy_per_noise = []
    auc_per_noise = []
    print("-------- Performance for each of noise types --------")
    noise_types = sorted(os.listdir(eval_file_dir))
    summary_ph = summary_dic["summary_ph"]

    for noise in noise_types:
        input_dir_for_noise = eval_file_dir + '/' + noise
        label_dir_for_noise = input_dir_for_noise + '/Labels'
        data_set = dr.DataReader(input_dir_for_noise, label_dir_for_noise, norm_dir,
                                 w=w, u=u, name="eval")
        cost, accuracy, acc_by_snr, auc, auc_by_snr = evaluation(
            m_eval, data_set, sess_eval, batch_size_eval)

        print("--noise type : " + noise)
        print("cost: %.4f, accuracy across all SNRs: %.4f, auc across all SNRS: %.4f"
              % (cost, accuracy * 100, auc))
        print('accuracy wrt SNR:')
        print('SNR_-5 : %.4f, SNR_0 : %.4f, SNR_5 : %.4f, SNR_10 : %.4f'
              % (acc_by_snr[0] * 100, acc_by_snr[1] * 100,
                 acc_by_snr[2] * 100, acc_by_snr[3] * 100))
        print('AUC wrt SNR:')
        print('SNR_-5 : %.4f, SNR_0 : %.4f, SNR_5 : %.4f, SNR_10 : %.4f'
              % (auc_by_snr[0], auc_by_snr[1], auc_by_snr[2], auc_by_snr[3]))
        print('')

        # Pair each summary op name with its value: cost, per-SNR accuracies, overall accuracy.
        summary_values = [cost] + acc_by_snr + [accuracy]
        for summary_name, value in zip(summary_list, summary_values):
            summary_str = sess_eval.run(summary_dic[noise + "_" + summary_name],
                                        feed_dict={summary_ph: value})
            summary_writer.add_summary(summary_str, itr)

        cost_per_noise.append(cost)
        accuracy_per_noise.append(accuracy)
        auc_per_noise.append(auc)

    overall_cost = np.mean(np.asarray(cost_per_noise))
    var_accuracy = np.var(np.asarray(accuracy_per_noise))
    overall_accuracy = np.mean(np.asarray(accuracy_per_noise))
    overall_auc = np.mean(np.asarray(auc_per_noise))

    # Aggregate summaries across every noise type.
    for summary_key, value in (("cost_across_all_noise_types", overall_cost),
                               ("accuracy_across_all_noise_types", overall_accuracy),
                               ("variance_across_all_noise_types", var_accuracy)):
        summary_writer.add_summary(
            sess_eval.run(summary_dic[summary_key], feed_dict={summary_ph: value}), itr)

    print("-------- Performance across all of noise types --------")
    print("cost : %.4f" % overall_cost)
    print("******* averaged accuracy across all noise_types : %.4f *******"
          % (overall_accuracy * 100))
    print("******* averaged auc across all noise_types : %.4f *******" % overall_auc)
    print("******* variance of accuracies across all noise_types : %4.4f *******"
          % (var_accuracy * 100))