Example #1
    def get_c2q_attention(self, session, context_path, qn_path, ans_path, dataset, num_samples=0):
        """
        Sample from the provided (train/dev) set.
        Inputs:
          session: TensorFlow session
          qn_path, context_path, ans_path: paths to {dev/train}.{question/context/answer} data files.
          dataset: string. Either "train" or "dev". Just for logging purposes.
        Returns:
          begin_prob, end_prob: The average probabilities the sampled examples.
        """
        total_c2q_attention = []
        example_num = 0
        for batch in get_batch_generator(
            self.word2id,
            context_path,
            qn_path,
            ans_path,
            self.FLAGS.batch_size,
            context_len=self.FLAGS.context_len,
            question_len=self.FLAGS.question_len,
            discard_long=False,
            random=False):

            c2q_dists = self.get_c2q_attention_dist(session, batch)
            c2q_list = c2q_dists.tolist() # list length batch_size

            for c2q_dist in c2q_list:
                example_num += 1
                total_c2q_attention.append(c2q_dist)
                # print_example(self.word2id, batch.context_tokens[ex_idx], batch.qn_tokens[ex_idx], batch.ans_span[ex_idx, 0], batch.ans_span[ex_idx, 1], pred_ans_start, pred_ans_end, true_answer, pred_answer, f1, em)
                if num_samples != 0 and example_num >= num_samples:
                    break
            if num_samples != 0 and example_num >= num_samples:
                break
        return np.asarray(total_c2q_attention)
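
A small usage sketch (not part of the original example; qa_model, session, and the data paths are placeholders) showing one way to visualize the array returned above, plotting a single example's context-to-question attention map:

import matplotlib.pyplot as plt

# Assumed shape: (num_examples, context_len, question_len)
c2q = qa_model.get_c2q_attention(session, context_path, qn_path, ans_path, "dev", num_samples=10)
plt.imshow(c2q[0], interpolation='nearest', aspect='auto')
plt.ylabel('context position')
plt.xlabel('question position')
plt.savefig('c2q_attention_example.png')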
Example #2
def evaluate(model, word2id, FLAGS, dev_context_path, dev_qn_path,
             dev_ans_path):
    logging.info("Calculating F1/EM for all examples in dev set...")

    f1_total = 0.
    em_total = 0.
    example_num = 0

    tic = time.time()

    for batch in get_batch_generator(word2id,
                                     dev_context_path,
                                     dev_qn_path,
                                     dev_ans_path,
                                     FLAGS.batch_size,
                                     context_len=FLAGS.context_len,
                                     question_len=FLAGS.question_len,
                                     discard_long=False):
        # print(type(batch))
        prob_start, prob_end = model.predict([
            batch.context_ids, batch.context_mask, batch.qn_ids, batch.qn_mask
        ])

        start_pos = np.argmax(prob_start, axis=1)
        end_pos = np.argmax(prob_end, axis=1)

        pred_start_pos = start_pos.tolist()
        pred_end_pos = end_pos.tolist()

        for ex_idx, (pred_ans_start, pred_ans_end,
                     true_ans_tokens) in enumerate(
                         zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
            example_num += 1

            # Get the predicted answer
            # Important: batch.context_tokens contains the original words (no UNKs)
            # You need to use the original no-UNK version when measuring F1/EM
            pred_ans_tokens = batch.context_tokens[ex_idx][
                pred_ans_start:pred_ans_end + 1]
            pred_answer = " ".join(pred_ans_tokens)

            # Get true answer (no UNKs)
            true_answer = " ".join(true_ans_tokens)

            # Calc F1/EM
            f1 = f1_score(pred_answer, true_answer)
            em = exact_match_score(pred_answer, true_answer)
            f1_total += f1
            em_total += em

            # print(f1, em, example_num)

    f1_total /= example_num
    em_total /= example_num

    toc = time.time()
    print("Calculating F1/EM for %i examples in %s set took %.2f seconds" %
          (example_num, "dev", toc - tic))

    return f1_total, em_total
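
The f1_score and exact_match_score helpers used above are not shown in these examples; they are assumed to follow the official SQuAD v1.1 evaluation script. A minimal sketch (the project's own utils may differ slightly):

import re
import string
from collections import Counter

def normalize_answer(s):
    """Lowercase, remove punctuation/articles, and collapse whitespace."""
    s = s.lower()
    s = ''.join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())

def f1_score(prediction, ground_truth):
    pred_tokens = normalize_answer(prediction).split()
    gt_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = float(num_same) / len(pred_tokens)
    recall = float(num_same) / len(gt_tokens)
    return (2 * precision * recall) / (precision + recall)

def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)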
Example #3
    def get_spans(self, session, context_path, qn_path, ans_path, dataset, num_samples=0):
        """
        Sample from the provided (train/dev) set.
        Inputs:
          session: TensorFlow session
          qn_path, context_path, ans_path: paths to {dev/train}.{question/context/answer} data files.
          dataset: string. Either "train" or "dev". Just for logging purposes.
        Returns:
          begin_prob, end_prob: The average probabilities the sampled examples.
        """
        total_start_dists = []
        total_end_dists = []
        f1_em_scores = []
        example_num = 0
        for batch in get_batch_generator(
            self.word2id,
            context_path,
            qn_path,
            ans_path,
            self.FLAGS.batch_size,
            context_len=self.FLAGS.context_len,
            question_len=self.FLAGS.question_len,
            discard_long=False,
            random=False):

            pred_start_dists, pred_end_dists = self.get_prob_dists(session, batch)
            pred_start_pos, pred_end_pos = self.get_start_end_pos(session, batch)

            # Convert the start and end positions to lists length batch_size
            pred_start_pos = pred_start_pos.tolist() # list length batch_size
            pred_end_pos = pred_end_pos.tolist() # list length batch_size
            pred_start_dists = pred_start_dists.tolist() # list length batch_size
            pred_end_dists = pred_end_dists.tolist() # list length batch_size

            for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) in enumerate(zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
                example_num += 1

                # Get the predicted answer
                # Important: batch.context_tokens contains the original words (no UNKs)
                # You need to use the original no-UNK version when measuring F1/EM
                pred_ans_tokens = batch.context_tokens[ex_idx][pred_ans_start : pred_ans_end + 1]
                pred_answer = " ".join(pred_ans_tokens)

                # Get true answer (no UNKs)
                true_answer = " ".join(true_ans_tokens)

                # Calc F1/EM
                f1 = f1_score(pred_answer, true_answer)
                em = exact_match_score(pred_answer, true_answer)
                f1_em_scores.append((f1,em))
                # print_example(self.word2id, batch.context_tokens[ex_idx], batch.qn_tokens[ex_idx], batch.ans_span[ex_idx, 0], batch.ans_span[ex_idx, 1], pred_ans_start, pred_ans_end, true_answer, pred_answer, f1, em)
                if num_samples != 0 and example_num >= num_samples:
                    break

            # Accumulate this batch's start and end distributions
            total_end_dists += pred_end_dists
            total_start_dists += pred_start_dists
            if num_samples != 0 and example_num >= num_samples:
                break
        return np.asarray(total_start_dists), np.asarray(total_end_dists), np.asarray(f1_em_scores)
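
A quick, assumed usage sketch (qa_model, session, and the paths are placeholders) for consuming the three arrays returned by get_spans:

start_dists, end_dists, f1_em = qa_model.get_spans(
    session, context_path, qn_path, ans_path, "dev", num_samples=100)

avg_start_dist = start_dists.mean(axis=0)  # average start probability per context position
avg_end_dist = end_dists.mean(axis=0)      # average end probability per context position
mean_f1, mean_em = f1_em.mean(axis=0)      # average F1 / EM over the sampled examples
print("mean F1: %.3f, mean EM: %.3f" % (mean_f1, mean_em))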
Example #4
    def get_dev_loss(self, session, dev_context_path, dev_qn_path, dev_ans_path):
        """
        Get loss for entire dev set.
        """
        
        logging.info("Calculating dev loss...")
        tic = time.time()
        loss_per_batch, batch_lengths = [], []

        # Iterate over dev set batches
        for batch in get_batch_generator(self.word2id, dev_context_path, dev_qn_path, dev_ans_path, self.FLAGS.batch_size, context_len=self.FLAGS.context_len, question_len=self.FLAGS.question_len, discard_long=True):

            # Get loss for this batch
            loss = self.get_loss(session, batch)
            curr_batch_size = batch.batch_size
            loss_per_batch.append(loss * curr_batch_size)
            batch_lengths.append(curr_batch_size)

        # Calculate average loss
        total_num_examples = sum(batch_lengths)
        toc = time.time()
        print ("Computed dev loss over %i examples in %.2f seconds" % (total_num_examples, toc-tic))

        # Overall loss is total loss divided by total number of examples
        dev_loss = sum(loss_per_batch) / float(total_num_examples)

        return dev_loss
    def get_dev_loss(self, session, dev_context_path, dev_qn_path, dev_ans_path):
        """
        Get loss for entire dev set.
        Inputs:
          session: TensorFlow session
          dev_qn_path, dev_context_path, dev_ans_path: paths to the dev.{context/question/answer} data files
        Outputs:
          dev_loss: float. Average loss across the dev set.
        """
        logging.info("Calculating dev loss...")
        tic = time.time()
        loss_per_batch, batch_lengths = [], []

        # Iterate over dev set batches
        # Note: here we set discard_long=True, meaning we discard any examples
        # which are longer than our context_len or question_len.
        # We need to do this because if, for example, the true answer is cut
        # off the context, then the loss function is undefined.
        for batch in get_batch_generator(self.word2id, dev_context_path, dev_qn_path, dev_ans_path, self.FLAGS.batch_size, context_len=self.FLAGS.context_len, question_len=self.FLAGS.question_len, discard_long=True):

            # Get loss for this batch
            loss = self.get_loss(session, batch)
            curr_batch_size = batch.batch_size
            loss_per_batch.append(loss * curr_batch_size)
            batch_lengths.append(curr_batch_size)

        # Calculate average loss
        total_num_examples = sum(batch_lengths)
        toc = time.time()
        print("Computed dev loss over %i examples in %.2f seconds" % (total_num_examples, toc-tic))

        # Overall loss is total loss divided by total number of examples
        dev_loss = sum(loss_per_batch) / float(total_num_examples)

        return dev_loss
    def check_f1_em(self, model, dataset, num_samples=100, print_to_screen=False):
        logging.info("Calculating F1/EM for %s examples in %s set..." % (str(num_samples) if num_samples != 0 else "all", dataset))

        if dataset == "train":
            context_path, qn_path, ans_path = self.train_context_path, self.train_qn_path, self.train_ans_path
        elif dataset == "dev":
            context_path, qn_path, ans_path = self.dev_context_path, self.dev_qn_path, self.dev_ans_path
        else:
            raise ValueError('dataset is not defined')

        f1_total = 0.
        em_total = 0.
        example_num = 0

        tic = time.time()

        for batch in get_batch_generator(self.word2id, context_path, qn_path, ans_path, config.batch_size,
                                         context_len=config.context_len, question_len=config.question_len,
                                         discard_long=False):

            pred_start_pos, pred_end_pos = self.test_one_batch(batch, model)

            pred_start_pos = pred_start_pos.tolist()
            pred_end_pos = pred_end_pos.tolist()

            for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) \
                    in enumerate(zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
                example_num += 1
                pred_ans_tokens = batch.context_tokens[ex_idx][pred_ans_start : pred_ans_end + 1]
                pred_answer = " ".join(pred_ans_tokens)

                true_answer = " ".join(true_ans_tokens)

                f1 = f1_score(pred_answer, true_answer)
                em = exact_match_score(pred_answer, true_answer)
                f1_total += f1
                em_total += em

                if print_to_screen:
                    print_example(self.word2id, batch.context_tokens[ex_idx], batch.qn_tokens[ex_idx],
                                  batch.ans_span[ex_idx, 0], batch.ans_span[ex_idx, 1], pred_ans_start,
                                  pred_ans_end, true_answer, pred_answer, f1, em)

                if num_samples != 0 and example_num >= num_samples:
                    break

            if num_samples != 0 and example_num >= num_samples:
                break

        f1_total /= example_num
        em_total /= example_num

        toc = time.time()
        logging.info("Calculating F1/EM for %i examples in %s set took %.2f seconds" % (example_num, dataset, toc-tic))

        return f1_total, em_total
    def check_f1_em(self,
                    context_path,
                    qn_path,
                    ans_path,
                    dataset,
                    num_samples=1000):
        f1_total = 0.
        em_total = 0.
        example_num = 0

        for batch in data_batcher.get_batch_generator(self.word2id,
                                                      self.id2idf,
                                                      context_path,
                                                      qn_path,
                                                      ans_path,
                                                      self.batch_size,
                                                      context_len=300,
                                                      question_len=30,
                                                      discard_long=False):

            pred_start_pos, pred_end_pos = self.get_predictions(batch)

            # Convert the start and end positions to lists length batch_size
            pred_start_pos = pred_start_pos.tolist()  # list length batch_size
            pred_end_pos = pred_end_pos.tolist()  # list length batch_size

            for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) in \
                    enumerate(zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
                example_num += 1

                # Get the predicted answer
                # Important: batch.context_tokens contains the original words (no UNKs)
                # You need to use the original no-UNK version when measuring F1/EM
                pred_ans_tokens = batch.context_tokens[ex_idx][
                    pred_ans_start:pred_ans_end + 1]
                pred_answer = " ".join(pred_ans_tokens)

                # Get true answer (no UNKs)
                true_answer = " ".join(true_ans_tokens)

                # Calc F1/EM
                f1 = f1_score(pred_answer, true_answer)
                em = exact_match_score(pred_answer, true_answer)
                f1_total += f1
                em_total += em

                if num_samples != 0 and example_num >= num_samples:
                    break

            if num_samples != 0 and example_num >= num_samples:
                break

        f1_total /= example_num
        em_total /= example_num

        return f1_total, em_total
Example #8
    def check_f1_em(self, session, context_path, qn_path, ans_path, dataset, num_samples=100, print_to_screen=False):
        """
        Sample from the provided (train/dev) set.
        For each sample, calculate F1 and EM score.
        Return average F1 and EM score for all samples.
        Optionally pretty-print examples.
        """
        logging.info("Calculating F1/EM for %s examples in %s set..." % (str(num_samples) if num_samples != 0 else "all", dataset))

        f1_total = 0.
        em_total = 0.
        example_num = 0

        tic = time.time()

        for batch in get_batch_generator(self.word2id, context_path, qn_path, ans_path, self.FLAGS.batch_size, context_len=self.FLAGS.context_len, question_len=self.FLAGS.question_len, discard_long=False):

            pred_start_pos, pred_end_pos = self.get_start_end_pos(session, batch)

            for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) in enumerate(zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
                example_num += 1

                # Get the predicted answer
                pred_ans_tokens = batch.context_tokens[ex_idx][pred_ans_start : pred_ans_end + 1]
                pred_answer = " ".join(pred_ans_tokens)

                # Get true answer (no UNKs)
                true_answer = " ".join(true_ans_tokens)

                # Calc F1/EM
                f1 = f1_score(pred_answer, true_answer)
                em = exact_match_score(pred_answer, true_answer)
                f1_total += f1
                em_total += em

                # Optionally pretty-print
                if print_to_screen:
                    print_example(self.word2id, batch.context_tokens[ex_idx], batch.qn_tokens[ex_idx], batch.ans_span[ex_idx, 0], batch.ans_span[ex_idx, 1], pred_ans_start, pred_ans_end, true_answer, pred_answer, f1, em)

                if num_samples != 0 and example_num >= num_samples:
                    break

            if num_samples != 0 and example_num >= num_samples:
                break

        f1_total /= example_num
        em_total /= example_num

        toc = time.time()
        logging.info("Calculating F1/EM for %i examples in %s set took %.2f seconds" % (example_num, dataset, toc-tic))

        return f1_total, em_total
Example #9
    def get_val_loss(self, session):
        '''
        Get average loss on the entire val set
        This function is called periodically during training
        '''
        total_loss, num_examples = 0., 0
        tic = time.time()
        for batch in get_batch_generator(self.word2id, self.img_features_map, self.val_caption_id_2_caption, self.caption_id_2_img_id, \
                                        self.FLAGS.batch_size, self.FLAGS.max_caption_len, 'train', None, self.FLAGS.data_source):
            total_loss += self.get_loss(session, batch) * batch.batch_size
            num_examples += batch.batch_size

        logging.info("Computing validation loss over {} examples took {} seconds".format(num_examples, time.time() - tic))
        return total_loss / num_examples
Example #10
    def get_dev_loss(self, session, dev_context_path, dev_qn_path,
                     dev_ans_path):
        """
        Get loss for entire dev set.

        Inputs:
          session: TensorFlow session
          dev_qn_path, dev_context_path, dev_ans_path: paths to the dev.{context/question/answer} data files

        Outputs:
          dev_loss: float. Average loss across the dev set.
        """
        logging.info("Calculating dev loss...")
        tic = time.time()
        loss_per_batch, batch_lengths = [], []

        i = 0
        for batch in get_batch_generator(
                self.word2id,
                self.context2id,
                self.ans2id,
                dev_context_path,
                dev_qn_path,
                dev_ans_path,
                self.FLAGS.batch_size,
                self.graph_vocab_class,
                context_len=self.FLAGS.context_len,
                question_len=self.FLAGS.question_len,
                answer_len=self.FLAGS.answer_len,
                discard_long=False,
                use_raw_graph=self.FLAGS.use_raw_graph,
                show_start_tokens=self.FLAGS.show_start_tokens):
            loss = self.get_loss(session, batch)
            curr_batch_size = batch.batch_size
            loss_per_batch.append(loss * curr_batch_size)
            batch_lengths.append(curr_batch_size)
            # Only compute the dev loss over the first few batches to keep
            # evaluation fast during training (check precedes the increment).
            if i == 10:
                break
            i += 1

        # Calculate average loss
        total_num_examples = sum(batch_lengths)
        toc = time.time()
        print "Computed dev loss over %i examples in %.2f seconds" % (
            total_num_examples, toc - tic)
        # Overall loss is total loss divided by total number of examples
        dev_loss = sum(loss_per_batch) / float(total_num_examples)

        return dev_loss
    def train(self, session, train_context_path, train_qn_path, train_ans_path,
              dev_qn_path, dev_context_path, dev_ans_path):
        summary_writer = tf.summary.FileWriter(
            "/Users/lam/Desktop/Lam-cs224n/Projects/qa/squad", session.graph)
        for batch in get_batch_generator(self.word2id,
                                         self.char2id,
                                         train_context_path,
                                         train_qn_path,
                                         train_ans_path,
                                         self.FLAGS.batch_size,
                                         self.FLAGS.context_len,
                                         self.FLAGS.question_len,
                                         self.FLAGS.max_word_len,
                                         discard_long=True):
            self.sample_batch = batch

            self.run_train_iter(session, batch, summary_writer)
            break  # only run a single training iteration (debug / smoke test)
Example #12
    def check_metric(self, session, mode='val', num_samples=0):
        '''
        Evaluate the model on the validation or test set.
        Inputs:
            mode: should be either 'val' or 'test'
            num_samples: number of images to evaluate on. Evaluate on all val images if 0.
        '''
        assert (mode == 'val' or mode == 'test')
        captions = []  # [{"image_id": image_id, "caption": caption_str}]

        # Generate all the captions and save in list 'captions'
        tic = time.time()
        num_seen = 0  # Record the number of samples predicted so far
        this_caption_map = self.val_caption_id_2_caption if mode == 'val' else self.test_caption_id_2_caption

        for batch in get_batch_generator(self.word2id, self.img_features_map, this_caption_map, self.caption_id_2_img_id, \
                                        self.FLAGS.batch_size, self.FLAGS.max_caption_len, 'eval', None, self.FLAGS.data_source):
            batch_captions = self.get_captions(session, batch)   # {image_id: caption_string}
            for id, cap in batch_captions.items():
                captions.append({"image_id": id, "caption": cap})

            num_seen += batch.batch_size
            if num_samples != 0 and num_seen >= num_samples:
                break

        logging.info("Predicting on {} examples took {} seconds".format(num_seen, time.time() - tic))

        # Dump the generated captions to json file
        with open(self.FLAGS.train_res_dir, 'w') as f:
            json.dump(captions, f)

        # Evaluate using the official evaluation API (The evaluation takes ~12s for 1000 examples)
        tic = time.time()
        cocoGold = COCO(self.FLAGS.goldAnn_val_dir) # Official annotations
        cocoRes = cocoGold.loadRes(self.FLAGS.train_res_dir) # Prediction
        cocoEval = COCOEvalCap(cocoGold, cocoRes)
        cocoEval.params['image_id'] = cocoRes.getImgIds() # Evaluate on a subset of the official captions_val2014
        cocoEval.evaluate()
        logging.info("Evaluating {} predictions took {} seconds".format(num_seen, time.time() - tic))

        scores = cocoEval.eval  # {metric_name: metric_score}
        return scores   # Bleu_1, Bleu_2, Bleu_3, Bleu_4, METEOR, ROUGE_L, CIDEr
Example #13
    def get_dev_loss(self, model):
        logging.info("Calculating dev loss...")
        tic = time.time()
        loss_per_batch, batch_lengths = [], []
        i = 0
        for batch in get_batch_generator(self.word2id, self.dev_context_path, self.dev_qn_path, self.dev_ans_path,
                                         config.batch_size, context_len=config.context_len,
                                         question_len=config.question_len, discard_long=True):

            loss, _, _ = self.eval_one_batch(batch, model)
            curr_batch_size = batch.batch_size
            loss_per_batch.append(loss * curr_batch_size)
            batch_lengths.append(curr_batch_size)
            i += 1
            # Only compute the dev loss over the first 10 batches to keep
            # evaluation fast during training.
            if i == 10:
                break
        total_num_examples = sum(batch_lengths)
        toc = time.time()
        print ("Computed dev loss over %i examples in %.2f seconds" % (total_num_examples, toc-tic))

        dev_loss = sum(loss_per_batch) / float(total_num_examples)

        return dev_loss
Example #14
    def get_dev_loss(self, session, dev_context_path, dev_qn_path, dev_ans_path):
        """
        Get loss for entire dev set.

        Inputs:
          session: TensorFlow session
          dev_qn_path, dev_context_path, dev_ans_path: paths to the dev.{context/question/answer} data files

        Outputs:
          dev_loss: float. Average loss across the dev set.
        """
        logging.info("Calculating dev loss...")
        tic = time.time()
        loss_per_batch, batch_lengths = [], []

        # Iterate over dev set batches
        # Note: here we set discard_long=True, meaning we discard any examples
        # which are longer than our context_len or question_len.
        # We need to do this because if, for example, the true answer is cut
        # off the context, then the loss function is undefined.
        for batch in get_batch_generator(self.word2id, dev_context_path, dev_qn_path, dev_ans_path, self.FLAGS.batch_size, context_len=self.FLAGS.context_len, question_len=self.FLAGS.question_len, discard_long=True):

            # Get loss for this batch
            loss = self.get_loss(session, batch)
            curr_batch_size = batch.batch_size
            loss_per_batch.append(loss * curr_batch_size)
            batch_lengths.append(curr_batch_size)

        # Calculate average loss
        total_num_examples = sum(batch_lengths)
        toc = time.time()
        print "Computed dev loss over %i examples in %.2f seconds" % (total_num_examples, toc-tic)

        # Overall loss is total loss divided by total number of examples
        dev_loss = sum(loss_per_batch) / float(total_num_examples)

        return dev_loss
Example #15
    def demo(self,
             session,
             context_path,
             qn_path,
             ans_path,
             dataset,
             num_samples=10,
             print_to_screen=False,
             write_out=False,
             file_out=None,
             shuffle=True):
        """
        Sample from the provided (train/dev) set.
        For each sample, calculate F1 and EM score.
        Return average F1 and EM score for all samples.
        Optionally pretty-print examples.

        Inputs:
          session: TensorFlow session
          qn_path, context_path, ans_path: paths to {dev/train}.{question/context/answer} data files.
          dataset: string. Either "train" or "dev". Just for logging purposes.
          num_samples: int. How many samples to use. If num_samples=0 then do whole dataset.
          print_to_screen: if True, pretty-prints each example to screen

        Returns:
          F1 and EM: Scalars. The average across the sampled examples.
        """
        logging.info(
            "Calculating F1/EM for %s examples in %s set..." %
            (str(num_samples) if num_samples != 0 else "all", dataset))
        example_num = 0

        tic = time.time()
        ans_list = []
        graph_route_info = []

        for batch in get_batch_generator(
                self.word2id,
                self.context2id,
                self.ans2id,
                context_path,
                qn_path,
                ans_path,
                self.FLAGS.batch_size,
                self.graph_vocab_class,
                context_len=self.FLAGS.context_len,
                question_len=self.FLAGS.question_len,
                answer_len=self.FLAGS.answer_len,
                discard_long=False,
                use_raw_graph=self.FLAGS.use_raw_graph,
                shuffle=shuffle,
                show_start_tokens=self.FLAGS.show_start_tokens,
                output_goal=True):
            train_ids, pred_ids, dev_final_states, pred_logits = self.get_prob_dists(
                session, batch)
            start_ids = batch.ans_ids[:, 0].reshape(-1)

            if self.FLAGS.pred_method != 'beam':
                pred_ids, confidence_score, ans_str = output_route(
                    start_ids, pred_logits, batch.context_tokens, self.ans2id,
                    self.id2ans, self.FLAGS.answer_len)

            pred_ids = pred_ids.tolist()  # the output of using test network
            dev_attention_map = create_attention_images_summary(
                dev_final_states)
            print "dev_attention_map", dev_attention_map.shape
            dev_attention_map = dev_attention_map.eval().tolist()

            # the output of using training network, that the true input is fed as the input of the next RNN, for debug.
            for ex_idx, (pred_ans_list, true_ans_tokens,
                         attention_map) in enumerate(
                             zip(pred_ids, list(batch.ans_tokens),
                                 dev_attention_map)):

                example_num += 1
                pred_ans_tokens = []
                for id in pred_ans_list:
                    if id == PAD_ID:
                        break
                    else:
                        pred_ans_tokens.append(self.id2ans[id])
                pred_answer = " ".join(pred_ans_tokens)

                # Get true answer (no UNKs)
                true_answer = " ".join(true_ans_tokens[:])
                # Calculate metrics
                f1, em, edit_dist, rough_em = compute_all_metrics(
                    pred_ans_tokens, true_ans_tokens)
                ans_list.append(pred_answer)

                if print_to_screen:
                    print_example(self.word2id, self.context2id, self.ans2id,
                                  batch.context_tokens[ex_idx],
                                  batch.qn_tokens[ex_idx], true_answer,
                                  pred_answer, f1, em, edit_dist,
                                  confidence_score[ex_idx])
                    # Draw attention map
                    draw_attention(batch, ex_idx, attention_map,
                                   pred_ans_tokens)

                if num_samples != 0 and example_num >= num_samples:
                    break

            if num_samples != 0 and example_num >= num_samples:
                break

        toc = time.time()
        logging.info(
            "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
            (example_num, dataset, toc - tic))
        if write_out:
            logging.info("Writing the prediction to {}".format(file_out))
            with open(file_out, 'w') as f:
                for line, extra_info in zip(ans_list, graph_route_info):
                    f.write(line + " " + " ".join(extra_info) + '\n')
            print("Wrote predictions to %s" % file_out)

        return
Example #16
    def train(self, session, train_context_path, train_qn_path, train_ans_path, dev_qn_path, dev_context_path, dev_ans_path):
        """
        Main training loop.
        """

        # Print number of model parameters
        tic = time.time()
        params = tf.trainable_variables()
        num_params = sum(map(lambda t: np.prod(tf.shape(t.value()).eval()), params))
        toc = time.time()
        logging.info("Number of params: %d (retrieval took %f secs)" % (num_params, toc - tic))

        exp_loss = None

        # Checkpoint management.
        checkpoint_path = os.path.join(self.FLAGS.train_dir, "qa.ckpt")
        bestmodel_dir = os.path.join(self.FLAGS.train_dir, "best_checkpoint")
        bestmodel_ckpt_path = os.path.join(bestmodel_dir, "qa_best.ckpt")
        best_dev_f1 = None
        best_dev_em = None

        # for TensorBoard
        summary_writer = tf.summary.FileWriter(self.FLAGS.train_dir, session.graph)

        epoch = 0

        logging.info("Beginning training loop...")
        while self.FLAGS.num_epochs == 0 or epoch < self.FLAGS.num_epochs:
            epoch += 1
            epoch_tic = time.time()

            for batch in get_batch_generator(self.word2id, train_context_path, train_qn_path, train_ans_path, self.FLAGS.batch_size, context_len=self.FLAGS.context_len, question_len=self.FLAGS.question_len, discard_long=True):

                iter_tic = time.time()
                loss, global_step, param_norm, grad_norm = self.run_train_iter(session, batch, summary_writer)
                iter_toc = time.time()
                iter_time = iter_toc - iter_tic

                if not exp_loss: # first iter
                    exp_loss = loss
                else:
                    exp_loss = 0.99 * exp_loss + 0.01 * loss

                if global_step % self.FLAGS.print_every == 0:
                    logging.info(
                        'epoch %d, iter %d, loss %.5f, smoothed loss %.5f, grad norm %.5f, param norm %.5f, batch time %.3f' %
                        (epoch, global_step, loss, exp_loss, grad_norm, param_norm, iter_time))

                if global_step % self.FLAGS.save_every == 0:
                    logging.info("Saving to %s..." % checkpoint_path)
                    self.saver.save(session, checkpoint_path, global_step=global_step)

                if global_step % self.FLAGS.eval_every == 0:

                    dev_loss = self.get_dev_loss(session, dev_context_path, dev_qn_path, dev_ans_path)
                    logging.info("Epoch %d, Iter %d, dev loss: %f" % (epoch, global_step, dev_loss))
                    write_summary(dev_loss, "dev/loss", summary_writer, global_step)


                    train_f1, train_em = self.check_f1_em(session, train_context_path, train_qn_path, train_ans_path, "train", num_samples=1000)
                    logging.info("Epoch %d, Iter %d, Train F1 score: %f, Train EM score: %f" % (epoch, global_step, train_f1, train_em))
                    write_summary(train_f1, "train/F1", summary_writer, global_step)
                    write_summary(train_em, "train/EM", summary_writer, global_step)


                    dev_f1, dev_em = self.check_f1_em(session, dev_context_path, dev_qn_path, dev_ans_path, "dev", num_samples=0)
                    logging.info("Epoch %d, Iter %d, Dev F1 score: %f, Dev EM score: %f" % (epoch, global_step, dev_f1, dev_em))
                    write_summary(dev_f1, "dev/F1", summary_writer, global_step)
                    write_summary(dev_em, "dev/EM", summary_writer, global_step)


                    if best_dev_f1 is None or dev_f1 > best_dev_f1:
                        best_dev_f1 = dev_f1
                        logging.info("Saving to %s..." % bestmodel_ckpt_path)
                        self.bestmodel_saver.save(session, bestmodel_ckpt_path, global_step=global_step)


            epoch_toc = time.time()
            logging.info("End of epoch %i. Time for epoch: %f" % (epoch, epoch_toc-epoch_tic))

        sys.stdout.flush()
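
The write_summary helper called throughout the training loop above is not shown in these examples; a minimal sketch of what it is assumed to do (a thin wrapper over the TF 1.x summary API):

import tensorflow as tf

def write_summary(value, tag, summary_writer, global_step):
    """Write a single scalar value to TensorBoard."""
    summary = tf.Summary()
    summary.value.add(tag=tag, simple_value=value)
    summary_writer.add_summary(summary, global_step)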
Example #17
    def check_f1_em(self, session, context_path, qn_path, ans_path, dataset, num_samples=100, print_to_screen=False):
        """
        Sample from the provided (train/dev) set.
        For each sample, calculate F1 and EM score.
        Return average F1 and EM score for all samples.
        Optionally pretty-print examples.

        Note: This function is not quite the same as the F1/EM numbers you get from "official_eval" mode.
        This function uses the pre-processed version of the e.g. dev set for speed,
        whereas "official_eval" mode uses the original JSON. Therefore:
          1. official_eval takes your max F1/EM score w.r.t. the three reference answers,
            whereas this function compares to just the first answer (which is what's saved in the preprocessed data)
          2. Our preprocessed version of the dev set is missing some examples
            due to tokenization issues (see squad_preprocess.py).
            "official_eval" includes all examples.

        Inputs:
          session: TensorFlow session
          qn_path, context_path, ans_path: paths to {dev/train}.{question/context/answer} data files.
          dataset: string. Either "train" or "dev". Just for logging purposes.
          num_samples: int. How many samples to use. If num_samples=0 then do whole dataset.
          print_to_screen: if True, pretty-prints each example to screen

        Returns:
          F1 and EM: Scalars. The average across the sampled examples.
        """
        logging.info("Calculating F1/EM for %s examples in %s set..." % (str(num_samples) if num_samples != 0 else "all", dataset))

        f1_total = 0.
        em_total = 0.
        example_num = 0

        tic = time.time()

        # Note here we select discard_long=False because we want to sample from the entire dataset
        # That means we're truncating, rather than discarding, examples with too-long context or questions
        for batch in get_batch_generator(self.word2id, context_path, qn_path, ans_path, self.FLAGS.batch_size, context_len=self.FLAGS.context_len, question_len=self.FLAGS.question_len, discard_long=False):

            pred_start_pos, pred_end_pos = self.get_start_end_pos(session, batch)

            # Convert the start and end positions to lists length batch_size
            pred_start_pos = pred_start_pos.tolist() # list length batch_size
            pred_end_pos = pred_end_pos.tolist() # list length batch_size

            for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) in enumerate(zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
                example_num += 1

                # Get the predicted answer
                # Important: batch.context_tokens contains the original words (no UNKs)
                # You need to use the original no-UNK version when measuring F1/EM
                pred_ans_tokens = batch.context_tokens[ex_idx][pred_ans_start : pred_ans_end + 1]
                pred_answer = " ".join(pred_ans_tokens)

                # Get true answer (no UNKs)
                true_answer = " ".join(true_ans_tokens)

                # Calc F1/EM
                f1 = f1_score(pred_answer, true_answer)
                em = exact_match_score(pred_answer, true_answer)
                f1_total += f1
                em_total += em

                # Optionally pretty-print
                if print_to_screen:
                    print_example(self.word2id, batch.context_tokens[ex_idx], batch.qn_tokens[ex_idx], batch.ans_span[ex_idx, 0], batch.ans_span[ex_idx, 1], pred_ans_start, pred_ans_end, true_answer, pred_answer, f1, em)

                if num_samples != 0 and example_num >= num_samples:
                    break

            if num_samples != 0 and example_num >= num_samples:
                break

        f1_total /= example_num
        em_total /= example_num

        toc = time.time()
        logging.info("Calculating F1/EM for %i examples in %s set took %.2f seconds" % (example_num, dataset, toc-tic))

        return f1_total, em_total
Example #18
    def train(self, session, train_context_path, train_qn_path, train_ans_path, dev_qn_path, dev_context_path, dev_ans_path):
        """
        Main training loop.

        Inputs:
          session: TensorFlow session
          {train/dev}_{qn/context/ans}_path: paths to {train/dev}.{context/question/answer} data files
        """

        # Print number of model parameters
        tic = time.time()
        params = tf.trainable_variables()
        num_params = sum(map(lambda t: np.prod(tf.shape(t.value()).eval()), params))
        toc = time.time()
        logging.info("Number of params: %d (retrieval took %f secs)" % (num_params, toc - tic))

        # We will keep track of exponentially-smoothed loss
        exp_loss = None

        # Checkpoint management.
        # We keep one latest checkpoint, and one best checkpoint (early stopping)
        checkpoint_path = os.path.join(self.FLAGS.train_dir, "qa.ckpt")
        bestmodel_dir = os.path.join(self.FLAGS.train_dir, "best_checkpoint")
        bestmodel_ckpt_path = os.path.join(bestmodel_dir, "qa_best.ckpt")
        best_dev_f1 = None
        best_dev_em = None

        # for TensorBoard
        summary_writer = tf.summary.FileWriter(self.FLAGS.train_dir, session.graph)

        epoch = 0

        logging.info("Beginning training loop...")
        while self.FLAGS.num_epochs == 0 or epoch < self.FLAGS.num_epochs:
            epoch += 1
            epoch_tic = time.time()

            # Loop over batches
            for batch in get_batch_generator(self.word2id, train_context_path, train_qn_path, train_ans_path, self.FLAGS.batch_size, context_len=self.FLAGS.context_len, question_len=self.FLAGS.question_len, discard_long=True):

                # Run training iteration
                iter_tic = time.time()
                loss, global_step, param_norm, grad_norm = self.run_train_iter(session, batch, summary_writer)
                iter_toc = time.time()
                iter_time = iter_toc - iter_tic

                # Update exponentially-smoothed loss
                if not exp_loss: # first iter
                    exp_loss = loss
                else:
                    exp_loss = 0.99 * exp_loss + 0.01 * loss

                # Sometimes print info to screen
                if global_step % self.FLAGS.print_every == 0:
                    logging.info(
                        'epoch %d, iter %d, loss %.5f, smoothed loss %.5f, grad norm %.5f, param norm %.5f, batch time %.3f' %
                        (epoch, global_step, loss, exp_loss, grad_norm, param_norm, iter_time))

                # Sometimes save model
                if global_step % self.FLAGS.save_every == 0:
                    logging.info("Saving to %s..." % checkpoint_path)
                    self.saver.save(session, checkpoint_path, global_step=global_step)

                # Sometimes evaluate model on dev loss, train F1/EM and dev F1/EM
                if global_step % self.FLAGS.eval_every == 0:

                    # Get loss for entire dev set and log to tensorboard
                    dev_loss = self.get_dev_loss(session, dev_context_path, dev_qn_path, dev_ans_path)
                    logging.info("Epoch %d, Iter %d, dev loss: %f" % (epoch, global_step, dev_loss))
                    write_summary(dev_loss, "dev/loss", summary_writer, global_step)


                    # Get F1/EM on train set and log to tensorboard
                    train_f1, train_em = self.check_f1_em(session, train_context_path, train_qn_path, train_ans_path, "train", num_samples=1000)
                    logging.info("Epoch %d, Iter %d, Train F1 score: %f, Train EM score: %f" % (epoch, global_step, train_f1, train_em))
                    write_summary(train_f1, "train/F1", summary_writer, global_step)
                    write_summary(train_em, "train/EM", summary_writer, global_step)


                    # Get F1/EM on dev set and log to tensorboard
                    dev_f1, dev_em = self.check_f1_em(session, dev_context_path, dev_qn_path, dev_ans_path, "dev", num_samples=0)
                    logging.info("Epoch %d, Iter %d, Dev F1 score: %f, Dev EM score: %f" % (epoch, global_step, dev_f1, dev_em))
                    write_summary(dev_f1, "dev/F1", summary_writer, global_step)
                    write_summary(dev_em, "dev/EM", summary_writer, global_step)


                    # Early stopping based on dev EM. You could switch this to use F1 instead.
                    if best_dev_em is None or dev_em > best_dev_em:
                        best_dev_em = dev_em
                        logging.info("Saving to %s..." % bestmodel_ckpt_path)
                        self.bestmodel_saver.save(session, bestmodel_ckpt_path, global_step=global_step)


            epoch_toc = time.time()
            logging.info("End of epoch %i. Time for epoch: %f" % (epoch, epoch_toc-epoch_tic))

        sys.stdout.flush()
Example #19
    def train(self, model_file_path):
        train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)
        model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(model_dir):
            os.mkdir(model_dir)
        bestmodel_dir = os.path.join(train_dir, 'bestmodel')
        if not os.path.exists(bestmodel_dir):
            os.makedirs(bestmodel_dir)

        summary_writer = SummaryWriter(train_dir)

        with open(os.path.join(train_dir, "flags.json"), 'w') as fout:
            json.dump(vars(config), fout)

        model = self.get_model(model_file_path)
        params = list(filter(lambda p: p.requires_grad, model.parameters()))
        optimizer = Adam(params, lr=config.lr, weight_decay=config.reg_lambda, amsgrad=True)

        num_params = sum(p.numel() for p in params)
        logging.info("Number of params: %d" % num_params)

        exp_loss, best_dev_f1, best_dev_em = None, None, None

        epoch = 0
        global_step = 0

        logging.info("Beginning training loop...")
        while config.num_epochs == 0 or epoch < config.num_epochs:
            epoch += 1
            epoch_tic = time.time()
            for batch in get_batch_generator(self.word2id, self.train_context_path,
                                             self.train_qn_path, self.train_ans_path,
                                             config.batch_size, context_len=config.context_len,
                                             question_len=config.question_len, discard_long=True):
                global_step += 1
                iter_tic = time.time()

                loss, param_norm, grad_norm = self.train_one_batch(batch, model, optimizer, params)
                summary_writer.add_scalar("train/loss", loss, global_step)

                iter_toc = time.time()
                iter_time = iter_toc - iter_tic

                if not exp_loss:
                    exp_loss = loss
                else:
                    exp_loss = 0.99 * exp_loss + 0.01 * loss

                if global_step % config.print_every == 0:
                    logging.info(
                        'epoch %d, iter %d, loss %.5f, smoothed loss %.5f, grad norm %.5f, param norm %.5f, batch time %.3f' %
                        (epoch, global_step, loss, exp_loss, grad_norm, param_norm, iter_time))


                if global_step % config.save_every == 0:
                    logging.info("Saving to %s..." % model_dir)
                    self.save_model(model, optimizer, loss, global_step, epoch, model_dir)

                if global_step % config.eval_every == 0:
                    dev_loss = self.get_dev_loss(model)
                    logging.info("Epoch %d, Iter %d, dev loss: %f" % (epoch, global_step, dev_loss))
                    summary_writer.add_scalar("dev/loss", dev_loss, global_step)

                    train_f1, train_em = self.check_f1_em(model, "train", num_samples=1000)
                    logging.info("Epoch %d, Iter %d, Train F1 score: %f, Train EM score: %f" % (
                        epoch, global_step, train_f1, train_em))
                    summary_writer.add_scalar("train/F1", train_f1, global_step)
                    summary_writer.add_scalar("train/EM", train_em, global_step)

                    dev_f1, dev_em = self.check_f1_em(model, "dev", num_samples=0)
                    logging.info(
                        "Epoch %d, Iter %d, Dev F1 score: %f, Dev EM score: %f" % (epoch, global_step, dev_f1, dev_em))
                    summary_writer.add_scalar("dev/F1", dev_f1, global_step)
                    summary_writer.add_scalar("dev/EM", dev_em, global_step)

                    if best_dev_f1 is None or dev_f1 > best_dev_f1:
                        best_dev_f1 = dev_f1

                    if best_dev_em is None or dev_em > best_dev_em:
                        best_dev_em = dev_em
                        logging.info("Saving to %s..." % bestmodel_dir)
                        self.save_model(model, optimizer, loss, global_step, epoch, bestmodel_dir)


            epoch_toc = time.time()
            logging.info("End of epoch %i. Time for epoch: %f" % (epoch, epoch_toc - epoch_tic))

        sys.stdout.flush()
Example #20
    def train(self, session):
        """
        Main training loop.
        """
        # Print number of model parameters
        tic = time.time()
        params = tf.trainable_variables()
        num_params = sum(map(lambda t: np.prod(tf.shape(t.value()).eval()), params))
        toc = time.time()
        logging.info("Number of params: %d (retrieval took %f secs)" % (num_params, toc - tic))
        print(params)   # For debugging purposes

        # We will keep track of exponentially-smoothed loss
        exp_loss = None

        # Checkpoint management. We keep one latest checkpoint, and one best checkpoint (early stopping)
        checkpoint_path = os.path.join(self.FLAGS.train_dir, "latest.ckpt")
        bestmodel_dir = self.FLAGS.bestmodel_dir
        bestmodel_ckpt_path = os.path.join(bestmodel_dir, "best.ckpt")
        best_val_metric = None

        # For TensorBoard
        summary_writer = tf.summary.FileWriter(self.FLAGS.train_dir, session.graph)

        epoch = 0

        logging.info("Beginning training loop...")
        while self.FLAGS.num_epochs == 0 or epoch < self.FLAGS.num_epochs:
            epoch += 1
            epoch_tic = time.time()

            # Loop over batches
            for batch in get_batch_generator(self.word2id, self.img_features_map, self.train_caption_id_2_caption,
                                             self.caption_id_2_img_id, self.FLAGS.batch_size, self.FLAGS.max_caption_len, 'train', None, self.FLAGS.data_source):

                # Run training iteration
                iter_tic = time.time()
                loss, global_step, param_norm, grad_norm = self.run_train_iter(session, batch, summary_writer)
                iter_toc = time.time()
                iter_time = iter_toc - iter_tic

                # Update exponentially-smoothed loss
                if not exp_loss: # first iter
                    exp_loss = loss
                else:
                    exp_loss = 0.99 * exp_loss + 0.01 * loss

                # Sometimes print info to screen
                if global_step % self.FLAGS.print_every == 0:
                    logging.info(
                        'epoch %d, iter %d, loss %.5f, smoothed loss %.5f, grad norm %.5f, param norm %.5f, batch time %.3f' %
                        (epoch, global_step, loss, exp_loss, grad_norm, param_norm, iter_time))
                    write_summary(loss, "train/loss", summary_writer, global_step)

                # Sometimes save model
                if global_step % self.FLAGS.save_every == 0:
                    logging.info("Saving to %s..." % checkpoint_path)
                    self.saver.save(session, checkpoint_path, global_step=global_step)

                # Sometimes evaluate the model
                if global_step % self.FLAGS.eval_every == 0:
                    # Get loss for entire val set and log to tensorboard
                    val_loss = self.get_val_loss(session)
                    logging.info("Epoch %d, Iter %d, Val loss: %f" % (epoch, global_step, val_loss))
                    write_summary(val_loss, "val/loss", summary_writer, global_step)

                    # Evaluate on val set and log all the metrics to tensorboard
                    val_scores = self.check_metric(session, mode='val', num_samples=0)
                    val_metric = val_scores[self.FLAGS.primary_metric]
                    for metric_name, metric_score in val_scores.items():
                        logging.info("Epoch {}, Iter {}, Val {}: {}".format(epoch, global_step, metric_name, metric_score))
                        write_summary(metric_score, "val/"+metric_name, summary_writer, global_step)

                    # Early stopping based on val evaluation
                    if best_val_metric is None or val_metric > best_val_metric:
                        best_val_metric = val_metric
                        logging.info("Saving to %s..." % bestmodel_ckpt_path)
                        self.bestmodel_saver.save(session, bestmodel_ckpt_path, global_step=global_step)

            epoch_toc = time.time()
            logging.info("End of epoch %i. Time for epoch: %f" % (epoch, epoch_toc-epoch_tic))

        sys.stdout.flush()
Example #21
    def visualize_results(self,
                          session,
                          context_path,
                          qn_path,
                          ans_path,
                          dataset,
                          num_samples=100,
                          print_to_screen=False):
        logging.info(
            "Calculating F1/EM for %s examples in %s set..." %
            (str(num_samples) if num_samples != 0 else "all", dataset))

        f1_total = 0.
        em_total = 0.
        example_num = 0

        tic = time.time()

        # Note here we select discard_long=False because we want to sample from the entire dataset
        # That means we're truncating, rather than discarding, examples with too-long context or questions
        for batch in get_batch_generator(self.word2id,
                                         context_path,
                                         qn_path,
                                         ans_path,
                                         self.FLAGS.batch_size,
                                         context_len=self.FLAGS.context_len,
                                         question_len=self.FLAGS.question_len,
                                         discard_long=False):

            pred_start_pos, pred_end_pos, c2q_attn, q2c_attn, strt_logts, end_logts = self.get_results_vis(
                session, batch)

            # Convert the start and end positions to lists length batch_size
            pred_start_pos = pred_start_pos.tolist()  # list length batch_size
            pred_end_pos = pred_end_pos.tolist()  # list length batch_size

            q2c_attn = np.argmax(q2c_attn, axis=1)
            q2c_attn = q2c_attn.tolist()

            fig = plt.figure()

            gs = grd.GridSpec(3, 1, height_ratios=[1, 3, 1])

            ax = plt.subplot(gs[1])
            c2q_attn_plt = c2q_attn[
                0, :len(batch.context_tokens[0]), :len(batch.qn_tokens[0])]
            p = ax.imshow(np.transpose(c2q_attn_plt),
                          interpolation='nearest',
                          aspect='auto')
            plt.ylabel('c2q attn')
            plt.xlim(0, len(batch.context_tokens[0]))

            ax2 = plt.subplot(gs[0])
            ax2.plot(strt_logts[0, :len(batch.context_tokens[0])])
            plt.ylabel('start logits')
            plt.xlim(0, len(batch.context_tokens[0]))

            ax3 = plt.subplot(gs[2])
            ax3.plot(end_logts[0, :len(batch.context_tokens[0])])
            plt.ylabel('end logits')
            plt.xlim(0, len(batch.context_tokens[0]))

            plt.savefig('c2q_attn.png')
            plt.clf()

            for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens,
                         q2c_attn_idx) in enumerate(
                             zip(pred_start_pos, pred_end_pos,
                                 batch.ans_tokens, q2c_attn)):
                example_num += 1

                # Get the predicted answer
                # Important: batch.context_tokens contains the original words (no UNKs)
                # You need to use the original no-UNK version when measuring F1/EM
                pred_ans_tokens = batch.context_tokens[ex_idx][
                    pred_ans_start:pred_ans_end + 1]
                pred_answer = " ".join(pred_ans_tokens)
                qn_attn_words = [
                    batch.context_tokens[ex_idx][i]
                    for i in q2c_attn_idx[:len(batch.qn_tokens[ex_idx])]
                ]
                qn_attn = " ".join(qn_attn_words)

                # Get true answer (no UNKs)
                true_answer = " ".join(true_ans_tokens)

                # Calc F1/EM
                f1 = f1_score(pred_answer, true_answer)
                em = exact_match_score(pred_answer, true_answer)
                f1_total += f1
                em_total += em

                # Optionally pretty-print
                if print_to_screen:
                    print_example_attn(self.word2id,
                                       batch.context_tokens[ex_idx],
                                       batch.qn_tokens[ex_idx], qn_attn,
                                       batch.ans_span[ex_idx, 0],
                                       batch.ans_span[ex_idx, 1],
                                       pred_ans_start, pred_ans_end,
                                       true_answer, pred_answer, f1, em)

                if num_samples != 0 and example_num >= num_samples:
                    break

            if num_samples != 0 and example_num >= num_samples:
                break

        f1_total /= example_num
        em_total /= example_num

        toc = time.time()
        logging.info(
            "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
            (example_num, dataset, toc - tic))

        return f1_total, em_total
    def check_f1_em(self,
                    session,
                    context_path,
                    qn_path,
                    ans_path,
                    dataset,
                    num_samples=100,
                    print_to_screen=False):
        """
        Sample from the provided (train/dev) set.
        For each sample, calculate F1 and EM score.
        Return average F1 and EM score for all samples.
        Optionally pretty-print examples.

        Note: This function is not quite the same as the F1/EM numbers you get from "official_eval" mode.
        This function uses the pre-processed version of the e.g. dev set for speed,
        whereas "official_eval" mode uses the original JSON. Therefore:
          1. official_eval takes your max F1/EM score w.r.t. the three reference answers,
            whereas this function compares to just the first answer (which is what's saved in the preprocessed data)
          2. Our preprocessed version of the dev set is missing some examples
            due to tokenization issues (see squad_preprocess.py).
            "official_eval" includes all examples.

        Inputs:
          session: TensorFlow session
          qn_path, context_path, ans_path: paths to {dev/train}.{question/context/answer} data files.
          dataset: string. Either "train" or "dev". Just for logging purposes.
          num_samples: int. How many samples to use. If num_samples=0 then do whole dataset.
          print_to_screen: if True, pretty-prints each example to screen

        Returns:
          F1 and EM: Scalars. The average across the sampled examples.
        """
        logging.info(
            "Calculating F1/EM for %s examples in %s set..." %
            (str(num_samples) if num_samples != 0 else "all", dataset))

        f1_total = 0.
        em_total = 0.
        example_num = 0

        tic = time.time()

        # Note here we select discard_long=False because we want to sample from the entire dataset
        # That means we're truncating, rather than discarding, examples with too-long context or questions
        for batch in get_batch_generator(self.word2id,
                                         context_path,
                                         qn_path,
                                         ans_path,
                                         self.FLAGS.batch_size,
                                         context_len=self.FLAGS.context_len,
                                         question_len=self.FLAGS.question_len,
                                         discard_long=False):

            pred_start_pos, pred_end_pos = self.get_start_end_pos(
                session, batch)

            # Convert the start and end positions to lists length batch_size
            pred_start_pos = pred_start_pos.tolist()  # list length batch_size
            pred_end_pos = pred_end_pos.tolist()  # list length batch_size

            for ex_idx, (pred_ans_start, pred_ans_end,
                         true_ans_tokens) in enumerate(
                             zip(pred_start_pos, pred_end_pos,
                                 batch.ans_tokens)):
                example_num += 1

                # Get the predicted answer
                # Important: batch.context_tokens contains the original words (no UNKs)
                # You need to use the original no-UNK version when measuring F1/EM
                pred_ans_tokens = batch.context_tokens[ex_idx][
                    pred_ans_start:pred_ans_end + 1]
                pred_answer = " ".join(pred_ans_tokens)

                # Get true answer (no UNKs)
                true_answer = " ".join(true_ans_tokens)

                # Calc F1/EM
                f1 = f1_score(pred_answer, true_answer)
                em = exact_match_score(pred_answer, true_answer)
                f1_total += f1
                em_total += em

                # Optionally pretty-print
                if print_to_screen:
                    print_example(self.word2id, batch.context_tokens[ex_idx],
                                  batch.qn_tokens[ex_idx],
                                  batch.ans_span[ex_idx,
                                                 0], batch.ans_span[ex_idx, 1],
                                  pred_ans_start, pred_ans_end, true_answer,
                                  pred_answer, f1, em)

                if num_samples != 0 and example_num >= num_samples:
                    break

            if num_samples != 0 and example_num >= num_samples:
                break

        f1_total /= example_num
        em_total /= example_num

        toc = time.time()
        logging.info(
            "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
            (example_num, dataset, toc - tic))

        return f1_total, em_total
    def check_f1_em(self,
                    session,
                    context_path,
                    qn_path,
                    ans_path,
                    dataset,
                    num_samples=100,
                    print_to_screen=False):

        logging.info(
            "Calculating F1/EM for %s examples in %s set..." %
            (str(num_samples) if num_samples != 0 else "all", dataset))

        f1_total = 0.
        em_total = 0.
        example_num = 0

        tic = time.time()

        # Note here we select discard_long=False because we want to sample from the entire dataset
        # That means we're truncating, rather than discarding, examples with too-long context or questions
        for batch in get_batch_generator(self.word2id,
                                         context_path,
                                         qn_path,
                                         ans_path,
                                         self.FLAGS.batch_size,
                                         context_len=self.FLAGS.context_len,
                                         question_len=self.FLAGS.question_len,
                                         discard_long=False):

            pred_start_pos, pred_end_pos = self.get_start_end_pos(
                session, batch)

            # Convert the start and end positions to lists length batch_size
            pred_start_pos = pred_start_pos.tolist()  # list length batch_size
            pred_end_pos = pred_end_pos.tolist()  # list length batch_size

            for ex_idx, (pred_ans_start, pred_ans_end,
                         true_ans_tokens) in enumerate(
                             zip(pred_start_pos, pred_end_pos,
                                 batch.ans_tokens)):
                example_num += 1

                # Get the predicted answer
                # Important: batch.context_tokens contains the original words (no UNKs)
                # You need to use the original no-UNK version when measuring F1/EM
                pred_ans_tokens = batch.context_tokens[ex_idx][
                    pred_ans_start:pred_ans_end + 1]
                pred_answer = " ".join(pred_ans_tokens)

                # Get true answer (no UNKs)
                true_answer = " ".join(true_ans_tokens)

                # Calc F1/EM
                f1 = f1_score(pred_answer, true_answer)
                em = exact_match_score(pred_answer, true_answer)
                f1_total += f1
                em_total += em

                # Optionally pretty-print
                if print_to_screen:
                    print_example(self.word2id, batch.context_tokens[ex_idx],
                                  batch.qn_tokens[ex_idx],
                                  batch.ans_span[ex_idx,
                                                 0], batch.ans_span[ex_idx, 1],
                                  pred_ans_start, pred_ans_end, true_answer,
                                  pred_answer, f1, em)

                if num_samples != 0 and example_num >= num_samples:
                    break

            if num_samples != 0 and example_num >= num_samples:
                break

        f1_total /= example_num
        em_total /= example_num

        toc = time.time()
        logging.info(
            "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
            (example_num, dataset, toc - tic))

        return f1_total, em_total
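Both check_f1_em variants above lean on f1_score and exact_match_score helpers imported elsewhere in these projects (they follow the official SQuAD evaluation script). As a rough reference only, a minimal sketch of token-overlap F1 and exact match, with hypothetical names and without the answer normalization the official script applies, could look like this:

from collections import Counter

def simple_f1(pred_answer, true_answer):
    # Token-overlap F1 over whitespace tokens (no lowercasing or punctuation stripping).
    pred_tokens = pred_answer.split()
    true_tokens = true_answer.split()
    common = Counter(pred_tokens) & Counter(true_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(true_tokens)
    return 2 * precision * recall / (precision + recall)

def simple_em(pred_answer, true_answer):
    # Exact match: 1.0 only when the two strings are identical.
    return float(pred_answer == true_answer)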
Example #24
0
    def get_error_stats(self,
                        session,
                        context_path,
                        qn_path,
                        ans_path,
                        dataset,
                        num_samples=10,
                        print_to_screen=False):
        """
        Sample from the provided (train/dev) set.
        For each sample, calculate F1 and EM score, and aggregate error
        statistics grouped by the first token of the question.
        Return average F1 and EM score for all samples.
        Optionally pretty-print examples.

        Note: This function is not quite the same as the F1/EM numbers you get from "official_eval" mode.
        This function uses the pre-processed version of the e.g. dev set for speed,
        whereas "official_eval" mode uses the original JSON. Therefore:
          1. official_eval takes your max F1/EM score w.r.t. the three reference answers,
            whereas this function compares to just the first answer (which is what's saved in the preprocessed data)
          2. Our preprocessed version of the dev set is missing some examples
            due to tokenization issues (see squad_preprocess.py).
            "official_eval" includes all examples.

        Inputs:
          session: TensorFlow session
          qn_path, context_path, ans_path: paths to {dev/train}.{question/context/answer} data files.
          dataset: string. Either "train" or "dev". Just for logging purposes.
          num_samples: int. How many samples to use. If num_samples=0 then do whole dataset.
          print_to_screen: if True, pretty-prints each example to screen

        Returns:
          F1 and EM: Scalars. The average across the sampled examples.
        """
        logging.info(
            "Calculating Error stats for %s examples in %s set..." %
            (str(num_samples) if num_samples != 0 else "all", dataset))

        f1_total = 0.
        em_total = 0.
        example_num = 0

        tic = time.time()

        # Note here we select discard_long=False because we want to sample from the entire dataset
        # That means we're truncating, rather than discarding, examples with too-long context or questions
        first_token_qn_dict_wrong = defaultdict(float)
        first_token_qn_dict_total = defaultdict(float)
        first_token_qn_dict_f1 = defaultdict(float)

        for batch in get_batch_generator(self.word2id,
                                         context_path,
                                         qn_path,
                                         ans_path,
                                         self.FLAGS.batch_size,
                                         context_len=self.FLAGS.context_len,
                                         question_len=self.FLAGS.question_len,
                                         discard_long=False):

            pred_start_pos, pred_end_pos = self.get_start_end_pos(
                session, batch)

            # Convert the start and end positions to lists length batch_size
            pred_start_pos = pred_start_pos.tolist()  # list length batch_size
            pred_end_pos = pred_end_pos.tolist()  # list length batch_size
            for ex_idx, (pred_ans_start, pred_ans_end,
                         true_ans_tokens) in enumerate(
                             zip(pred_start_pos, pred_end_pos,
                                 batch.ans_tokens)):
                example_num += 1

                # Get the predicted answer
                # Important: batch.context_tokens contains the original words (no UNKs)
                # You need to use the original no-UNK version when measuring F1/EM
                pred_ans_tokens = batch.context_tokens[ex_idx][
                    pred_ans_start:pred_ans_end + 1]
                pred_answer = " ".join(pred_ans_tokens)

                # Get true answer (no UNKs)
                true_answer = " ".join(true_ans_tokens)

                # Calc F1/EM
                f1 = f1_score(pred_answer, true_answer)
                em = exact_match_score(pred_answer, true_answer)

                first_token_qn = batch.qn_tokens[ex_idx][0]
                first_token_qn_dict_total[first_token_qn] += 1
                #print 'example_num: ', example_num
                #print 'total words seen in first_token_qn_dict: ', sum(first_token_qn_dict_total.itervalues())
                if not em:
                    #we have found an error:
                    #get first token of error question:
                    first_token_qn_dict_wrong[first_token_qn] += 1

                f1_total += f1
                first_token_qn_dict_f1[first_token_qn] += f1
                em_total += em

                # Optionally pretty-print
                if print_to_screen:
                    print_example(self.word2id, batch.context_tokens[ex_idx],
                                  batch.qn_tokens[ex_idx],
                                  batch.ans_span[ex_idx,
                                                 0], batch.ans_span[ex_idx, 1],
                                  pred_ans_start, pred_ans_end, true_answer,
                                  pred_answer, f1, em)

                if num_samples != 0 and example_num >= num_samples:
                    break

            if num_samples != 0 and example_num >= num_samples:
                break

        f1_total /= example_num
        em_total /= example_num
        print('total words:', sum(first_token_qn_dict_total.values()))

        toc = time.time()
        logging.info(
            "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
            (example_num, dataset, toc - tic))

        final_freq_dict = {}
        for token, count in sorted(first_token_qn_dict_total.items(),
                                   key=lambda kv: (kv[1], kv[0])):
            # key is the first token of the question, value is how many times that token occurs
            freq = first_token_qn_dict_wrong[token] / first_token_qn_dict_total[token]
            f1 = first_token_qn_dict_f1[token] / first_token_qn_dict_total[token]
            print("When first token is: [", token, "] f1:", f1,
                  "We got:", first_token_qn_dict_wrong[token],
                  "wrong exact matches out of", first_token_qn_dict_total[token],
                  "percentage of 1st tokens that are this token:",
                  first_token_qn_dict_total[token] / sum(first_token_qn_dict_total.values()),
                  "percentage of this token WRONG:", freq)

        print('em_total:', em_total)
        print('f1_total:', f1_total)
        return f1_total, em_total
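The per-first-token bookkeeping in get_error_stats is just three defaultdicts keyed by the first word of each question. A minimal standalone sketch of the same aggregation pattern, with illustrative names that are not part of the original code:

from collections import defaultdict

def first_token_error_report(records):
    # records: iterable of (first_question_token, f1, em) tuples gathered during evaluation.
    totals = defaultdict(float)
    wrong = defaultdict(float)
    f1_sums = defaultdict(float)
    for token, f1, em in records:
        totals[token] += 1
        f1_sums[token] += f1
        if not em:
            wrong[token] += 1
    grand_total = sum(totals.values())
    for token, count in sorted(totals.items(), key=lambda kv: kv[1]):
        print("token=%-10s avg_f1=%.3f wrong=%d/%d share=%.3f" %
              (token, f1_sums[token] / count, wrong[token], count, count / grand_total))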
Example #25
0
    def check_f1_em(self,
                    session,
                    context_path,
                    qn_path,
                    ans_path,
                    dataset,
                    num_samples=10,
                    print_to_screen=False,
                    write_out=False,
                    file_out=None,
                    shuffle=True):
        """
        Sample from the provided (train/dev) set.
        For each sample, calculate F1 and EM score.
        Return average F1 and EM score for all samples.
        Optionally pretty-print examples.

        Inputs:
          session: TensorFlow session
          qn_path, context_path, ans_path: paths to {dev/train}.{question/context/answer} data files.
          dataset: string. Either "train" or "dev". Just for logging purposes.
          num_samples: int. How many samples to use. If num_samples=0 then do whole dataset.
          print_to_screen: if True, pretty-prints each example to screen

        Returns:
          F1 and EM: Scalars. The average across the sampled examples.
        """
        logging.info(
            "Calculating F1/EM for %s examples in %s set..." %
            (str(num_samples) if num_samples != 0 else "all", dataset))

        f1_total = 0.
        em_total = 0.
        ed_total = 0.
        rough_em_total = 0.
        example_num = 0

        tic = time.time()
        ans_list = []
        graph_route_info = []
        # Accumulate per-example scores across all batches so write_out covers the full run
        f1_scores, em_scores, ed_scores, gm_scores = [], [], [], []
        # Note here we select discard_long=False because we want to sample from the entire dataset
        # That means we're truncating, rather than discarding, examples with too-long context or questions
        for batch in get_batch_generator(
                self.word2id,
                self.context2id,
                self.ans2id,
                context_path,
                qn_path,
                ans_path,
                self.FLAGS.batch_size,
                self.graph_vocab_class,
                context_len=self.FLAGS.context_len,
                question_len=self.FLAGS.question_len,
                answer_len=self.FLAGS.answer_len,
                discard_long=False,
                use_raw_graph=self.FLAGS.use_raw_graph,
                shuffle=shuffle,
                show_start_tokens=self.FLAGS.show_start_tokens,
                output_goal=True):
            train_ids, pred_ids, dev_final_states, pred_logits = self.get_prob_dists(
                session, batch)
            start_ids = batch.ans_ids[:, 0].reshape(-1)
            graph_length = np.sum(batch.context_mask, axis=1)

            if self.FLAGS.pred_method != 'beam':
                pred_ids, confidence_score, ans_str = verify_route(
                    start_ids, pred_logits, batch.context_tokens, self.ans2id,
                    self.id2ans, self.FLAGS.answer_len)


            pred_ids = pred_ids.tolist()  # the output of using test network
            for ex_idx, (pred_ans_list, true_ans_tokens) in enumerate(
                    zip(pred_ids, list(batch.ans_tokens))):
                example_num += 1
                pred_ans_tokens = []
                for id in pred_ans_list:
                    if id == PAD_ID:
                        break
                    else:
                        pred_ans_tokens.append(self.id2ans[id])
                pred_answer = " ".join(pred_ans_tokens)

                # Get true answer (no UNKs)
                true_answer = " ".join(true_ans_tokens[:])

                # Calculate metrics
                f1, em, edit_dist, goal_match = compute_all_metrics(
                    pred_ans_tokens, true_ans_tokens)
                f1_scores.append(f1)
                em_scores.append(em)
                ed_scores.append(edit_dist)
                gm_scores.append(goal_match)

                f1_total += f1

                em_total += em
                ed_total += edit_dist
                rough_em_total += goal_match
                ans_list.append(pred_answer)
                graph_route_info.append(
                    (str(int(graph_length[ex_idx])),
                     str(len(true_ans_tokens[1:-1])), str(int(em))))

                # Optionally pretty-print
                if print_to_screen:
                    print_example(self.word2id, self.context2id, self.ans2id,
                                  batch.context_tokens[ex_idx],
                                  batch.qn_tokens[ex_idx], true_answer,
                                  pred_answer, f1, em, edit_dist,
                                  confidence_score[ex_idx])

                if num_samples != 0 and example_num >= num_samples:
                    break

            if num_samples != 0 and example_num >= num_samples:
                break
        f1_total /= example_num
        em_total /= example_num
        ed_total /= example_num
        rough_em_total /= example_num

        toc = time.time()
        logging.info(
            "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
            (example_num, dataset, toc - tic))
        if write_out:
            logging.info("Writing the prediction to {}".format(file_out))
            with open(file_out, 'w') as f:
                for line, extra_info in zip(ans_list, graph_route_info):
                    f.write(line + " " + " ".join(extra_info) + '\n')
            print("Wrote predictions to %s" % file_out)

            em_file = "em_" + str(file_out)
            logging.info("Writing EM scores to {}".format(em_file))
            with open(em_file, 'w') as f:
                for em in em_scores:
                    f.write(str(em) + '\n')
            print("Wrote EM Scores to %s" % em_file)

            ed_file = "ed_" + str(file_out)
            logging.info("Writing ED scores to {}".format(ed_file))
            with open(ed_file, 'w') as f:
                for ed in ed_scores:
                    f.write(str(ed) + '\n')
            print("Wrote ED Scores to %s" % ed_file)

            gm_file = "gm_" + str(file_out)
            logging.info("Writing GM scores to {}".format(gm_file))
            with open(gm_file, 'w') as f:
                for gm in gm_scores:
                    f.write(str(gm) + '\n')
            print("Wrote GM Scores to %s" % gm_file)

        return f1_total, em_total, ed_total, rough_em_total
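When write_out is set, this variant writes one score per line into em_/ed_/gm_ files derived from file_out. A small sketch of a hypothetical helper (not part of the original code) for reading one of those files back and averaging it, assuming the scores were serialized as numbers or True/False:

def average_score_file(path):
    # Each non-empty line holds one per-example score written by check_f1_em above.
    scores = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line in ("True", "False"):
                scores.append(1.0 if line == "True" else 0.0)
            else:
                scores.append(float(line))
    return sum(scores) / len(scores) if scores else 0.0

# e.g. average_score_file("em_" + "dev_predictions.txt")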
Example #26
0
def main():
    print("Your TensorFlow version: %s" % tf.__version__)

    # Define train_dir
    if not FLAGS.experiment_name and not FLAGS.train_dir and FLAGS.mode != "official_eval":
        raise Exception(
            "You need to specify either --experiment_name or --train_dir")

    FLAGS.train_dir = FLAGS.train_dir or os.path.join(EXPERIMENTS_DIR,
                                                      FLAGS.experiment_name)
    bestmodel_dir = os.path.join(FLAGS.train_dir, "best_checkpoint")

    # Define path for glove vecs
    FLAGS.glove_path = FLAGS.glove_path or os.path.join(
        DEFAULT_DATA_DIR + "/glove.6B/", "glove.6B.{}d.txt".format(
            FLAGS.embedding_size))

    # Load embedding matrix and vocab mappings
    emb_matrix, word2id, id2word = get_glove(FLAGS.glove_path,
                                             FLAGS.embedding_size)

    # Get filepaths to train/dev datafiles for tokenized queries, contexts and answers
    train_context_path = os.path.join(FLAGS.data_dir, "train.context")
    train_qn_path = os.path.join(FLAGS.data_dir, "train.question")
    train_ans_path = os.path.join(FLAGS.data_dir, "train.span")
    dev_context_path = os.path.join(FLAGS.data_dir, "dev.context")
    dev_qn_path = os.path.join(FLAGS.data_dir, "dev.question")
    dev_ans_path = os.path.join(FLAGS.data_dir, "dev.span")

    global_step = 1
    epoch = 0
    print("Beginning training loop...")

    # Initialize model
    model = QAModel(FLAGS, id2word, word2id, emb_matrix)
    optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)

    while FLAGS.num_epochs == 0 or epoch < FLAGS.num_epochs:
        epoch += 1
        epoch_tic = time.time()

        for batch in get_batch_generator( \
            word2id, train_context_path, train_qn_path, \
            train_ans_path, FLAGS.batch_size, context_len=FLAGS.context_len, \
            question_len=FLAGS.question_len, discard_long=True):
            # print(batch.ans_span)

            with tf.GradientTape() as tape:
                prob_start, prob_end = model([
                    batch.context_ids, batch.context_mask, batch.qn_ids,
                    batch.qn_mask
                ])
                # prob_start, prob_end = model(batch.context_ids, batch.context_mask, batch.qn_ids, batch.qn_mask)

                loss_start = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=prob_start, labels=batch.ans_span[:, 0])
                loss_start = tf.reduce_mean(loss_start)

                loss_end = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=prob_end, labels=batch.ans_span[:, 1])
                loss_end = tf.reduce_mean(loss_end)

                loss = loss_start + loss_end
                # print("loss %f" % (loss.numpy()))

            grads = tape.gradient(loss, model.variables)
            optimizer.apply_gradients(
                grads_and_vars=zip(grads, model.variables))

            if global_step % FLAGS.eval_every == 0:
                print("==== start evaluating ==== ")
                dev_f1, dev_em = evaluate(model, word2id, FLAGS,
                                          dev_context_path, dev_qn_path,
                                          dev_ans_path)
                print("Epoch %d, Iter %d, Dev F1 score: %f, Dev EM score: %f" %
                      (epoch, global_step, dev_f1, dev_em))
                print("==========================")
            global_step += 1

        epoch_toc = time.time()
        print("End of epoch %i. Time for epoch: %f" %
              (epoch, epoch_toc - epoch_tic))

    sys.stdout.flush()
Example #27
0
    def train(self, session, train_context_path, train_qn_path, train_ans_path,
              dev_qn_path, dev_context_path, dev_ans_path):
        """
        Main training loop.

        Inputs:
          session: TensorFlow session
          {train/dev}_{qn/context/ans}_path: paths to {train/dev}.{context/question/answer} data files
        """

        # Print number of model parameters
        tic = time.time()
        params = tf.trainable_variables()
        num_params = sum(
            map(lambda t: np.prod(tf.shape(t.value()).eval()), params))
        toc = time.time()
        logging.info("Number of params: %d (retrieval took %f secs)" %
                     (num_params, toc - tic))

        # We will keep track of exponentially-smoothed loss
        exp_loss = None

        # Checkpoint management.
        # We keep one latest checkpoint, and one best checkpoint (early stopping)
        checkpoint_path = os.path.join(self.FLAGS.train_dir, "qa.ckpt")
        bestmodel_dir = os.path.join(self.FLAGS.train_dir, "best_checkpoint")
        bestmodel_ckpt_path = os.path.join(bestmodel_dir, "qa_best.ckpt")
        best_dev_f1 = None
        best_dev_em = None

        # for TensorBoard
        summary_writer = tf.summary.FileWriter(self.FLAGS.train_dir,
                                               session.graph)

        epoch = 0

        logging.info("Beginning training loop...")
        while self.FLAGS.num_epochs == 0 or epoch < self.FLAGS.num_epochs:
            epoch += 1
            epoch_tic = time.time()

            # Loop over batches
            for batch in get_batch_generator(
                    self.word2id,
                    train_context_path,
                    train_qn_path,
                    train_ans_path,
                    self.FLAGS.batch_size,
                    context_len=self.FLAGS.context_len,
                    question_len=self.FLAGS.question_len,
                    discard_long=True):

                # Run training iteration
                iter_tic = time.time()
                loss, global_step, param_norm, grad_norm = self.run_train_iter(
                    session, batch, summary_writer)
                iter_toc = time.time()
                iter_time = iter_toc - iter_tic

                # Update exponentially-smoothed loss
                if not exp_loss:  # first iter
                    exp_loss = loss
                else:
                    exp_loss = 0.99 * exp_loss + 0.01 * loss

                # Sometimes print info to screen
                if global_step % self.FLAGS.print_every == 0:
                    logging.info(
                        'epoch %d, iter %d, loss %.5f, smoothed loss %.5f, grad norm %.5f, param norm %.5f, batch time %.3f'
                        % (epoch, global_step, loss, exp_loss, grad_norm,
                           param_norm, iter_time))

                # Sometimes save model
                if global_step % self.FLAGS.save_every == 0:
                    logging.info("Saving to %s..." % checkpoint_path)
                    self.saver.save(session,
                                    checkpoint_path,
                                    global_step=global_step)

                # Sometimes evaluate model on dev loss, train F1/EM and dev F1/EM
                if global_step % self.FLAGS.eval_every == 0:

                    # Get loss for entire dev set and log to tensorboard
                    dev_loss = self.get_dev_loss(session, dev_context_path,
                                                 dev_qn_path, dev_ans_path)
                    logging.info("Epoch %d, Iter %d, dev loss: %f" %
                                 (epoch, global_step, dev_loss))
                    write_summary(dev_loss, "dev/loss", summary_writer,
                                  global_step)

                    # Get F1/EM on train set and log to tensorboard
                    train_f1, train_em = self.check_f1_em(session,
                                                          train_context_path,
                                                          train_qn_path,
                                                          train_ans_path,
                                                          "train",
                                                          num_samples=1000)
                    logging.info(
                        "Epoch %d, Iter %d, Train F1 score: %f, Train EM score: %f"
                        % (epoch, global_step, train_f1, train_em))
                    write_summary(train_f1, "train/F1", summary_writer,
                                  global_step)
                    write_summary(train_em, "train/EM", summary_writer,
                                  global_step)

                    # Get F1/EM on dev set and log to tensorboard
                    dev_f1, dev_em = self.check_f1_em(session,
                                                      dev_context_path,
                                                      dev_qn_path,
                                                      dev_ans_path,
                                                      "dev",
                                                      num_samples=0)
                    logging.info(
                        "Epoch %d, Iter %d, Dev F1 score: %f, Dev EM score: %f"
                        % (epoch, global_step, dev_f1, dev_em))
                    write_summary(dev_f1, "dev/F1", summary_writer,
                                  global_step)
                    write_summary(dev_em, "dev/EM", summary_writer,
                                  global_step)

                    # Early stopping based on dev EM. You could switch this to use F1 instead.
                    if best_dev_em is None or dev_em > best_dev_em:
                        best_dev_em = dev_em
                        logging.info("Saving to %s..." % bestmodel_ckpt_path)
                        self.bestmodel_saver.save(session,
                                                  bestmodel_ckpt_path,
                                                  global_step=global_step)

            epoch_toc = time.time()
            logging.info("End of epoch %i. Time for epoch: %f" %
                         (epoch, epoch_toc - epoch_tic))

        sys.stdout.flush()
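The smoothed loss in the loop above is a plain exponential moving average with decay 0.99, seeded with the first observed loss. The same update rule as a tiny standalone sketch:

def update_smoothed_loss(exp_loss, loss, decay=0.99):
    # No history yet on the first iteration, so take the raw loss.
    if exp_loss is None:
        return loss
    return decay * exp_loss + (1.0 - decay) * loss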
Example #28
0

if __name__ == "__main__":

    logits_start, probdist_start, logits_end, probdist_end = build_graph()

    # run the program
    with tf.Session() as sess:
        # It is necessary to initialize variables once before running inference.
        sess.run(tf.global_variables_initializer())
        #sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        for batch in get_batch_generator(word2id,
                                         char2id,
                                         train_context_path,
                                         train_qn_path,
                                         train_ans_path,
                                         batch_size,
                                         context_len,
                                         question_len,
                                         max_word_len,
                                         discard_long=True):

            # Create batches of data.
            input_feed = {}
            input_feed[context_elmo] = batcher.batch_sentences(
                batch.context_tokens)
            input_feed[question_elmo] = batcher.batch_sentences(
                batch.qn_tokens)
            input_feed[context_ids] = batch.context_ids
            input_feed[context_mask] = batch.context_mask
            input_feed[qn_ids] = batch.qn_ids
            input_feed[qn_mask] = batch.qn_mask