def train_simple(sess):
    """Create a new model and pretrain it with simple cross-entropy (STEP A).

    Loads the vocabulary and pretrained word embeddings, the training data and
    the whole validation set (as one padded batch), builds ``MY_Model`` in
    ``sess`` and optionally resumes from the ``step-a.model.ckpt.epoch-<k>``
    checkpoint selected by ``FLAGS.model_to_load`` (``-1`` starts fresh).
    Each epoch trains on shuffled full mini-batches with the gold labels, then
    prints the mean training loss plus validation accuracy / MRR / MAP, and
    saves a checkpoint when ``FLAGS.save_models`` is set.

    Side effects: may overwrite ``FLAGS.sentembed_size``; temporarily sets
    ``FLAGS.authorise_gold_label`` / ``FLAGS.use_dropout`` during validation.

    Args:
      sess: TensorFlow session in which the model is created and trained.
    """
    tf.set_random_seed(seed)
    ### Prepare data for training
    vocab_dict, word_embedding_array = DataProcessor(
    ).prepare_vocab_embeddingdict()
    # vocab_dict contains _PAD and _UNK but not word_embedding_array

    train_data = DataProcessor().prepare_news_data(vocab_dict,
                                                   data_type="training")

    # data in whole batch with padded matrixes
    val_batch = batch_load_data(DataProcessor().prepare_news_data(
        vocab_dict, data_type="validation"))

    # With "concat" filter output, round sentembed_size down to a multiple of
    # the number of filter lengths (presumably each length gets an equal
    # slice of the embedding — confirm in MY_Model).
    fil_lens_to_test = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    if FLAGS.handle_filter_output == "concat" and FLAGS.sentembed_size % fil_lens_to_test != 0:
        q = int(FLAGS.sentembed_size // fil_lens_to_test)
        FLAGS.sentembed_size = q * fil_lens_to_test
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    # Create Model with various operations (_PAD and _UNK are not counted)
    model = MY_Model(sess, len(vocab_dict) - 2)
    init_epoch = 1
    # Resume training if indicated Select the model
    if FLAGS.model_to_load != -1:
        selected_modelpath = FLAGS.train_dir + "/step-a.model.ckpt.epoch-" + str(
            FLAGS.model_to_load)
        init_epoch = FLAGS.model_to_load + 1
        print("Reading model parameters from %s" % selected_modelpath)
        model.saver.restore(sess, selected_modelpath)
        print("Model loaded.")

    # Initialize word embedding before training
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    ### STEP A : Start Pretraining the policy with Supervised Labels: Simple Cross Entropy Training
    for epoch in range(init_epoch, FLAGS.train_epoch_crossentropy + 1):
        ep_time = time.time()  # to check duration

        train_data.shuffle_fileindices()
        # Only full batches are used; a trailing partial batch is skipped.
        num_batches = len(train_data.fileindices) // FLAGS.batch_size
        total_loss = 0
        for batch_idx in range(num_batches):
            start = batch_idx * FLAGS.batch_size
            # Get batch data as Numpy Arrays
            batch = train_data.get_batch(start, start + FLAGS.batch_size)

            # Run optimizer: optimize policy and reward estimator
            _, ce_loss = sess.run(
                [model.train_op_policynet_withgold, model.cross_entropy_loss],
                feed_dict={
                    model.document_placeholder: batch.docs,
                    model.label_placeholder: batch.labels,
                    model.weight_placeholder: batch.weights
                })
            total_loss += ce_loss
        # Mean loss over the batches actually run (guard against no batches).
        total_loss /= max(num_batches, 1)

        ## eval metrics: predict the validation logits without gold labels
        ## and without dropout, restoring the flags even if prediction fails.
        prev_use_dpt = FLAGS.use_dropout
        FLAGS.authorise_gold_label = False
        FLAGS.use_dropout = False
        try:
            # retrieve batch with updated logits in it
            val_batch = batch_predict_with_a_model(val_batch,
                                                   "validation",
                                                   model,
                                                   session=sess)
        finally:
            FLAGS.authorise_gold_label = True
            FLAGS.use_dropout = prev_use_dpt

        probs = sess.run(
            model.predictions,
            feed_dict={model.logits_placeholder: val_batch.logits})
        validation_acc = accuracy_qas_top(probs, val_batch.labels,
                                          val_batch.weights,
                                          val_batch.isf_score_ids)
        val_mrr = mrr_metric(probs, val_batch.labels, val_batch.weights,
                             val_batch.isf_score_ids, "validation")
        val_map = map_score(probs, val_batch.labels, val_batch.weights,
                            val_batch.isf_score_ids, "validation")

        print(
            "\tEpoch %2d || Train ce_loss: %4.3f || Val acc: %.4f || Val mrr: %.4f || Val map: %.4f || duration: %3.2f"
            % (epoch, total_loss, validation_acc, val_mrr, val_map,
               time.time() - ep_time))

        if FLAGS.save_models:
            print("Saving model after epoch completion")
            checkpoint_path = os.path.join(
                FLAGS.train_dir, "step-a.model.ckpt.epoch-" + str(epoch))
            model.saver.save(sess, checkpoint_path)
        print(
            "------------------------------------------------------------------------------------------"
        )
    #END-FOR-EPOCH

    print("Optimization Finished!")
# Example No. 2
# 0
def train():
    """Create a new model and train it with reward-weighted cross-entropy.

    Builds a fresh graph and session (ops placed on ``FLAGS.use_gpu`` with soft
    placement), loads vocabulary/embeddings, training and validation data and a
    ROUGE ``Reward_Generator``, then runs ``FLAGS.train_epoch_wce`` epochs of
    MRT-style training on multiple pre-estimated oracle samples. After every
    epoch it saves ``model.ckpt.epoch-<n>`` in ``FLAGS.train_dir``, backs up
    the ROUGE dictionary, writes validation summaries and reports validation
    accuracy and ROUGE.
    """

    # Training: make a fresh graph the default and scope ops to FLAGS.use_gpu.
    # The two context managers must be separated by a comma: ``a and b`` would
    # evaluate to ``b`` alone, entering only the device scope, so the new
    # graph would never become the default.
    with tf.Graph().as_default(), tf.device(FLAGS.use_gpu):

        # Allow TF to fall back to another device for ops unavailable on the
        # requested one.
        config = tf.ConfigProto(allow_soft_placement=True)

        # Start a session
        with tf.Session(config=config) as sess:

            ### Prepare data for training
            print("Prepare vocab dict and read pretrained word embeddings ...")
            vocab_dict, word_embedding_array = DataProcessor(
            ).prepare_vocab_embeddingdict()
            # vocab_dict contains _PAD and _UNK but not word_embedding_array

            print("Prepare training data ...")
            train_data = DataProcessor().prepare_news_data(
                vocab_dict, data_type="training")

            print("Prepare validation data ...")
            validation_data = DataProcessor().prepare_news_data(
                vocab_dict, data_type="validation")

            print("Prepare ROUGE reward generator ...")
            rouge_generator = Reward_Generator()

            # Create Model with various operations (_PAD and _UNK not counted)
            model = MY_Model(sess, len(vocab_dict) - 2)

            # Training always starts at epoch 1. To resume, restore
            # FLAGS.train_dir + "/model.ckpt.epoch-" + str(start_epoch - 1)
            # via model.saver.restore(sess, path) and bump start_epoch.
            start_epoch = 1

            # Initialize word embedding before training
            print(
                "Initialize word embedding vocabulary with pretrained embeddings ..."
            )
            sess.run(model.vocab_embed_variable.assign(word_embedding_array))

            ########### Start (No Mixer) Training : Reinforcement learning ################
            # Reward aware training as part of Reward weighted CE ,
            # No Curriculum learning: No annealing, consider full document like in MRT
            # Multiple Samples (include gold sample),  No future reward, Similar to MRT
            # During training does not use PYROUGE to avoid multiple file rewritings
            # Approximate MRT with multiple pre-estimated oracle samples
            # June 2017: Use Single sample from multiple oracles
            ###############################################################################

            print(
                "Start Reinforcement Training (single rollout from largest prob mass) ..."
            )

            for epoch in range(start_epoch, FLAGS.train_epoch_wce + 1):
                print("MRT: Epoch " + str(epoch))

                print("MRT: Epoch " + str(epoch) +
                      " : Reshuffle training document indices")
                train_data.shuffle_fileindices()

                print("MRT: Epoch " + str(epoch) + " : Restore Rouge Dict")
                rouge_generator.restore_rouge_dict()

                # Start Batch Training (full batches only; a trailing partial
                # batch is skipped)
                step = 1
                while (step * FLAGS.batch_size) <= len(train_data.fileindices):
                    # Get batch data as Numpy Arrays
                    batch_docnames, batch_docs, batch_label, batch_weight, batch_oracle_multiple, batch_reward_multiple = train_data.get_batch(
                        ((step - 1) * FLAGS.batch_size),
                        (step * FLAGS.batch_size))

                    # Print the progress
                    if (step % FLAGS.training_checkpoint) == 0:

                        ce_loss_val, ce_loss_sum, acc_val, acc_sum = sess.run(
                            [
                                model.
                                rewardweighted_cross_entropy_loss_multisample,
                                model.
                                rewardweighted_ce_multisample_loss_summary,
                                model.accuracy, model.taccuracy_summary
                            ],
                            feed_dict={
                                model.document_placeholder: batch_docs,
                                model.predicted_multisample_label_placeholder:
                                batch_oracle_multiple,
                                model.actual_reward_multisample_placeholder:
                                batch_reward_multiple,
                                model.label_placeholder: batch_label,
                                model.weight_placeholder: batch_weight
                            })

                        # Print Summary to Tensor Board; the x-axis is the
                        # number of training documents covered so far.
                        summary_step = ((epoch - 1) * len(train_data.fileindices) +
                                        step * FLAGS.batch_size)
                        model.summary_writer.add_summary(ce_loss_sum,
                                                         summary_step)
                        model.summary_writer.add_summary(acc_sum, summary_step)

                        print(
                            "MRT: Epoch " + str(epoch) + " : Covered " +
                            str(step * FLAGS.batch_size) + "/" +
                            str(len(train_data.fileindices)) +
                            " : Minibatch Reward Weighted Multisample CE Loss= {:.6f}"
                            .format(ce_loss_val) +
                            " : Minibatch training accuracy= {:.6f}".format(
                                acc_val))

                    # Run optimizer: optimize policy network
                    sess.run(
                        [model.train_op_policynet_expreward],
                        feed_dict={
                            model.document_placeholder: batch_docs,
                            model.predicted_multisample_label_placeholder:
                            batch_oracle_multiple,
                            model.actual_reward_multisample_placeholder:
                            batch_reward_multiple,
                            model.weight_placeholder: batch_weight
                        })

                    # Increase step
                    step += 1

                # Save Model
                print("MRT: Epoch " + str(epoch) +
                      " : Saving model after epoch completion")
                checkpoint_path = os.path.join(
                    FLAGS.train_dir, "model.ckpt.epoch-" + str(epoch))
                model.saver.save(sess, checkpoint_path)

                # Backup Rouge Dict
                print("MRT: Epoch " + str(epoch) +
                      " : Saving rouge dictionary")
                rouge_generator.save_rouge_dict()

                # Performance on the validation set
                print("MRT: Epoch " + str(epoch) +
                      " : Performance on the validation data")
                # Get Predictions. NOTE(review): gold labels are meant to be
                # prohibited here, but this block does not touch
                # FLAGS.authorise_gold_label — presumably the helper handles it.
                validation_logits, validation_labels, validation_weights = batch_predict_with_a_model(
                    validation_data, model, session=sess)
                # Validation Accuracy and Prediction
                validation_acc, validation_sum = sess.run(
                    [model.final_accuracy, model.vaccuracy_summary],
                    feed_dict={
                        model.logits_placeholder:
                        validation_logits.eval(session=sess),
                        model.label_placeholder:
                        validation_labels.eval(session=sess),
                        model.weight_placeholder:
                        validation_weights.eval(session=sess)
                    })
                # Print Validation Summary
                model.summary_writer.add_summary(
                    validation_sum, (epoch * len(train_data.fileindices)))

                print("MRT: Epoch " + str(epoch) + " : Validation (" +
                      str(len(validation_data.fileindices)) +
                      ") accuracy= {:.6f}".format(validation_acc))
                # Writing validation predictions and final summaries
                print("MRT: Epoch " + str(epoch) +
                      " : Writing final validation summaries")
                validation_data.write_prediction_summaries(
                    validation_logits,
                    "model.ckpt.epoch-" + str(epoch),
                    session=sess)
                # Estimate Rouge Scores on the top-ranked validation summaries
                rouge_score = rouge_generator.get_full_rouge(
                    FLAGS.train_dir + "/model.ckpt.epoch-" + str(epoch) +
                    ".validation-summary-topranked", "validation")
                print("MRT: Epoch " + str(epoch) + " : Validation (" +
                      str(len(validation_data.fileindices)) +
                      ") rouge= {:.6f}".format(rouge_score))

            print("Optimization Finished!")
def train_debug(sess):
    """Create a new model and pretrain it with cross-entropy, logging verbosely.

    Same STEP A training as the plain variant, but after every optimizer step
    the batch is re-evaluated (without dropout) to track training loss and
    accuracy. Every ``FLAGS.training_checkpoint`` steps the averaged training
    loss/accuracy and the validation loss/accuracy are printed and written to
    TensorBoard. At the end of each epoch it prints validation accuracy / MRR
    / MAP and saves a checkpoint if ``FLAGS.save_models`` is set.

    Side effects: may overwrite ``FLAGS.sentembed_size``; temporarily sets
    ``FLAGS.authorise_gold_label`` / ``FLAGS.use_dropout`` during evaluation.

    Args:
      sess: TensorFlow session in which the model is created and trained.
    """
    tf.set_random_seed(seed)
    ### Prepare data for training
    print("Prepare vocab dict and read pretrained word embeddings ...")
    vocab_dict, word_embedding_array = DataProcessor(
    ).prepare_vocab_embeddingdict()
    # vocab_dict contains _PAD and _UNK but not word_embedding_array

    print("Prepare training data ...")
    train_data = DataProcessor().prepare_news_data(vocab_dict,
                                                   data_type="training")
    print("Training size: ", len(train_data.fileindices))

    print("Prepare validation data ...")
    # data in whole batch with padded matrixes
    val_batch = batch_load_data(DataProcessor().prepare_news_data(
        vocab_dict, data_type="validation"))
    print("Validation size: ", val_batch.docs.shape[0])

    # With "concat" filter output, round sentembed_size down to a multiple of
    # the number of filter lengths (presumably each length gets an equal
    # slice of the embedding — confirm in MY_Model).
    fil_lens_to_test = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    if FLAGS.handle_filter_output == "concat" and FLAGS.sentembed_size % fil_lens_to_test != 0:
        q = int(FLAGS.sentembed_size // fil_lens_to_test)
        FLAGS.sentembed_size = q * fil_lens_to_test
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    # Create Model with various operations (_PAD and _UNK not counted)
    model = MY_Model(sess, len(vocab_dict) - 2)

    def _predict_validation(batch_data):
        """Return ``batch_data`` with fresh validation logits.

        Gold labels and dropout are disabled for the prediction; dropout is
        restored to its previous value and gold labels re-authorised even if
        the prediction raises.
        """
        prev_use_dpt = FLAGS.use_dropout
        FLAGS.authorise_gold_label = False
        FLAGS.use_dropout = False
        try:
            return batch_predict_with_a_model(batch_data,
                                              "validation",
                                              model,
                                              session=sess)
        finally:
            FLAGS.use_dropout = prev_use_dpt
            FLAGS.authorise_gold_label = True

    init_epoch = 1
    # Resume training if indicated Select the model
    if FLAGS.model_to_load != -1:
        selected_modelpath = FLAGS.train_dir + "/step-a.model.ckpt.epoch-" + str(
            FLAGS.model_to_load)
        init_epoch = FLAGS.model_to_load + 1
        print("Reading model parameters from %s" % selected_modelpath)
        model.saver.restore(sess, selected_modelpath)
        print("Model loaded.")

    # Initialize word embedding before training
    print(
        "Initialize word embedding vocabulary with pretrained embeddings ...")
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    ### STEP A : Start Pretraining the policy with Supervised Labels: Simple Cross Entropy Training
    # TensorBoard step; advanced once every 5 training steps (see below).
    counter = 0
    for epoch in range(init_epoch, FLAGS.train_epoch_crossentropy + 1):
        ep_time = time.time()  # to check duration

        train_data.shuffle_fileindices()

        # Start Batch Training (full batches only)
        step = 1
        # Running sums since the last progress checkpoint.
        total_ce_loss = 0
        total_train_acc = 0
        while (step * FLAGS.batch_size) <= len(train_data.fileindices):
            # Get batch data as Numpy Arrays
            batch = train_data.get_batch(((step - 1) * FLAGS.batch_size),
                                         (step * FLAGS.batch_size))

            # Run optimizer: optimize policy and reward estimator
            sess.run(
                [model.train_op_policynet_withgold],
                feed_dict={
                    model.document_placeholder: batch.docs,
                    model.label_placeholder: batch.labels,
                    model.weight_placeholder: batch.weights
                })

            # Re-evaluate the updated model on this batch with dropout off.
            prev_use_dpt = FLAGS.use_dropout
            FLAGS.use_dropout = False
            try:
                batch_logits, ce_loss, merged_summ = sess.run(
                    [model.logits, model.cross_entropy_loss, model.merged],
                    feed_dict={
                        model.document_placeholder: batch.docs,
                        model.label_placeholder: batch.labels,
                        model.weight_placeholder: batch.weights
                    })
                probs = sess.run(
                    model.predictions,
                    feed_dict={model.logits_placeholder: batch_logits})
                training_acc = accuracy_qas_top(probs, batch.labels,
                                                batch.weights,
                                                batch.isf_score_ids)
            finally:
                FLAGS.use_dropout = prev_use_dpt
            total_ce_loss += ce_loss
            total_train_acc += training_acc

            # Print the progress
            if (step % FLAGS.training_checkpoint) == 0:
                # Average over the steps accumulated since the last checkpoint.
                total_train_acc /= FLAGS.training_checkpoint
                acc_sum = sess.run(
                    model.tstepa_accuracy_summary,
                    feed_dict={model.train_acc_placeholder: total_train_acc})

                total_ce_loss /= FLAGS.training_checkpoint
                # Print Summary to Tensor Board
                model.summary_writer.add_summary(merged_summ, counter)
                model.summary_writer.add_summary(acc_sum, counter)

                # Performance on the validation set
                # Get Predictions: Prohibit the use of gold labels
                val_batch = _predict_validation(val_batch)

                # Validation Accuracy and Prediction
                probs = sess.run(
                    model.predictions,
                    feed_dict={model.logits_placeholder: val_batch.logits})
                validation_acc = accuracy_qas_top(probs, val_batch.labels,
                                                  val_batch.weights,
                                                  val_batch.isf_score_ids)

                ce_loss_val, ce_loss_sum, acc_sum = sess.run(
                    [
                        model.cross_entropy_loss_val,
                        model.ce_loss_summary_val,
                        model.vstepa_accuracy_summary
                    ],
                    feed_dict={
                        model.logits_placeholder: val_batch.logits,
                        model.label_placeholder: val_batch.labels,
                        model.weight_placeholder: val_batch.weights,
                        model.val_acc_placeholder: validation_acc
                    })

                # Print Validation Summary
                model.summary_writer.add_summary(acc_sum, counter)
                model.summary_writer.add_summary(ce_loss_sum, counter)
                # Report the averaged training accuracy (as sent to
                # TensorBoard), matching the averaged CE loss next to it.
                print(
                    "Epoch %2d, step: %2d(%2d) || CE loss || Train : %4.3f , Val : %4.3f ||| ACC || Train : %.3f , Val : %.3f"
                    % (epoch, step, counter, total_ce_loss, ce_loss_val,
                       total_train_acc, validation_acc))
                total_ce_loss = 0
                total_train_acc = 0

            if (step % 5) == 0:  # to have comparable tensorboard plots
                counter += 1
            # Increase step
            step += 1
        #END-WHILE-TRAINING  ... but wait there is more
        ## eval metrics
        val_batch = _predict_validation(val_batch)
        # Validation metrics
        probs = sess.run(
            model.predictions,
            feed_dict={model.logits_placeholder: val_batch.logits})
        acc = accuracy_qas_top(probs, val_batch.labels, val_batch.weights,
                               val_batch.isf_score_ids)
        mrr = mrr_metric(probs, val_batch.labels, val_batch.weights,
                         val_batch.isf_score_ids, "validation")
        _map = map_score(probs, val_batch.labels, val_batch.weights,
                         val_batch.isf_score_ids, "validation")
        print("Metrics: acc: %.4f | mrr: %.4f | map: %.4f" % (acc, mrr, _map))

        print("Epoch %2d : Duration: %.4f" % (epoch, time.time() - ep_time))
        if FLAGS.save_models:
            print("Saving model after epoch completion")
            checkpoint_path = os.path.join(
                FLAGS.train_dir, "step-a.model.ckpt.epoch-" + str(epoch))
            model.saver.save(sess, checkpoint_path)
        print(
            "------------------------------------------------------------------------------------------"
        )
    #END-FOR-EPOCH

    print("Optimization Finished!")
def train(sess):
    """Create a new model and pretrain it with supervised labels (STEP A).

    Loads vocabulary/embeddings, training data and the whole validation set
    (as one padded batch), builds ``MY_Model`` in ``sess`` and optionally
    resumes from ``FLAGS.train_dir/step-a.model.ckpt.epoch-<k>``
    (``FLAGS.model_to_load``; ``-1`` starts fresh). Each epoch trains on
    shuffled full mini-batches, logs progress every
    ``FLAGS.training_checkpoint`` steps, saves a checkpoint, evaluates
    validation CE loss / accuracy / MRR / MAP and stores them via
    ``save_metrics``.

    Side effects: may overwrite ``FLAGS.sentembed_size``; temporarily sets
    ``FLAGS.authorise_gold_label`` / ``FLAGS.use_dropout`` during validation
    and restores their previous values afterwards.

    Args:
      sess: TensorFlow session in which the model is created and trained.

    Raises:
      SystemExit: with status 1 when the requested checkpoint does not exist.
    """
    ### Prepare data for training
    print("Prepare vocab dict and read pretrained word embeddings ...")
    vocab_dict, word_embedding_array = DataProcessor(
    ).prepare_vocab_embeddingdict()
    # vocab_dict contains _PAD and _UNK but not word_embedding_array

    print("Prepare training data ...")
    train_data = DataProcessor().prepare_news_data(vocab_dict,
                                                   data_type="training")

    print("Prepare validation data ...")
    # data in whole batch with padded matrixes
    val_batch = batch_load_data(DataProcessor().prepare_news_data(
        vocab_dict, data_type="validation"))

    # With "concat" filter output, round sentembed_size down to a multiple of
    # the number of filter lengths (presumably each length gets an equal
    # slice of the embedding — confirm in MY_Model).
    fil_lens_to_test = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    if FLAGS.handle_filter_output == "concat" and FLAGS.sentembed_size % fil_lens_to_test != 0:
        q = int(FLAGS.sentembed_size // fil_lens_to_test)
        FLAGS.sentembed_size = q * fil_lens_to_test
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    # Create Model with various operations (_PAD and _UNK not counted)
    model = MY_Model(sess, len(vocab_dict) - 2)

    init_epoch = 1
    # Resume training if indicated Select the model
    if FLAGS.model_to_load != -1:
        selected_modelpath = FLAGS.train_dir + "/step-a.model.ckpt.epoch-" + str(
            FLAGS.model_to_load)
        # Saver.save in the V2 checkpoint format writes "<prefix>.index" and
        # "<prefix>.data-*" rather than a file named "<prefix>", so accept
        # either layout.
        if not (os.path.isfile(selected_modelpath)
                or os.path.isfile(selected_modelpath + ".index")):
            print("Model not found in checkpoint folder.")
            # Non-zero status: a missing model is a failure, not success.
            raise SystemExit(1)

        # Reload saved model and test
        init_epoch = FLAGS.model_to_load + 1
        print("Reading model parameters from %s" % selected_modelpath)
        model.saver.restore(sess, selected_modelpath)
        print("Model loaded.")

    # Initialize word embedding before training
    print(
        "Initialize word embedding vocabulary with pretrained embeddings ...")
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    ### STEP A : Start Pretraining the policy with Supervised Labels: Simple Cross Entropy Training

    for epoch in range(init_epoch, FLAGS.train_epoch_crossentropy + 1):
        ep_time = time.time()  # to check duration

        print("STEP A: Epoch " + str(epoch) +
              " : Start pretraining with supervised labels")

        print("STEP A: Epoch " + str(epoch) +
              " : Reshuffle training document indices")
        train_data.shuffle_fileindices()

        # Start Batch Training (full batches only)
        step = 1
        total_ce_loss = 0  # running sum since the last progress checkpoint
        while (step * FLAGS.batch_size) <= len(train_data.fileindices):
            # Get batch data as Numpy Arrays
            batch = train_data.get_batch(((step - 1) * FLAGS.batch_size),
                                         (step * FLAGS.batch_size))

            # Run optimizer: optimize policy and reward estimator
            _, batch_logits, ce_loss, merged_summ = sess.run(
                [
                    model.train_op_policynet_withgold, model.logits,
                    model.cross_entropy_loss, model.merged
                ],
                feed_dict={
                    model.document_placeholder: batch.docs,
                    model.label_placeholder: batch.labels,
                    model.weight_placeholder: batch.weights
                })
            total_ce_loss += ce_loss
            # Print the progress
            if (step % FLAGS.training_checkpoint) == 0:
                probs = sess.run(
                    model.predictions,
                    feed_dict={model.logits_placeholder: batch_logits})
                # Same signature as every other accuracy_qas_top call site.
                training_acc = accuracy_qas_top(probs, batch.labels,
                                                batch.weights,
                                                batch.isf_score_ids)
                acc_sum = sess.run(
                    model.tstepa_accuracy_summary,
                    feed_dict={model.train_acc_placeholder: training_acc})

                total_ce_loss /= FLAGS.training_checkpoint
                print("STEP A: Epoch " + str(epoch) + " : Covered " +
                      str(step * FLAGS.batch_size) + "/" +
                      str(len(train_data.fileindices)) +
                      " : Minibatch CE Loss= {:.6f}".format(total_ce_loss) +
                      ", Minibatch Accuracy= {:.6f}".format(training_acc))
                total_ce_loss = 0
                # Print Summary to Tensor Board; x-axis is documents covered.
                summary_step = ((epoch - 1) * len(train_data.fileindices) +
                                step * FLAGS.batch_size)
                model.summary_writer.add_summary(merged_summ, summary_step)
                model.summary_writer.add_summary(acc_sum, summary_step)
            # Increase step
            step += 1
        #END-WHILE-TRAINING

        # Save Model
        print("STEP A: Epoch " + str(epoch) +
              " : Saving model after epoch completion")
        checkpoint_path = os.path.join(FLAGS.train_dir,
                                       "step-a.model.ckpt.epoch-" + str(epoch))
        model.saver.save(sess, checkpoint_path)

        # Performance on the validation set
        print("STEP A: Epoch " + str(epoch) +
              " : Performance on the validation data")
        # Get Predictions: Prohibit the use of gold labels and dropout, then
        # restore both flags so the next training epoch is unaffected.
        prev_gold = FLAGS.authorise_gold_label
        prev_use_dpt = FLAGS.use_dropout
        FLAGS.authorise_gold_label = False
        FLAGS.use_dropout = False
        try:
            val_batch = batch_predict_with_a_model(val_batch,
                                                   "validation",
                                                   model,
                                                   session=sess)
        finally:
            FLAGS.authorise_gold_label = prev_gold
            FLAGS.use_dropout = prev_use_dpt

        # Validation Accuracy and Prediction
        probs = sess.run(
            model.predictions,
            feed_dict={model.logits_placeholder: val_batch.logits})
        validation_acc = accuracy_qas_top(probs, val_batch.labels,
                                          val_batch.weights,
                                          val_batch.isf_score_ids)

        ce_loss_val, ce_loss_sum, acc_sum = sess.run(
            [
                model.cross_entropy_loss_val, model.ce_loss_summary_val,
                model.vstepa_accuracy_summary
            ],
            feed_dict={
                model.logits_placeholder: val_batch.logits,
                model.label_placeholder: val_batch.labels,
                model.weight_placeholder: val_batch.weights,
                model.val_acc_placeholder: validation_acc
            })

        # Print Validation Summary
        model.summary_writer.add_summary(acc_sum,
                                         epoch * len(train_data.fileindices))
        model.summary_writer.add_summary(ce_loss_sum,
                                         epoch * len(train_data.fileindices))
        print("STEP A: Epoch %s : Validation (%s) CE loss = %.6f" %
              (str(epoch), str(val_batch.docs.shape[0]), ce_loss_val))
        print("STEP A: Epoch %s : Validation (%s) Accuracy= %.6f" %
              (str(epoch), str(val_batch.docs.shape[0]), validation_acc))

        # Estimate MRR on validation set
        mrr_score = mrr_metric(probs, val_batch.labels, val_batch.weights,
                               val_batch.isf_score_ids, "validation")
        print("STEP A: Epoch %s : Validation (%s) MRR= %.6f" %
              (str(epoch), str(val_batch.docs.shape[0]), mrr_score))
        # Estimate MAP score on validation set
        mapsc = map_score(probs, val_batch.labels, val_batch.weights,
                          val_batch.isf_score_ids, "validation")
        print("STEP A: Epoch %s : Validation (%s) MAP= %.6f" %
              (str(epoch), str(val_batch.docs.shape[0]), mapsc))

        fn = "%s/step-a.model.ckpt.validation-metrics" % (FLAGS.train_dir)
        save_metrics(fn, FLAGS.load_prediction, validation_acc, mrr_score,
                     mapsc)

        print("STEP A: Epoch %d : Duration: %.4f" %
              (epoch, time.time() - ep_time))

    #END-FOR-EPOCH

    print("Optimization Finished!")
# Example No. 5
# 0
    best_wer = 200
    best_cer = 200
    best_ep = 0
    best_ep_wer = 0
    with tf.Graph().as_default() and tf.device('/gpu:'+FLAGS.gpu_id):
      config = tf.ConfigProto(allow_soft_placement = True)
      tf.set_random_seed(seed)
      with tf.Session(config = config) as sess:
        model = MY_Model(sess, len(vocab_dict))
        init_epoch = 1
        
        for epoch in range(init_epoch, FLAGS.train_epoch+1):
          ep_time = time.time() # to check duration
          shuffle = (epoch != init_epoch) if FLAGS.do_sortagrad  else True
          if shuffle:
            train_data.shuffle_fileindices()

          total_loss = 0
          # Start Batch Training
          step = 1
          while (step * FLAGS.batch_size) <= len(train_data.fileindices):
            # Get batch data as Numpy Arrays
            batch = train_data.get_batch(((step-1)*FLAGS.batch_size), (step * FLAGS.batch_size))
            dshape = np.array([batch.spect.shape[0],max_fixed],dtype=np.int64)
            sparse_labels = tf.SparseTensorValue(batch.label_indices,batch.label_values,dshape)

            # Run optimizer: optimize policy and reward estimator
            _,ctc_loss = sess.run([model.train_main_network,
                                  model.ctc_loss],
                                  feed_dict={model.spect_placeholder: batch.spect,
                                             model.label_placeholder: sparse_labels,
def train_debug(sess):
  """
  Training Mode: Create a new model and train the network (debug variant).

  Prepares the vocabulary/pretrained embeddings, the training data and the
  validation data (loaded as one padded batch), builds a new ``MY_Model`` in
  ``sess`` and runs supervised cross-entropy pretraining ("STEP A") for
  ``FLAGS.train_epoch_crossentropy`` epochs. Training always starts from
  scratch; no checkpoint is restored.

  Every ``FLAGS.training_checkpoint`` steps the averaged training loss and
  accuracy are logged. The model is evaluated on the whole validation batch,
  and TensorBoard summaries are written. After each epoch the model is saved,
  unless ``FLAGS.use_subsampled_dataset`` is set.

  Args:
    sess: TensorFlow session in which the model is built and run.
  """
  ### Prepare data for training
  print("Prepare vocab dict and read pretrained word embeddings ...")
  vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict()
  # vocab_dict contains _PAD and _UNK but not word_embedding_array

  print("Prepare training data ...")
  train_data = DataProcessor().prepare_news_data(vocab_dict, data_type="training")
  print("Training size: ",len(train_data.fileindices))

  print("Prepare validation data ...")
  # Data in one whole batch with padded matrices. The training set's
  # normalizer and PCA model are passed in, presumably so validation features
  # go through the same transforms as the training features.
  val_batch = batch_load_data(DataProcessor().prepare_news_data(vocab_dict,
                                                                data_type="validation",
                                                                normalizer=train_data.normalizer,
                                                                pca_model=train_data.pca_model))
  print("Validation size: ",val_batch.docs.shape[0])

  # Create Model with various operations. The -2 presumably accounts for
  # _PAD and _UNK, which have no row in word_embedding_array.
  model = MY_Model(sess, len(vocab_dict)-2)

  init_epoch = 1

  # Initialize word embedding before training
  print("Initialize word embedding vocabulary with pretrained embeddings ...")
  sess.run(model.vocab_embed_variable.assign(word_embedding_array))

  ### STEP A : Start Pretraining the policy with Supervised Labels: Simple Cross Entropy Training
  counter = 0  # TensorBoard x-axis; advanced once every 5 training steps
  for epoch in range(init_epoch, FLAGS.train_epoch_crossentropy+1):
    ep_time = time.time() # to check duration

    train_data.shuffle_fileindices()

    # Start Batch Training (a trailing partial batch is skipped)
    step = 1
    total_ce_loss = 0
    total_train_acc = 0
    while (step * FLAGS.batch_size) <= len(train_data.fileindices):
      # Get batch data as Numpy Arrays
      batch = train_data.get_batch(((step-1)*FLAGS.batch_size), (step * FLAGS.batch_size))

      # Run optimizer: optimize policy and reward estimator
      _,batch_logits,ce_loss,merged_summ = sess.run([
                                model.train_op_policynet_withgold,
                                model.logits,
                                model.cross_entropy_loss,
                                model.merged],
                                feed_dict={model.document_placeholder: batch.docs,
                                           model.label_placeholder: batch.labels,
                                           model.weight_placeholder: batch.weights,
                                           model.isf_placeholder: batch.isf_score,
                                           model.idf_placeholder: batch.idf_score,
                                           model.locisf_score_placeholder:  batch.locisf_score
                                           })

      total_ce_loss += ce_loss
      # NOTE(review): this reads `model.predictions`, while the validation code
      # below reads `model.probs` for the same logits_placeholder. Confirm that
      # both attributes exist on MY_Model and mean the same thing.
      probs = sess.run(model.predictions,feed_dict={model.logits_placeholder: batch_logits})
      training_acc = accuracy_qas_top( probs,
                                       batch.labels,
                                       batch.weights,
                                       batch.isf_score_ids)
      total_train_acc += training_acc

      # Print the progress
      if (step % FLAGS.training_checkpoint) == 0:
        # Average over the steps since the last checkpoint.
        total_train_acc /= FLAGS.training_checkpoint
        total_ce_loss /= FLAGS.training_checkpoint

        # Log the averaged accuracy (the same value printed below), not just
        # the accuracy of the last minibatch.
        acc_sum = sess.run( model.tstepa_accuracy_summary,
                            feed_dict={model.train_acc_placeholder: total_train_acc})

        # Print Summary to Tensor Board
        model.summary_writer.add_summary(merged_summ, counter)
        model.summary_writer.add_summary(acc_sum, counter)

        # Performance on the validation set.
        # Get Predictions: prohibit the use of gold labels and disable dropout.
        # Afterwards, restore whatever the flags were before, even if
        # prediction raises. Forcing them to True would re-enable options the
        # user may have switched off.
        prev_gold_label = FLAGS.authorise_gold_label
        prev_dropout = FLAGS.use_dropout
        FLAGS.authorise_gold_label = False
        FLAGS.use_dropout = False
        try:
          val_batch = batch_predict_with_a_model(val_batch,"validation", model, session=sess)
        finally:
          FLAGS.authorise_gold_label = prev_gold_label
          FLAGS.use_dropout = prev_dropout

        # Validation Accuracy and Prediction
        probs = sess.run(model.probs,feed_dict={model.logits_placeholder: val_batch.logits})
        validation_acc = accuracy_qas_top(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids)

        # NOTE(review): this feed uses `model.isf_score_placeholder`, whereas the
        # training feed above uses `model.isf_placeholder`. Confirm which one
        # `cross_entropy_loss_val` actually depends on.
        ce_loss_val, ce_loss_sum, acc_sum = sess.run([ model.cross_entropy_loss_val,
                                                       model.ce_loss_summary_val,
                                                       model.vstepa_accuracy_summary],
                                                      feed_dict={model.logits_placeholder: val_batch.logits,
                                                                 model.label_placeholder:  val_batch.labels,
                                                                 model.weight_placeholder: val_batch.weights,
                                                                 model.isf_score_placeholder:  val_batch.isf_score,
                                                                 model.locisf_score_placeholder:  val_batch.locisf_score,
                                                                 model.val_acc_placeholder: validation_acc})

        # Print Validation Summary
        model.summary_writer.add_summary(acc_sum, counter)
        model.summary_writer.add_summary(ce_loss_sum, counter)

        print("Epoch %2d, step: %2d(%2d) || CE loss || Train : %4.3f , Val : %4.3f ||| ACC || Train : %.3f , Val : %.3f" %
            (epoch,step,counter,total_ce_loss,ce_loss_val,total_train_acc,validation_acc))
        total_ce_loss = 0
        total_train_acc = 0

      if (step % 5) == 0: # to have comparable tensorboard plots
        counter += 1

      # Increase step
      step += 1

    #END-WHILE-TRAINING
    print("Epoch %2d : Duration: %.4f" % (epoch,time.time()-ep_time) )
    if not FLAGS.use_subsampled_dataset:
      print("Saving model after epoch completion")
      checkpoint_path = os.path.join(FLAGS.train_dir, "step-a.model.ckpt.epoch-"+str(epoch))
      model.saver.save(sess, checkpoint_path)

  #END-FOR-EPOCH

  print("Optimization Finished!")
def train(device='/gpu:2'):
  """
  Training Mode: Create a new model and train the network.

  Creates a new TensorFlow graph and session and prepares the
  vocabulary/pretrained embeddings, the training and validation data and a
  ROUGE reward generator. It then runs supervised cross-entropy pretraining
  ("STEP A") for ``FLAGS.train_epoch_crossentropy`` epochs. After each epoch
  it saves a checkpoint and computes validation accuracy and ROUGE. It also
  writes validation prediction summaries and finally prints the collected
  ``[rouge_score, epoch]`` pairs.

  Args:
    device: TensorFlow device string the ops are placed on. Defaults to the
      previously hard-coded '/gpu:2'. Soft placement is enabled, so TF may
      fall back to another device when an op cannot run on this one.
  """

  # Training: build everything in a fresh graph pinned to `device`.
  # These must be two comma-separated context managers: `with A and B:`
  # evaluates to just `B`, so the new graph would never become the default.
  with tf.Graph().as_default(), tf.device(device):

    config = tf.ConfigProto(allow_soft_placement = True)

    # Start a session
    with tf.Session(config = config) as sess:

      ### Prepare data for training
      print("Prepare vocab dict and read pretrained word embeddings ...")
      vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict()
      # vocab_dict contains _PAD and _UNK but not word_embedding_array

      print("Prepare training data ...")
      train_data = DataProcessor().prepare_news_data(vocab_dict, data_type="training")

      print("Prepare validation data ...")
      validation_data = DataProcessor().prepare_news_data(vocab_dict, data_type="validation")

      print("Prepare ROUGE reward generator ...")
      rouge_generator = Reward_Generator()

      # Create Model with various operations. The -2 presumably accounts for
      # _PAD and _UNK, which have no row in word_embedding_array.
      model = MY_Model(sess, len(vocab_dict)-2)

      # Initialize word embedding before training
      print("Initialize word embedding vocabulary with pretrained embeddings ...")
      sess.run(model.vocab_embed_variable.assign(word_embedding_array))

      ### STEP A : Start Pretraining the policy with Supervised Labels: Simple Cross Entropy Training
      validation_epochvsrougescores = []  # one [rouge_score, epoch] entry per epoch
      for epoch in range(1, FLAGS.train_epoch_crossentropy+1):
        print("STEP A: Epoch "+str(epoch)+" : Start pretraining with supervised labels")

        print("STEP A: Epoch "+str(epoch)+" : Reshuffle training document indices")
        train_data.shuffle_fileindices()

        # Start Batch Training (a trailing partial batch is skipped)
        step = 1
        while (step * FLAGS.batch_size) <= len(train_data.fileindices):
          # Get batch data as Numpy Arrays.
          # NOTE(review): train_debug treats get_batch's result as a single
          # batch object and calls batch_predict_with_a_model with a different
          # signature. Confirm that the unpacking here and the call below
          # match the current helpers.
          batch_docnames, batch_docs, batch_label, batch_weight = train_data.get_batch(((step-1)*FLAGS.batch_size), (step * FLAGS.batch_size))

          # Run optimizer: optimize policy and reward estimator
          sess.run([model.train_op_policynet_withgold], feed_dict={model.document_placeholder: batch_docs,
                                                                   model.label_placeholder: batch_label,
                                                                   model.weight_placeholder: batch_weight})

          # Print the progress
          if (step % FLAGS.training_checkpoint)  == 0:
            ce_loss_val, acc_val, ce_loss_sum, acc_sum = sess.run([model.cross_entropy_loss, model.accuracy, model.ce_loss_summary, model.tstepa_accuracy_summary],
                                                                  feed_dict={model.document_placeholder: batch_docs,
                                                                             model.label_placeholder: batch_label,
                                                                             model.weight_placeholder: batch_weight})

            print("STEP A: Epoch "+str(epoch)+" : Covered " + str(step*FLAGS.batch_size)+"/"+str(len(train_data.fileindices))+
                  " : Minibatch CE Loss= {:.6f}".format(ce_loss_val) + ", Minibatch Accuracy= {:.6f}".format(acc_val))

            # Print Summary to Tensor Board (x-axis = documents seen so far)
            model.summary_writer.add_summary(ce_loss_sum, (epoch-1)*len(train_data.fileindices)+step*FLAGS.batch_size)
            model.summary_writer.add_summary(acc_sum, (epoch-1)*len(train_data.fileindices)+step*FLAGS.batch_size)

          # Increase step
          step += 1

        # Save Model
        print("STEP A: Epoch "+str(epoch)+" : Saving model after epoch completion")
        checkpoint_path = os.path.join(FLAGS.train_dir, "step-a.model.ckpt.epoch-"+str(epoch))
        model.saver.save(sess, checkpoint_path)

        # Performance on the validation set
        print("STEP A: Epoch "+str(epoch)+" : Performance on the validation data")
        # Get Predictions: prohibit the use of gold labels. Afterwards, restore
        # the previous setting even if prediction raises.
        prev_gold_label = FLAGS.authorise_gold_label
        FLAGS.authorise_gold_label = False
        try:
          validation_logits, validation_labels, validation_weights = batch_predict_with_a_model(validation_data, model, session=sess)
        finally:
          FLAGS.authorise_gold_label = prev_gold_label
        # Validation Accuracy and Prediction
        validation_acc, validation_sum = sess.run([model.final_accuracy, model.vstepa_accuracy_summary], feed_dict={model.logits_placeholder: validation_logits.eval(session=sess),
                                                                                                                    model.label_placeholder: validation_labels.eval(session=sess),
                                                                                                                    model.weight_placeholder: validation_weights.eval(session=sess)})
        # Print Validation Summary
        model.summary_writer.add_summary(validation_sum, epoch*len(train_data.fileindices))
        print("STEP A: Epoch "+str(epoch)+" : Validation ("+str(len(validation_data.fileindices))+") accuracy= {:.6f}".format(validation_acc))
        # Writing validation predictions and final summaries
        print("STEP A: Epoch "+str(epoch)+" : Writing final validation summaries")
        validation_data.write_prediction_summaries(validation_logits, "step-a.model.ckpt.epoch-"+str(epoch), session=sess)
        # Estimate Rouge Scores on the ranked-top-3 summary file written above
        rouge_score = rouge_generator.get_full_rouge(FLAGS.train_dir+"/step-a.model.ckpt.epoch-"+str(epoch)+".validation-summary-rankedtop3", "validation")
        print("STEP A: Epoch "+str(epoch)+" : Validation ("+str(len(validation_data.fileindices))+") rouge= {:.6f}".format(rouge_score))
        # Store validation rouge scores
        validation_epochvsrougescores.append([rouge_score, epoch])

      print(validation_epochvsrougescores)

      print("Optimization Finished!")