def vad_test(m_eval, sess_eval, batch_size_eval, eval_file_dir, norm_dir, data_len, eval_type):
    """Run evaluation over one evaluation directory and return its raw outputs.

    A data reader is built over ``eval_file_dir``, with labels taken from its
    ``Labels`` sub-directory and a fixed window configuration (``w=19, u=9``).
    The reader is then handed to ``evaluation``.

    Args:
        m_eval: model handle forwarded to ``evaluation``.
        sess_eval: session forwarded to ``evaluation``.
        batch_size_eval: evaluation batch size; also used to size the padding.
        eval_file_dir: directory holding the evaluation inputs.
        norm_dir: normalization directory forwarded to the reader.
        data_len: number of samples; used to compute the pad size.
        eval_type: reader selector. ``2`` uses ``dnn_dr.DataReader``; any
            other value uses ``dr.DataReader``. Also forwarded to ``evaluation``.

    Returns:
        The ``(final_softout, final_label)`` pair produced by ``evaluation``.
    """
    # Pad so the total length reaches a whole number of batches.
    # NOTE(review): when data_len is already a multiple of batch_size_eval this
    # pads a full extra batch rather than 0. It is left as-is because the
    # reader's handling of pad=0 is not visible here -- confirm before changing.
    leftover = data_len % batch_size_eval
    pad_size = batch_size_eval - leftover

    reader_module = dnn_dr if eval_type == 2 else dr
    eval_data_set = reader_module.DataReader(
        eval_file_dir,
        eval_file_dir + '/Labels',
        norm_dir,
        w=19,
        u=9,
        name="eval",
        pad=pad_size,
    )

    final_softout, final_label = evaluation(m_eval, eval_data_set, sess_eval, batch_size_eval, eval_type)
    return final_softout, final_label
def full_evaluation(m_eval, sess_eval, batch_size_eval, eval_file_dir, summary_writer, summary_dic, itr):
    """Evaluate the model on each noise-type sub-directory and log the results.

    Every entry of ``eval_file_dir`` (sorted by name) is treated as one noise
    type. Its inputs are read from ``<eval_file_dir>/<noise>`` and its labels
    from ``<eval_file_dir>/<noise>/Labels``.

    For each noise type, the cost, accuracy and per-SNR accuracy/AUC are
    printed, and scalar summaries are written. Afterwards, the mean cost, mean
    accuracy and accuracy variance across noise types are written as
    summaries. These are printed together with the mean AUC.

    Relies on the module-level names ``norm_dir``, ``w``, ``u`` and
    ``summary_list``.

    Args:
        m_eval: model handle forwarded to ``evaluation``.
        sess_eval: session used for ``evaluation`` and to run the summary ops.
        batch_size_eval: batch size forwarded to ``evaluation``.
        eval_file_dir: directory whose entries are the noise-type directories.
        summary_writer: writer receiving every summary via ``add_summary``.
        summary_dic: summary ops keyed by name. It must contain
            ``"summary_ph"`` (the placeholder fed with each value), and
            ``"<noise>_<name>"`` for every noise type and every name in
            ``summary_list``. It must also contain
            ``"cost_across_all_noise_types"``,
            ``"accuracy_across_all_noise_types"`` and
            ``"variance_across_all_noise_types"``.
        itr: step value recorded with every summary.

    Returns:
        ``(mean_auc, var_accuracy)``: the mean AUC and the variance of the
        accuracy across noise types.
    """
    costs, accuracies, aucs = [], [], []

    print("-------- Performance for each of noise types --------")
    noise_types = sorted(os.listdir(eval_file_dir))
    summary_ph = summary_dic["summary_ph"]
    # Hard-coded SNR column labels. Presumably evaluation() returns its per-SNR
    # values in this order (-5, 0, 5, 10) -- TODO confirm.
    snr_fmt = 'SNR_-5 : %.4f, SNR_0 : %.4f, SNR_5 : %.4f, SNR_10 : %.4f'

    for noise_type in noise_types:
        noise_dir = eval_file_dir + '/' + noise_type
        eval_data_set = dr.DataReader(noise_dir, noise_dir + '/Labels', norm_dir, w=w, u=u, name="eval")
        # NOTE(review): evaluation() is unpacked into five values here but into
        # two elsewhere in this file -- confirm which variant is in scope.
        eval_cost, eval_accuracy, eval_list, eval_auc, eval_auc_list = evaluation(
            m_eval, eval_data_set, sess_eval, batch_size_eval)

        print("--noise type : " + noise_type)
        print("cost: %.4f, accuracy across all SNRs: %.4f" % (eval_cost, eval_accuracy*100))
        print('accuracy wrt SNR:')
        print(snr_fmt % tuple(eval_list[k] * 100 for k in range(4)))
        print('AUC wrt SNR:')
        print(snr_fmt % tuple(eval_auc_list[k] * 100 for k in range(4)))
        print('')

        # Indexed in step with summary_list: cost, the per-SNR accuracies, then
        # the overall accuracy.
        summary_values = [eval_cost] + eval_list + [eval_accuracy]
        for k, summary_name in enumerate(summary_list):
            per_noise_op = summary_dic[noise_type + "_" + summary_name]
            summary_str = sess_eval.run(per_noise_op, feed_dict={summary_ph: summary_values[k]})
            summary_writer.add_summary(summary_str, itr)

        costs.append(eval_cost)
        accuracies.append(eval_accuracy)
        aucs.append(eval_auc)

    mean_cost = np.mean(np.asarray(costs))
    var_accuracy = np.var(np.asarray(accuracies))
    mean_accuracy = np.mean(np.asarray(accuracies))
    mean_auc = np.mean(np.asarray(aucs))

    across_noise = (
        ("cost_across_all_noise_types", mean_cost),
        ("accuracy_across_all_noise_types", mean_accuracy),
        ("variance_across_all_noise_types", var_accuracy),
    )
    for summary_key, value in across_noise:
        summary_writer.add_summary(sess_eval.run(summary_dic[summary_key], feed_dict={summary_ph: value}), itr)

    print("-------- Performance across all of noise types --------")
    print("cost : %.4f" % mean_cost)
    print("******* averaged accuracy across all noise_types : %.4f *******" % (mean_accuracy*100))
    print("******* averaged auc across all noise_types : %.7f *******" % (mean_auc*100))
    print("******* variance of accuracies across all noise_types : %6.6f *******" % var_accuracy)
    return mean_auc, var_accuracy
def main(argv=None):
    """Build the train/valid graphs, then train or evaluate per ``FLAGS.mode``.

    If a checkpoint exists in ``logs_dir``, variables are restored from it.
    Otherwise they are freshly initialized.

    ``'train'`` mode runs ``max_epoch`` optimization steps:
      * Every 50 steps, it logs the train cost/accuracy to stdout and to the
        ``train`` summary writer.
      * Every 100 steps, it saves a checkpoint, evaluates the validation set,
        and logs the result to the ``valid`` summary writer.

    ``'test'`` mode evaluates the validation set once and prints its accuracy.

    The session and both summary writers are always closed on exit, so
    buffered summary events are flushed to disk.

    Args:
        argv: unused; kept for entry-point compatibility.

    Raises:
        ValueError: if ``FLAGS.mode`` is neither ``'train'`` nor ``'test'``.
    """
    # Graph Part #
    print("Graph initialization...")
    with tf.device(device):
        # The validation model reuses the training model's variables.
        with tf.variable_scope("model", reuse=None):
            m_train = Model(is_training=True)
        with tf.variable_scope("model", reuse=True):
            m_valid = Model(is_training=False)
    print("Done")

    # Summary Part #
    print("Setting up summary op...")
    cost_ph = tf.placeholder(dtype=tf.float32)
    accuracy_ph = tf.placeholder(dtype=tf.float32)
    cost_summary_op = tf.summary.scalar("cost", cost_ph)
    accuracy_summary_op = tf.summary.scalar("accuracy", accuracy_ph)
    train_summary_writer = tf.summary.FileWriter(logs_dir + '/train/')
    valid_summary_writer = tf.summary.FileWriter(logs_dir + '/valid/', max_queue=2)
    print("Done")

    # Model Save Part #
    print("Setting up Saver...")
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(logs_dir)
    print("Done")

    # Session Part #
    sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    try:
        if ckpt and ckpt.model_checkpoint_path:  # model restore
            print("Model restored...")
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Done")
        else:
            # No checkpoint yet: start from freshly initialized variables.
            sess.run(tf.global_variables_initializer())

        data_set = dr.DataReader(input_dir, output_dir, norm_dir, w=w, u=u, name="train")
        valid_data_set = dr.DataReader(valid_input_dir, valid_output_dir, norm_dir, w=w, u=u, name="valid")

        # Compare by value: `is` tests object identity, which is not reliable
        # for strings (e.g. a mode string parsed from the command line).
        if FLAGS.mode == 'train':
            for itr in range(max_epoch):
                train_inputs, train_labels = data_set.next_batch(batch_size)
                one_hot_labels = dense_to_one_hot(train_labels.reshape((-1, 1)), num_classes=2)
                feed_dict = {m_train.inputs: train_inputs,
                             m_train.labels: one_hot_labels,
                             m_train.keep_probability: dropout_rate}
                sess.run(m_train.train_op, feed_dict=feed_dict)

                if itr % 50 == 0:
                    # Re-uses the training feed_dict, so this is measured on the
                    # batch just trained on, with keep_probability=dropout_rate.
                    train_cost, train_accuracy = sess.run([m_train.cost, m_train.accuracy], feed_dict=feed_dict)
                    print("Step: %d, train_cost: %.3f, train_accuracy=%3.3f" % (itr, train_cost, train_accuracy))
                    train_cost_summary_str = sess.run(cost_summary_op, feed_dict={cost_ph: train_cost})
                    train_accuracy_summary_str = sess.run(accuracy_summary_op, feed_dict={accuracy_ph: train_accuracy})
                    train_summary_writer.add_summary(train_cost_summary_str, itr)
                    train_summary_writer.add_summary(train_accuracy_summary_str, itr)

                if itr % 100 == 0:
                    saver.save(sess, logs_dir + "/model.ckpt", itr)  # model save
                    valid_cost, valid_accuracy = evaluation(m_valid, valid_data_set, sess, valid_batch_size)
                    print("valid_cost: %.3f, valid_accuracy: %.3f" % (valid_cost, valid_accuracy))
                    print('')
                    valid_summary_str_cost = sess.run(cost_summary_op, feed_dict={cost_ph: valid_cost})
                    valid_summary_str_accuracy = sess.run(accuracy_summary_op, feed_dict={accuracy_ph: valid_accuracy})
                    valid_summary_writer.add_summary(valid_summary_str_cost, itr)
                    valid_summary_writer.add_summary(valid_summary_str_accuracy, itr)

        elif FLAGS.mode == 'test':
            _, valid_accuracy = evaluation(m_valid, valid_data_set, sess, valid_batch_size)
            print("valid_accuracy = %.3f" % valid_accuracy)

        else:
            raise ValueError("Unknown mode %r: expected 'train' or 'test'" % (FLAGS.mode,))
    finally:
        # Flush pending summary events (the valid writer only buffers
        # max_queue=2 events) and release the session's resources.
        train_summary_writer.close()
        valid_summary_writer.close()
        sess.close()