Example 1
            def validation_step(x_val, y_val, y_val_tuple, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = dh.batch_iter(
                    list(zip(x_val, y_val, y_val_tuple)), FLAGS.batch_size, 1)

                # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
                eval_counter, eval_loss, eval_auc = 0, 0.0, 0.0
                eval_rec_ts, eval_pre_ts, eval_F_ts = 0.0, 0.0, 0.0
                eval_rec_tk = [0.0] * FLAGS.top_num
                eval_pre_tk = [0.0] * FLAGS.top_num
                eval_F_tk = [0.0] * FLAGS.top_num
                val_scores = []

                for batch_validation in batches_validation:
                    x_batch_val, y_batch_val, y_batch_val_tuple = zip(*batch_validation)

                    y_batch_val_first = [i[0] for i in y_batch_val_tuple]
                    y_batch_val_second = [j[1] for j in y_batch_val_tuple]
                    y_batch_val_third = [k[2] for k in y_batch_val_tuple]

                    feed_dict = {
                        ann.input_x: x_batch_val,
                        ann.input_y_first: y_batch_val_first,
                        ann.input_y_second: y_batch_val_second,
                        ann.input_y_third: y_batch_val_third,
                        ann.input_y: y_batch_val,
                        ann.dropout_keep_prob: 1.0,
                        ann.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run(
                        [ann.global_step, validation_summary_op, ann.scores, ann.loss], feed_dict)

                    for predicted_scores in scores:
                        val_scores.append(predicted_scores)

                    # Predict by threshold
                    predicted_labels_threshold, predicted_values_threshold = \
                        dh.get_label_using_scores_by_threshold(scores=scores, threshold=FLAGS.threshold)

                    cur_rec_ts, cur_pre_ts, cur_F_ts = 0.0, 0.0, 0.0

                    for index, predicted_label_threshold in enumerate(predicted_labels_threshold):
                        rec_inc_ts, pre_inc_ts = dh.cal_metric(predicted_label_threshold, y_batch_val[index])
                        cur_rec_ts, cur_pre_ts = cur_rec_ts + rec_inc_ts, cur_pre_ts + pre_inc_ts

                    cur_rec_ts = cur_rec_ts / len(y_batch_val)
                    cur_pre_ts = cur_pre_ts / len(y_batch_val)

                    cur_F_ts = dh.cal_F(cur_rec_ts, cur_pre_ts)

                    eval_rec_ts, eval_pre_ts = eval_rec_ts + cur_rec_ts, eval_pre_ts + cur_pre_ts

                    # Predict by topK
                    topK_predicted_labels = []
                    for top_num in range(FLAGS.top_num):
                        predicted_labels_topk, predicted_values_topk = \
                            dh.get_label_using_scores_by_topk(scores=scores, top_num=top_num+1)
                        topK_predicted_labels.append(predicted_labels_topk)

                    cur_rec_tk = [0.0] * FLAGS.top_num
                    cur_pre_tk = [0.0] * FLAGS.top_num
                    cur_F_tk = [0.0] * FLAGS.top_num

                    for top_num, predicted_labels_topK in enumerate(topK_predicted_labels):
                        for index, predicted_label_topK in enumerate(predicted_labels_topK):
                            rec_inc_tk, pre_inc_tk = dh.cal_metric(predicted_label_topK, y_batch_val[index])
                            cur_rec_tk[top_num], cur_pre_tk[top_num] = \
                                cur_rec_tk[top_num] + rec_inc_tk, cur_pre_tk[top_num] + pre_inc_tk

                        cur_rec_tk[top_num] = cur_rec_tk[top_num] / len(y_batch_val)
                        cur_pre_tk[top_num] = cur_pre_tk[top_num] / len(y_batch_val)

                        cur_F_tk[top_num] = dh.cal_F(cur_rec_tk[top_num], cur_pre_tk[top_num])

                        eval_rec_tk[top_num], eval_pre_tk[top_num] = \
                            eval_rec_tk[top_num] + cur_rec_tk[top_num], eval_pre_tk[top_num] + cur_pre_tk[top_num]

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                # Calculate the average AUC
                val_scores = np.array(val_scores)
                y_val = np.array(y_val)
                missing_labels_num = 0
                for index in range(FLAGS.total_classes):
                    y_true = y_val[:, index]
                    y_score = val_scores[:, index]
                    try:
                        eval_auc = eval_auc + roc_auc_score(y_true=y_true, y_score=y_score)
                    except ValueError:
                        # roc_auc_score raises ValueError when a class never occurs in y_true
                        missing_labels_num += 1

                eval_auc = eval_auc / (FLAGS.total_classes - missing_labels_num)
                eval_loss = float(eval_loss / eval_counter)
                eval_rec_ts = float(eval_rec_ts / eval_counter)
                eval_pre_ts = float(eval_pre_ts / eval_counter)
                eval_F_ts = dh.cal_F(eval_rec_ts, eval_pre_ts)

                for top_num in range(FLAGS.top_num):
                    eval_rec_tk[top_num] = float(eval_rec_tk[top_num] / eval_counter)
                    eval_pre_tk[top_num] = float(eval_pre_tk[top_num] / eval_counter)
                    eval_F_tk[top_num] = dh.cal_F(eval_rec_tk[top_num], eval_pre_tk[top_num])

                return eval_loss, eval_auc, eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk, eval_F_tk
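The dh helpers called above are not shown in these examples. A minimal sketch of what dh.get_label_using_scores_by_threshold might look like, assuming scores is a (batch, num_classes) array of per-class probabilities and an argmax fallback so every sample gets at least one label (the name and behavior are inferred from the call sites, not the library's confirmed API):

    import numpy as np

    def get_label_using_scores_by_threshold(scores, threshold=0.5):
        """Per sample, keep the indices whose score reaches the threshold;
        fall back to the single best index so every sample gets a label.
        (Sketch of the assumed dh helper, not its confirmed source.)"""
        predicted_labels, predicted_values = [], []
        for score in scores:
            indices = [i for i, s in enumerate(score) if s >= threshold]
            if not indices:  # guarantee at least one predicted label
                indices = [int(np.argmax(score))]
            predicted_labels.append(indices)
            predicted_values.append([score[i] for i in indices])
        return predicted_labels, predicted_values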
Example 2
            def validation_step(x_validation, y_validation, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = batch_iter(
                    list(zip(x_validation, y_validation)), FLAGS.batch_size, 1)

                # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
                eval_counter, eval_loss, eval_rec_ts, eval_acc_ts, eval_F_ts = 0, 0.0, 0.0, 0.0, 0.0
                eval_rec_tk = [0.0] * FLAGS.top_num
                eval_acc_tk = [0.0] * FLAGS.top_num
                eval_F_tk = [0.0] * FLAGS.top_num

                for batch_validation in batches_validation:
                    x_batch_validation, y_batch_validation = zip(*batch_validation)
                    feed_dict = {
                        rcnn.input_x: x_batch_validation,
                        rcnn.input_y: y_batch_validation,
                        rcnn.dropout_keep_prob: 1.0,
                        rcnn.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run(
                        [rcnn.global_step, validation_summary_op, rcnn.scores, rcnn.loss], feed_dict)

                    # Predict by threshold
                    predicted_labels_threshold, predicted_values_threshold = \
                        dh.get_label_using_scores_by_threshold(scores=scores, threshold=FLAGS.threshold)

                    cur_rec_ts, cur_acc_ts, cur_F_ts = 0.0, 0.0, 0.0

                    for index, predicted_label_threshold in enumerate(predicted_labels_threshold):
                        rec_inc_ts, acc_inc_ts, F_inc_ts = dh.cal_metric(predicted_label_threshold,
                                                                         y_batch_validation[index])
                        cur_rec_ts, cur_acc_ts, cur_F_ts = cur_rec_ts + rec_inc_ts, \
                                                           cur_acc_ts + acc_inc_ts, \
                                                           cur_F_ts + F_inc_ts

                    cur_rec_ts = cur_rec_ts / len(y_batch_validation)
                    cur_acc_ts = cur_acc_ts / len(y_batch_validation)
                    cur_F_ts = cur_F_ts / len(y_batch_validation)

                    eval_rec_ts, eval_acc_ts, eval_F_ts = eval_rec_ts + cur_rec_ts, \
                                                          eval_acc_ts + cur_acc_ts, \
                                                          eval_F_ts + cur_F_ts

                    # Predict by topK
                    topK_predicted_labels = []
                    for top_num in range(FLAGS.top_num):
                        predicted_labels_topk, predicted_values_topk = \
                            dh.get_label_using_scores_by_topk(scores=scores, top_num=top_num+1)
                        topK_predicted_labels.append(predicted_labels_topk)

                    cur_rec_tk = [0.0] * FLAGS.top_num
                    cur_acc_tk = [0.0] * FLAGS.top_num
                    cur_F_tk = [0.0] * FLAGS.top_num

                    for top_num, predicted_labels_topK in enumerate(topK_predicted_labels):
                        for index, predicted_label_topK in enumerate(predicted_labels_topK):
                            rec_inc_tk, acc_inc_tk, F_inc_tk = dh.cal_metric(predicted_label_topK,
                                                                             y_batch_validation[index])
                            cur_rec_tk[top_num], cur_acc_tk[top_num], cur_F_tk[top_num] = \
                                cur_rec_tk[top_num] + rec_inc_tk, \
                                cur_acc_tk[top_num] + acc_inc_tk, \
                                cur_F_tk[top_num] + F_inc_tk

                        cur_rec_tk[top_num] = cur_rec_tk[top_num] / len(y_batch_validation)
                        cur_acc_tk[top_num] = cur_acc_tk[top_num] / len(y_batch_validation)
                        cur_F_tk[top_num] = cur_F_tk[top_num] / len(y_batch_validation)

                        eval_rec_tk[top_num], eval_acc_tk[top_num], eval_F_tk[top_num] = \
                            eval_rec_tk[top_num] + cur_rec_tk[top_num], \
                            eval_acc_tk[top_num] + cur_acc_tk[top_num], \
                            eval_F_tk[top_num] + cur_F_tk[top_num]

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    logger.info("✔︎ validation batch {0}: loss {1:g}".format(eval_counter, cur_loss))
                    logger.info("︎☛ Predict by threshold: recall {0:g}, accuracy {1:g}, F {2:g}"
                                .format(cur_rec_ts, cur_acc_ts, cur_F_ts))

                    logger.info("︎☛ Predict by topK:")
                    for top_num in range(FLAGS.top_num):
                        logger.info("Top{0}: recall {1:g}, accuracy {2:g}, F {3:g}"
                                    .format(top_num + 1, cur_rec_tk[top_num], cur_acc_tk[top_num], cur_F_tk[top_num]))

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)
                eval_rec_ts = float(eval_rec_ts / eval_counter)
                eval_acc_ts = float(eval_acc_ts / eval_counter)
                eval_F_ts = float(eval_F_ts / eval_counter)

                for top_num in range(FLAGS.top_num):
                    eval_rec_tk[top_num] = float(eval_rec_tk[top_num] / eval_counter)
                    eval_acc_tk[top_num] = float(eval_acc_tk[top_num] / eval_counter)
                    eval_F_tk[top_num] = float(eval_F_tk[top_num] / eval_counter)

                return eval_loss, eval_rec_ts, eval_acc_ts, eval_F_ts, eval_rec_tk, eval_acc_tk, eval_F_tk
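Note that Example 2's dh.cal_metric returns three values per sample (recall, accuracy, F), while Examples 1 and 3 use a two-value variant and compute F separately via dh.cal_F. A hedged sketch of the three-value version, treating the true label vector as multi-hot and using set-overlap definitions (an assumption, not the confirmed implementation):

    def cal_metric(predicted_labels, true_onehot):
        """Per-sample recall, accuracy (Jaccard), and F1 for a multi-label
        prediction; a guess at the three-value dh.cal_metric used above."""
        true_set = {i for i, v in enumerate(true_onehot) if v == 1}
        pred_set = set(predicted_labels)
        inter = float(len(true_set & pred_set))
        rec = inter / len(true_set) if true_set else 0.0
        acc = inter / len(true_set | pred_set) if (true_set | pred_set) else 0.0
        pre = inter / len(pred_set) if pred_set else 0.0
        f = 2.0 * pre * rec / (pre + rec) if (pre + rec) > 0 else 0.0
        return rec, acc, f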
Example 3
            def validation_step(x_validation, y_validation, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = dh.batch_iter(
                    list(zip(x_validation, y_validation)), FLAGS.batch_size, 1)

                # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
                eval_counter, eval_loss, eval_rec_ts, eval_pre_ts, eval_F_ts = 0, 0.0, 0.0, 0.0, 0.0
                eval_rec_tk = [0.0] * FLAGS.top_num
                eval_pre_tk = [0.0] * FLAGS.top_num
                eval_F_tk = [0.0] * FLAGS.top_num

                for batch_validation in batches_validation:
                    x_batch_validation, y_batch_validation = zip(
                        *batch_validation)
                    feed_dict = {
                        fasttext.input_x: x_batch_validation,
                        fasttext.input_y: y_batch_validation,
                        fasttext.dropout_keep_prob: 1.0,
                        fasttext.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run([
                        fasttext.global_step, validation_summary_op,
                        fasttext.scores, fasttext.loss
                    ], feed_dict)

                    # Predict by threshold
                    predicted_labels_threshold, predicted_values_threshold = \
                        dh.get_label_using_scores_by_threshold(scores=scores, threshold=FLAGS.threshold)

                    cur_rec_ts, cur_pre_ts, cur_F_ts = 0.0, 0.0, 0.0

                    for index, predicted_label_threshold in enumerate(
                            predicted_labels_threshold):
                        rec_inc_ts, pre_inc_ts = dh.cal_metric(
                            predicted_label_threshold,
                            y_batch_validation[index])
                        cur_rec_ts, cur_pre_ts = cur_rec_ts + rec_inc_ts, cur_pre_ts + pre_inc_ts

                    cur_rec_ts = cur_rec_ts / len(y_batch_validation)
                    cur_pre_ts = cur_pre_ts / len(y_batch_validation)

                    cur_F_ts = dh.cal_F(cur_rec_ts, cur_pre_ts)

                    eval_rec_ts, eval_pre_ts = eval_rec_ts + cur_rec_ts, eval_pre_ts + cur_pre_ts

                    # Predict by topK
                    topK_predicted_labels = []
                    for top_num in range(FLAGS.top_num):
                        predicted_labels_topk, predicted_values_topk = \
                            dh.get_label_using_scores_by_topk(scores=scores, top_num=top_num+1)
                        topK_predicted_labels.append(predicted_labels_topk)

                    cur_rec_tk = [0.0] * FLAGS.top_num
                    cur_pre_tk = [0.0] * FLAGS.top_num
                    cur_F_tk = [0.0] * FLAGS.top_num

                    for top_num, predicted_labels_topK in enumerate(
                            topK_predicted_labels):
                        for index, predicted_label_topK in enumerate(
                                predicted_labels_topK):
                            rec_inc_tk, pre_inc_tk = dh.cal_metric(
                                predicted_label_topK,
                                y_batch_validation[index])
                            cur_rec_tk[top_num], cur_pre_tk[top_num] = \
                                cur_rec_tk[top_num] + rec_inc_tk, cur_pre_tk[top_num] + pre_inc_tk

                        cur_rec_tk[top_num] = cur_rec_tk[top_num] / len(
                            y_batch_validation)
                        cur_pre_tk[top_num] = cur_pre_tk[top_num] / len(
                            y_batch_validation)

                        cur_F_tk[top_num] = dh.cal_F(cur_rec_tk[top_num],
                                                     cur_pre_tk[top_num])

                        eval_rec_tk[top_num], eval_pre_tk[top_num] = \
                            eval_rec_tk[top_num] + cur_rec_tk[top_num], eval_pre_tk[top_num] + cur_pre_tk[top_num]

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)
                eval_rec_ts = float(eval_rec_ts / eval_counter)
                eval_pre_ts = float(eval_pre_ts / eval_counter)
                eval_F_ts = dh.cal_F(eval_rec_ts, eval_pre_ts)

                for top_num in range(FLAGS.top_num):
                    eval_rec_tk[top_num] = float(eval_rec_tk[top_num] /
                                                 eval_counter)
                    eval_pre_tk[top_num] = float(eval_pre_tk[top_num] /
                                                 eval_counter)
                    eval_F_tk[top_num] = dh.cal_F(eval_rec_tk[top_num],
                                                  eval_pre_tk[top_num])

                return eval_loss, eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk, eval_F_tk
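The topK companion helper is also easy to sketch. Assuming it returns the top_num highest-scoring class indices and their scores for each sample (again inferred from the call sites, not the confirmed implementation):

    import numpy as np

    def get_label_using_scores_by_topk(scores, top_num=1):
        """Take the top_num highest-scoring class indices per sample,
        ordered from highest to lowest score. (Assumed dh helper.)"""
        predicted_labels, predicted_values = [], []
        for score in np.asarray(scores):
            idx = np.argsort(score)[-top_num:][::-1]  # descending by score
            predicted_labels.append(idx.tolist())
            predicted_values.append(score[idx].tolist())
        return predicted_labels, predicted_values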
def test_mann():
    """Test MANN model."""

    # Load data
    logger.info("✔ Loading data...")
    logger.info('Recommended padding sequence length is: {0}'.format(
        FLAGS.pad_seq_len))

    logger.info('✔︎ Test data processing...')
    test_data = dh.load_data_and_labels(FLAGS.test_data_file,
                                        FLAGS.num_classes, FLAGS.embedding_dim)

    logger.info('✔︎ Test data padding...')
    x_test, y_test = dh.pad_data(test_data, FLAGS.pad_seq_len)
    y_test_labels = test_data.labels

    # Load MANN model
    logger.info("✔ Loading model...")
    checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    logger.info(checkpoint_file)

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph(
                "{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]
            is_training = graph.get_operation_by_name("is_training").outputs[0]

            # Tensors we want to evaluate
            scores = graph.get_operation_by_name("output/scores").outputs[0]
            loss = graph.get_operation_by_name("loss/loss").outputs[0]

            # Separate the output node names with '|' if there are several output nodes
            output_node_names = 'output/logits|output/scores'

            # Save the .pb model file
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names.split("|"))
            tf.train.write_graph(output_graph_def,
                                 'graph',
                                 'graph-mann-{0}.pb'.format(MODEL),
                                 as_text=False)

            # Generate batches for one epoch
            batches = dh.batch_iter(list(zip(x_test, y_test, y_test_labels)),
                                    FLAGS.batch_size,
                                    1,
                                    shuffle=False)

            # Collect the predictions here
            all_labels = []
            all_predicted_labels = []
            all_predicted_values = []

            # Calculate the metric
            test_counter, test_loss, test_rec, test_pre, test_F = 0, 0.0, 0.0, 0.0, 0.0

            for batch_test in batches:
                x_batch_test, y_batch_test, y_batch_test_labels = zip(
                    *batch_test)
                feed_dict = {
                    input_x: x_batch_test,
                    input_y: y_batch_test,
                    dropout_keep_prob: 1.0,
                    is_training: False
                }
                batch_scores, cur_loss = sess.run([scores, loss], feed_dict)

                # Predict by threshold
                predicted_labels_threshold, predicted_values_threshold = \
                    dh.get_label_using_scores_by_threshold(scores=batch_scores, threshold=FLAGS.threshold)

                cur_rec, cur_pre, cur_F = 0.0, 0.0, 0.0

                for index, predicted_label_threshold in enumerate(
                        predicted_labels_threshold):
                    rec_inc, pre_inc = dh.cal_metric(predicted_label_threshold,
                                                     y_batch_test[index])
                    cur_rec, cur_pre = cur_rec + rec_inc, cur_pre + pre_inc

                cur_rec = cur_rec / len(y_batch_test)
                cur_pre = cur_pre / len(y_batch_test)

                test_rec, test_pre = test_rec + cur_rec, test_pre + cur_pre

                # Add results to collection
                for item in y_batch_test_labels:
                    all_labels.append(item)
                for item in predicted_labels_threshold:
                    all_predicted_labels.append(item)
                for item in predicted_values_threshold:
                    all_predicted_values.append(item)

                test_loss = test_loss + cur_loss
                test_counter = test_counter + 1

            test_loss = float(test_loss / test_counter)
            test_rec = float(test_rec / test_counter)
            test_pre = float(test_pre / test_counter)
            test_F = dh.cal_F(test_rec, test_pre)

            logger.info("☛ All Test Dataset: Loss {0:g}".format(test_loss))

            # Predict by threshold
            logger.info(
                "☛ Predict by threshold: Recall {0:g}, Precision {1:g}, F {2:g}"
                .format(test_rec, test_pre, test_F))
            # Save the prediction result
            if not os.path.exists(SAVE_DIR):
                os.makedirs(SAVE_DIR)
            dh.create_prediction_file(output_file=SAVE_DIR +
                                      '/predictions.json',
                                      data_id=test_data.testid,
                                      all_labels=all_labels,
                                      all_predict_labels=all_predicted_labels,
                                      all_predict_values=all_predicted_values)

    logger.info("✔ Done.")
def test_rcnn():
    """Test RCNN model."""

    # Load data
    logger.info("✔ Loading data...")
    logger.info('Recommended padding sequence length is: {0}'.format(
        FLAGS.pad_seq_len))

    logger.info('✔︎ Test data processing...')
    test_data = dh.load_data_and_labels(FLAGS.test_data_file,
                                        FLAGS.num_classes, FLAGS.embedding_dim)

    logger.info('✔︎ Test data padding...')
    x_test, y_test = dh.pad_data(test_data, FLAGS.pad_seq_len)

    # Load RCNN model
    logger.info("✔ Loading model...")
    checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    logger.info(checkpoint_file)

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph(
                "{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]
            is_training = graph.get_operation_by_name("is_training").outputs[0]

            # Tensors we want to evaluate
            scores = graph.get_operation_by_name("output/scores").outputs[0]
            loss = graph.get_operation_by_name("loss/loss").outputs[0]

            # Separate the output node names with '|' if there are several output nodes
            output_node_names = 'output/logits|output/scores'

            # Save the .pb model file
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names.split("|"))
            tf.train.write_graph(output_graph_def,
                                 'graph',
                                 'graph-rcnn-{0}.pb'.format(MODEL),
                                 as_text=False)

            # Generate batches for one epoch

            # Warning: the last partial batch is dropped, so there will be no predictions for that data
            # Advice: you can pad the test set so its size is an exact multiple of the batch size
            batches = batch_iter(list(zip(x_test, y_test)),
                                 FLAGS.batch_size,
                                 1,
                                 shuffle=False)

            # Collect the predictions here
            all_predicted_label_ts = []
            all_predicted_values_ts = []

            all_predicted_label_tk = []
            all_predicted_values_tk = []

            # Calculate the metric
            test_counter, test_loss, test_rec_ts, test_acc_ts, test_F_ts = 0, 0.0, 0.0, 0.0, 0.0
            test_rec_tk = [0.0] * FLAGS.top_num
            test_acc_tk = [0.0] * FLAGS.top_num
            test_F_tk = [0.0] * FLAGS.top_num

            for batch_test in batches:
                x_batch_test, y_batch_test = zip(*batch_test)
                feed_dict = {
                    input_x: x_batch_test,
                    input_y: y_batch_test,
                    dropout_keep_prob: 1.0,
                    is_training: False
                }
                batch_scores, cur_loss = sess.run([scores, loss], feed_dict)

                # Predict by threshold
                predicted_labels_threshold, predicted_values_threshold = \
                    dh.get_label_using_scores_by_threshold(scores=batch_scores, threshold=FLAGS.threshold)

                cur_rec_ts, cur_acc_ts, cur_F_ts = 0.0, 0.0, 0.0

                for index, predicted_label_threshold in enumerate(
                        predicted_labels_threshold):
                    rec_inc_ts, acc_inc_ts, F_inc_ts = dh.cal_metric(
                        predicted_label_threshold, y_batch_test[index])
                    cur_rec_ts, cur_acc_ts, cur_F_ts = cur_rec_ts + rec_inc_ts, \
                                                       cur_acc_ts + acc_inc_ts, \
                                                       cur_F_ts + F_inc_ts

                cur_rec_ts = cur_rec_ts / len(y_batch_test)
                cur_acc_ts = cur_acc_ts / len(y_batch_test)
                cur_F_ts = cur_F_ts / len(y_batch_test)

                test_rec_ts, test_acc_ts, test_F_ts = test_rec_ts + cur_rec_ts, \
                                                      test_acc_ts + cur_acc_ts, \
                                                      test_F_ts + cur_F_ts

                # Add results to collection
                for item in predicted_labels_threshold:
                    all_predicted_label_ts.append(item)
                for item in predicted_values_threshold:
                    all_predicted_values_ts.append(item)

                # Predict by topK
                topK_predicted_labels = []
                for top_num in range(FLAGS.top_num):
                    predicted_labels_topk, predicted_values_topk = \
                        dh.get_label_using_scores_by_topk(batch_scores, top_num=top_num + 1)
                    topK_predicted_labels.append(predicted_labels_topk)

                cur_rec_tk = [0.0] * FLAGS.top_num
                cur_acc_tk = [0.0] * FLAGS.top_num
                cur_F_tk = [0.0] * FLAGS.top_num

                for top_num, predicted_labels_topK in enumerate(
                        topK_predicted_labels):
                    for index, predicted_label_topK in enumerate(
                            predicted_labels_topK):
                        rec_inc_tk, acc_inc_tk, F_inc_tk = dh.cal_metric(
                            predicted_label_topK, y_batch_test[index])
                        cur_rec_tk[top_num], cur_acc_tk[top_num], cur_F_tk[top_num] = \
                            cur_rec_tk[top_num] + rec_inc_tk, \
                            cur_acc_tk[top_num] + acc_inc_tk, \
                            cur_F_tk[top_num] + F_inc_tk

                    cur_rec_tk[top_num] = cur_rec_tk[top_num] / len(
                        y_batch_test)
                    cur_acc_tk[top_num] = cur_acc_tk[top_num] / len(
                        y_batch_test)
                    cur_F_tk[top_num] = cur_F_tk[top_num] / len(y_batch_test)

                    test_rec_tk[top_num], test_acc_tk[top_num], test_F_tk[top_num] = \
                        test_rec_tk[top_num] + cur_rec_tk[top_num], \
                        test_acc_tk[top_num] + cur_acc_tk[top_num], \
                        test_F_tk[top_num] + cur_F_tk[top_num]

                test_loss = test_loss + cur_loss
                test_counter = test_counter + 1

            test_loss = float(test_loss / test_counter)
            test_rec_ts = float(test_rec_ts / test_counter)
            test_acc_ts = float(test_acc_ts / test_counter)
            test_F_ts = float(test_F_ts / test_counter)

            for top_num in range(FLAGS.top_num):
                test_rec_tk[top_num] = float(test_rec_tk[top_num] /
                                             test_counter)
                test_acc_tk[top_num] = float(test_acc_tk[top_num] /
                                             test_counter)
                test_F_tk[top_num] = float(test_F_tk[top_num] / test_counter)

            logger.info("☛ All Test Dataset: Loss {0:g}".format(test_loss))

            # Predict by threshold
            logger.info(
                "☛ Predict by threshold: Recall {0:g}, Accuracy {1:g}, F {2:g}"
                .format(test_rec_ts, test_acc_ts, test_F_ts))

            # Predict by topK
            logger.info("︎☛ Predict by topK:")
            for top_num in range(FLAGS.top_num):
                logger.info(
                    "Top{0}: recall {1:g}, accuracy {2:g}, F {3:g}".format(
                        top_num + 1, test_rec_tk[top_num],
                        test_acc_tk[top_num], test_F_tk[top_num]))

            # Save the prediction result
            if not os.path.exists(SAVE_DIR):
                os.makedirs(SAVE_DIR)
            dh.create_prediction_file(
                output_file=SAVE_DIR + '/predictions.json',
                data_id=test_data.testid,
                all_predict_labels_ts=all_predicted_label_ts,
                all_predict_values_ts=all_predicted_values_ts)

    logger.info("✔ Done.")
def test_lmlp():
    """Test LMLP model."""

    # Load data
    logger.info("✔︎ Loading data...")
    logger.info("Recommended padding Sequence length is: {0}".format(
        FLAGS.pad_seq_len))

    logger.info("✔︎ Test data processing...")
    test_data = dh.load_data_and_labels(FLAGS.test_data_file,
                                        FLAGS.num_classes_list,
                                        FLAGS.embedding_dim,
                                        data_aug_flag=False)

    logger.info("✔︎ Test data padding...")
    x_test, y_test = dh.pad_data(test_data, FLAGS.pad_seq_len)
    y_test_labels = test_data.labels

    # Load LMLP model
    BEST_OR_LATEST = input("☛ Load Best or Latest Model? (B/L): ")

    while not (BEST_OR_LATEST.isalpha()
               and BEST_OR_LATEST.upper() in ['B', 'L']):
        BEST_OR_LATEST = input(
            "✘ Invalid input format, please re-enter: ")
    if BEST_OR_LATEST.upper() == 'B':
        logger.info("✔︎ Loading best model...")
        checkpoint_file = cm.get_best_checkpoint(FLAGS.best_checkpoint_dir,
                                                 select_maximum_value=True)
    else:
        logger.info("✔︎ Loading latest model...")
        checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    logger.info(checkpoint_file)

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph(
                "{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            input_y_first = graph.get_operation_by_name(
                "input_y_first").outputs[0]
            input_y_second = graph.get_operation_by_name(
                "input_y_second").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]
            is_training = graph.get_operation_by_name("is_training").outputs[0]

            # Tensors we want to evaluate
            scores = graph.get_operation_by_name("output/scores").outputs[0]
            loss = graph.get_operation_by_name("loss/loss").outputs[0]

            # Separate the output node names with '|' if there are several output nodes
            output_node_names = "output/logits|output/scores"

            # Save the .pb model file
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names.split("|"))
            tf.train.write_graph(output_graph_def,
                                 "graph",
                                 "graph-cnn-{0}.pb".format(MODEL),
                                 as_text=False)

            # Generate batches for one epoch
            batches = dh.batch_iter(list(zip(x_test, y_test, y_test_labels)),
                                    FLAGS.batch_size,
                                    1,
                                    shuffle=False)

            # Collect the predictions here
            all_labels = []
            all_predicted_labels = []
            all_predicted_values = []

            # Calculate the metric
            test_counter, test_loss, test_rec, test_pre, test_F = 0, 0.0, 0.0, 0.0, 0.0

            for batch_test in batches:
                x_batch_test, y_batch_test, y_batch_test_labels = zip(
                    *batch_test)

                y_batch_test_first = [i[0] for i in y_batch_test]
                y_batch_test_second = [j[1] for j in y_batch_test]
                y_batch_test_third = [k[2] for k in y_batch_test]

                feed_dict = {
                    input_x: x_batch_test,
                    input_y_first: y_batch_test_first,
                    input_y_second: y_batch_test_second,
                    input_y: y_batch_test_third,
                    dropout_keep_prob: 1.0,
                    is_training: False
                }
                batch_scores, cur_loss = sess.run([scores, loss], feed_dict)

                # Predict by threshold
                predicted_labels_threshold, predicted_values_threshold = \
                    dh.get_label_using_scores_by_threshold(scores=batch_scores, threshold=FLAGS.threshold)

                cur_rec, cur_pre, cur_F = 0.0, 0.0, 0.0

                for index, predicted_label_threshold in enumerate(
                        predicted_labels_threshold):
                    rec_inc, pre_inc = dh.cal_metric(predicted_label_threshold,
                                                     y_batch_test_third[index])
                    cur_rec, cur_pre = cur_rec + rec_inc, cur_pre + pre_inc

                cur_rec = cur_rec / len(y_batch_test_third)
                cur_pre = cur_pre / len(y_batch_test_third)

                test_rec, test_pre = test_rec + cur_rec, test_pre + cur_pre

                # Add results to collection
                for item in y_batch_test_labels:
                    all_labels.append(item)
                for item in predicted_labels_threshold:
                    all_predicted_labels.append(item)
                for item in predicted_values_threshold:
                    all_predicted_values.append(item)

                test_loss = test_loss + cur_loss
                test_counter = test_counter + 1

            test_loss = float(test_loss / test_counter)
            test_rec = float(test_rec / test_counter)
            test_pre = float(test_pre / test_counter)
            test_F = dh.cal_F(test_rec, test_pre)

            logger.info("☛ All Test Dataset: Loss {0:g}".format(test_loss))

            # Predict by threshold
            logger.info(
                "☛ Predict by threshold: Recall {0:g}, Precision {1:g}, F {2:g}"
                .format(test_rec, test_pre, test_F))
            # Save the prediction result
            if not os.path.exists(SAVE_DIR):
                os.makedirs(SAVE_DIR)
            dh.create_prediction_file(output_file=SAVE_DIR +
                                      "/predictions.json",
                                      data_id=test_data.patent_id,
                                      all_labels=all_labels,
                                      all_predict_labels=all_predicted_labels,
                                      all_predict_values=all_predicted_values)

    logger.info("✔︎ Done.")