    def train(self):
        self.logger.info(str(self.cfg))

        config_path = os.path.join(self.cfg["checkpoint_path"], "config.json")
        write_json(config_path, self.cfg)

        batch_size = self.cfg["batch_size"]
        epochs = self.cfg["epochs"]
        train_path = self.cfg["train_set"]
        valid_path = self.cfg["valid_set"]
        self.n_gold_spans = count_gold_spans(valid_path)

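        # With precomputed kNN sampling, neighbour ids are read from knn_ids.hdf5
        # and validation runs one sentence at a time in the original order,
        # presumably to keep each sentence aligned with its precomputed ids;
        # otherwise the validation set is batched and shuffled like training.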
        if self.cfg["knn_sampling"] == "knn":
            self.knn_ids = h5py.File(
                os.path.join(self.cfg["raw_path"], "knn_ids.hdf5"), "r")
            valid_batch_size = 1
            shuffle = False
        else:
            valid_batch_size = batch_size
            shuffle = True

        valid_set = list(
            self.batcher.batchnize_dataset(data=valid_path,
                                           data_name="valid",
                                           batch_size=valid_batch_size,
                                           shuffle=shuffle))
        best_f1 = -np.inf
        init_lr = self.cfg["lr"]

        self.log_trainable_variables()
        self.logger.info("Start training...")
        self._add_summary()

        for epoch in range(1, epochs + 1):
            self.logger.info('Epoch {}/{}:'.format(epoch, epochs))

            train_set = self.batcher.batchnize_dataset(data=train_path,
                                                       data_name="train",
                                                       batch_size=batch_size,
                                                       shuffle=True)
            _ = self.train_knn_epoch(train_set, "train")

            if self.cfg["use_lr_decay"]:  # learning rate decay
                self.cfg["lr"] = max(
                    init_lr / (1.0 + self.cfg["lr_decay"] * epoch),
                    self.cfg["minimal_lr"])

            eval_metrics = self.evaluate_knn_epoch(valid_set, "valid")
            cur_valid_f1 = eval_metrics[0]

            if cur_valid_f1 > best_f1:
                best_f1 = cur_valid_f1
                self.save_session(epoch)
                self.logger.info(
                    '-- new BEST F1 on valid set: {:>7.2%}'.format(best_f1))

        self.train_writer.close()
        self.test_writer.close()

    def save_span_representation(self, data_name, preprocessor):
        self.logger.info(str(self.cfg))

        ########################
        # Load validation data #
        ########################
        valid_data = preprocessor.load_dataset(
            self.cfg["data_path"],
            keep_number=True,
            lowercase=self.cfg["char_lowercase"])
        valid_data = valid_data[:self.cfg["data_size"]]
        dataset = preprocessor.build_dataset(valid_data, self.word_dict,
                                             self.char_dict, self.tag_dict)
        self.logger.info("Valid sentences: {:>7}".format(len(dataset)))

        #############
        # Main loop #
        #############
        start_time = time.time()
        gold_labels = {}
        fout_path = os.path.join(self.cfg["checkpoint_path"],
                                 "%s.span_reps.hdf5" % data_name)
        fout = h5py.File(fout_path, 'w')

        print("PREDICTION START")
        for record, data in zip(valid_data, dataset):
            valid_sent_id = record["sent_id"]
            batch = self.make_one_batch_for_target(data, valid_sent_id)

            if (valid_sent_id + 1) % 100 == 0:
                print("%d" % (valid_sent_id + 1), flush=True, end=" ")

            ##############
            # Prediction #
            ##############
            feed_dict = self._get_feed_dict(batch)
            span_reps = self.sess.run([self.span_rep], feed_dict)[0][0]
            span_tags = batch["tags"][0]
            assert len(span_reps) == len(span_tags)

            ##################
            # Add the result #
            ##################
            fout.create_dataset(name='{}'.format(valid_sent_id),
                                dtype='float32',
                                data=span_reps)
            gold_labels[valid_sent_id] = [
                self.rev_tag_dict[int(tag)] for tag in span_tags
            ]
        fout.close()
        path = os.path.join(self.cfg["checkpoint_path"],
                            "%s.gold_labels.json" % data_name)
        write_json(path, gold_labels)
        self.logger.info("-- Time: %f seconds\nFINISHED." %
                         (time.time() - start_time))
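
# A minimal, standalone sketch of the learning-rate schedule used in train()
# above: inverse-time decay clipped at a floor. Names here are illustrative,
# not part of the original code.
def decayed_lr(init_lr, lr_decay, minimal_lr, epoch):
    return max(init_lr / (1.0 + lr_decay * epoch), minimal_lr)

# e.g. init_lr=0.01, lr_decay=0.05 (and a smaller minimal_lr) gives
# roughly 0.00952 at epoch 1 and 0.00667 at epoch 10.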
Example #3
def process_data(config):
    train_file = os.path.join(config["raw_path"], "train.txt")
    dev_file = os.path.join(config["raw_path"], "dev.txt")
    ref_file = os.path.join(config["raw_path"], "ref.txt")
    asr_file = os.path.join(config["raw_path"], "asr.txt")
    if not os.path.exists(config["save_path"]):
        os.makedirs(config["save_path"])
    # build vocabulary
    word_vocab, char_vocab = build_vocab_list([train_file],
                                              config["min_word_count"],
                                              config["min_char_count"],
                                              config["max_vocab_size"])
    if not config["use_pretrained"]:
        word_dict, char_dict = build_vocabulary(word_vocab, char_vocab)
    else:
        glove_path = config["glove_path"].format(config["glove_name"],
                                                 config["emb_dim"])
        glove_vocab = load_glove_vocab(glove_path, config["glove_name"])
        glove_vocab = glove_vocab & {word.lower() for word in glove_vocab}
        word_vocab = [word for word in word_vocab if word in glove_vocab]
        word_dict, char_dict = build_vocabulary(word_vocab, char_vocab)
        tmp_word_dict = word_dict.copy()
        del tmp_word_dict[UNK], tmp_word_dict[NUM], tmp_word_dict[END]
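        # The special tokens carry no pretrained vectors, so they are dropped from
        # the copy before filtering the GloVe matrix (presumably they are re-added
        # with randomly initialised vectors when the embeddings are loaded).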
        vectors = filter_glove_emb(tmp_word_dict, glove_path,
                                   config["glove_name"], config["emb_dim"])
        np.savez_compressed(config["pretrained_emb"], embeddings=vectors)
    # create indices dataset
    punct_dict = {punct: idx
                  for idx, punct in enumerate(PUNCTUATION_VOCABULARY)}
    train_set = build_dataset([train_file], word_dict, char_dict, punct_dict,
                              config["max_sequence_len"])
    dev_set = build_dataset([dev_file], word_dict, char_dict, punct_dict,
                            config["max_sequence_len"])
    ref_set = build_dataset([ref_file], word_dict, char_dict, punct_dict,
                            config["max_sequence_len"])
    asr_set = build_dataset([asr_file], word_dict, char_dict, punct_dict,
                            config["max_sequence_len"])
    vocab = {
        "word_dict": word_dict,
        "char_dict": char_dict,
        "tag_dict": punct_dict
    }
    # write to file
    write_json(config["vocab"], vocab)
    write_json(config["train_set"], train_set)
    write_json(config["dev_set"], dev_set)
    write_json(config["ref_set"], ref_set)
    write_json(config["asr_set"], asr_set)
Example #4
    def train(self):
        self.logger.info(str(self.cfg))
        write_json(os.path.join(self.cfg["checkpoint_path"], "config.json"),
                   self.cfg)

        batch_size = self.cfg["batch_size"]
        epochs = self.cfg["epochs"]
        train_path = self.cfg["train_set"]
        valid_path = self.cfg["valid_set"]
        self.n_gold_spans = count_gold_spans(valid_path)
        valid_set = list(
            self.batcher.batchnize_dataset(valid_path,
                                           batch_size,
                                           shuffle=True))

        best_f1 = -np.inf
        init_lr = self.cfg["lr"]

        self.log_trainable_variables()
        self.logger.info("Start training...")
        self._add_summary()
        for epoch in range(1, epochs + 1):
            self.logger.info('Epoch {}/{}:'.format(epoch, epochs))

            train_set = self.batcher.batchnize_dataset(train_path,
                                                       batch_size,
                                                       shuffle=True)
            _ = self.train_epoch(train_set)

            if self.cfg["use_lr_decay"]:  # learning rate decay
                self.cfg["lr"] = max(
                    init_lr / (1.0 + self.cfg["lr_decay"] * epoch),
                    self.cfg["minimal_lr"])

            eval_metrics = self.evaluate_epoch(valid_set, "valid")
            cur_valid_f1 = eval_metrics[0]

            if cur_valid_f1 > best_f1:
                best_f1 = cur_valid_f1
                self.save_session(epoch)
                self.logger.info(
                    "-- new BEST F1 on valid set: {:>7.2%}".format(best_f1))

        self.train_writer.close()
        self.test_writer.close()
Example #5
def process_data(dataset, config):
    dataset["train"] = data_formatted(dataset["train"])
    dataset["dev"] = data_formatted(dataset["dev"])
    dataset["test"] = data_formatted(dataset["test"])
    train_data = load_dataset(dataset["train"], config["task_name"])
    dev_data = load_dataset(dataset["dev"], config["task_name"])
    test_data = load_dataset(dataset["test"], config["task_name"])
    if not os.path.exists(config["save_path"]):
        os.makedirs(config["save_path"])
    # build vocabulary
    if not config["use_pretrained"]:
        word_dict = build_word_vocab([train_data, dev_data, test_data])
    else:
        glove_path = config["glove_path"].format(config["glove_name"],
                                                 config["emb_dim"])
        glove_vocab = load_glove_vocab(glove_path, config["glove_name"])
        word_dict = build_word_vocab_pretrained(
            [train_data, dev_data, test_data], glove_vocab)
        vectors = filter_glove_emb(word_dict, glove_path, config["glove_name"],
                                   config["emb_dim"])
        np.savez_compressed(config["pretrained_emb"], embeddings=vectors)
    tag_dict = build_tag_vocab([train_data, dev_data, test_data],
                               config["task_name"])
    # build char dict
    train_data = load_dataset(dataset["train"],
                              config["task_name"],
                              keep_number=True,
                              lowercase=config["char_lowercase"])
    dev_data = load_dataset(dataset["dev"],
                            config["task_name"],
                            keep_number=True,
                            lowercase=config["char_lowercase"])
    test_data = load_dataset(dataset["test"],
                             config["task_name"],
                             keep_number=True,
                             lowercase=config["char_lowercase"])
    char_dict = build_char_vocab([train_data, dev_data, test_data])
    # create indices dataset
    train_set = build_dataset(train_data, word_dict, char_dict, tag_dict)
    dev_set = build_dataset(dev_data, word_dict, char_dict, tag_dict)
    test_set = build_dataset(test_data, word_dict, char_dict, tag_dict)
    vocab = {
        "word_dict": word_dict,
        "char_dict": char_dict,
        "tag_dict": tag_dict
    }

    write_json(os.path.join(config["save_path"], "vocab.json"), vocab)
    write_json(os.path.join(config["save_path"], "train.json"), train_set)
    write_json(os.path.join(config["save_path"], "dev.json"), dev_set)
    write_json(os.path.join(config["save_path"], "test.json"), test_set)

    return (train_set, dev_set, test_set, vocab)
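
# Illustrative call site (the driver script is not shown in these examples):
# train_set, dev_set, test_set, vocab = process_data(
#     {"train": ..., "dev": ..., "test": ...}, config)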
Example #6
    def eval(self, preprocessor):
        self.logger.info(str(self.cfg))
        data = preprocessor.load_dataset(self.cfg["data_path"],
                                         keep_number=True,
                                         lowercase=self.cfg["char_lowercase"])
        data = data[:self.cfg["data_size"]]
        dataset = preprocessor.build_dataset(data, self.word_dict,
                                             self.char_dict, self.tag_dict)
        write_json(os.path.join(self.cfg["save_path"], "tmp.json"), dataset)
        self.n_gold_spans = count_gold_spans(
            os.path.join(self.cfg["save_path"], "tmp.json"))
        self.logger.info("Target data: %s sentences" % len(dataset))
        del dataset

        batches = list(
            self.batcher.batchnize_dataset(os.path.join(
                self.cfg["save_path"], "tmp.json"),
                                           batch_size=self.cfg["batch_size"],
                                           shuffle=True))
        self.logger.info("Target data: %s batches" % len(batches))
        _ = self.evaluate_epoch(batches, "valid")
Example #7
def process_data(config):
    # load raw data
    train_data = load_dataset(os.path.join(config["raw_path"], "train.txt"))
    dev_data = load_dataset(os.path.join(config["raw_path"], "valid.txt"))
    test_data = load_dataset(os.path.join(config["raw_path"], "test.txt"))
    # build vocabulary
    word_dict, char_dict, _ = build_vocab([train_data, dev_data])
    *_, tag_dict = build_vocab([train_data, dev_data, test_data])
    # create indices dataset
    train_set = build_dataset(train_data, word_dict, char_dict, tag_dict)
    dev_set = build_dataset(dev_data, word_dict, char_dict, tag_dict)
    test_set = build_dataset(test_data, word_dict, char_dict, tag_dict)
    vocab = {
        "word_dict": word_dict,
        "char_dict": char_dict,
        "tag_dict": tag_dict
    }
    # write to file
    if not os.path.exists(config["save_path"]):
        os.makedirs(config["save_path"])
    write_json(os.path.join(config["save_path"], "vocab.json"), vocab)
    write_json(os.path.join(config["save_path"], "train.json"), train_set)
    write_json(os.path.join(config["save_path"], "dev.json"), dev_set)
    write_json(os.path.join(config["save_path"], "test.json"), test_set)

    def preprocess(self):
        config = self.config
        os.makedirs(config["save_path"], exist_ok=True)

        # List[{'words': List[str], 'tags': List[str]}]
        train_data = self.load_dataset(os.path.join(config["raw_path"],
                                                    "train.json"),
                                       keep_number=False,
                                       lowercase=True)
        valid_data = self.load_dataset(os.path.join(config["raw_path"],
                                                    "valid.json"),
                                       keep_number=False,
                                       lowercase=True)
        train_data = train_data[:config["data_size"]]
        valid_data = valid_data[:config["data_size"]]

        # build vocabulary
        if config["use_pretrained"]:
            glove_path = self.config["glove_path"].format(
                config["glove_name"], config["emb_dim"])
            glove_vocab = self.load_glove_vocab(glove_path,
                                                config["glove_name"])
            word_dict = self.build_word_vocab_pretrained(
                [train_data, valid_data], glove_vocab)
            vectors = self.filter_glove_emb(word_dict, glove_path,
                                            config["glove_name"],
                                            config["emb_dim"])
            np.savez_compressed(config["pretrained_emb"], embeddings=vectors)
        else:
            word_dict = self.build_word_vocab([train_data, valid_data])

        # build tag dict
        tag_dict = self.build_tag_vocab([train_data, valid_data])

        # build char dict
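        # The data is reloaded with numbers kept and casing controlled by
        # char_lowercase: the character vocabulary (and the indexed datasets
        # below) are built from these surface forms, whereas the word vocabulary
        # above was built from lowercased text loaded with keep_number=False.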
        train_data = self.load_dataset(os.path.join(config["raw_path"],
                                                    "train.json"),
                                       keep_number=True,
                                       lowercase=config["char_lowercase"])
        valid_data = self.load_dataset(os.path.join(config["raw_path"],
                                                    "valid.json"),
                                       keep_number=True,
                                       lowercase=config["char_lowercase"])

        train_data = train_data[:config["data_size"]]
        valid_data = valid_data[:config["data_size"]]

        char_dict = self.build_char_vocab([train_data])

        # create indices dataset
        # List[{'words': List[str], 'chars': List[List[str]], 'tags': List[str]}]
        train_set = self.build_dataset(train_data, word_dict, char_dict,
                                       tag_dict)
        valid_set = self.build_dataset(valid_data, word_dict, char_dict,
                                       tag_dict)
        vocab = {
            "word_dict": word_dict,
            "char_dict": char_dict,
            "tag_dict": tag_dict
        }

        print("Train Sents: %d" % len(train_set))
        print("Valid Sents: %d" % len(valid_set))

        # write to file
        write_json(os.path.join(config["save_path"], "vocab.json"), vocab)
        write_json(os.path.join(config["save_path"], "train.json"), train_set)
        write_json(os.path.join(config["save_path"], "valid.json"), valid_set)

    def save_nearest_spans(self, data_name, preprocessor, print_knn):
        self.logger.info(str(self.cfg))

        ########################
        # Load validation data #
        ########################
        valid_data = preprocessor.load_dataset(
            self.cfg["data_path"],
            keep_number=True,
            lowercase=self.cfg["char_lowercase"])
        valid_data = valid_data[:self.cfg["data_size"]]
        dataset = preprocessor.build_dataset(valid_data, self.word_dict,
                                             self.char_dict, self.tag_dict)
        dataset_path = os.path.join(self.cfg["save_path"], "tmp.json")
        write_json(dataset_path, dataset)
        self.logger.info("Valid sentences: {:>7}".format(len(dataset)))
        self.n_gold_spans = count_gold_spans(dataset_path)

        ######################
        # Load training data #
        ######################
        train_sents = load_json(self.cfg["train_set"])
        if self.cfg["knn_sampling"] == "random":
            train_sent_ids = [sent_id for sent_id in range(len(train_sents))]
        else:
            train_sent_ids = None
        train_data = preprocessor.load_dataset(os.path.join(
            self.cfg["raw_path"], "train.json"),
                                               keep_number=True,
                                               lowercase=False)
        self.logger.info("Train sentences: {:>7}".format(len(train_sents)))

        #############
        # Main loop #
        #############
        correct = 0
        p_total = 0
        start_time = time.time()
        file_path = os.path.join(self.cfg["checkpoint_path"],
                                 "%s.nearest_spans.txt" % data_name)
        fout_txt = open(file_path, "w")
        print("PREDICTION START")
        for record, data in zip(valid_data, dataset):
            valid_sent_id = record["sent_id"]
            batch = self.make_one_batch_for_target(data, valid_sent_id)

            if (valid_sent_id + 1) % 100 == 0:
                print("%d" % (valid_sent_id + 1), flush=True, end=" ")

            #####################
            # Sentence sampling #
            #####################
            batch, sampled_sent_ids = self._make_batch_and_sample_sent_ids(
                batch, record, train_sents, train_sent_ids)

            ##############
            # Prediction #
            ##############
            feed_dict = self._get_feed_dict(batch)
            batch_sims, batch_preds = self.sess.run(
                [self.similarity, self.predicts], feed_dict)

            crr_i, p_total_i = count_gold_and_system_outputs(
                batch["tags"], batch_preds, NULL_LABEL_ID)
            correct += crr_i
            p_total += p_total_i

            ####################
            # Write the result #
            ####################
            self._write_predictions(fout_txt, record)
            self._write_nearest_spans(fout_txt, record, train_data,
                                      sampled_sent_ids, batch_sims,
                                      batch_preds, print_knn)

        fout_txt.close()

        p, r, f = f_score(correct, p_total, self.n_gold_spans)
        self.logger.info("-- Time: %f seconds" % (time.time() - start_time))
        self.logger.info(
            "-- {} set\tF:{:>7.2%} P:{:>7.2%} ({:>5}/{:>5}) R:{:>7.2%} ({:>5}/{:>5})"
            .format(data_name, f, p, correct, p_total, r, correct,
                    self.n_gold_spans))
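
# f_score above is the repo's own helper; under the usual precision/recall/F1
# definitions it reduces to the sketch below (parameter names are illustrative,
# and the real helper may guard zero denominators differently):
def f_score(correct, p_total, r_total):
    p = correct / p_total if p_total else 0.0
    r = correct / r_total if r_total else 0.0
    f = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f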

    def save_predicted_bio_tags(self, data_name, preprocessor):
        self.logger.info(str(self.cfg))

        ########################
        # Load validation data #
        ########################
        valid_data = preprocessor.load_dataset(
            self.cfg["data_path"],
            keep_number=True,
            lowercase=self.cfg["char_lowercase"])
        valid_data = valid_data[:self.cfg["data_size"]]
        dataset = preprocessor.build_dataset(valid_data, self.word_dict,
                                             self.char_dict, self.tag_dict)
        dataset_path = os.path.join(self.cfg["save_path"], "tmp.json")
        write_json(dataset_path, dataset)
        self.logger.info("Valid sentences: {:>7}".format(len(dataset)))

        ######################
        # Load training data #
        ######################
        train_sents = load_json(self.cfg["train_set"])
        if self.cfg["knn_sampling"] == "random":
            train_sent_ids = [sent_id for sent_id in range(len(train_sents))]
        else:
            train_sent_ids = None
        self.logger.info("Train sentences: {:>7}".format(len(train_sents)))

        #############
        # Main loop #
        #############
        start_time = time.time()
        path = os.path.join(self.cfg["checkpoint_path"],
                            "%s.bio.txt" % data_name)
        fout_txt = open(path, "w")
        print("PREDICTION START")
        for record, data in zip(valid_data, dataset):
            valid_sent_id = record["sent_id"]
            batch = self.make_one_batch_for_target(data,
                                                   valid_sent_id,
                                                   add_tags=False)
            if (valid_sent_id + 1) % 100 == 0:
                print("%d" % (valid_sent_id + 1), flush=True, end=" ")

            #####################
            # Sentence sampling #
            #####################
            batch, sampled_sent_ids = self._make_batch_and_sample_sent_ids(
                batch, record, train_sents, train_sent_ids)

            ###############
            # KNN predict #
            ###############
            feed_dict = self._get_feed_dict(batch)
            proba = self.sess.run([self.marginal_proba], feed_dict)[0][0]

            ########################
            # Make predicted spans #
            ########################
            words = record["words"]
            triples = greedy_search(proba,
                                    n_words=len(words),
                                    max_span_len=self.max_span_len,
                                    null_label_id=NULL_LABEL_ID)
            pred_bio_tags = span2bio(spans=triples,
                                     n_words=len(words),
                                     tag_dict=self.rev_tag_dict)
            gold_bio_tags = span2bio(spans=record["tags"], n_words=len(words))
            assert len(words) == len(pred_bio_tags) == len(gold_bio_tags)

            ####################
            # Write the result #
            ####################
            for word, gold_tag, pred_tag in zip(words, gold_bio_tags,
                                                pred_bio_tags):
                fout_txt.write("%s _ %s %s\n" % (word, gold_tag, pred_tag))
            fout_txt.write("\n")

        self.logger.info("-- Time: %f seconds\nFINISHED." %
                         (time.time() - start_time))
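
# span2bio above is the repo's own converter; a rough sketch of the assumed
# behaviour, turning (label, i, j) span triples with inclusive offsets (the
# span format used elsewhere in these examples) into a word-level BIO tag
# sequence. The real helper may resolve overlaps or invalid spans differently.
def span2bio(spans, n_words, tag_dict=None):
    tags = ["O"] * n_words
    for label, i, j in spans:
        name = tag_dict[int(label)] if tag_dict is not None else label
        tags[i] = "B-%s" % name
        for k in range(i + 1, j + 1):
            tags[k] = "I-%s" % name
    return tags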

    def save_predicted_spans(self, data_name, preprocessor):
        self.logger.info(str(self.cfg))

        ########################
        # Load validation data #
        ########################
        valid_data = preprocessor.load_dataset(
            self.cfg["data_path"],
            keep_number=True,
            lowercase=self.cfg["char_lowercase"])
        valid_data = valid_data[:self.cfg["data_size"]]
        dataset = preprocessor.build_dataset(valid_data, self.word_dict,
                                             self.char_dict, self.tag_dict)
        dataset_path = os.path.join(self.cfg["save_path"], "tmp.json")
        write_json(dataset_path, dataset)
        self.logger.info("Valid sentences: {:>7}".format(len(dataset)))

        ######################
        # Load training data #
        ######################
        train_sents = load_json(self.cfg["train_set"])
        if self.cfg["knn_sampling"] == "random":
            train_sent_ids = [sent_id for sent_id in range(len(train_sents))]
        else:
            train_sent_ids = None
        self.logger.info("Train sentences: {:>7}".format(len(train_sents)))

        #############
        # Main loop #
        #############
        start_time = time.time()
        results = []
        print("PREDICTION START")
        for record, data in zip(valid_data, dataset):
            valid_sent_id = record["sent_id"]
            batch = self.make_one_batch_for_target(data,
                                                   valid_sent_id,
                                                   add_tags=False)
            if (valid_sent_id + 1) % 100 == 0:
                print("%d" % (valid_sent_id + 1), flush=True, end=" ")

            #####################
            # Sentence sampling #
            #####################
            batch, sampled_sent_ids = self._make_batch_and_sample_sent_ids(
                batch, record, train_sents, train_sent_ids)

            ###############
            # KNN predict #
            ###############
            feed_dict = self._get_feed_dict(batch)
            batch_preds = self.sess.run([self.predicts], feed_dict)[0]
            preds = batch_preds[0]

            ########################
            # Make predicted spans #
            ########################
            indx_i, indx_j = get_span_indices(n_words=len(record["words"]),
                                              max_span_len=self.max_span_len)
            assert len(preds) == len(indx_i) == len(indx_j)
            pred_spans = [[
                self.rev_tag_dict[pred_label_id],
                int(i), int(j)
            ] for pred_label_id, i, j in zip(preds, indx_i, indx_j)
                          if pred_label_id != NULL_LABEL_ID]

            ##################
            # Add the result #
            ##################
            results.append({
                "sent_id": valid_sent_id,
                "words": record["words"],
                "spans": pred_spans,
                "train_sent_ids": sampled_sent_ids
            })

        path = os.path.join(self.cfg["checkpoint_path"],
                            "%s.predicted_spans.json" % data_name)
        write_json(path, results)
        self.logger.info("-- Time: %f seconds\nFINISHED." %
                         (time.time() - start_time))
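
# get_span_indices above is the repo's own helper; a sketch of the assumed
# behaviour, enumerating every candidate span of width up to max_span_len with
# inclusive end offsets. The enumeration order shown here is an assumption; it
# must match the order in which the model scores spans.
import numpy as np

def get_span_indices(n_words, max_span_len):
    starts, ends = [], []
    for i in range(n_words):
        for j in range(i, min(i + max_span_len, n_words)):
            starts.append(i)
            ends.append(j)
    return np.asarray(starts), np.asarray(ends)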
Example #12
    def save_span_representation(self, data_name, preprocessor):
        self.logger.info(str(self.cfg))

        ########################
        # Load validation data #
        ########################
        valid_data = preprocessor.load_dataset(
            self.cfg["data_path"],
            keep_number=True,
            lowercase=self.cfg["char_lowercase"])
        valid_data = valid_data[:self.cfg["data_size"]]
        dataset = preprocessor.build_dataset(valid_data, self.word_dict,
                                             self.char_dict, self.tag_dict)
        dataset_path = os.path.join(self.cfg["save_path"], "tmp.json")
        write_json(dataset_path, dataset)
        self.logger.info("Valid sentences: {:>7}".format(len(dataset)))

        #############
        # Main loop #
        #############
        start_time = time.time()
        results = []
        fout_hdf5 = h5py.File(
            os.path.join(self.cfg["checkpoint_path"],
                         "%s.span_reps.hdf5" % data_name), 'w')
        print("PREDICTION START")
        for record, data in zip(valid_data, dataset):
            valid_sent_id = record["sent_id"]
            batch = self.batcher.make_each_batch(
                batch_words=[data["words"]],
                batch_chars=[data["chars"]],
                max_span_len=self.max_span_len,
                batch_tags=[data["tags"]])

            if (valid_sent_id + 1) % 100 == 0:
                print("%d" % (valid_sent_id + 1), flush=True, end=" ")

            #################
            # Predict spans #
            #################
            feed_dict = self._get_feed_dict(batch)
            preds, span_reps = self.sess.run([self.predicts, self.span_rep],
                                             feed_dict=feed_dict)
            golds = batch["tags"][0]
            preds = preds[0]
            span_reps = span_reps[0]
            assert len(span_reps) == len(golds) == len(preds)

            ########################
            # Make predicted spans #
            ########################
            indx_i, indx_j = get_span_indices(n_words=len(record["words"]),
                                              max_span_len=self.max_span_len)
            assert len(preds) == len(indx_i) == len(indx_j)
            pred_spans = [[self.rev_tag_dict[label_id],
                           int(i), int(j)]
                          for label_id, i, j in zip(preds, indx_i, indx_j)]
            gold_spans = [[self.rev_tag_dict[label_id],
                           int(i), int(j)]
                          for label_id, i, j in zip(golds, indx_i, indx_j)]

            ####################
            # Write the result #
            ####################
            fout_hdf5.create_dataset(name='{}'.format(valid_sent_id),
                                     dtype='float32',
                                     data=span_reps)
            results.append({
                "sent_id": valid_sent_id,
                "words": record["words"],
                "gold_spans": gold_spans,
                "pred_spans": pred_spans
            })
        fout_hdf5.close()
        write_json(
            os.path.join(self.cfg["checkpoint_path"],
                         "%s.spans.json" % data_name), results)
        self.logger.info("-- Time: %f seconds\nFINISHED." %
                         (time.time() - start_time))
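
# A short sketch of reading back the artefacts written by
# save_span_representation above, assuming write_json emits plain JSON. The
# file names follow "%s.span_reps.hdf5" / "%s.spans.json" under
# checkpoint_path; the paths and data_name="valid" here are illustrative.
import json
import h5py

with open("checkpoint/valid.spans.json") as f:
    spans = json.load(f)
with h5py.File("checkpoint/valid.span_reps.hdf5", "r") as reps:
    for record in spans:
        sent_reps = reps[str(record["sent_id"])][:]  # (n_spans, dim) float32
        assert len(sent_reps) == len(record["pred_spans"])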