def test(sess):
    """Evaluate a saved checkpoint on the test split and print its metrics.

    Sequence of work:
      1. seed TensorFlow's graph-level RNG with the module-level ``seed``;
      2. load the vocabulary, the pretrained word embeddings and the test data;
      3. in "concat" filter mode, round ``FLAGS.sentembed_size`` down to a
         multiple of the number of filter widths;
      4. build ``MY_Model``, restore checkpoint
         ``<train_dir>/step-a.model.ckpt.epoch-<model_to_load>`` and then
         assign the pretrained embeddings to ``model.vocab_embed_variable``;
      5. predict on the test batch and print accuracy / MRR / MAP;
      6. if ``FLAGS.save_preds`` is set, write prediction summaries and
         cosine similarities.

    Args:
        sess: TensorFlow session used to build, restore and run the model.

    Side effects:
        Mutates the global ``FLAGS``: possibly ``sentembed_size``, and sets
        ``authorise_gold_label`` and ``use_dropout`` to False (not restored).
    """
    tf.set_random_seed(seed)

    # Vocabulary and pretrained embeddings. Per the original note, vocab_dict
    # contains _PAD and _UNK but word_embedding_array has no rows for them,
    # which is why the model is built with len(vocab_dict) - 2 below.
    print("Prepare vocab dict and read pretrained word embeddings ...")
    vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict()

    print("Prepare test data ...")
    raw_test_data = DataProcessor().prepare_news_data(vocab_dict, data_type="test")
    batch = batch_load_data(raw_test_data)

    # In "concat" mode the sentence-embedding size must divide evenly by the
    # number of filter widths (presumably because each width's output fills an
    # equal slice of it); otherwise round it down to the nearest multiple.
    n_filter_widths = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    needs_rounding = (FLAGS.handle_filter_output == "concat"
                      and FLAGS.sentembed_size % n_filter_widths != 0)
    if needs_rounding:
        per_width = int(FLAGS.sentembed_size // n_filter_widths)
        FLAGS.sentembed_size = per_width * n_filter_widths
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    model = MY_Model(sess, len(vocab_dict) - 2)

    checkpoint_path = (FLAGS.train_dir + "/step-a.model.ckpt.epoch-"
                       + str(FLAGS.model_to_load))
    model.saver.restore(sess, checkpoint_path)

    # Runs after restore, so the pretrained vectors overwrite whatever the
    # checkpoint held for the embedding variable.
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    # Disable the gold-label and dropout flags before predicting.
    FLAGS.authorise_gold_label = False
    FLAGS.use_dropout = False
    batch = batch_predict_with_a_model(batch, "test", model, session=sess)
    probs = sess.run(model.predictions,
                     feed_dict={model.logits_placeholder: batch.logits})

    labels, weights, isf_ids = batch.labels, batch.weights, batch.isf_score_ids
    acc = accuracy_qas_top(probs, labels, weights, isf_ids)
    mrr = mrr_metric(probs, labels, weights, isf_ids, "test")
    _map = map_score(probs, labels, weights, isf_ids, "test")
    print("Metrics: acc: %.4f | mrr: %.4f | map: %.4f" % (acc, mrr, _map))

    if FLAGS.save_preds:
        modelname = "step-a.model.ckpt.epoch-" + str(FLAGS.model_to_load)
        write_prediction_summaries(batch, probs, modelname, "test")
        write_cos_sim(batch.cos_sim, modelname, "test")
def test_val(sess):
    """Test on validation mode: load an existing model and evaluate it on the validation set.

    Sequence of work:
      1. seed TensorFlow's graph-level RNG with the module-level ``seed`` and,
         when ``FLAGS.load_prediction != -1``, print a banner with that index;
      2. load the vocabulary, the pretrained word embeddings and the
         validation data;
      3. in "concat" filter mode, round ``FLAGS.sentembed_size`` down to a
         multiple of the number of filter widths;
      4. build ``MY_Model``, restore checkpoint
         ``<train_dir>/step-a.model.ckpt.epoch-<model_to_load>`` and then
         assign the pretrained embeddings to ``model.vocab_embed_variable``;
      5. predict on the validation batch and print accuracy / MRR / MAP;
      6. when ``FLAGS.load_prediction != -1``, pass those metrics to
         ``save_metrics`` together with a metrics-file path under ``train_dir``;
      7. if ``FLAGS.save_preds`` is set, write prediction summaries and
         cosine similarities.

    Args:
        sess: TensorFlow session used to build, restore and run the model.

    Side effects:
        Mutates the global ``FLAGS``:
          * possibly ``sentembed_size``;
          * ``use_dropout`` is set to False;
          * ``authorise_gold_label`` is False during prediction and True
            afterwards.
    """
    tf.set_random_seed(seed)
    if FLAGS.load_prediction != -1:
        print(
            "====================================== [%d] ======================================"
            % (FLAGS.load_prediction))

    # vocab_dict contains _PAD and _UNK but word_embedding_array has no rows
    # for them, hence the "- 2" when sizing the model below.
    vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict()
    val_batch = batch_load_data(DataProcessor().prepare_news_data(
        vocab_dict, data_type="validation"))

    # In "concat" mode the sentence-embedding size must divide evenly by the
    # number of filter widths; otherwise round it down to the nearest multiple.
    fil_lens_to_test = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    if (FLAGS.handle_filter_output == "concat"
            and FLAGS.sentembed_size % fil_lens_to_test != 0):
        q = int(FLAGS.sentembed_size // fil_lens_to_test)
        FLAGS.sentembed_size = q * fil_lens_to_test
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    model = MY_Model(sess, len(vocab_dict) - 2)

    selected_modelpath = (FLAGS.train_dir + "/step-a.model.ckpt.epoch-"
                          + str(FLAGS.model_to_load))
    model.saver.restore(sess, selected_modelpath)

    # Runs after restore, so the pretrained vectors overwrite whatever the
    # checkpoint held for the embedding variable.
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    # Disable gold labels and dropout for prediction; gold labels are
    # re-enabled immediately afterwards.
    FLAGS.authorise_gold_label = False
    FLAGS.use_dropout = False
    val_batch = batch_predict_with_a_model(val_batch, "validation", model, session=sess)
    FLAGS.authorise_gold_label = True

    probs = sess.run(model.predictions,
                     feed_dict={model.logits_placeholder: val_batch.logits})
    acc = accuracy_qas_top(probs, val_batch.labels, val_batch.weights,
                           val_batch.isf_score_ids)
    mrr = mrr_metric(probs, val_batch.labels, val_batch.weights,
                     val_batch.isf_score_ids, "validation")
    _map = map_score(probs, val_batch.labels, val_batch.weights,
                     val_batch.isf_score_ids, "validation")
    print("Metrics: acc: %.4f | mrr: %.4f | map: %.4f" % (acc, mrr, _map))

    if FLAGS.load_prediction != -1:
        if FLAGS.filtered_setting:
            fn = "%s/step-a.model.ckpt.%s-top%d-isf-metrics" % (
                FLAGS.train_dir, "validation", FLAGS.topK)
        else:
            fn = "%s/step-a.model.ckpt.%s-metrics" % (FLAGS.train_dir, "validation")
        # Pass the metrics computed above in this function; nothing named
        # validation_acc / mrr_score / mapsc exists in this scope.
        save_metrics(fn, FLAGS.load_prediction, acc, mrr, _map)

    if FLAGS.save_preds:
        modelname = "step-a.model.ckpt.epoch-" + str(FLAGS.model_to_load)
        write_prediction_summaries(val_batch, probs, modelname, "validation")
        write_cos_sim(val_batch.cos_sim, modelname, "validation")