def validation_step(x_validation,
                                y_validation,
                                y_validation_bind,
                                writer=None):
                """Evaluate the model on a validation set.

                Makes one pass over the validation data in batches of
                ``FLAGS.batch_size``, runs the model with dropout disabled
                and ``is_training`` set to False, and averages the per-batch
                loss, recall and accuracy.

                Args:
                    x_validation: Validation inputs, fed to
                        ``fasttext.input_x``.
                    y_validation: Validation labels, fed to
                        ``fasttext.input_y`` and compared against the
                        predicted labels.
                    y_validation_bind: Per-sample class-bind data; only used
                        when ``FLAGS.use_classbind_or_not`` is ``'Y'``.
                    writer: Optional summary writer. When truthy, the
                        summaries of every batch are added to it at the
                        current global step.

                Returns:
                    A ``(loss, recall, accuracy)`` tuple of floats, each the
                    mean of the per-batch values.

                Raises:
                    ValueError: If ``FLAGS.use_classbind_or_not`` is neither
                        ``'Y'`` nor ``'N'``.
                    ZeroDivisionError: If the validation data yields no
                        batches.
                """
                # Fail early on a bad flag value instead of hitting an unbound
                # ``predicted_labels`` inside the batch loop.
                use_classbind = FLAGS.use_classbind_or_not
                if use_classbind not in ('Y', 'N'):
                    raise ValueError(
                        "FLAGS.use_classbind_or_not must be 'Y' or 'N', "
                        "got {!r}".format(use_classbind))

                # One epoch is enough for evaluation: iterating
                # FLAGS.num_epochs times would re-evaluate the same data over
                # and over and write duplicate summaries for the same step.
                batches_validation = data_helpers.batch_iter(
                    list(zip(x_validation, y_validation, y_validation_bind)),
                    FLAGS.batch_size, 1)

                eval_loss, eval_rec, eval_acc, eval_counter = 0.0, 0.0, 0.0, 0
                for batch_validation in batches_validation:
                    x_batch_validation, y_batch_validation, y_batch_validation_bind = zip(
                        *batch_validation)
                    feed_dict = {
                        fasttext.input_x: x_batch_validation,
                        fasttext.input_y: y_batch_validation,
                        # Keep every unit: no dropout while evaluating.
                        fasttext.dropout_keep_prob: 1.0,
                        fasttext.is_training: False
                    }
                    step, summaries, logits, cur_loss = sess.run([
                        fasttext.global_step, validation_summary_op,
                        fasttext.logits, fasttext.loss
                    ], feed_dict)

                    if use_classbind == 'Y':
                        predicted_labels = data_helpers.get_label_using_logits_and_classbind(
                            logits,
                            y_batch_validation_bind,
                            top_number=FLAGS.top_num)
                    else:
                        predicted_labels = data_helpers.get_label_using_logits(
                            logits, top_number=FLAGS.top_num)

                    # Sum the per-sample recall/accuracy, then average over
                    # the batch size.
                    cur_rec, cur_acc = 0.0, 0.0
                    for index, predicted_label in enumerate(predicted_labels):
                        rec_inc, acc_inc = data_helpers.cal_rec_and_acc(
                            predicted_label, y_batch_validation[index])
                        cur_rec += rec_inc
                        cur_acc += acc_inc

                    batch_len = len(y_batch_validation)
                    cur_rec = cur_rec / batch_len
                    cur_acc = cur_acc / batch_len

                    eval_loss += cur_loss
                    eval_rec += cur_rec
                    eval_acc += cur_acc
                    eval_counter += 1
                    logger.info("✔︎ validation batch {} finished.".format(
                        eval_counter))

                    if writer:
                        writer.add_summary(summaries, step)

                # Mean over batches (every batch carries the same weight).
                eval_loss = float(eval_loss / eval_counter)
                eval_rec = float(eval_rec / eval_counter)
                eval_acc = float(eval_acc / eval_counter)

                return eval_loss, eval_rec, eval_acc
def test_fasttext():
    """Evaluate the latest FASTTEXT checkpoint on the test data.

    Loads and pads the test data named by ``FLAGS.test_data_file``, restores
    the most recent checkpoint found in ``FLAGS.checkpoint_dir``, runs the
    restored graph batch by batch with dropout disabled, logs the mean
    per-batch recall and accuracy, and writes the predicted labels to
    ``SAVE_FILE`` (one value per line).

    Raises:
        ValueError: If ``FLAGS.use_classbind_or_not`` is neither ``'Y'`` nor
            ``'N'``.
        IOError: If no checkpoint can be found in ``FLAGS.checkpoint_dir``.
        ZeroDivisionError: If the test data yields no batches.
    """
    # Fail early on a bad flag value instead of hitting an unbound
    # ``predicted_labels`` inside the batch loop.
    use_classbind = FLAGS.use_classbind_or_not
    if use_classbind not in ('Y', 'N'):
        raise ValueError(
            "FLAGS.use_classbind_or_not must be 'Y' or 'N', "
            "got {!r}".format(use_classbind))

    # Load data
    logger.info("✔ Loading data...")
    logger.info('Recommand padding Sequence length is: {}'.format(
        FLAGS.pad_seq_len))

    logger.info('✔︎ Test data processing...')
    test_data = data_helpers.load_data_and_labels(FLAGS.test_data_file,
                                                  FLAGS.num_classes,
                                                  FLAGS.embedding_dim)

    logger.info('✔︎ Test data padding...')
    x_test, y_test = data_helpers.pad_data(test_data, FLAGS.pad_seq_len)
    y_test_bind = test_data.labels_bind

    # Build vocabulary
    # NOTE(review): neither value is used further down in this function; the
    # calls are kept in case the loaders have side effects -- confirm and drop.
    VOCAB_SIZE = data_helpers.load_vocab_size(FLAGS.embedding_dim)
    pretrained_word2vec_matrix = data_helpers.load_word2vec_matrix(
        VOCAB_SIZE, FLAGS.embedding_dim)

    # Load fasttext model
    logger.info("✔ Loading model...")
    checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    logger.info(checkpoint_file)
    if checkpoint_file is None:
        # latest_checkpoint returns None when nothing is found; without this
        # check the failure would be about a bogus "None.meta" file.
        raise IOError("No checkpoint found in {!r}".format(
            FLAGS.checkpoint_dir))

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        # Using the session itself as the context manager makes it the
        # default session AND closes it on exit (as_default() alone does not
        # release its resources).
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph(
                "{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]

            # input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]

            # pre-trained_word2vec
            # NOTE(review): looked up but never fed or fetched below.
            pretrained_embedding = graph.get_operation_by_name(
                "embedding/embedding").outputs[0]

            # Tensors we want to evaluate
            logits = graph.get_operation_by_name("output/logits").outputs[0]

            # Generate batches for one epoch, in file order
            batches = data_helpers.batch_iter(list(
                zip(x_test, y_test, y_test_bind)),
                                              FLAGS.batch_size,
                                              1,
                                              shuffle=False)

            # Collect the predictions here
            all_predictions = []
            eval_rec, eval_acc, eval_counter = 0.0, 0.0, 0
            for batch_test in batches:
                x_batch_test, y_batch_test, y_batch_test_bind = zip(
                    *batch_test)
                # Keep every unit: no dropout while evaluating.
                feed_dict = {input_x: x_batch_test, dropout_keep_prob: 1.0}
                batch_logits = sess.run(logits, feed_dict)

                if use_classbind == 'Y':
                    predicted_labels = data_helpers.get_label_using_logits_and_classbind(
                        batch_logits,
                        y_batch_test_bind,
                        top_number=FLAGS.top_num)
                else:
                    predicted_labels = data_helpers.get_label_using_logits(
                        batch_logits, top_number=FLAGS.top_num)

                # np.append without ``axis`` flattens both operands, so the
                # predictions end up as one flat 1-D array.
                all_predictions = np.append(all_predictions, predicted_labels)

                # Sum the per-sample recall/accuracy, then average over the
                # batch size.
                cur_rec, cur_acc = 0.0, 0.0
                for index, predicted_label in enumerate(predicted_labels):
                    rec_inc, acc_inc = data_helpers.cal_rec_and_acc(
                        predicted_label, y_batch_test[index])
                    cur_rec += rec_inc
                    cur_acc += acc_inc

                batch_len = len(y_batch_test)
                cur_rec = cur_rec / batch_len
                cur_acc = cur_acc / batch_len

                eval_rec += cur_rec
                eval_acc += cur_acc
                eval_counter += 1
                logger.info(
                    "✔︎ validation batch {} finished.".format(eval_counter))

            # Mean over batches (every batch carries the same weight).
            eval_rec = float(eval_rec / eval_counter)
            eval_acc = float(eval_acc / eval_counter)
            logger.info("☛ Recall {:g}, Accuracy {:g}".format(
                eval_rec, eval_acc))
            # zip() wraps every value in a 1-tuple so that savetxt writes one
            # prediction per row.
            np.savetxt(SAVE_FILE, list(zip(all_predictions)), fmt='%s')

    logger.info("✔ Done.")