def get_dev_loss(self, session, dev_context_path, dev_qn_path,
                     dev_ans_path):
        """
        Get loss for entire dev set.

        Inputs:
          session: TensorFlow session
          dev_context_path, dev_qn_path, dev_ans_path: paths to the dev.{context/question/answer} data files

        Outputs:
          dev_loss: float. Average loss across the dev set.
        """
        logging.info("Calculating dev loss...")
        tic = time.time()
        loss_per_batch, batch_lengths = [], []

        # Iterate over dev set batches
        # Note: here we set discard_long=True, meaning we discard any examples
        # which are longer than our context_len or question_len.
        # We need to do this because if, for example, the true answer were cut
        # off by truncating the context, the loss function would be undefined.
        for batch in get_batch_generator(self.word2id,
                                         dev_context_path,
                                         dev_qn_path,
                                         dev_ans_path,
                                         self.FLAGS.batch_size,
                                         context_len=self.FLAGS.context_len,
                                         question_len=self.FLAGS.question_len,
                                         discard_long=True):

            # Get loss for this batch
            loss = self.get_loss(session, batch)
            curr_batch_size = batch.batch_size
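            # Weight each batch's mean loss by its size, so the final average
            # is an exact per-example mean even when the last batch is smaller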
            loss_per_batch.append(loss * curr_batch_size)
            batch_lengths.append(curr_batch_size)

        # Calculate average loss
        total_num_examples = sum(batch_lengths)
        toc = time.time()
        print "Computed dev loss over %i examples in %.2f seconds" % (
            total_num_examples, toc - tic)

        # Overall loss is total loss divided by total number of examples
        dev_loss = sum(loss_per_batch) / float(total_num_examples)

        return dev_loss

    def train(self, session, train_context_path, train_qn_path, train_ans_path,
              dev_qn_path, dev_context_path, dev_ans_path):
        """
        Main training loop.

        Inputs:
          session: TensorFlow session
          {train/dev}_{qn/context/ans}_path: paths to {train/dev}.{context/question/answer} data files
        """

        # Print number of model parameters
        tic = time.time()
        params = tf.trainable_variables()
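        # (tf.shape(t.value()).eval() evaluates each variable's shape in the
        # session; the static t.get_shape() would avoid these per-variable runs)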
        num_params = sum(
            map(lambda t: np.prod(tf.shape(t.value()).eval()), params))
        toc = time.time()
        logging.info("Number of params: %d (retrieval took %f secs)" %
                     (num_params, toc - tic))

        # We will keep track of exponentially-smoothed loss
        exp_loss = None

        # Checkpoint management.
        # We keep one latest checkpoint, and one best checkpoint (early stopping)
        checkpoint_path = os.path.join(self.FLAGS.train_dir, "qa.ckpt")
        bestmodel_dir = os.path.join(self.FLAGS.train_dir, "best_checkpoint")
        bestmodel_ckpt_path = os.path.join(bestmodel_dir, "qa_best.ckpt")
        best_dev_f1 = None
        best_dev_em = None

        # for TensorBoard
        summary_writer = tf.summary.FileWriter(self.FLAGS.train_dir,
                                               session.graph)

        epoch = 0

        logging.info("Beginning training loop...")
        while self.FLAGS.num_epochs == 0 or epoch < self.FLAGS.num_epochs:
            epoch += 1
            epoch_tic = time.time()

            # Loop over batches
            for batch in get_batch_generator(
                    self.word2id,
                    train_context_path,
                    train_qn_path,
                    train_ans_path,
                    self.FLAGS.batch_size,
                    context_len=self.FLAGS.context_len,
                    question_len=self.FLAGS.question_len,
                    discard_long=True):

                # Run training iteration
                iter_tic = time.time()
                loss, global_step, param_norm, grad_norm = self.run_train_iter(
                    session, batch, summary_writer)
                iter_toc = time.time()
                iter_time = iter_toc - iter_tic

                # Update exponentially-smoothed loss
                if exp_loss is None:  # first iter
                    exp_loss = loss
                else:
                    exp_loss = 0.99 * exp_loss + 0.01 * loss
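                # (an exponential moving average with decay 0.99, i.e. roughly
                # the mean loss over the last ~100 iterations)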

                # Sometimes print info to screen
                if global_step % self.FLAGS.print_every == 0:
                    logging.info(
                        'epoch %d, iter %d, loss %.5f, smoothed loss %.5f, grad norm %.5f, param norm %.5f, batch time %.3f'
                        % (epoch, global_step, loss, exp_loss, grad_norm,
                           param_norm, iter_time))

                # Sometimes save model
                if global_step % self.FLAGS.save_every == 0:
                    logging.info("Saving to %s..." % checkpoint_path)
                    self.saver.save(session,
                                    checkpoint_path,
                                    global_step=global_step)

                # Sometimes evaluate model on dev loss, train F1/EM and dev F1/EM
                if global_step % self.FLAGS.eval_every == 0:

                    # Get loss for entire dev set and log to tensorboard
                    dev_loss = self.get_dev_loss(session, dev_context_path,
                                                 dev_qn_path, dev_ans_path)
                    logging.info("Epoch %d, Iter %d, dev loss: %f" %
                                 (epoch, global_step, dev_loss))
                    write_summary(dev_loss, "dev/loss", summary_writer,
                                  global_step)

                    # Get F1/EM on train set and log to tensorboard
                    train_f1, train_em = self.check_f1_em(session,
                                                          train_context_path,
                                                          train_qn_path,
                                                          train_ans_path,
                                                          "train",
                                                          num_samples=1000)
                    logging.info(
                        "Epoch %d, Iter %d, Train F1 score: %f, Train EM score: %f"
                        % (epoch, global_step, train_f1, train_em))
                    write_summary(train_f1, "train/F1", summary_writer,
                                  global_step)
                    write_summary(train_em, "train/EM", summary_writer,
                                  global_step)

                    # Get F1/EM on dev set and log to tensorboard
                    dev_f1, dev_em = self.check_f1_em(session,
                                                      dev_context_path,
                                                      dev_qn_path,
                                                      dev_ans_path,
                                                      "dev",
                                                      num_samples=0)
                    logging.info(
                        "Epoch %d, Iter %d, Dev F1 score: %f, Dev EM score: %f"
                        % (epoch, global_step, dev_f1, dev_em))
                    write_summary(dev_f1, "dev/F1", summary_writer,
                                  global_step)
                    write_summary(dev_em, "dev/EM", summary_writer,
                                  global_step)

                    # Early stopping based on dev EM. You could switch this to use F1 instead.
                    if best_dev_em is None or dev_em > best_dev_em:
                        best_dev_em = dev_em
                        logging.info("Saving to %s..." % bestmodel_ckpt_path)
                        self.bestmodel_saver.save(session,
                                                  bestmodel_ckpt_path,
                                                  global_step=global_step)

            epoch_toc = time.time()
            logging.info("End of epoch %i. Time for epoch: %f" %
                         (epoch, epoch_toc - epoch_tic))

        sys.stdout.flush()

    def check_f1_em(self,
                    session,
                    context_path,
                    qn_path,
                    ans_path,
                    dataset,
                    num_samples=100,
                    print_to_screen=False):
        """
        Sample from the provided (train/dev) set.
        For each sample, calculate F1 and EM score.
        Return average F1 and EM score for all samples.
        Optionally pretty-print examples.

        Note: This function does not give exactly the same F1/EM numbers as "official_eval" mode.
        This function uses the pre-processed version of the data (e.g. the dev set) for speed,
        whereas "official_eval" mode uses the original JSON. Therefore:
          1. official_eval takes your max F1/EM score w.r.t. the three reference answers,
            whereas this function compares to just the first answer (which is what's saved in the preprocessed data)
          2. Our preprocessed version of the dev set is missing some examples
            due to tokenization issues (see squad_preprocess.py).
            "official_eval" includes all examples.

        Inputs:
          session: TensorFlow session
          context_path, qn_path, ans_path: paths to {train/dev}.{context/question/answer} data files.
          dataset: string. Either "train" or "dev". Just for logging purposes.
          num_samples: int. How many samples to use. If num_samples=0 then do whole dataset.
          print_to_screen: if True, pretty-prints each example to screen

        Returns:
          F1 and EM: Scalars. The average across the sampled examples.
        """
        logging.info(
            "Calculating F1/EM for %s examples in %s set..." %
            (str(num_samples) if num_samples != 0 else "all", dataset))

        f1_total = 0.
        em_total = 0.
        example_num = 0

        tic = time.time()

        # Note here we select discard_long=False because we want to sample from the entire dataset
        # That means we're truncating, rather than discarding, examples with too-long context or questions
        for batch in get_batch_generator(self.word2id,
                                         context_path,
                                         qn_path,
                                         ans_path,
                                         self.FLAGS.batch_size,
                                         context_len=self.FLAGS.context_len,
                                         question_len=self.FLAGS.question_len,
                                         discard_long=False):

            pred_start_pos, pred_end_pos = self.get_start_end_pos(
                session, batch)

            # Convert the start and end positions to lists of length batch_size
            pred_start_pos = pred_start_pos.tolist()
            pred_end_pos = pred_end_pos.tolist()

            for ex_idx, (pred_ans_start, pred_ans_end,
                         true_ans_tokens) in enumerate(
                             zip(pred_start_pos, pred_end_pos,
                                 batch.ans_tokens)):
                example_num += 1

                # Get the predicted answer
                # Important: batch.context_tokens contains the original words (no UNKs)
                # You need to use the original no-UNK version when measuring F1/EM
                pred_ans_tokens = batch.context_tokens[ex_idx][
                    pred_ans_start:pred_ans_end + 1]
                pred_answer = " ".join(pred_ans_tokens)

                # Get true answer (no UNKs)
                true_answer = " ".join(true_ans_tokens)

                # Calc F1/EM
                f1 = f1_score(pred_answer, true_answer)
                em = exact_match_score(pred_answer, true_answer)
                f1_total += f1
                em_total += em
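                # (f1_score / exact_match_score compute token-overlap F1 and
                # exact string match, as in the official SQuAD evaluation)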

                # Optionally pretty-print
                if print_to_screen:
                    print_example(self.word2id, batch.context_tokens[ex_idx],
                                  batch.qn_tokens[ex_idx],
                                  batch.ans_span[ex_idx, 0],
                                  batch.ans_span[ex_idx, 1],
                                  pred_ans_start, pred_ans_end,
                                  true_answer, pred_answer, f1, em)

                if num_samples != 0 and example_num >= num_samples:
                    break

            if num_samples != 0 and example_num >= num_samples:
                break

        f1_total /= example_num
        em_total /= example_num

        toc = time.time()
        logging.info(
            "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
            (example_num, dataset, toc - tic))

        return f1_total, em_total
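

# The training loop above logs scalars via a write_summary helper. A minimal
# sketch of such a helper (assumed, not shown in this file; tf.Summary is the
# standard TF1 API for writing scalar values manually):
def write_summary(value, tag, summary_writer, global_step):
    """Write a single scalar summary value to TensorBoard."""
    summary = tf.Summary()
    summary.value.add(tag=tag, simple_value=value)
    summary_writer.add_summary(summary, global_step)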
# Gradient computation and clipping. `clipped_gradients` is used below, so we
# compute it here (assumes `loss` is defined earlier in the graph and
# FLAGS.max_gradient_norm is the clipping threshold):
gradients = tf.gradients(loss, params)
gradient_norm = tf.global_norm(gradients)
clipped_gradients, _ = tf.clip_by_global_norm(gradients, FLAGS.max_gradient_norm)
param_norm = tf.global_norm(params)


global_step = tf.Variable(0, name="global_step", trainable=False)
opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate) # you can try other optimizers
updates = opt.apply_gradients(zip(clipped_gradients, params), global_step=global_step)

# Define savers (for checkpointing) and summaries (for tensorboard)
saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.keep)
bestmodel_saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
summaries = tf.summary.merge_all()
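
# To resume training later, the latest checkpoint could be restored before the
# loop below (a sketch; assumes a checkpoint already exists in FLAGS.train_dir):
#   ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
#   if ckpt and ckpt.model_checkpoint_path:
#       saver.restore(session, ckpt.model_checkpoint_path)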

# In[]


for batch in get_batch_generator(word2id,
                                 train_context_path,
                                 train_qn_path,
                                 train_ans_path,
                                 FLAGS.batch_size,
                                 context_len=FLAGS.context_len,
                                 question_len=FLAGS.question_len,
                                 discard_long=True):
    # Unrolled equivalent of:
    #   loss, global_step, param_norm, grad_norm = run_train_iter(session, batch, summary_writer)
    # Match up our input data with the placeholders
    input_feed = {}
    input_feed[context_ids] = batch.context_ids
    input_feed[context_mask] = batch.context_mask
    input_feed[qn_ids] = batch.qn_ids
    input_feed[qn_mask] = batch.qn_mask
    input_feed[ans_span] = batch.ans_span
    input_feed[keep_prob] = 1.0 - FLAGS.dropout  # apply dropout at train time
    input_feed[question_length] = batch.qn_length
    input_feed[document_length] = batch.context_length
    # output_feed contains the ops/tensors we want to fetch; running `updates`
    # applies the gradient step (assumes `loss` is defined earlier in the graph)
    output_feed = [updates, summaries, loss, global_step, param_norm, gradient_norm]
    _, summary_out, loss_val, global_step_val, param_norm_val, grad_norm_val = \
        session.run(output_feed, input_feed)