# Example #1 (score: 0)
# File: utils.py — Project: yunhojang/VAD
def vad_test(m_eval, sess_eval, batch_size_eval, eval_file_dir, norm_dir,
             data_len, eval_type):
    """Run VAD evaluation over the feature files in *eval_file_dir*.

    Args:
        m_eval: evaluation model whose graph nodes are consumed by
            ``evaluation``.
        sess_eval: TensorFlow session used to run the model.
        batch_size_eval: mini-batch size during evaluation.
        eval_file_dir: directory with input features; labels are expected
            under ``<eval_file_dir>/Labels``.
        norm_dir: directory holding the normalization statistics.
        data_len: total number of frames in the evaluation data; used to
            compute how much padding fills the last batch.
        eval_type: reader selector — 2 selects the DNN reader
            (``dnn_dr``), anything else the default reader (``dr``).

    Returns:
        Tuple ``(final_softout, final_label)`` as produced by
        ``evaluation``.
    """
    eval_input_dir = eval_file_dir
    eval_output_dir = eval_file_dir + '/Labels'

    # Pad so the data length becomes a multiple of the batch size.  The
    # original expression `batch_size_eval - data_len % batch_size_eval`
    # produced a full extra batch of padding when data_len was already
    # divisible; `-data_len % batch_size_eval` is 0 in that case.
    pad_size = -data_len % batch_size_eval

    # Both readers take the same constructor arguments; only the class
    # differs, so select it once instead of duplicating the call.
    reader_cls = dnn_dr.DataReader if eval_type == 2 else dr.DataReader
    eval_data_set = reader_cls(eval_input_dir,
                               eval_output_dir,
                               norm_dir,
                               w=19,
                               u=9,
                               name="eval",
                               pad=pad_size)

    final_softout, final_label = evaluation(m_eval, eval_data_set, sess_eval,
                                            batch_size_eval, eval_type)

    return final_softout, final_label
# Example #2 (score: 0)
# File: utils.py — Project: orctom/VAD
def do_validation(m_valid, sess, valid_batch_size, valid_file_dir, norm_dir, model_config, type='DNN'):
    """Run one full pass over the validation files and return averages.

    Args:
        m_valid: validation model exposing ``inputs``, ``labels``,
            ``keep_probability``, ``cost`` and ``accuracy`` nodes.
        sess: TensorFlow session.
        valid_batch_size: mini-batch size for validation.
        valid_file_dir: directory with validation inputs; labels are read
            from ``<valid_file_dir>/Labels``.
        norm_dir: directory with normalization statistics.
        model_config: dict providing the reader's ``w`` and ``u`` options.
        type: reader selector; only ``'DNN'`` is supported.  (The name
            shadows the builtin but is kept for interface compatibility.)

    Returns:
        Tuple ``(total_avg_valid_accuracy, total_avg_valid_cost)`` —
        per-file averages, then averaged across files.

    Raises:
        ValueError: if *type* is not ``'DNN'``.  (The original fell
        through silently and later crashed with ``NameError``.)
    """
    # dataset reader setting #
    # `==` instead of the original `is`: identity comparison on string
    # literals depends on CPython interning and is not guaranteed.
    if type == 'DNN':
        valid_data_set = dnn_dr.DataReader(valid_file_dir, valid_file_dir + '/Labels', norm_dir,
                                           w=model_config['w'], u=model_config['u'], name="eval")
    else:
        raise ValueError("unsupported reader type: %r" % (type,))

    avg_valid_accuracy = 0.
    avg_valid_cost = 0.
    itr_sum = 0.

    accuracy_list = [0] * valid_data_set._file_len
    cost_list = [0] * valid_data_set._file_len
    itr_file = 0
    while True:

        valid_inputs, valid_labels = valid_data_set.next_batch(valid_batch_size)

        # A file boundary was crossed: store this file's running averages
        # and reset the per-file accumulators.
        if valid_data_set.file_change_checker():
            accuracy_list[itr_file] = avg_valid_accuracy / itr_sum
            cost_list[itr_file] = avg_valid_cost / itr_sum
            avg_valid_cost = 0.
            avg_valid_accuracy = 0.
            itr_sum = 0
            itr_file += 1
            valid_data_set.file_change_initialize()

        if valid_data_set.eof_checker():
            valid_data_set.reader_initialize()
            print('Valid data reader was initialized!')  # initialize eof flag & num_file & start index
            break

        one_hot_labels = valid_labels.reshape((-1, 1))
        one_hot_labels = dense_to_one_hot(one_hot_labels, num_classes=2)

        # keep_probability = 1: no dropout at validation time.
        feed_dict = {m_valid.inputs: valid_inputs, m_valid.labels: one_hot_labels,
                     m_valid.keep_probability: 1}

        valid_cost, valid_accuracy = sess.run([m_valid.cost, m_valid.accuracy], feed_dict=feed_dict)

        avg_valid_accuracy += valid_accuracy
        avg_valid_cost += valid_cost
        itr_sum += 1

    # np.asscalar was removed in NumPy >= 1.23; float() is equivalent here.
    total_avg_valid_accuracy = float(np.mean(np.asarray(accuracy_list)))
    total_avg_valid_cost = float(np.mean(np.asarray(cost_list)))

    return total_avg_valid_accuracy, total_avg_valid_cost
# Example #3 (score: 0)
# File: VAD_DNN.py — Project: orctom/VAD
def main(argv=None):
    """Build the shared train/valid graph, then train or test according to
    the module-level ``mode`` flag.

    Relies on module-level configuration: ``mode``, ``device``,
    ``logs_dir``, ``initial_logs_dir``, ``ckpt_name``, ``input_dir``,
    ``output_dir``, ``norm_dir``, ``w``, ``u``, ``batch_size``,
    ``valid_batch_size``, ``max_epoch``, ``dropout_rate``,
    ``model_config``, ``valid_file_dir``, ``test_file_dir``, ``data_len``,
    ``eval_type``.

    Returns:
        ``None`` in train mode; in test mode, the ``(softout, label)``
        pair from ``utils.vad_test``, truncated to ``data_len`` rows when
        ``data_len`` is given.

    NOTE(review): all string comparisons below use ``==`` instead of the
    original ``is`` — identity checks on string literals rely on CPython
    interning and are not a language guarantee.
    """
    #                               Graph Part                               #
    print("Graph initialization...")
    with tf.device(device):
        # Two replicas share one variable scope: the training model
        # creates the variables, the validation model reuses them.
        with tf.variable_scope("model", reuse=None):
            m_train = Model(is_training=True)
        with tf.variable_scope("model", reuse=True):
            m_valid = Model(is_training=False)

    print("Done")

    #                               Summary Part                             #

    print("Setting up summary op...")

    # One scalar placeholder is shared by both summary ops; the value to
    # log is fed in at sess.run time.
    summary_ph = tf.placeholder(dtype=tf.float32)
    with tf.variable_scope("Training_procedure"):

        cost_summary_op = tf.summary.scalar("cost", summary_ph)
        accuracy_summary_op = tf.summary.scalar("accuracy", summary_ph)

    if mode == 'train':
        train_summary_writer = tf.summary.FileWriter(logs_dir + '/train/',
                                                     max_queue=2)
        valid_summary_writer = tf.summary.FileWriter(logs_dir + '/valid/',
                                                     max_queue=2)

    # summary_dic = summary_generation(valid_file_dir)

    print("Done")

    #                               Model Save Part                           #

    print("Setting up Saver...")
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(logs_dir)
    print("Done")

    #                               Session Part                              #

    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=False)
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    if ckpt and ckpt.model_checkpoint_path:  # model restore
        print("Model restored...")

        # Train resumes from the latest checkpoint in logs_dir; test loads
        # a specific named checkpoint from initial_logs_dir.
        if mode == 'train':
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            saver.restore(sess, initial_logs_dir + ckpt_name)

        print("Done")
    else:
        sess.run(tf.global_variables_initializer()
                 )  # if the checkpoint doesn't exist, do initialization
    if mode == 'train':
        train_data_set = dr.DataReader(
            input_dir, output_dir, norm_dir, w=w, u=u,
            name="train")  # training data reader initialization

    if mode == 'train':

        for itr in range(max_epoch):

            train_inputs, train_labels = train_data_set.next_batch(batch_size)
            one_hot_labels = train_labels.reshape((-1, 1))
            one_hot_labels = dense_to_one_hot(one_hot_labels, num_classes=2)

            feed_dict = {
                m_train.inputs: train_inputs,
                m_train.labels: one_hot_labels,
                m_train.keep_probability: dropout_rate
            }

            sess.run(m_train.train_op, feed_dict=feed_dict)

            # Log training cost/accuracy every 10 steps.  (The original
            # also tested `itr >= 0`, which is always true for a range.)
            if itr % 10 == 0:

                train_cost, train_accuracy \
                    = sess.run([m_train.cost, m_train.accuracy], feed_dict=feed_dict)

                print("Step: %d, train_cost: %.4f, train_accuracy=%4.4f" %
                      (itr, train_cost, train_accuracy * 100))

                train_cost_summary_str = sess.run(
                    cost_summary_op, feed_dict={summary_ph: train_cost})
                train_accuracy_summary_str = sess.run(
                    accuracy_summary_op,
                    feed_dict={summary_ph: train_accuracy})
                train_summary_writer.add_summary(
                    train_cost_summary_str,
                    itr)  # write the train phase summary to event files
                train_summary_writer.add_summary(train_accuracy_summary_str,
                                                 itr)

            # Checkpoint + validation every 50 steps (skipping step 0).
            if itr % 50 == 0 and itr > 0:

                saver.save(sess, logs_dir + "/model.ckpt", itr)  # model save
                print('validation start!')
                valid_accuracy, valid_cost = \
                    utils.do_validation(m_valid, sess, valid_batch_size, valid_file_dir, norm_dir,
                                        model_config, type='DNN')

                print("valid_cost: %.4f, valid_accuracy=%4.4f" %
                      (valid_cost, valid_accuracy * 100))
                valid_cost_summary_str = sess.run(
                    cost_summary_op, feed_dict={summary_ph: valid_cost})
                valid_accuracy_summary_str = sess.run(
                    accuracy_summary_op,
                    feed_dict={summary_ph: valid_accuracy})
                valid_summary_writer.add_summary(
                    valid_cost_summary_str,
                    itr)  # write the valid phase summary to event files
                valid_summary_writer.add_summary(valid_accuracy_summary_str,
                                                 itr)

    elif mode == 'test':

        final_softout, final_label = utils.vad_test(m_valid, sess,
                                                    valid_batch_size,
                                                    test_file_dir, norm_dir,
                                                    data_len, eval_type)

        if data_len is None:
            return final_softout, final_label
        else:
            return final_softout[0:data_len, :], final_label[0:data_len, :]
# Example #4 (score: 0)
# File: VAD_DNN.py — Project: orctom/VAD
def full_evaluation(m_eval, sess_eval, batch_size_eval, eval_file_dir,
                    summary_writer, summary_dic, itr):
    """Evaluate the model on every noise-type subdirectory of
    *eval_file_dir*, print per-noise and aggregate metrics, and write the
    corresponding TensorBoard summaries.

    Relies on module-level names: ``dr`` (data reader module),
    ``evaluation``, ``norm_dir``, ``w``, ``u``, ``summary_list``.
    """
    costs = []
    accuracies = []
    aucs = []

    print("-------- Performance for each of noise types --------")

    noise_types = sorted(os.listdir(eval_file_dir))
    summary_ph = summary_dic["summary_ph"]

    for noise in noise_types:
        in_dir = eval_file_dir + '/' + noise
        out_dir = eval_file_dir + '/' + noise + '/Labels'
        data_set = dr.DataReader(in_dir,
                                 out_dir,
                                 norm_dir,
                                 w=w,
                                 u=u,
                                 name="eval")

        (eval_cost, eval_accuracy, eval_list,
         eval_auc, eval_auc_list) = evaluation(m_eval, data_set, sess_eval,
                                               batch_size_eval)

        print("--noise type : " + noise)
        print(
            "cost: %.4f, accuracy across all SNRs: %.4f, auc across all SNRS: %.4f"
            % (eval_cost, eval_accuracy * 100, eval_auc))

        print('accuracy wrt SNR:')
        print('SNR_-5 : %.4f, SNR_0 : %.4f, SNR_5 : %.4f, SNR_10 : %.4f' %
              (eval_list[0] * 100, eval_list[1] * 100, eval_list[2] * 100,
               eval_list[3] * 100))
        print('AUC wrt SNR:')
        print('SNR_-5 : %.4f, SNR_0 : %.4f, SNR_5 : %.4f, SNR_10 : %.4f' %
              (eval_auc_list[0], eval_auc_list[1], eval_auc_list[2],
               eval_auc_list[3]))
        print('')

        # Per-noise summary values, ordered to match summary_list.
        per_noise_values = [eval_cost] + eval_list + [eval_accuracy]
        for j, tag in enumerate(summary_list):
            summary_str = sess_eval.run(
                summary_dic[noise + "_" + tag],
                feed_dict={summary_ph: per_noise_values[j]})
            summary_writer.add_summary(summary_str, itr)

        costs.append(eval_cost)
        accuracies.append(eval_accuracy)
        aucs.append(eval_auc)

    mean_cost = np.mean(np.asarray(costs))
    var_accuracy = np.var(np.asarray(accuracies))
    mean_accuracy = np.mean(np.asarray(accuracies))
    mean_auc = np.mean(np.asarray(aucs))

    # Aggregate summaries across every noise type.
    for tag, value in (("cost_across_all_noise_types", mean_cost),
                       ("accuracy_across_all_noise_types", mean_accuracy),
                       ("variance_across_all_noise_types", var_accuracy)):
        summary_writer.add_summary(
            sess_eval.run(summary_dic[tag], feed_dict={summary_ph: value}),
            itr)

    print("-------- Performance across all of noise types --------")
    print("cost : %.4f" % mean_cost)
    print("******* averaged accuracy across all noise_types : %.4f *******" %
          (mean_accuracy * 100))
    print("******* averaged auc across all noise_types : %.4f *******" %
          mean_auc)
    print(
        "******* variance of accuracies across all noise_types : %4.4f *******"
        % (var_accuracy * 100))