def test(sess):
    """Evaluate the step-A checkpoint chosen by FLAGS.model_to_load on the test split.

    In order:
    1. Seed TensorFlow's graph-level RNG.
    2. Load the vocabulary and pretrained embeddings, and the test split as one batch.
    3. When ``FLAGS.handle_filter_output == "concat"``, round
       ``FLAGS.sentembed_size`` down to a multiple of the number of filter lengths.
    4. Build ``MY_Model`` and restore it from ``FLAGS.train_dir``.
    5. Re-assign the pretrained embeddings.
    6. Predict with gold labels and dropout disabled, then print accuracy / MRR / MAP.

    When ``FLAGS.save_preds`` is set, the prediction summaries and cosine
    similarities are written as well.

    Side effect: ``FLAGS.authorise_gold_label`` and ``FLAGS.use_dropout`` are
    left set to False on return.

    Args:
        sess: session in which the model is built, restored and run.
    """
    tf.set_random_seed(seed)

    print("Prepare vocab dict and read pretrained word embeddings ...")
    vocab, pretrained_vectors = DataProcessor().prepare_vocab_embeddingdict()
    # `vocab` includes _PAD and _UNK while `pretrained_vectors` does not
    # (presumably why the model is sized with len(vocab) - 2 below).

    print("Prepare test data ...")
    raw_test_data = DataProcessor().prepare_news_data(vocab, data_type="test")
    test_batch = batch_load_data(raw_test_data)

    # Round the sentence-embedding size down to a multiple of the number of
    # filter lengths when their outputs are concatenated.
    n_filter_lens = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    must_round = (FLAGS.handle_filter_output == "concat"
                  and FLAGS.sentembed_size % n_filter_lens != 0)
    if must_round:
        FLAGS.sentembed_size = int(
            FLAGS.sentembed_size // n_filter_lens) * n_filter_lens
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    model = MY_Model(sess, len(vocab) - 2)

    ckpt_name = "step-a.model.ckpt.epoch-" + str(FLAGS.model_to_load)
    model.saver.restore(sess, FLAGS.train_dir + "/" + ckpt_name)

    # Overwrite the (restored) embedding variable with the pretrained vectors.
    sess.run(model.vocab_embed_variable.assign(pretrained_vectors))

    # Inference mode: no gold labels, no dropout.
    FLAGS.authorise_gold_label = False
    FLAGS.use_dropout = False
    test_batch = batch_predict_with_a_model(test_batch, "test", model,
                                            session=sess)
    probs = sess.run(model.predictions,
                     feed_dict={model.logits_placeholder: test_batch.logits})

    labels = test_batch.labels
    weights = test_batch.weights
    isf_ids = test_batch.isf_score_ids
    acc = accuracy_qas_top(probs, labels, weights, isf_ids)
    mrr = mrr_metric(probs, labels, weights, isf_ids, "test")
    map_value = map_score(probs, labels, weights, isf_ids, "test")
    print("Metrics: acc: %.4f | mrr: %.4f | map: %.4f" % (acc, mrr, map_value))

    if not FLAGS.save_preds:
        return
    write_prediction_summaries(test_batch, probs, ckpt_name, "test")
    write_cos_sim(test_batch.cos_sim, ckpt_name, "test")
def test(sess):
  """Evaluate the step-A checkpoint chosen by FLAGS.model_to_load on the test split.

  The training split is prepared only so that its fitted ``normalizer`` and
  ``pca_model`` can be applied to the test split; it is then discarded.
  The function then:
  - restores ``FLAGS.train_dir/step-a.model.ckpt.epoch-<model_to_load>``;
  - re-assigns the pretrained word embeddings;
  - predicts with gold labels disabled;
  - prints accuracy, MRR and MAP;
  - writes the prediction summaries.

  Args:
    sess: session in which the model is built, restored and run.

  Exits the process with status 1 if the checkpoint cannot be found.
  """
  import sys

  ### Prepare data for training
  print("Prepare vocab dict and read pretrained word embeddings ...")
  vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict()
  # vocab_dict contains _PAD and _UNK but not word_embedding_array

  print("Prepare test data ...")
  # The training split is only needed for its fitted normalizer / PCA model.
  train_data = DataProcessor().prepare_news_data(vocab_dict, data_type="training")
  test_batch = batch_load_data(DataProcessor().prepare_news_data(vocab_dict,
                                                                data_type="test",
                                                                normalizer=train_data.normalizer,
                                                                pca_model=train_data.pca_model))
  del train_data

  # Create Model with various operations
  model = MY_Model(sess, len(vocab_dict)-2)

  # Select the model
  selected_modelpath = FLAGS.train_dir+"/step-a.model.ckpt.epoch-"+str(FLAGS.model_to_load)
  # tf.train.Saver's V2 format writes "<prefix>.index" / "<prefix>.data-*"
  # rather than a file named exactly "<prefix>", so accept either layout.
  if not (os.path.isfile(selected_modelpath) or os.path.isfile(selected_modelpath + ".index")):
    print("Model not found in checkpoint folder.")
    # Non-zero status so callers/scripts can detect the failure.
    sys.exit(1)

  # Reload saved model and test
  print("Reading model parameters from %s" % selected_modelpath)
  model.saver.restore(sess, selected_modelpath)
  print("Model loaded.")

  # Overwrite the (restored) embedding variable with the pretrained vectors.
  print("Initialize word embedding vocabulary with pretrained embeddings ...")
  sess.run(model.vocab_embed_variable.assign(word_embedding_array))

  # Test Accuracy and Prediction
  print("Performance on the test data:")
  FLAGS.authorise_gold_label = False
  test_batch = batch_predict_with_a_model(test_batch,"test",model, session=sess)
  probs = sess.run(model.probs,feed_dict={model.logits_placeholder: test_batch.logits})

  n_docs = str(test_batch.docs.shape[0])
  test_acc = accuracy_qas_top(probs, test_batch.labels, test_batch.weights, test_batch.isf_score_ids)
  print("Test ("+n_docs+") Accuracy = {:.6f}".format(test_acc))
  mrr_score = mrr_metric(probs, test_batch.labels, test_batch.weights, test_batch.isf_score_ids)
  print("Test ("+n_docs+") MRR = {:.6f}".format(mrr_score))
  mapsc = map_score(probs, test_batch.labels, test_batch.weights, test_batch.isf_score_ids)
  print("Test ("+n_docs+") MAP = {:.6f}".format(mapsc))

  # Writing test predictions and final summaries
  modelname = "step-a.model.ckpt.epoch-" + str(FLAGS.model_to_load)
  write_prediction_summaries(test_batch, probs, modelname, "test")
def test_val(sess):
    """Test on validation mode: load an existing model and evaluate it on the validation set.

    In order:
    1. Seed TensorFlow's graph-level RNG.
    2. Load the vocabulary/embeddings and the validation split as one batch.
    3. When ``FLAGS.handle_filter_output == "concat"``, round
       ``FLAGS.sentembed_size`` down to a multiple of the number of filter lengths.
    4. Build ``MY_Model`` and restore the checkpoint for ``FLAGS.model_to_load``.
    5. Re-assign the pretrained embeddings.
    6. Predict with gold labels and dropout disabled, then print accuracy / MRR / MAP.

    When ``FLAGS.load_prediction != -1`` the metrics are passed to
    ``save_metrics`` under that id. When ``FLAGS.save_preds`` is set, the
    prediction summaries and cosine similarities are written.

    Args:
        sess: session in which the model is built, restored and run.
    """
    tf.set_random_seed(seed)
    if FLAGS.load_prediction != -1:
        print(
            "====================================== [%d] ======================================"
            % (FLAGS.load_prediction))

    ### Prepare data for training
    vocab_dict, word_embedding_array = DataProcessor(
    ).prepare_vocab_embeddingdict()
    # vocab_dict contains _PAD and _UNK but not word_embedding_array

    val_batch = batch_load_data(DataProcessor().prepare_news_data(
        vocab_dict, data_type="validation"))

    # Round the sentence-embedding size down to a multiple of the number of
    # filter lengths when their outputs are concatenated.
    fil_lens_to_test = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    if FLAGS.handle_filter_output == "concat" and FLAGS.sentembed_size % fil_lens_to_test != 0:
        q = int(FLAGS.sentembed_size // fil_lens_to_test)
        FLAGS.sentembed_size = q * fil_lens_to_test
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    # Create Model with various operations
    model = MY_Model(sess, len(vocab_dict) - 2)

    # Select the model
    selected_modelpath = FLAGS.train_dir + "/step-a.model.ckpt.epoch-" + str(
        FLAGS.model_to_load)

    # Reload saved model and test
    model.saver.restore(sess, selected_modelpath)

    # Overwrite the (restored) embedding variable with the pretrained vectors.
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    # Inference mode: no gold labels, no dropout.
    FLAGS.authorise_gold_label = False
    FLAGS.use_dropout = False
    val_batch = batch_predict_with_a_model(val_batch,
                                           "validation",
                                           model,
                                           session=sess)
    FLAGS.authorise_gold_label = True
    probs = sess.run(model.predictions,
                     feed_dict={model.logits_placeholder: val_batch.logits})

    acc = accuracy_qas_top(probs, val_batch.labels, val_batch.weights,
                           val_batch.isf_score_ids)
    mrr = mrr_metric(probs, val_batch.labels, val_batch.weights,
                     val_batch.isf_score_ids, "validation")
    _map = map_score(probs, val_batch.labels, val_batch.weights,
                     val_batch.isf_score_ids, "validation")
    print("Metrics: acc: %.4f | mrr: %.4f | map: %.4f" % (acc, mrr, _map))

    if FLAGS.load_prediction != -1:
        if FLAGS.filtered_setting:
            fn = "%s/step-a.model.ckpt.%s-top%d-isf-metrics" % (
                FLAGS.train_dir, "validation", FLAGS.topK)
        else:
            fn = "%s/step-a.model.ckpt.%s-metrics" % (FLAGS.train_dir,
                                                      "validation")
        # Pass the metrics computed above (the original referenced the
        # undefined names validation_acc / mrr_score / mapsc -> NameError).
        save_metrics(fn, FLAGS.load_prediction, acc, mrr, _map)

    if FLAGS.save_preds:
        modelname = "step-a.model.ckpt.epoch-" + str(FLAGS.model_to_load)
        write_prediction_summaries(val_batch, probs, modelname, "validation")
        write_cos_sim(val_batch.cos_sim, modelname, "validation")
def train_simple(sess):
    """Training mode: create a new model (or resume one) and train with cross-entropy.

    Setup:
    1. Seed TensorFlow's graph-level RNG.
    2. Load the vocabulary/embeddings, the training split, and the validation
       split as one padded batch.
    3. Build ``MY_Model``.
    4. When ``FLAGS.model_to_load != -1``, restore that epoch's checkpoint and
       continue from the next epoch.

    Each epoch:
    - reshuffles the training indices;
    - runs ``train_op_policynet_withgold`` over full mini-batches (a trailing
      partial batch is skipped);
    - evaluates accuracy / MRR / MAP on validation with gold labels and
      dropout disabled;
    - prints one summary line;
    - saves a checkpoint when ``FLAGS.save_models`` is set.

    Args:
        sess: session used to build, restore, train and save the model.
    """
    tf.set_random_seed(seed)
    ### Prepare data for training
    vocab_dict, word_embedding_array = DataProcessor(
    ).prepare_vocab_embeddingdict()
    # vocab_dict contains _PAD and _UNK but not word_embedding_array

    train_data = DataProcessor().prepare_news_data(vocab_dict,
                                                   data_type="training")

    # data in whole batch with padded matrixes
    val_batch = batch_load_data(DataProcessor().prepare_news_data(
        vocab_dict, data_type="validation"))

    # Round the sentence-embedding size down to a multiple of the number of
    # filter lengths when their outputs are concatenated.
    fil_lens_to_test = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    if FLAGS.handle_filter_output == "concat" and FLAGS.sentembed_size % fil_lens_to_test != 0:
        q = int(FLAGS.sentembed_size // fil_lens_to_test)
        FLAGS.sentembed_size = q * fil_lens_to_test
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    # Create Model with various operations
    model = MY_Model(sess, len(vocab_dict) - 2)
    init_epoch = 1
    # Resume training if indicated: restore the checkpoint, continue after it.
    if FLAGS.model_to_load != -1:
        selected_modelpath = FLAGS.train_dir + "/step-a.model.ckpt.epoch-" + str(
            FLAGS.model_to_load)
        init_epoch = FLAGS.model_to_load + 1
        print("Reading model parameters from %s" % selected_modelpath)
        model.saver.restore(sess, selected_modelpath)
        print("Model loaded.")

    # Overwrite the embedding variable with the pretrained vectors.
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    ### STEP A : Start Pretraining the policy with Supervised Labels: Simple Cross Entropy Training
    for epoch in range(init_epoch, FLAGS.train_epoch_crossentropy + 1):
        ep_time = time.time()  # to check duration

        train_data.shuffle_fileindices()
        # Start Batch Training
        step = 1
        total_loss = 0
        while (step * FLAGS.batch_size) <= len(train_data.fileindices):
            # Get batch data as Numpy Arrays
            batch = train_data.get_batch(((step - 1) * FLAGS.batch_size),
                                         (step * FLAGS.batch_size))

            # Run optimizer: optimize policy and reward estimator
            _, ce_loss = sess.run(
                [model.train_op_policynet_withgold, model.cross_entropy_loss],
                feed_dict={
                    model.document_placeholder: batch.docs,
                    model.label_placeholder: batch.labels,
                    model.weight_placeholder: batch.weights
                })
            total_loss += ce_loss
            step += 1
        #END-WHILE-TRAINING

        # `step` ends one past the last processed batch, so the number of
        # batches is step - 1; guard against an epoch with no full batch.
        n_batches = step - 1
        total_loss /= max(n_batches, 1)

        ## eval metrics: predict without gold labels and without dropout
        FLAGS.authorise_gold_label = False
        prev_use_dpt = FLAGS.use_dropout
        FLAGS.use_dropout = False
        # retrieve batch with updated logits in it
        val_batch = batch_predict_with_a_model(val_batch,
                                               "validation",
                                               model,
                                               session=sess)
        FLAGS.authorise_gold_label = True
        FLAGS.use_dropout = prev_use_dpt
        probs = sess.run(
            model.predictions,
            feed_dict={model.logits_placeholder: val_batch.logits})
        validation_acc = accuracy_qas_top(probs, val_batch.labels,
                                          val_batch.weights,
                                          val_batch.isf_score_ids)
        val_mrr = mrr_metric(probs, val_batch.labels, val_batch.weights,
                             val_batch.isf_score_ids, "validation")
        val_map = map_score(probs, val_batch.labels, val_batch.weights,
                            val_batch.isf_score_ids, "validation")

        print(
            "\tEpoch %2d || Train ce_loss: %4.3f || Val acc: %.4f || Val mrr: %.4f || Val map: %.4f || duration: %3.2f"
            % (epoch, total_loss, validation_acc, val_mrr, val_map,
               time.time() - ep_time))

        if FLAGS.save_models:
            print("Saving model after epoch completion")
            checkpoint_path = os.path.join(
                FLAGS.train_dir, "step-a.model.ckpt.epoch-" + str(epoch))
            model.saver.save(sess, checkpoint_path)
        print(
            "------------------------------------------------------------------------------------------"
        )
    #END-FOR-EPOCH

    print("Optimization Finished!")
def train_debug(sess):
    """Training mode (debug variant): train with cross-entropy and verbose monitoring.

    Setup:
    1. Seed TensorFlow's graph-level RNG.
    2. Load the vocabulary/embeddings, the training split, and the validation
       split as one padded batch.
    3. Build ``MY_Model``.
    4. When ``FLAGS.model_to_load != -1``, restore that epoch's checkpoint and
       continue from the next epoch.

    Each mini-batch runs one optimizer step. It then re-runs the batch with
    ``FLAGS.use_dropout`` cleared to collect logits, CE loss, merged summaries
    and training accuracy.

    Every ``FLAGS.training_checkpoint`` steps:
    - the averaged train loss/accuracy and the validation loss/accuracy are
      written to TensorBoard and printed.

    At the end of each epoch:
    - validation accuracy / MRR / MAP are printed;
    - a checkpoint is saved when ``FLAGS.save_models`` is set.

    Args:
        sess: session used to build, restore, train and save the model.
    """
    tf.set_random_seed(seed)
    ### Prepare data for training
    print("Prepare vocab dict and read pretrained word embeddings ...")
    vocab_dict, word_embedding_array = DataProcessor(
    ).prepare_vocab_embeddingdict()
    # vocab_dict contains _PAD and _UNK but not word_embedding_array

    print("Prepare training data ...")
    train_data = DataProcessor().prepare_news_data(vocab_dict,
                                                   data_type="training")
    print("Training size: ", len(train_data.fileindices))

    print("Prepare validation data ...")
    # data in whole batch with padded matrixes
    val_batch = batch_load_data(DataProcessor().prepare_news_data(
        vocab_dict, data_type="validation"))
    print("Validation size: ", val_batch.docs.shape[0])

    # Round the sentence-embedding size down to a multiple of the number of
    # filter lengths when their outputs are concatenated.
    fil_lens_to_test = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    if FLAGS.handle_filter_output == "concat" and FLAGS.sentembed_size % fil_lens_to_test != 0:
        q = int(FLAGS.sentembed_size // fil_lens_to_test)
        FLAGS.sentembed_size = q * fil_lens_to_test
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    # Create Model with various operations
    model = MY_Model(sess, len(vocab_dict) - 2)

    init_epoch = 1
    # Resume training if indicated: restore the checkpoint, continue after it.
    if FLAGS.model_to_load != -1:
        selected_modelpath = FLAGS.train_dir + "/step-a.model.ckpt.epoch-" + str(
            FLAGS.model_to_load)
        init_epoch = FLAGS.model_to_load + 1
        print("Reading model parameters from %s" % selected_modelpath)
        model.saver.restore(sess, selected_modelpath)
        print("Model loaded.")

    # Overwrite the embedding variable with the pretrained vectors.
    print(
        "Initialize word embedding vocabulary with pretrained embeddings ...")
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    ### STEP A : Start Pretraining the policy with Supervised Labels: Simple Cross Entropy Training
    # TensorBoard x-axis; advanced every 5 steps so plots stay comparable.
    counter = 0
    for epoch in range(init_epoch, FLAGS.train_epoch_crossentropy + 1):
        ep_time = time.time()  # to check duration

        train_data.shuffle_fileindices()

        # Start Batch Training
        step = 1
        total_ce_loss = 0
        total_train_acc = 0
        while (step * FLAGS.batch_size) <= len(train_data.fileindices):
            # Get batch data as Numpy Arrays
            batch = train_data.get_batch(((step - 1) * FLAGS.batch_size),
                                         (step * FLAGS.batch_size))

            # Run optimizer: optimize policy and reward estimator
            sess.run(
                [model.train_op_policynet_withgold],
                feed_dict={
                    model.document_placeholder: batch.docs,
                    model.label_placeholder: batch.labels,
                    model.weight_placeholder: batch.weights
                })

            # Re-evaluate the same batch with dropout disabled for monitoring.
            prev_use_dpt = FLAGS.use_dropout
            FLAGS.use_dropout = False
            batch_logits, ce_loss, merged_summ = sess.run(
                [model.logits, model.cross_entropy_loss, model.merged],
                feed_dict={
                    model.document_placeholder: batch.docs,
                    model.label_placeholder: batch.labels,
                    model.weight_placeholder: batch.weights
                })
            total_ce_loss += ce_loss
            probs = sess.run(
                model.predictions,
                feed_dict={model.logits_placeholder: batch_logits})
            training_acc = accuracy_qas_top(probs, batch.labels, batch.weights,
                                            batch.isf_score_ids)
            FLAGS.use_dropout = prev_use_dpt
            total_train_acc += training_acc
            # Print the progress
            if (step % FLAGS.training_checkpoint) == 0:
                # Average train accuracy/loss over the checkpoint window.
                total_train_acc /= FLAGS.training_checkpoint
                acc_sum = sess.run(
                    model.tstepa_accuracy_summary,
                    feed_dict={model.train_acc_placeholder: total_train_acc})

                total_ce_loss /= FLAGS.training_checkpoint
                # Print Summary to Tensor Board
                model.summary_writer.add_summary(merged_summ, counter)
                model.summary_writer.add_summary(acc_sum, counter)

                # Performance on the validation set
                # Get Predictions: Prohibit the use of gold labels
                FLAGS.authorise_gold_label = False
                prev_use_dpt = FLAGS.use_dropout
                FLAGS.use_dropout = False
                val_batch = batch_predict_with_a_model(val_batch,
                                                       "validation",
                                                       model,
                                                       session=sess)
                FLAGS.use_dropout = prev_use_dpt
                FLAGS.authorise_gold_label = True

                # Validation Accuracy and Prediction
                probs = sess.run(
                    model.predictions,
                    feed_dict={model.logits_placeholder: val_batch.logits})
                validation_acc = accuracy_qas_top(probs, val_batch.labels,
                                                  val_batch.weights,
                                                  val_batch.isf_score_ids)

                ce_loss_val, ce_loss_sum, acc_sum = sess.run(
                    [
                        model.cross_entropy_loss_val,
                        model.ce_loss_summary_val,
                        model.vstepa_accuracy_summary
                    ],
                    feed_dict={
                        model.logits_placeholder: val_batch.logits,
                        model.label_placeholder: val_batch.labels,
                        model.weight_placeholder: val_batch.weights,
                        model.val_acc_placeholder: validation_acc
                    })

                # Print Validation Summary. Report the window-averaged train
                # accuracy, consistent with the averaged CE loss and with the
                # value logged to TensorBoard above.
                model.summary_writer.add_summary(acc_sum, counter)
                model.summary_writer.add_summary(ce_loss_sum, counter)
                print(
                    "Epoch %2d, step: %2d(%2d) || CE loss || Train : %4.3f , Val : %4.3f ||| ACC || Train : %.3f , Val : %.3f"
                    % (epoch, step, counter, total_ce_loss, ce_loss_val,
                       total_train_acc, validation_acc))
                total_ce_loss = 0
                total_train_acc = 0

            if (step % 5) == 0:  # to have comparable tensorboard plots
                counter += 1
            # Increase step
            step += 1
        #END-WHILE-TRAINING

        ## eval metrics: predict without gold labels and without dropout
        FLAGS.authorise_gold_label = False
        prev_use_dpt = FLAGS.use_dropout
        FLAGS.use_dropout = False
        val_batch = batch_predict_with_a_model(val_batch,
                                               "validation",
                                               model,
                                               session=sess)
        FLAGS.use_dropout = prev_use_dpt
        FLAGS.authorise_gold_label = True
        # Validation metrics
        probs = sess.run(
            model.predictions,
            feed_dict={model.logits_placeholder: val_batch.logits})
        acc = accuracy_qas_top(probs, val_batch.labels, val_batch.weights,
                               val_batch.isf_score_ids)
        mrr = mrr_metric(probs, val_batch.labels, val_batch.weights,
                         val_batch.isf_score_ids, "validation")
        _map = map_score(probs, val_batch.labels, val_batch.weights,
                         val_batch.isf_score_ids, "validation")
        print("Metrics: acc: %.4f | mrr: %.4f | map: %.4f" % (acc, mrr, _map))

        print("Epoch %2d : Duration: %.4f" % (epoch, time.time() - ep_time))
        if FLAGS.save_models:
            print("Saving model after epoch completion")
            checkpoint_path = os.path.join(
                FLAGS.train_dir, "step-a.model.ckpt.epoch-" + str(epoch))
            model.saver.save(sess, checkpoint_path)
        print(
            "------------------------------------------------------------------------------------------"
        )
    #END-FOR-EPOCH

    print("Optimization Finished!")
def train(sess):
    """Training mode: create a new model (or resume one) and train with cross-entropy.

    Setup:
    1. Load the vocabulary/embeddings, the training split, and the validation
       split as one padded batch.
    2. Build ``MY_Model``.
    3. When ``FLAGS.model_to_load != -1``, restore that epoch's checkpoint and
       continue from the next epoch.

    Each epoch:
    - reshuffles the training indices;
    - runs ``train_op_policynet_withgold`` over full mini-batches (a trailing
      partial batch is skipped);
    - saves a checkpoint;
    - evaluates on validation (CE loss, accuracy, MRR, MAP) with gold labels
      and dropout disabled, logging to stdout and TensorBoard.

    ``FLAGS.authorise_gold_label`` and ``FLAGS.use_dropout`` are restored to
    their previous values after each validation pass.

    Args:
        sess: session used to build, restore, train and save the model.

    Exits the process with status 1 if the checkpoint to resume from cannot
    be found.
    """
    import sys

    ### Prepare data for training
    print("Prepare vocab dict and read pretrained word embeddings ...")
    vocab_dict, word_embedding_array = DataProcessor(
    ).prepare_vocab_embeddingdict()
    # vocab_dict contains _PAD and _UNK but not word_embedding_array

    print("Prepare training data ...")
    train_data = DataProcessor().prepare_news_data(vocab_dict,
                                                   data_type="training")

    print("Prepare validation data ...")
    # data in whole batch with padded matrixes
    val_batch = batch_load_data(DataProcessor().prepare_news_data(
        vocab_dict, data_type="validation"))

    # Round the sentence-embedding size down to a multiple of the number of
    # filter lengths when their outputs are concatenated.
    fil_lens_to_test = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    if FLAGS.handle_filter_output == "concat" and FLAGS.sentembed_size % fil_lens_to_test != 0:
        q = int(FLAGS.sentembed_size // fil_lens_to_test)
        FLAGS.sentembed_size = q * fil_lens_to_test
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    # Create Model with various operations
    model = MY_Model(sess, len(vocab_dict) - 2)

    init_epoch = 1
    # Resume training if indicated: restore the checkpoint, continue after it.
    if FLAGS.model_to_load != -1:
        selected_modelpath = FLAGS.train_dir + "/step-a.model.ckpt.epoch-" + str(
            FLAGS.model_to_load)
        # tf.train.Saver's V2 format writes "<prefix>.index" / "<prefix>.data-*"
        # rather than a file named exactly "<prefix>", so accept either layout.
        if not (os.path.isfile(selected_modelpath)
                or os.path.isfile(selected_modelpath + ".index")):
            print("Model not found in checkpoint folder.")
            # Non-zero status so callers/scripts can detect the failure.
            sys.exit(1)

        init_epoch = FLAGS.model_to_load + 1
        print("Reading model parameters from %s" % selected_modelpath)
        model.saver.restore(sess, selected_modelpath)
        print("Model loaded.")

    # Overwrite the embedding variable with the pretrained vectors.
    print(
        "Initialize word embedding vocabulary with pretrained embeddings ...")
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    ### STEP A : Start Pretraining the policy with Supervised Labels: Simple Cross Entropy Training

    for epoch in range(init_epoch, FLAGS.train_epoch_crossentropy + 1):
        ep_time = time.time()  # to check duration

        print("STEP A: Epoch " + str(epoch) +
              " : Start pretraining with supervised labels")

        print("STEP A: Epoch " + str(epoch) +
              " : Reshuffle training document indices")
        train_data.shuffle_fileindices()

        # Start Batch Training
        step = 1
        total_ce_loss = 0
        while (step * FLAGS.batch_size) <= len(train_data.fileindices):
            # Get batch data as Numpy Arrays
            batch = train_data.get_batch(((step - 1) * FLAGS.batch_size),
                                         (step * FLAGS.batch_size))

            # Run optimizer: optimize policy and reward estimator
            _, batch_logits, ce_loss, merged_summ = sess.run(
                [
                    model.train_op_policynet_withgold, model.logits,
                    model.cross_entropy_loss, model.merged
                ],
                feed_dict={
                    model.document_placeholder: batch.docs,
                    model.label_placeholder: batch.labels,
                    model.weight_placeholder: batch.weights
                })
            total_ce_loss += ce_loss
            # Print the progress
            if (step % FLAGS.training_checkpoint) == 0:
                probs = sess.run(
                    model.predictions,
                    feed_dict={model.logits_placeholder: batch_logits})
                # Pass isf_score_ids like every other accuracy_qas_top call.
                training_acc = accuracy_qas_top(probs, batch.labels,
                                                batch.weights,
                                                batch.isf_score_ids)
                acc_sum = sess.run(
                    model.tstepa_accuracy_summary,
                    feed_dict={model.train_acc_placeholder: training_acc})

                total_ce_loss /= FLAGS.training_checkpoint
                print("STEP A: Epoch " + str(epoch) + " : Covered " +
                      str(step * FLAGS.batch_size) + "/" +
                      str(len(train_data.fileindices)) +
                      " : Minibatch CE Loss= {:.6f}".format(total_ce_loss) +
                      ", Minibatch Accuracy= {:.6f}".format(training_acc))
                total_ce_loss = 0
                # Print Summary to Tensor Board (x-axis = documents seen)
                global_pos = ((epoch - 1) * len(train_data.fileindices) +
                              step * FLAGS.batch_size)
                model.summary_writer.add_summary(merged_summ, global_pos)
                model.summary_writer.add_summary(acc_sum, global_pos)
            # Increase step
            step += 1
        #END-WHILE-TRAINING

        # Save Model
        print("STEP A: Epoch " + str(epoch) +
              " : Saving model after epoch completion")
        checkpoint_path = os.path.join(FLAGS.train_dir,
                                       "step-a.model.ckpt.epoch-" + str(epoch))
        model.saver.save(sess, checkpoint_path)

        # Performance on the validation set
        print("STEP A: Epoch " + str(epoch) +
              " : Performance on the validation data")
        # Get Predictions: prohibit gold labels and dropout, then restore the
        # caller's settings so the next epoch trains exactly like this one.
        prev_gold = FLAGS.authorise_gold_label
        prev_use_dpt = FLAGS.use_dropout
        FLAGS.authorise_gold_label = False
        FLAGS.use_dropout = False
        val_batch = batch_predict_with_a_model(val_batch,
                                               "validation",
                                               model,
                                               session=sess)
        FLAGS.authorise_gold_label = prev_gold
        FLAGS.use_dropout = prev_use_dpt

        # Validation Accuracy and Prediction
        probs = sess.run(
            model.predictions,
            feed_dict={model.logits_placeholder: val_batch.logits})
        validation_acc = accuracy_qas_top(probs, val_batch.labels,
                                          val_batch.weights,
                                          val_batch.isf_score_ids)

        ce_loss_val, ce_loss_sum, acc_sum = sess.run(
            [
                model.cross_entropy_loss_val, model.ce_loss_summary_val,
                model.vstepa_accuracy_summary
            ],
            feed_dict={
                model.logits_placeholder: val_batch.logits,
                model.label_placeholder: val_batch.labels,
                model.weight_placeholder: val_batch.weights,
                model.val_acc_placeholder: validation_acc
            })

        # Print Validation Summary
        n_val = str(val_batch.docs.shape[0])
        model.summary_writer.add_summary(acc_sum,
                                         epoch * len(train_data.fileindices))
        model.summary_writer.add_summary(ce_loss_sum,
                                         epoch * len(train_data.fileindices))
        print("STEP A: Epoch %s : Validation (%s) CE loss = %.6f" %
              (str(epoch), n_val, ce_loss_val))
        print("STEP A: Epoch %s : Validation (%s) Accuracy= %.6f" %
              (str(epoch), n_val, validation_acc))

        # Estimate MRR on validation set
        mrr_score = mrr_metric(probs, val_batch.labels, val_batch.weights,
                               val_batch.isf_score_ids, "validation")
        print("STEP A: Epoch %s : Validation (%s) MRR= %.6f" %
              (str(epoch), n_val, mrr_score))
        # Estimate MAP score on validation set
        mapsc = map_score(probs, val_batch.labels, val_batch.weights,
                          val_batch.isf_score_ids, "validation")
        print("STEP A: Epoch %s : Validation (%s) MAP= %.6f" %
              (str(epoch), n_val, mapsc))

        # NOTE(review): the id passed is FLAGS.load_prediction, which is the
        # same for every epoch; confirm whether `epoch` was intended here.
        fn = "%s/step-a.model.ckpt.validation-metrics" % (FLAGS.train_dir)
        save_metrics(fn, FLAGS.load_prediction, validation_acc, mrr_score,
                     mapsc)

        print("STEP A: Epoch %d : Duration: %.4f" %
              (epoch, time.time() - ep_time))

    #END-FOR-EPOCH

    print("Optimization Finished!")
def test_val(sess):
  """Test on validation mode: load an existing model and evaluate it on the validation set.

  The training split is prepared only so that its fitted ``normalizer`` and
  ``pca_model`` can be applied to the validation split; it is then discarded.
  The function then:
  - restores ``FLAGS.train_dir/step-a.model.ckpt.epoch-<model_to_load>``;
  - re-assigns the pretrained word embeddings;
  - predicts with gold labels disabled;
  - prints accuracy, MRR and MAP.

  When ``FLAGS.load_prediction != -1`` the metrics are passed to
  ``save_metrics`` under that id. Finally the prediction summaries and
  cosine similarities are written.

  Args:
    sess: session in which the model is built, restored and run.

  Exits the process with status 1 if the checkpoint cannot be found.
  """
  import sys

  if FLAGS.load_prediction != -1:
    print("====================================== [%d] ======================================" % (FLAGS.load_prediction))

  ### Prepare data for training
  print("Prepare vocab dict and read pretrained word embeddings ...")
  vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict()
  # vocab_dict contains _PAD and _UNK but not word_embedding_array

  print("Prepare validation data ...")
  # The training split is only needed for its fitted normalizer / PCA model.
  train_data = DataProcessor().prepare_news_data(vocab_dict, data_type="training")
  val_batch = batch_load_data(DataProcessor().prepare_news_data(vocab_dict,
                                                                data_type="validation",
                                                                normalizer=train_data.normalizer,
                                                                pca_model=train_data.pca_model))
  del train_data

  # Create Model with various operations
  model = MY_Model(sess, len(vocab_dict)-2)

  # Select the model
  selected_modelpath = FLAGS.train_dir+"/step-a.model.ckpt.epoch-"+str(FLAGS.model_to_load)
  # tf.train.Saver's V2 format writes "<prefix>.index" / "<prefix>.data-*"
  # rather than a file named exactly "<prefix>", so accept either layout.
  if not (os.path.isfile(selected_modelpath) or os.path.isfile(selected_modelpath + ".index")):
    print("Model not found in checkpoint folder.")
    # Non-zero status so callers/scripts can detect the failure.
    sys.exit(1)

  # Reload saved model and test
  print("Reading model parameters from %s" % selected_modelpath)
  model.saver.restore(sess, selected_modelpath)
  print("Model loaded.")

  # Overwrite the (restored) embedding variable with the pretrained vectors.
  print("Initialize word embedding vocabulary with pretrained embeddings ...")
  sess.run(model.vocab_embed_variable.assign(word_embedding_array))

  # Predict without gold labels, then re-enable them.
  print("Performance on the validation data:")
  FLAGS.authorise_gold_label = False
  val_batch = batch_predict_with_a_model(val_batch,"validation", model, session=sess)
  FLAGS.authorise_gold_label = True
  probs = sess.run(model.probs,feed_dict={model.logits_placeholder: val_batch.logits})

  n_docs = str(val_batch.docs.shape[0])
  validation_acc = accuracy_qas_top(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids)
  print("Validation (%s) Accuracy= %.6f" % (n_docs,validation_acc))
  mrr_score = mrr_metric(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids)
  print("Validation (%s) MRR= %.6f" % (n_docs,mrr_score))
  mapsc = map_score(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids)
  print("Validation (%s) MAP= %.6f" % (n_docs,mapsc))

  if FLAGS.load_prediction != -1:
    if FLAGS.filtered_setting:
      fn = "%s/step-a.model.ckpt.%s-top%d-isf-metrics" % (FLAGS.train_dir,"validation",FLAGS.topK)
    else:
      fn = "%s/step-a.model.ckpt.%s-metrics" % (FLAGS.train_dir,"validation")
    save_metrics(fn,FLAGS.load_prediction,validation_acc,mrr_score,mapsc)

  # Writing validation predictions, embeddings
  print("Writing final validation summaries and embeddings")
  modelname = "step-a.model.ckpt.epoch-" + str(FLAGS.model_to_load)
  write_prediction_summaries(val_batch, probs, modelname, "validation")
  write_cos_sim(val_batch.cos_sim, modelname + "-cos_sim","validation")
def train_debug(sess):
  """
  Training Mode: Create a new model and train the network.

  Step-A (supervised cross-entropy) training loop used for debugging. It
  builds a new MY_Model in `sess`, loads the pretrained word embeddings into
  it and trains for FLAGS.train_epoch_crossentropy epochs. Every
  FLAGS.training_checkpoint steps it writes TensorBoard summaries, runs the
  model on the validation set and prints averaged train / validation loss and
  accuracy. After each epoch the model is checkpointed under FLAGS.train_dir
  unless FLAGS.use_subsampled_dataset is set.

  Args:
    sess: TensorFlow session in which the model is created and run.
  """
  ### Prepare data for training
  print("Prepare vocab dict and read pretrained word embeddings ...")
  vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict()
  # vocab_dict contains _PAD and _UNK but not word_embedding_array

  print("Prepare training data ...")
  train_data = DataProcessor().prepare_news_data(vocab_dict, data_type="training")
  print("Training size: ",len(train_data.fileindices))

  print("Prepare validation data ...")
  # Whole validation set as one padded batch, prepared with the normalizer and
  # PCA model taken from the training data.
  val_batch = batch_load_data(DataProcessor().prepare_news_data(vocab_dict,
                                                                data_type="validation",
                                                                normalizer=train_data.normalizer,
                                                                pca_model=train_data.pca_model))
  print("Validation size: ",val_batch.docs.shape[0])

  # Create Model with various operations (_PAD and _UNK are not counted)
  model = MY_Model(sess, len(vocab_dict)-2)

  init_epoch = 1  # no resume support here: always start from epoch 1

  # Initialize word embedding before training
  print("Initialize word embedding vocabulary with pretrained embeddings ...")
  sess.run(model.vocab_embed_variable.assign(word_embedding_array))

  ### STEP A : Start Pretraining the policy with Supervised Labels: Simple Cross Entropy Training
  counter = 0  # TensorBoard step; advanced every 5 training steps
  for epoch in range(init_epoch, FLAGS.train_epoch_crossentropy+1):
    ep_time = time.time() # to check duration

    train_data.shuffle_fileindices()

    # Start Batch Training
    step = 1
    total_ce_loss = 0
    total_train_acc = 0
    # Only full batches are used; a trailing partial batch is skipped.
    while (step * FLAGS.batch_size) <= len(train_data.fileindices):
      # Get batch data as Numpy Arrays
      batch = train_data.get_batch(((step-1)*FLAGS.batch_size), (step * FLAGS.batch_size))

      # Run optimizer: optimize policy and reward estimator.
      # Placeholder names are the same ones used for validation below.
      _,batch_logits,ce_loss,merged_summ = sess.run([
                                model.train_op_policynet_withgold,
                                model.logits,
                                model.cross_entropy_loss,
                                model.merged],
                                feed_dict={model.document_placeholder: batch.docs,
                                           model.label_placeholder: batch.labels,
                                           model.weight_placeholder: batch.weights,
                                           model.isf_score_placeholder: batch.isf_score,
                                           model.idf_score_placeholder: batch.idf_score,
                                           model.locisf_score_placeholder: batch.locisf_score})

      total_ce_loss += ce_loss
      probs = sess.run(model.predictions,feed_dict={model.logits_placeholder: batch_logits})
      training_acc = accuracy_qas_top(probs,
                                      batch.labels,
                                      batch.weights,
                                      batch.isf_score_ids)
      total_train_acc += training_acc

      # Print the progress
      if (step % FLAGS.training_checkpoint) == 0:
        # Average loss / accuracy over the steps since the last checkpoint
        total_train_acc /= FLAGS.training_checkpoint
        total_ce_loss /= FLAGS.training_checkpoint
        # Log the averaged accuracy (the value printed below), not only the
        # accuracy of the last batch.
        acc_sum = sess.run(model.tstepa_accuracy_summary,
                           feed_dict={model.train_acc_placeholder: total_train_acc})

        # Print Summary to Tensor Board
        model.summary_writer.add_summary(merged_summ, counter)
        model.summary_writer.add_summary(acc_sum, counter)

        # Performance on the validation set
        # Get Predictions: Prohibit the use of gold labels and dropout. The
        # flags are restored even if prediction fails, and dropout returns to
        # its previous setting instead of being forced on.
        prev_dropout = FLAGS.use_dropout
        FLAGS.authorise_gold_label = False
        FLAGS.use_dropout = False
        try:
          val_batch = batch_predict_with_a_model(val_batch,"validation", model, session=sess)
        finally:
          FLAGS.authorise_gold_label = True
          FLAGS.use_dropout = prev_dropout

        # Validation Accuracy and Prediction
        probs = sess.run(model.probs,feed_dict={model.logits_placeholder: val_batch.logits})
        validation_acc = accuracy_qas_top(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids)

        ce_loss_val, ce_loss_sum, acc_sum = sess.run([ model.cross_entropy_loss_val,
                                                       model.ce_loss_summary_val,
                                                       model.vstepa_accuracy_summary],
                                                      feed_dict={model.logits_placeholder: val_batch.logits,
                                                                 model.label_placeholder:  val_batch.labels,
                                                                 model.weight_placeholder: val_batch.weights,
                                                                 model.isf_score_placeholder:  val_batch.isf_score,
                                                                 model.locisf_score_placeholder:  val_batch.locisf_score,
                                                                 model.val_acc_placeholder: validation_acc})

        # Print Validation Summary
        model.summary_writer.add_summary(acc_sum, counter)
        model.summary_writer.add_summary(ce_loss_sum, counter)

        print("Epoch %2d, step: %2d(%2d) || CE loss || Train : %4.3f , Val : %4.3f ||| ACC || Train : %.3f , Val : %.3f" %
            (epoch,step,counter,total_ce_loss,ce_loss_val,total_train_acc,validation_acc))
        total_ce_loss = 0
        total_train_acc = 0

      if (step % 5) == 0: # to have comparable tensorboard plots
        counter += 1

      # Increase step
      step += 1

    #END-WHILE-TRAINING
    print("Epoch %2d : Duration: %.4f" % (epoch,time.time()-ep_time) )
    if not FLAGS.use_subsampled_dataset:
      print("Saving model after epoch completion")
      checkpoint_path = os.path.join(FLAGS.train_dir, "step-a.model.ckpt.epoch-"+str(epoch))
      model.saver.save(sess, checkpoint_path)

  #END-FOR-EPOCH

  print("Optimization Finished!")
Пример #9
0
            total_loss += ce_loss
            # Increase step
            if step%5000==0:
              print ("\tStep: ",step)
            step += 1
          #END-WHILE-TRAINING
          total_loss /= step
          FLAGS.authorise_gold_label = False
          FLAGS.use_dropout = False
          # retrieve batch with updated logits in it
          val_batch = batch_predict_with_a_model(val_batch, "validation", model, session=sess)
          FLAGS.authorise_gold_label = True
          FLAGS.use_dropout = prev_drpt
          probs = sess.run(model.predictions,feed_dict={model.logits_placeholder: val_batch.logits})
          probs,lab,w = group_by_doc(probs,val_batch.labels,val_batch.qids)
          validation_acc = accuracy_qas_top(probs,lab,w)
          val_mrr = mrr_metric(probs,lab,w,"validation")
 
          print("\tEpoch %2d || Train ce_loss: %4.3f || Val acc: %.4f || Val mrr: %.4f || duration: %3.2f" % 
            (epoch,total_loss,validation_acc,val_mrr,time.time()-ep_time))
          output.write("\tEpoch %2d || Train ce_loss: %4.3f || Val acc: %.4f || Val mrr: %.4f || duration: %3.2f\n" % 
            (epoch,total_loss,validation_acc,val_mrr,time.time()-ep_time))

          if validation_acc > best_acc:
            best_acc = validation_acc
            best_ep = epoch
          if val_mrr > best_mrr:
            best_mrr = val_mrr
            best_ep_mrr = epoch
          #break # for time testing
        #END-FOR-EPOCH
Пример #10
0
                                             model.weight_placeholder: batch.weights})
            total_loss += ce_loss
            # Increase step
            if step%500==0:
              print ("\tStep: ",step)
            step += 1
          #END-WHILE-TRAINING
          total_loss /= step
          FLAGS.authorise_gold_label = False
          FLAGS.use_dropout = False
          # retrieve batch with updated logits in it
          val_batch = batch_predict_with_a_model(val_batch, "validation", model, session=sess)
          FLAGS.authorise_gold_label = True
          FLAGS.use_dropout = prev_drpt
          probs = sess.run(model.predictions,feed_dict={model.logits_placeholder: val_batch.logits})
          validation_acc = accuracy_qas_top(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids)
          val_mrr = mrr_metric(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids,"validation")

          print("\tEpoch %2d || Train ce_loss: %4.3f || Val acc: %.4f || Val mrr: %.4f || duration: %3.2f" % 
            (epoch,total_loss,validation_acc,val_mrr,time.time()-ep_time))
          output.write("\tEpoch %2d || Train ce_loss: %4.3f || Val acc: %.4f || Val mrr: %.4f || duration: %3.2f\n" % 
            (epoch,total_loss,validation_acc,val_mrr,time.time()-ep_time))

          if validation_acc > best_acc:
            best_acc = validation_acc
            best_ep = epoch
          if val_mrr > best_mrr:
            best_mrr = val_mrr
            best_ep_mrr = epoch
          #break # for time testing
        #END-FOR-EPOCH
Пример #11
0
def evaluate_model(assignments,train_data,val_batch,score="mrr"):
  """
  Train a fresh model with one hyper-parameter assignment and return its best
  validation score.

  The assignment is written into the global FLAGS. A new model is then built
  in its own graph and session, trained for FLAGS.train_epoch_crossentropy
  epochs with the supervised cross-entropy objective, and scored on
  `val_batch` after every epoch.

  Args:
    assignments: mapping with keys "batch_size", "log_learning_rate", "size"
      and "sentembed_size". The learning rate used is
      exp(assignments["log_learning_rate"]).
    train_data: training data object providing `fileindices`,
      `shuffle_fileindices()` and `get_batch(start, end)`.
    val_batch: the whole validation set as a single batch.
    score: validation metric to maximise, "acc" or "mrr" (default).

  Returns:
    The best value of `score` seen over all epochs (-1 if no epoch ran).

  Raises:
    ValueError: if `score` is neither "acc" nor "mrr".

  NOTE: reads the module-level `vocab_dict`, `word_embedding_array` and
  `seed`, and leaves the assignment in FLAGS after returning.
  """
  # Validate up front: an unknown metric used to fail only after a full
  # training epoch, with an UnboundLocalError on `metric`.
  if score not in ("acc", "mrr"):
    raise ValueError("score must be 'acc' or 'mrr', got %r" % (score,))

  FLAGS.batch_size = assignments["batch_size"]
  FLAGS.learning_rate = math.exp(assignments["log_learning_rate"])
  FLAGS.size = assignments["size"]
  FLAGS.sentembed_size = assignments["sentembed_size"]

  # With "concat" filter handling the sentence embedding size must be a
  # multiple of the number of filter lengths, so round it down to one.
  fil_lens_to_test = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
  if FLAGS.handle_filter_output == "concat" and FLAGS.sentembed_size%fil_lens_to_test != 0:
    q = int(FLAGS.sentembed_size // fil_lens_to_test)
    FLAGS.sentembed_size = q * fil_lens_to_test

  print("Setup: bs: %d | lr: %f | size: %d | sent_emb: %d" %
    (FLAGS.batch_size,
     FLAGS.learning_rate,
     FLAGS.size,
     FLAGS.sentembed_size)
    )

  best_metric = -1
  best_ep = 0
  # Both context managers must be entered. The previous `A and B` form
  # evaluated to B only, so the fresh graph was never made the default.
  with tf.Graph().as_default(), tf.device('/gpu:'+FLAGS.gpu_id):
    config = tf.ConfigProto(allow_soft_placement = True)
    tf.set_random_seed(seed)
    with tf.Session(config = config) as sess:
      model = MY_Model(sess, len(vocab_dict)-2)
      init_epoch = 1
      sess.run(model.vocab_embed_variable.assign(word_embedding_array))

      for epoch in range(init_epoch, FLAGS.train_epoch_crossentropy+1):
        ep_time = time.time() # to check duration

        train_data.shuffle_fileindices()
        total_loss = 0
        # Start Batch Training (full batches only; a partial tail is skipped)
        step = 1
        while (step * FLAGS.batch_size) <= len(train_data.fileindices):
          # Get batch data as Numpy Arrays
          batch = train_data.get_batch(((step-1)*FLAGS.batch_size), (step * FLAGS.batch_size))

          # Run optimizer: optimize policy and reward estimator
          _,ce_loss = sess.run([model.train_op_policynet_withgold,
                                model.cross_entropy_loss],
                                feed_dict={model.document_placeholder: batch.docs,
                                           model.label_placeholder: batch.labels,
                                           model.weight_placeholder: batch.weights,
                                           model.isf_score_placeholder: batch.isf_score,
                                           model.idf_score_placeholder: batch.idf_score,
                                           model.locisf_score_placeholder: batch.locisf_score})
          total_loss += ce_loss
          # Increase step
          if step%500==0:
            print ("\tStep: ",step)
          step += 1
        #END-WHILE-TRAINING
        # `step` ends one past the number of batches processed; average over
        # the batches actually run (and never divide by zero).
        total_loss /= max(step - 1, 1)

        # Predict on validation without gold labels; always restore the flag.
        # NOTE(review): dropout is not toggled here (the toggles were
        # commented out), unlike other validation loops in this file. Confirm
        # that this is intended.
        FLAGS.authorise_gold_label = False
        try:
          # retrieve batch with updated logits in it
          val_batch = batch_predict_with_a_model(val_batch, "validation", model, session=sess)
        finally:
          FLAGS.authorise_gold_label = True
        probs = sess.run(model.predictions,feed_dict={model.logits_placeholder: val_batch.logits})
        if score=="acc":
          metric = accuracy_qas_top(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids)
        else:  # score == "mrr" (validated above)
          metric = mrr_metric(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids,"validation")

        print("\tEpoch %2d || Train ce_loss: %4.3f || Val %s: %.4f || duration: %3.2f" % (epoch,total_loss,score,metric,time.time()-ep_time))

        if metric > best_metric:
          best_metric = metric
          best_ep = epoch
      #END-FOR-EPOCH
  #END-GRAPH
  # The model's graph is dropped with the `with` block. The reset, kept from
  # the original, also clears the global default graph. It has to run outside
  # the graph context because tf.reset_default_graph() refuses to clear
  # nested graphs.
  tf.reset_default_graph()
  print("Best metric:%.6f | ep: %d" % (best_metric,best_ep))

  return best_metric