# Example No. 1
# 0
def run_test_musicnet():
    """Evaluate a restored model on the MusicNet test data, one length bin at a time.

    For every length in ``cnf.bins`` a fresh graph and tester are built, the
    weights are restored from ``cnf.model_file`` and all test inputs of that
    length are fed through in batches of ``cnf.batch_size``.  For each input
    the 128 labels of the middle label frame (frame index ``n_frames // 2``)
    are compared against the flattened predictions, and the average precision
    score over the whole bin is printed.

    Bins that contain no test inputs are skipped with a message instead of
    building a graph and scoring empty arrays.

    Side effects: sets the module-level ``BATCH_SIZE``, calls
    ``prepare_data_for_test()``, resets ``data_gen.test_counters`` and prints
    progress and results to stdout.  Returns ``None``.
    """
    global BATCH_SIZE
    BATCH_SIZE = cnf.batch_size
    prepare_data_for_test()
    data_gen.test_counters = np.zeros(cnf.bin_max_len, dtype=np.int32)  # reset test counters
    data_supplier = data_feeder.create_data_supplier()

    # Loop-invariant label geometry (previously recomputed on every batch).
    n_notes = 128  # number of note labels scored per input
    stride_labels = 128
    n_frames = cnf.musicnet_window_size // stride_labels - 1
    label_start = stride_labels * (n_frames // 2)  # slice for the middle label frame
    label_end = label_start + n_notes
    progress_step = 1000 // BATCH_SIZE  # print progress roughly every 1000 inputs

    for test_length in cnf.bins:
        full_test_len = len(data_gen.test_set[cnf.task][test_length])
        if full_test_len == 0:
            # Nothing to score: avoid building/restoring a graph and passing
            # empty arrays to average_precision_score.
            print("No test inputs of length {}, skipping".format(test_length))
            continue

        # True ceiling to a whole number of batches.  The old formula
        # ((len // B) * B + B) added a full extra batch whenever len was
        # already divisible by the batch size; that batch was only trimmed away.
        n_batches = -(-full_test_len // BATCH_SIZE)
        n_test_inputs = n_batches * BATCH_SIZE

        with tf.Graph().as_default():
            tester = create_tester(test_length)

            with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto()) as sess:
                sess.run(tf.compat.v1.global_variables_initializer())
                saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
                saver.restore(sess, cnf.model_file)

                print("Testing on {} test inputs: ".format(n_test_inputs), end="")
                pred_chunks = []
                label_chunks = []
                threshold = 0
                for i in range(n_batches):
                    if i > threshold:
                        print("{},".format(i * BATCH_SIZE), end="", flush=True)
                        threshold += progress_step
                    batch_xs, batch_ys = data_supplier.supply_test_data(test_length, BATCH_SIZE)
                    pred_chunks.append((tester.get_result(sess, batch_xs, batch_ys)).flatten())

                    labels_pre = np.array(batch_ys[0])[:, label_start:label_end]
                    # Labels appear to be stored offset by one; subtracting 1
                    # gives the 0/1 labels on the 128 notes.
                    label_chunks.append((labels_pre - 1).flatten())

                # One concatenation instead of boxing every element into a list.
                predictions = np.concatenate(pred_chunks)
                labels = np.concatenate(label_chunks)

                # The last batch may hold more inputs than the bin contains;
                # those surplus inputs are duplicates and are cut from the end.
                n_overshoot = n_test_inputs - full_test_len
                if n_overshoot > 0:
                    cut = n_notes * n_overshoot
                    predictions = predictions[:-cut]
                    labels = labels[:-cut]

                avg_prec_score = average_precision_score(labels, predictions)
                print("\n")
                print("Cutting {} input duplicates".format(n_overshoot))
                print("Done testing on all {} test inputs".format(len(labels) // n_notes))
                print("AVERAGE PRECISION SCORE on all test data = {0:.7f}".format(avg_prec_score))
# Example No. 2
# 0
# Set up the data generator and build the data set for the configured task.
data_gen.init()

if cnf.task not in cnf.language_tasks:
    # Generated (non-language) tasks: create cnf.data_size examples for
    # every length from 1 up to and including max_length.
    for length in range(1, max_length + 1):
        n_examples = cnf.data_size
        data_gen.init_data(cnf.task, length, n_examples, cnf.n_input)
    data_gen.collect_bins()
    # When the train set holds no examples of length cnf.forward_max,
    # generate cnf.test_data_size of them.
    if len(data_gen.train_set[cnf.task][cnf.forward_max]) == 0:
        data_gen.init_data(cnf.task, cnf.forward_max, cnf.test_data_size, cnf.n_input)
else:
    # Language tasks: look up the task object and let it prepare the data,
    # then bin it and report bin usage.
    task = data_gen.find_data_task(cnf.task)
    task.prepare_data()
    data_gen.collect_bins()
    data_gen.print_bin_usage()

data_supplier = data_feeder.create_data_supplier()


# Perform training
with tf.Graph().as_default():
    learner = RSE(cnf.n_hidden, cnf.bins, cnf.n_input, countList, cnf.n_output, cnf.dropout_keep_prob,
                  create_translation_model=cnf.task in cnf.language_tasks, use_two_gpus=cnf.use_two_gpus)
    learner.create_graph()
    learner.variable_summaries = tf.compat.v1.summary.merge_all()
    tf.compat.v1.get_variable_scope().reuse_variables()
    learner.create_test_graph(cnf.forward_max)
    saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())

    with tf.compat.v1.Session(config=cnf.tf_config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run(tf.compat.v1.local_variables_initializer())