Exemple #1
0
def train_ann(word2vec_path):
    """Train the ANN model, evaluating and checkpointing along the way.

    Loads and pads the training and validation sets named by ``FLAGS``,
    builds a ``TextANN`` graph optimized with Adam (exponentially decayed
    learning rate, gradients clipped by global norm), then iterates over the
    training batches. Every ``FLAGS.evaluate_every`` steps the whole
    validation set is scored and the result is handed to the best-checkpoint
    saver; every ``FLAGS.checkpoint_every`` steps a regular checkpoint is
    written under ``runs/<run id>/checkpoints``.

    When ``FLAGS.train_or_restore == 'R'`` the run id is read interactively
    from stdin and the latest checkpoint of that run is restored; otherwise
    a new run directory named after the current Unix timestamp is used.

    Args:
        word2vec_path: Location of the word2vec model; forwarded unchanged
            to the ``feed`` helpers.

    Raises:
        IOError: In restore mode, when no checkpoint is found in the
            selected run's checkpoint directory.
    """
    # Load sentences, labels, and training parameters
    logger.info("✔︎ Loading data...")

    logger.info("✔︎ Training data processing...")
    train_data = feed.load_data_and_labels(FLAGS.training_data_file,
                                           FLAGS.num_classes,
                                           FLAGS.embedding_dim,
                                           data_aug_flag=False,
                                           word2vec_path=word2vec_path)

    logger.info("✔︎ Validation data processing...")
    val_data = feed.load_data_and_labels(FLAGS.validation_data_file,
                                         FLAGS.num_classes,
                                         FLAGS.embedding_dim,
                                         data_aug_flag=False,
                                         word2vec_path=word2vec_path)

    logger.info("Recommended padding Sequence length is: {0}".format(
        FLAGS.pad_seq_len))

    logger.info("✔︎ Training data padding...")
    x_train, y_train = feed.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info("✔︎ Validation data padding...")
    x_val, y_val = feed.pad_data(val_data, FLAGS.pad_seq_len)

    # Build vocabulary
    vocab_size = feed.load_vocab_size(FLAGS.embedding_dim,
                                      word2vec_path=word2vec_path)

    # Use pretrained W2V
    pretrained_word2vec_matrix = feed.load_word2vec_matrix(
        vocab_size, FLAGS.embedding_dim, word2vec_path=word2vec_path)

    # Build a graph and ann object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        # Entering the session makes it the default session and guarantees
        # that its resources are released when the block exits.
        with tf.Session(config=session_conf) as sess:
            ann = TextANN(sequence_length=FLAGS.pad_seq_len,
                          num_classes=FLAGS.num_classes,
                          vocab_size=vocab_size,
                          fc_hidden_size=FLAGS.fc_hidden_size,
                          embedding_size=FLAGS.embedding_dim,
                          embedding_type=FLAGS.embedding_type,
                          l2_reg_lambda=FLAGS.l2_reg_lambda,
                          pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(
                    learning_rate=FLAGS.learning_rate,
                    global_step=ann.global_step,
                    decay_steps=FLAGS.decay_steps,
                    decay_rate=FLAGS.decay_rate,
                    staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, variables = zip(*optimizer.compute_gradients(ann.loss))
                grads, _ = tf.clip_by_global_norm(grads,
                                                  clip_norm=FLAGS.norm_ratio)
                train_op = optimizer.apply_gradients(
                    zip(grads, variables),
                    global_step=ann.global_step,
                    name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, variables):
                if g is not None:
                    grad_summaries.append(tf.summary.histogram(
                        "{0}/grad/hist".format(v.name), g))
                    grad_summaries.append(tf.summary.scalar(
                        "{0}/grad/sparsity".format(v.name),
                        tf.nn.zero_fraction(g)))
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            restoring = FLAGS.train_or_restore == 'R'
            if restoring:
                # The run to restore is identified by its 10-digit
                # timestamp directory name.
                run_id = input(
                    "☛ Please input the checkpoints model you want to "
                    "restore, it should be like(1490175368): ")
                while not (run_id.isdigit() and len(run_id) == 10):
                    run_id = input("✘ The format of your input is illegal, "
                                   "please re-input: ")
                logger.info("✔︎ The format of your input is legal, "
                            "now loading to next step...")
            else:
                run_id = str(int(time.time()))

            out_dir = os.path.abspath(
                os.path.join(os.path.curdir, "runs", run_id))
            logger.info("✔︎ Writing to {0}\n".format(out_dir))

            checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", ann.loss)

            # Train summaries
            train_summary_op = tf.summary.merge(
                [loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(
                train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries",
                                                  "validation")
            validation_summary_writer = tf.summary.FileWriter(
                validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(),
                                   max_to_keep=FLAGS.num_checkpoints)
            best_saver = checkpoints.BestCheckpointSaver(
                save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True)

            if restoring:
                # Load ann model
                logger.info("✔︎ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)
                if checkpoint_file is None:
                    # Fail with an actionable message instead of trying to
                    # open the non-existent file "None.meta".
                    raise IOError("No checkpoint found in {0}".format(
                        checkpoint_dir))

                # Load the saved meta graph and restore variables
                # NOTE(review): this imports the meta graph into a graph
                # that already contains the freshly built model and rebinds
                # `saver` (dropping max_to_keep). Restoring with the saver
                # created above may be what is intended -- confirm.
                saver = tf.train.import_meta_graph(
                    "{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                os.makedirs(checkpoint_dir, exist_ok=True)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = FLAGS.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer,
                                               config)

                # Save the embedding visualization
                saver.save(
                    sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(ann.global_step)
            print("current_step: ", current_step)

            def train_step(x_batch, y_batch):
                """Run one optimization step and record its summaries."""
                feed_dict = {
                    ann.input_x: x_batch,
                    ann.input_y: y_batch,
                    ann.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    ann.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, ann.global_step, train_summary_op, ann.loss],
                    feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(_x_val, _y_val, writer=None):
                """Score the model on a whole validation set.

                Returns a tuple ``(loss, auc, prc, rec_ts, pre_ts, F_ts,
                rec_tk, pre_tk, F_tk)``: the mean batch loss, the
                micro-averaged AUC and average precision, the
                threshold-based recall/precision/F1, and three lists of
                length ``FLAGS.top_num`` holding recall/precision/F1 for
                top-1 .. top-``FLAGS.top_num`` predictions. When ``writer``
                is given, each batch's summary is added to it.
                """
                batches_validation = feed.batch_iter(list(zip(_x_val, _y_val)),
                                                     FLAGS.batch_size, 1)

                # Predict classes by threshold or topk
                # ('ts': threshold; 'tk': topk)
                _eval_counter, _eval_loss = 0, 0.0

                _eval_pre_tk = [0.0] * FLAGS.top_num
                _eval_rec_tk = [0.0] * FLAGS.top_num
                _eval_F_tk = [0.0] * FLAGS.top_num

                true_onehot_labels = []
                predicted_onehot_scores = []
                predicted_onehot_labels_ts = []
                predicted_onehot_labels_tk = [[] for _ in range(FLAGS.top_num)]

                for batch_validation in batches_validation:
                    x_batch_val, y_batch_val = zip(*batch_validation)
                    feed_dict = {
                        ann.input_x: x_batch_val,
                        ann.input_y: y_batch_val,
                        ann.dropout_keep_prob: 1.0,
                        ann.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run([
                        ann.global_step, validation_summary_op, ann.scores,
                        ann.loss
                    ], feed_dict)

                    # Prepare for calculating metrics
                    true_onehot_labels.extend(y_batch_val)
                    predicted_onehot_scores.extend(scores)

                    # Predict by threshold
                    predicted_onehot_labels_ts.extend(
                        feed.get_onehot_label_threshold(
                            scores=scores, threshold=FLAGS.threshold))

                    # Predict by topK
                    for _top_num in range(FLAGS.top_num):
                        predicted_onehot_labels_tk[_top_num].extend(
                            feed.get_onehot_label_topk(
                                scores=scores, top_num=_top_num + 1))

                    _eval_loss = _eval_loss + cur_loss
                    _eval_counter = _eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                _eval_loss = float(_eval_loss / _eval_counter)

                # Convert the collected rows once; every metric below
                # consumes the same arrays.
                y_true = np.array(true_onehot_labels)
                y_pred_ts = np.array(predicted_onehot_labels_ts)

                # Calculate Precision & Recall & F1 (threshold & topK)
                _eval_pre_ts = precision_score(
                    y_true=y_true, y_pred=y_pred_ts, average='micro')
                _eval_rec_ts = recall_score(
                    y_true=y_true, y_pred=y_pred_ts, average='micro')
                _eval_F_ts = f1_score(
                    y_true=y_true, y_pred=y_pred_ts, average='micro')

                for _top_num in range(FLAGS.top_num):
                    y_pred_tk = np.array(predicted_onehot_labels_tk[_top_num])
                    _eval_pre_tk[_top_num] = precision_score(
                        y_true=y_true, y_pred=y_pred_tk, average='micro')
                    _eval_rec_tk[_top_num] = recall_score(
                        y_true=y_true, y_pred=y_pred_tk, average='micro')
                    _eval_F_tk[_top_num] = f1_score(
                        y_true=y_true, y_pred=y_pred_tk, average='micro')

                y_score = np.array(predicted_onehot_scores)
                # Calculate the average AUC
                _eval_auc = roc_auc_score(
                    y_true=y_true, y_score=y_score, average='micro')
                # Calculate the average PR
                _eval_prc = average_precision_score(
                    y_true=y_true, y_score=y_score, average='micro')

                return (_eval_loss, _eval_auc, _eval_prc, _eval_rec_ts,
                        _eval_pre_ts, _eval_F_ts, _eval_rec_tk, _eval_pre_tk,
                        _eval_F_tk)

            # Generate batches
            batches_train = feed.batch_iter(list(zip(x_train, y_train)),
                                            FLAGS.batch_size, FLAGS.num_epochs)

            num_batches_per_epoch = int(
                (len(x_train) - 1) / FLAGS.batch_size) + 1

            # Training loop. For each batch...
            try:
                for batch_train in batches_train:
                    x_batch_train, y_batch_train = zip(*batch_train)
                    train_step(x_batch_train, y_batch_train)
                    current_step = tf.train.global_step(sess, ann.global_step)

                    if current_step % FLAGS.evaluate_every == 0:
                        logger.info("\nEvaluation:")
                        (eval_loss, eval_auc, eval_prc, eval_rec_ts,
                         eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk,
                         eval_F_tk) = validation_step(
                             x_val, y_val, writer=validation_summary_writer)

                        logger.info(
                            "All Validation set: Loss {0:g} | AUC {1:g} | AUPRC {2:g}"
                            .format(eval_loss, eval_auc, eval_prc))

                        # Predict by threshold
                        logger.info("☛ Predict by threshold: Precision {0:g}, "
                                    "Recall {1:g}, F {2:g}".format(
                                        eval_pre_ts, eval_rec_ts, eval_F_ts))

                        # Predict by topK
                        logger.info("☛ Predict by topK:")
                        for top_num in range(FLAGS.top_num):
                            logger.info(
                                "Top{0}: Precision {1:g}, Recall {2:g}, F {3:g}".
                                format(top_num + 1, eval_pre_tk[top_num],
                                       eval_rec_tk[top_num],
                                       eval_F_tk[top_num]))
                        # The best-checkpoint saver ranks checkpoints by
                        # validation AUPRC (maximize=True above).
                        best_saver.handle(eval_prc, sess, current_step)
                    if current_step % FLAGS.checkpoint_every == 0:
                        checkpoint_prefix = os.path.join(checkpoint_dir,
                                                         "model")
                        path = saver.save(sess,
                                          checkpoint_prefix,
                                          global_step=current_step)
                        logger.info(
                            "✔︎ Saved model checkpoint to {0}\n".format(path))
                    if current_step % num_batches_per_epoch == 0:
                        current_epoch = current_step // num_batches_per_epoch
                        logger.info(
                            "✔︎ Epoch {0} has finished!".format(current_epoch))
            finally:
                # Flush and close the event files even when training is
                # interrupted, so buffered summaries are not lost.
                train_summary_writer.close()
                validation_summary_writer.close()

    logger.info("✔︎ Done.")
Exemple #2
0
def test_ann(word2vec_path):
    """Evaluate a saved ANN model on the test set and save its predictions.

    Loads and pads the test data named by ``FLAGS``, asks on stdin whether
    the best or the latest checkpoint should be used, restores that
    checkpoint, exports a frozen graph to ``graph/graph-ann-<MODEL>.pb``,
    runs the test set through the model in a single unshuffled pass, logs
    loss / AUC / AUPRC plus threshold-based and top-K precision, recall and
    F1, and finally writes ``predictions.json`` under ``SAVE_DIR``.

    ``MODEL`` and ``SAVE_DIR`` are read but not assigned in this function;
    presumably they are module-level globals -- verify against the rest of
    the file.

    Args:
        word2vec_path: Location of the word2vec model; forwarded unchanged
            to ``feed.load_data_and_labels``.

    Raises:
        IOError: When no checkpoint could be located for the chosen mode.
    """
    # Load data
    logger.info("✔︎ Loading data...")
    logger.info("Recommended padding Sequence length is: {0}".format(
        FLAGS.pad_seq_len))

    logger.info("✔︎ Test data processing...")
    test_data = feed.load_data_and_labels(FLAGS.test_data_file,
                                          FLAGS.num_classes,
                                          FLAGS.embedding_dim,
                                          data_aug_flag=False,
                                          word2vec_path=word2vec_path)

    logger.info("✔︎ Test data padding...")
    x_test, y_test = feed.pad_data(test_data, FLAGS.pad_seq_len)
    y_test_labels = test_data.labels

    # Load ann model
    choice = input("☛ Load Best or Latest Model?(B/L): ")
    # Only the single letters B/L (any case) are accepted.
    while choice.upper() not in ('B', 'L'):
        choice = \
            input("✘ The format of your input is illegal, please re-input: ")
    if choice.upper() == 'B':
        logger.info("✔︎ Loading best model...")
        checkpoint_file = checkpoints.get_best_checkpoint(
            FLAGS.best_checkpoint_dir, select_maximum_value=True)
    else:
        logger.info("✔︎ Loading latest model...")
        checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    logger.info(checkpoint_file)
    if checkpoint_file is None:
        # Fail with an actionable message instead of trying to open the
        # non-existent file "None.meta" below.
        raise IOError("No checkpoint found; check FLAGS.checkpoint_dir "
                      "and FLAGS.best_checkpoint_dir")

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        # Entering the session makes it the default session and guarantees
        # that its resources are released when the block exits.
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph(
                "{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]
            is_training = graph.get_operation_by_name("is_training").outputs[0]

            # Tensors we want to evaluate
            scores = graph.get_operation_by_name("output/scores").outputs[0]
            loss = graph.get_operation_by_name("loss/loss").outputs[0]

            # Split the output nodes name by '|' if you have several output
            # nodes
            output_node_names = "output/scores"

            # Save the .pb model file
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names.split("|"))
            tf.train.write_graph(output_graph_def,
                                 "graph",
                                 "graph-ann-{0}.pb".format(MODEL),
                                 as_text=False)

            # Generate batches for one epoch
            batches = feed.batch_iter(list(zip(x_test, y_test, y_test_labels)),
                                      FLAGS.batch_size,
                                      1,
                                      shuffle=False)

            test_counter, test_loss = 0, 0.0

            test_pre_tk = [0.0] * FLAGS.top_num
            test_rec_tk = [0.0] * FLAGS.top_num
            test_F_tk = [0.0] * FLAGS.top_num

            # Collect the predictions here
            true_labels = []
            predicted_labels = []
            predicted_scores = []

            # Collect for calculating metrics
            true_onehot_labels = []
            predicted_onehot_scores = []
            predicted_onehot_labels_ts = []
            predicted_onehot_labels_tk = [[] for _ in range(FLAGS.top_num)]

            for batch_test in batches:
                x_batch_test, y_batch_test, y_batch_test_labels = zip(
                    *batch_test)
                # Debug output: dumps every raw batch to stdout.
                print("x_batch_test", x_batch_test)
                print("y_batch_test", y_batch_test)
                feed_dict = {
                    input_x: x_batch_test,
                    input_y: y_batch_test,
                    dropout_keep_prob: 1.0,
                    is_training: False
                }
                batch_scores, cur_loss = sess.run([scores, loss], feed_dict)

                # Prepare for calculating metrics
                true_onehot_labels.extend(y_batch_test)
                predicted_onehot_scores.extend(batch_scores)

                # Get the predicted labels by threshold
                batch_predicted_labels_ts, batch_predicted_scores_ts = \
                    feed.get_label_threshold(scores=batch_scores,
                                             threshold=FLAGS.threshold)

                # Add results to collection
                true_labels.extend(y_batch_test_labels)
                predicted_labels.extend(batch_predicted_labels_ts)
                predicted_scores.extend(batch_predicted_scores_ts)

                # Get onehot predictions by threshold
                predicted_onehot_labels_ts.extend(
                    feed.get_onehot_label_threshold(
                        scores=batch_scores, threshold=FLAGS.threshold))

                # Get onehot predictions by topK
                for top_num in range(FLAGS.top_num):
                    predicted_onehot_labels_tk[top_num].extend(
                        feed.get_onehot_label_topk(scores=batch_scores,
                                                   top_num=top_num + 1))

                test_loss = test_loss + cur_loss
                test_counter = test_counter + 1

            # Convert the collected rows once; every metric below consumes
            # the same arrays.
            y_true = np.array(true_onehot_labels)
            y_pred_ts = np.array(predicted_onehot_labels_ts)

            # Calculate Precision & Recall & F1 (threshold & topK)
            test_pre_ts = precision_score(
                y_true=y_true, y_pred=y_pred_ts, average='micro')
            test_rec_ts = recall_score(
                y_true=y_true, y_pred=y_pred_ts, average='micro')
            test_F_ts = f1_score(
                y_true=y_true, y_pred=y_pred_ts, average='micro')

            for top_num in range(FLAGS.top_num):
                y_pred_tk = np.array(predicted_onehot_labels_tk[top_num])
                test_pre_tk[top_num] = precision_score(
                    y_true=y_true, y_pred=y_pred_tk, average='micro')
                test_rec_tk[top_num] = recall_score(
                    y_true=y_true, y_pred=y_pred_tk, average='micro')
                test_F_tk[top_num] = f1_score(
                    y_true=y_true, y_pred=y_pred_tk, average='micro')

            y_score = np.array(predicted_onehot_scores)

            # Calculate the average AUC
            test_auc = roc_auc_score(
                y_true=y_true, y_score=y_score, average='micro')

            # Calculate the average PR
            test_prc = average_precision_score(
                y_true=y_true, y_score=y_score, average="micro")
            test_loss = float(test_loss / test_counter)

            logger.info(
                "☛ All Test Dataset: Loss {0:g} | AUC {1:g} | AUPRC {2:g}".
                format(test_loss, test_auc, test_prc))

            # Predict by threshold
            logger.info(
                "☛ Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
                .format(test_pre_ts, test_rec_ts, test_F_ts))

            # Predict by topK
            logger.info("☛ Predict by topK:")
            for top_num in range(FLAGS.top_num):
                logger.info(
                    "Top{0}: Precision {1:g}, Recall {2:g}, F {3:g}".format(
                        top_num + 1, test_pre_tk[top_num],
                        test_rec_tk[top_num], test_F_tk[top_num]))

            # Save the prediction result
            os.makedirs(SAVE_DIR, exist_ok=True)
            feed.create_prediction_file(output_file=SAVE_DIR +
                                        "/predictions.json",
                                        data_id=test_data.testid,
                                        all_labels=true_labels,
                                        all_predict_labels=predicted_labels,
                                        all_predict_scores=predicted_scores)

    logger.info("✔︎ Done.")
Exemple #3
0
def test_ann(word2vec_path, model_number):
    # Parameters
    # =============================================================================

    logger = feed.logger_fn("tflog",
                            "logs/test-{0}.log".format(time.asctime()))

    # MODEL = input("☛ Please input the model file you want to test, "
    #               "it should be like(1490175368): ")

    MODEL = str(model_number)

    while not (MODEL.isdigit() and len(MODEL) == 10):
        MODEL = input("✘ The format of your input is illegal, "
                      "it should be like(1490175368), please re-input: ")

    logger.info("✔︎ The format of your input is legal, "
                "now loading to next step...")

    TRAININGSET_DIR = 'models/citability/data/Train.json'
    VALIDATIONSET_DIR = 'models/citability/data/Validation.json'
    # TEST_DIR = 'data/Test.json'
    cwd = os.getcwd()
    TEST_DIR = os.path.join(cwd, 'web/test_data.json')

    cwd = os.getcwd()
    MODEL_DIR = os.path.join(cwd, 'web/runs/' + MODEL + '/checkpoints/')
    print(MODEL_DIR)
    BEST_MODEL_DIR = 'runs/' + MODEL + '/bestcheckpoints/'
    SAVE_DIR = 'results/' + MODEL

    # Data Parameters
    tf.flags.DEFINE_string("training_data_file", TRAININGSET_DIR,
                           "Data source for the training data.")
    tf.flags.DEFINE_string("validation_data_file", VALIDATIONSET_DIR,
                           "Data source for the validation data")
    tf.flags.DEFINE_string("test_data_file", TEST_DIR,
                           "Data source for the test data")
    tf.flags.DEFINE_string("checkpoint_dir", MODEL_DIR,
                           "Checkpoint directory from training run")
    tf.flags.DEFINE_string("best_checkpoint_dir", BEST_MODEL_DIR,
                           "Best checkpoint directory from training run")

    # Model Hyperparameters
    tf.flags.DEFINE_integer(
        "pad_seq_len", 35842, "Recommended padding Sequence length of data "
        "(depends on the data)")
    tf.flags.DEFINE_integer(
        "embedding_dim", 300, "Dimensionality of character embedding "
        "(default: 128)")
    tf.flags.DEFINE_integer("embedding_type", 1,
                            "The embedding type (default: 1)")
    tf.flags.DEFINE_integer(
        "fc_hidden_size", 1024, "Hidden size for fully connected layer "
        "(default: 1024)")
    tf.flags.DEFINE_float("dropout_keep_prob", 0.5,
                          "Dropout keep probability (default: 0.5)")
    tf.flags.DEFINE_float("l2_reg_lambda", 0.0,
                          "L2 regularization lambda (default: 0.0)")
    tf.flags.DEFINE_integer("num_classes", 80,
                            "Number of labels (depends on the task)")
    tf.flags.DEFINE_integer("top_num", 80,
                            "Number of top K prediction classes (default: 5)")
    tf.flags.DEFINE_float("threshold", 0.5,
                          "Threshold for prediction classes (default: 0.5)")

    # Test Parameters
    tf.flags.DEFINE_integer("batch_size", 1, "Batch Size (default: 1)")

    # Misc Parameters
    tf.flags.DEFINE_boolean("allow_soft_placement", True,
                            "Allow device soft device placement")
    tf.flags.DEFINE_boolean("log_device_placement", False,
                            "Log placement of ops on devices")
    tf.flags.DEFINE_boolean("gpu_options_allow_growth", True,
                            "Allow gpu options growth")

    FLAGS = tf.flags.FLAGS
    FLAGS(sys.argv)
    dilim = '-' * 100
    logger.info('\n'.join([
        dilim, *[
            '{0:>50}|{1:<50}'.format(attr.upper(), FLAGS.__getattr__(attr))
            for attr in sorted(FLAGS.__dict__['__wrapped'])
        ], dilim
    ]))
    """Test ANN model."""

    # Load data
    logger.info("✔︎ Loading data...")
    logger.info("Recommended padding Sequence length is: {0}".format(
        FLAGS.pad_seq_len))

    logger.info("✔︎ Test data processing...")
    test_data = feed.load_data_and_labels(FLAGS.test_data_file,
                                          FLAGS.num_classes,
                                          FLAGS.embedding_dim,
                                          data_aug_flag=False,
                                          word2vec_path=word2vec_path)

    logger.info("✔︎ Test data padding...")
    x_test, y_test = feed.pad_data(test_data, FLAGS.pad_seq_len)
    y_test_labels = test_data.labels

    # Load ann model
    # BEST_OR_LATEST = input("☛ Load Best or Latest Model?(B/L): ")
    BEST_OR_LATEST = 'L'

    while not (BEST_OR_LATEST.isalpha()
               and BEST_OR_LATEST.upper() in ['B', 'L']):
        BEST_OR_LATEST = \
            input("✘ The format of your input is illegal, please re-input: ")
    if BEST_OR_LATEST.upper() == 'B':
        logger.info("✔︎ Loading best model...")
        checkpoint_file = checkpoints.get_best_checkpoint(
            FLAGS.best_checkpoint_dir, select_maximum_value=True)
    else:
        logger.info("✔︎ Loading latest model...")
        checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    logger.info(checkpoint_file)

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph(
                "{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]
            is_training = graph.get_operation_by_name("is_training").outputs[0]

            # Tensors we want to evaluate
            scores = graph.get_operation_by_name("output/scores").outputs[0]
            loss = graph.get_operation_by_name("loss/loss").outputs[0]

            # Split the output nodes name by '|' if you have several output
            # nodes
            output_node_names = "output/scores"

            # Save the .pb model file
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names.split("|"))
            tf.train.write_graph(output_graph_def,
                                 "graph",
                                 "graph-ann-{0}.pb".format(MODEL),
                                 as_text=False)

            # Generate batches for one epoch
            batches = feed.batch_iter(list(zip(x_test, y_test, y_test_labels)),
                                      FLAGS.batch_size,
                                      1,
                                      shuffle=False)

            test_counter, test_loss = 0, 0.0

            test_pre_tk = [0.0] * FLAGS.top_num
            test_rec_tk = [0.0] * FLAGS.top_num
            test_F_tk = [0.0] * FLAGS.top_num

            # Collect the predictions here
            true_labels = []
            predicted_labels = []
            predicted_scores = []

            # Collect for calculating metrics
            true_onehot_labels = []
            predicted_onehot_scores = []
            predicted_onehot_labels_ts = []
            predicted_onehot_labels_tk = [[] for _ in range(FLAGS.top_num)]

            for batch_test in batches:
                x_batch_test, y_batch_test, y_batch_test_labels = zip(
                    *batch_test)
                print("x_batch_test", x_batch_test)
                print("y_batch_test", y_batch_test)
                feed_dict = {
                    input_x: x_batch_test,
                    input_y: y_batch_test,
                    dropout_keep_prob: 1.0,
                    is_training: False
                }
                batch_scores, cur_loss = sess.run([scores, loss], feed_dict)

                # Prepare for calculating metrics
                for i in y_batch_test:
                    true_onehot_labels.append(i)
                for j in batch_scores:
                    predicted_onehot_scores.append(j)

                # Get the predicted labels by threshold
                batch_predicted_labels_ts, batch_predicted_scores_ts = \
                    feed.get_label_threshold(scores=batch_scores,
                                             threshold=FLAGS.threshold)

                # Add results to collection
                for i in y_batch_test_labels:
                    true_labels.append(i)
                for j in batch_predicted_labels_ts:
                    predicted_labels.append(j)
                for k in batch_predicted_scores_ts:
                    predicted_scores.append(k)

                # Get onehot predictions by threshold
                batch_predicted_onehot_labels_ts = \
                    feed.get_onehot_label_threshold(scores=batch_scores,
                                                    threshold=FLAGS.threshold)
                for i in batch_predicted_onehot_labels_ts:
                    predicted_onehot_labels_ts.append(i)

                # Get onehot predictions by topK
                for top_num in range(FLAGS.top_num):
                    batch_predicted_onehot_labels_tk = feed.\
                        get_onehot_label_topk(scores=batch_scores,
                                              top_num=top_num + 1)

                    for i in batch_predicted_onehot_labels_tk:
                        predicted_onehot_labels_tk[top_num].append(i)

                test_loss = test_loss + cur_loss
                test_counter = test_counter + 1

            # Calculate Precision & Recall & F1 (threshold & topK)
            test_pre_ts = precision_score(
                y_true=np.array(true_onehot_labels),
                y_pred=np.array(predicted_onehot_labels_ts),
                average='micro')
            test_rec_ts = recall_score(
                y_true=np.array(true_onehot_labels),
                y_pred=np.array(predicted_onehot_labels_ts),
                average='micro')
            test_F_ts = f1_score(y_true=np.array(true_onehot_labels),
                                 y_pred=np.array(predicted_onehot_labels_ts),
                                 average='micro')

            for top_num in range(FLAGS.top_num):
                test_pre_tk[top_num] = precision_score(
                    y_true=np.array(true_onehot_labels),
                    y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                    average='micro')
                test_rec_tk[top_num] = recall_score(
                    y_true=np.array(true_onehot_labels),
                    y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                    average='micro')
                test_F_tk[top_num] = f1_score(
                    y_true=np.array(true_onehot_labels),
                    y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                    average='micro')

            # Calculate the micro-averaged ROC AUC
            test_auc = roc_auc_score(y_true=np.array(true_onehot_labels),
                                     y_score=np.array(predicted_onehot_scores),
                                     average='micro')

            # Calculate the micro-averaged average precision (logged as AUPRC)
            test_prc = average_precision_score(
                y_true=np.array(true_onehot_labels),
                y_score=np.array(predicted_onehot_scores),
                average="micro")
            test_loss = float(test_loss / test_counter)

            logger.info(
                "☛ All Test Dataset: Loss {0:g} | AUC {1:g} | AUPRC {2:g}".
                format(test_loss, test_auc, test_prc))

            # Predict by threshold
            logger.info(
                "☛ Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
                .format(test_pre_ts, test_rec_ts, test_F_ts))

            # Predict by topK
            logger.info("☛ Predict by topK:")
            for top_num in range(FLAGS.top_num):
                logger.info(
                    "Top{0}: Precision {1:g}, Recall {2:g}, F {3:g}".format(
                        top_num + 1, test_pre_tk[top_num],
                        test_rec_tk[top_num], test_F_tk[top_num]))

            # Save the prediction result
            if not os.path.exists(SAVE_DIR):
                os.makedirs(SAVE_DIR)
            feed.create_prediction_file(output_file=SAVE_DIR +
                                        "/predictions.json",
                                        data_id=test_data.testid,
                                        all_labels=true_labels,
                                        all_predict_labels=predicted_labels,
                                        all_predict_scores=predicted_scores)

    logger.info("✔︎ Done.")