def test(sess):
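    """
    Test mode: loads a saved checkpoint and evaluates it on the test set.
    """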
    tf.set_random_seed(seed)
    ### Prepare data for testing
    print("Prepare vocab dict and read pretrained word embeddings ...")
    vocab_dict, word_embedding_array = DataProcessor(
    ).prepare_vocab_embeddingdict()
    # vocab_dict contains _PAD and _UNK; word_embedding_array has no rows for them

    print("Prepare test data ...")
    test_batch = batch_load_data(DataProcessor().prepare_news_data(
        vocab_dict, data_type="test"))

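    # When filter outputs are concatenated, sentembed_size must be divisible by
    # the number of filter lengths; otherwise round it down to the nearest multiple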
    fil_lens_to_test = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    if FLAGS.handle_filter_output == "concat" and FLAGS.sentembed_size % fil_lens_to_test != 0:
        q = FLAGS.sentembed_size // fil_lens_to_test
        FLAGS.sentembed_size = q * fil_lens_to_test
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    # Create Model with various operations
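    # _PAD and _UNK have no rows in word_embedding_array, hence len(vocab_dict) - 2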
    model = MY_Model(sess, len(vocab_dict) - 2)

    # Select the model checkpoint to load
    selected_modelpath = FLAGS.train_dir + "/step-a.model.ckpt.epoch-" + str(
        FLAGS.model_to_load)

    # Reload saved model and test
    #print("Reading model parameters from %s" % selected_modelpath)
    model.saver.restore(sess, selected_modelpath)
    #print("Model loaded.")

    # Assign the pretrained word embeddings after restoring the checkpoint
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    # Test Accuracy and Prediction
    #print("Performance on the test data:")
    FLAGS.authorise_gold_label = False
    FLAGS.use_dropout = False
    test_batch = batch_predict_with_a_model(test_batch,
                                            "test",
                                            model,
                                            session=sess)
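    # Run the prediction op on the collected logits to obtain probabilities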
    probs = sess.run(model.predictions,
                     feed_dict={model.logits_placeholder: test_batch.logits})

    acc = accuracy_qas_top(probs, test_batch.labels, test_batch.weights,
                           test_batch.isf_score_ids)
    mrr = mrr_metric(probs, test_batch.labels, test_batch.weights,
                     test_batch.isf_score_ids, "test")
    _map = map_score(probs, test_batch.labels, test_batch.weights,
                     test_batch.isf_score_ids, "test")
    print("Metrics: acc: %.4f | mrr: %.4f | map: %.4f" % (acc, mrr, _map))

    # Writing test predictions and final summaries
    if FLAGS.save_preds:
        #print("Writing predictions...")
        modelname = "step-a.model.ckpt.epoch-" + str(FLAGS.model_to_load)
        write_prediction_summaries(test_batch, probs, modelname, "test")
        write_cos_sim(test_batch.cos_sim, modelname, "test")


def test_val(sess):
    """
    Test-on-validation mode: loads a saved checkpoint and evaluates it on the validation set.
    """
    tf.set_random_seed(seed)
    if FLAGS.load_prediction != -1:
        print(
            "====================================== [%d] ======================================"
            % (FLAGS.load_prediction))

    ### Prepare data for validation
    #print("Prepare vocab dict and read pretrained word embeddings ...")
    vocab_dict, word_embedding_array = DataProcessor(
    ).prepare_vocab_embeddingdict()
    # vocab_dict contains _PAD and _UNK; word_embedding_array has no rows for them

    val_batch = batch_load_data(DataProcessor().prepare_news_data(
        vocab_dict, data_type="validation"))

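    # When filter outputs are concatenated, sentembed_size must be divisible by
    # the number of filter lengths; otherwise round it down to the nearest multiple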
    fil_lens_to_test = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    if FLAGS.handle_filter_output == "concat" and FLAGS.sentembed_size % fil_lens_to_test != 0:
        q = FLAGS.sentembed_size // fil_lens_to_test
        FLAGS.sentembed_size = q * fil_lens_to_test
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    # Create Model with various operations
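    # _PAD and _UNK have no rows in word_embedding_array, hence len(vocab_dict) - 2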
    model = MY_Model(sess, len(vocab_dict) - 2)

    # Select the model checkpoint to load
    selected_modelpath = FLAGS.train_dir + "/step-a.model.ckpt.epoch-" + str(
        FLAGS.model_to_load)

    # Reload saved model and test
    model.saver.restore(sess, selected_modelpath)

    # Assign the pretrained word embeddings after restoring the checkpoint
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    # Test Accuracy and Prediction
    FLAGS.authorise_gold_label = False
    FLAGS.use_dropout = False
    val_batch = batch_predict_with_a_model(val_batch,
                                           "validation",
                                           model,
                                           session=sess)
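    # Re-enable gold labels once prediction is done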
    FLAGS.authorise_gold_label = True
    probs = sess.run(model.predictions,
                     feed_dict={model.logits_placeholder: val_batch.logits})

    acc = accuracy_qas_top(probs, val_batch.labels, val_batch.weights,
                           val_batch.isf_score_ids)
    mrr = mrr_metric(probs, val_batch.labels, val_batch.weights,
                     val_batch.isf_score_ids, "validation")
    _map = map_score(probs, val_batch.labels, val_batch.weights,
                     val_batch.isf_score_ids, "validation")
    print("Metrics: acc: %.4f | mrr: %.4f | map: %.4f" % (acc, mrr, _map))

    if FLAGS.load_prediction != -1:
        fn = ''
        if FLAGS.filtered_setting:
            fn = "%s/step-a.model.ckpt.%s-top%d-isf-metrics" % (
                FLAGS.train_dir, "validation", FLAGS.topK)
        else:
            fn = "%s/step-a.model.ckpt.%s-metrics" % (FLAGS.train_dir,
                                                      "validation")
        save_metrics(fn, FLAGS.load_prediction, acc, mrr, _map)

    if FLAGS.save_preds:
        modelname = "step-a.model.ckpt.epoch-" + str(FLAGS.model_to_load)
        write_prediction_summaries(val_batch, probs, modelname, "validation")
        write_cos_sim(val_batch.cos_sim, modelname, "validation")