def _get_nearest_spans(_sampled_train_sents): _nearest_spans = [] _prev_indx = 0 _temp_indx = 0 for _record in _sampled_train_sents: _indx_i, _indx_j = get_span_indices( n_words=len(_record["words"]), max_span_len=self.max_span_len) _temp_indx += len(_indx_i) _temp_scores = scores[_prev_indx:_temp_indx] assert len(_temp_scores) == len(_indx_i) == len(_indx_j) _nearest_spans.extend( get_scores_and_spans(spans=_record["tags"], scores=_temp_scores, sent_id=_record["sent_id"], indx_i=_indx_i, indx_j=_indx_j)) _prev_indx = _temp_indx return _nearest_spans
def predict_on_command_line(self, preprocessor):
    """Interactively predict spans for sentences typed on the command line.

    Training sentences are loaded once. Then, for every whitespace-tokenized
    sentence read from stdin, training sentences are sampled (by mean-GloVe
    cosine k-nearest neighbours when ``self.cfg["knn_sampling"] == "knn"``,
    otherwise via ``train_sent_ids``), the model is run, and each span whose
    predicted label id is not ``NULL_LABEL_ID`` is printed together with the
    five highest-scoring training spans and up to five words of context on
    each side.

    Blank input lines are skipped; the loop ends when stdin reaches EOF.

    Args:
        preprocessor: object whose ``load_dataset(path, keep_number=...,
            lowercase=...)`` loads the raw training data.
    """

    def _load_glove(glove_path):
        """Load a GloVe text file into a word->row dict and a vector matrix."""
        vocab = {}
        vectors = []
        total = int(4e5)  # only used as the progress-bar length hint
        with codecs.open(glove_path, mode='r', encoding='utf-8') as f:
            for line in tqdm(f, total=total, desc="Load glove"):
                line = line.strip().split(" ")
                vocab[line[0]] = len(vocab)
                vectors.append([float(x) for x in line[1:]])
        assert len(vocab) == len(vectors)
        return vocab, np.asarray(vectors)

    def _mean_vectors(sents, emb, vocab):
        """Average each sentence's lowercased word vectors (OOV -> zeros)."""
        unk_vec = np.zeros(emb.shape[1])
        mean_vecs = []
        for words in sents:
            vecs = []
            for word in words:
                word = word.lower()
                vecs.append(emb[vocab[word]] if word in vocab else unk_vec)
            mean_vecs.append(np.mean(vecs, axis=0))
        return mean_vecs

    def _cosine_sim(p0, p1):
        """Cosine similarity, defined as 0.0 when either vector has zero norm."""
        d = norm(p0) * norm(p1)
        if d > 0:
            return np.dot(p0, p1) / d
        return 0.0

    def _setup_repository(_train_sents, _train_data=None):
        """Prepare what the configured sampling strategy needs.

        "random" sampling only needs the list of training sentence ids; any
        other strategy loads GloVe and embeds every training sentence.
        """
        if self.cfg["knn_sampling"] == "random":
            _train_sent_ids = list(range(len(_train_sents)))
            _vocab = _glove = _train_embs = None
        else:
            _train_sent_ids = None
            _vocab, _glove = _load_glove("data/emb/glove.6B.100d.txt")
            _train_words = [[w.lower() for w in _train_record["words"]]
                            for _train_record in _train_data]
            _train_embs = _mean_vectors(_train_words, _glove, _vocab)
        return _train_sent_ids, _train_embs, _vocab, _glove

    def _make_ids(_words):
        """Map words to char-id lists and word ids, falling back to UNK."""
        _char_ids = []
        _word_ids = []
        for word in _words:
            _char_ids.append([
                self.char_dict[char] if char in self.char_dict
                else self.char_dict[UNK]
                for char in word
            ])
            word = word_convert(word, keep_number=False, lowercase=True)
            _word_ids.append(self.word_dict[word] if word in self.word_dict
                             else self.word_dict[UNK])
        return _char_ids, _word_ids

    def _retrieve_knn_train_sents(_record, _train_embs, _vocab, _glove):
        """Store the ids of the k most cosine-similar training sentences."""
        test_words = [w.lower() for w in _record["words"]]
        test_emb = _mean_vectors([test_words], _glove, _vocab)[0]
        sim = [_cosine_sim(train_emb, test_emb) for train_emb in _train_embs]
        arg_sort = np.argsort(sim)[::-1][:self.cfg["k"]]
        _record["train_sent_ids"] = [int(arg) for arg in arg_sort]
        return _record

    def _get_nearest_spans(_sampled_train_sents, _scores):
        """Split the flat ``_scores`` per sentence and rank spans by score.

        ``_scores`` holds one entry per candidate span of the sampled
        sentences, concatenated in sentence order.
        """
        _nearest_spans = []
        _prev_indx = 0
        _temp_indx = 0
        for _record in _sampled_train_sents:
            _indx_i, _indx_j = get_span_indices(
                n_words=len(_record["words"]),
                max_span_len=self.max_span_len)
            _temp_indx += len(_indx_i)
            _temp_scores = _scores[_prev_indx:_temp_indx]
            assert len(_temp_scores) == len(_indx_i) == len(_indx_j)
            _nearest_spans.extend(
                get_scores_and_spans(spans=_record["tags"],
                                     scores=_temp_scores,
                                     sent_id=_record["sent_id"],
                                     indx_i=_indx_i,
                                     indx_j=_indx_j))
            _prev_indx = _temp_indx
        # Highest score (last tuple field) first.
        _nearest_spans.sort(key=lambda span: span[-1], reverse=True)
        return _nearest_spans

    def _format_context(_words, _start, _end, _width=5):
        """Render ``_words[_start:_end + 1]`` in brackets with context.

        Up to ``_width`` words are shown on each side; "..." marks that the
        sentence continues beyond the shown context on that side.
        """
        left = " ".join(_words[max(0, _start - _width):_start])
        if _start > _width:
            left = "... " + left
        right = " ".join(_words[_end + 1:_end + 1 + _width])
        if _end + 1 + _width < len(_words):
            right = right + " ..."
        mention = " ".join(_words[_start:_end + 1])
        return "{} [{}] {}".format(left, mention, right)

    # Load training data
    train_sents = load_json(self.cfg["train_set"])
    train_data = preprocessor.load_dataset(
        os.path.join(self.cfg["raw_path"], "train.json"),
        keep_number=True, lowercase=False)
    train_sent_ids, train_embs, vocab, glove = _setup_repository(
        train_sents, train_data)

    # Read sentences from the command line until EOF
    print("\nPREDICTION START\n")
    while True:
        try:
            sentence = input('\nEnter a tokenized sentence: ')
        except EOFError:
            # stdin is exhausted (e.g. Ctrl-D or end of piped input).
            print()
            break
        words = sentence.split()
        if not words:
            # Skip blank input rather than sending an empty sentence on.
            continue

        char_ids, word_ids = _make_ids(words)
        data = {"words": word_ids, "chars": char_ids}
        record = {"sent_id": 0, "words": words, "train_sent_ids": None}
        batch = self.make_one_batch_for_target(data, sent_id=0,
                                               add_tags=False)

        # Sentence sampling
        if self.cfg["knn_sampling"] == "knn":
            record = _retrieve_knn_train_sents(record, train_embs, vocab,
                                               glove)
        batch, sampled_sent_ids = self._make_batch_and_sample_sent_ids(
            batch, record, train_sents, train_sent_ids)

        # Prediction
        feed_dict = self._get_feed_dict(batch)
        batch_sims, batch_preds = self.sess.run(
            [self.similarity, self.predicts], feed_dict)

        # Write the result
        sims = batch_sims[0]  # 1D: n_spans, 2D: n_instances
        preds = batch_preds[0]  # 1D: n_spans
        indx_i, indx_j = get_span_indices(n_words=len(record["words"]),
                                          max_span_len=self.max_span_len)
        assert len(sims) == len(preds) == len(indx_i) == len(indx_j)
        sampled_train_sents = [
            train_data[sent_id] for sent_id in sampled_sent_ids
        ]

        for scores, pred_label_id, i, j in zip(sims, preds, indx_i, indx_j):
            if pred_label_id == NULL_LABEL_ID:
                continue
            pred_label = self.rev_tag_dict[pred_label_id]
            print("#(%d,%d) || %s || %s" %
                  (i, j, " ".join(record["words"][i:j + 1]), pred_label))

            nearest_spans = _get_nearest_spans(sampled_train_sents, scores)
            for k, (r, _sent_id, a, b, _score) in enumerate(nearest_spans[:5]):
                train_words = train_data[_sent_id]["words"]
                text = "{}: {}".format(r, _format_context(train_words, a, b))
                print("## %d %s" % (k, text))
def save_predicted_spans(self, data_name, preprocessor):
    """Predict labelled spans for the validation data and save them as JSON.

    Side effects:
      * the preprocessed validation dataset is written to
        ``<save_path>/tmp.json``;
      * the predictions are written to
        ``<checkpoint_path>/<data_name>.predicted_spans.json``, one entry per
        sentence with its ``sent_id``, ``words``, the non-null predicted
        ``spans`` as ``[label, start, end]`` and the ``train_sent_ids`` that
        were sampled for it.

    Args:
        data_name: prefix of the output file name.
        preprocessor: provides ``load_dataset`` and ``build_dataset``.
    """
    cfg = self.cfg
    self.logger.info(str(cfg))

    # --- validation data ---
    valid_data = preprocessor.load_dataset(
        cfg["data_path"], keep_number=True, lowercase=cfg["char_lowercase"])
    valid_data = valid_data[:cfg["data_size"]]
    dataset = preprocessor.build_dataset(valid_data, self.word_dict,
                                         self.char_dict, self.tag_dict)
    write_json(os.path.join(cfg["save_path"], "tmp.json"), dataset)
    self.logger.info("Valid sentences: {:>7}".format(len(dataset)))

    # --- training data ---
    train_sents = load_json(cfg["train_set"])
    if cfg["knn_sampling"] == "random":
        train_sent_ids = list(range(len(train_sents)))
    else:
        train_sent_ids = None
    self.logger.info("Train sentences: {:>7}".format(len(train_sents)))

    # --- prediction loop ---
    start_time = time.time()
    results = []
    print("PREDICTION START")
    for record, data in zip(valid_data, dataset):
        sent_id = record["sent_id"]
        batch = self.make_one_batch_for_target(data, sent_id, add_tags=False)
        if (sent_id + 1) % 100 == 0:
            print("%d" % (sent_id + 1), flush=True, end=" ")

        batch, sampled_sent_ids = self._make_batch_and_sample_sent_ids(
            batch, record, train_sents, train_sent_ids)

        # Only the first (and only) sentence of the batch is used.
        preds = self.sess.run([self.predicts], self._get_feed_dict(batch))[0][0]

        starts, ends = get_span_indices(n_words=len(record["words"]),
                                        max_span_len=self.max_span_len)
        assert len(preds) == len(starts) == len(ends)

        pred_spans = []
        for label_id, start, end in zip(preds, starts, ends):
            if label_id == NULL_LABEL_ID:
                continue
            pred_spans.append(
                [self.rev_tag_dict[label_id], int(start), int(end)])

        results.append({
            "sent_id": sent_id,
            "words": record["words"],
            "spans": pred_spans,
            "train_sent_ids": sampled_sent_ids
        })

    out_path = os.path.join(self.cfg["checkpoint_path"],
                            "%s.predicted_spans.json" % data_name)
    write_json(out_path, results)
    self.logger.info("-- Time: %f seconds\nFINISHED." %
                     (time.time() - start_time))
def _write_nearest_spans(self, fout_txt, record, train_data, sampled_sent_ids,
                         batch_sims, batch_preds, print_knn):
    """Write one sentence's span predictions and their nearest training spans.

    For every candidate span of ``record`` that is predicted with a non-null
    label or is a gold span, a line
    ``##(i,j) || mention || predicted label || gold label`` is written (gold
    label ``O`` when the span is not in ``record["tags"]``). It is followed by
    that span's ten highest-scoring training spans as ``####RANK:k`` lines.
    A blank line closes the sentence.

    Args:
        fout_txt: writable text file.
        record: sentence dict with ``words`` and ``tags`` given as
            ``(label, start, end)`` triples.
        train_data: training records indexed by sentence id.
        sampled_sent_ids: ids of the training sentences used as instances.
        batch_sims: batched per-span similarity scores; row 0 is used.
        batch_preds: batched predicted label ids; row 0 is used.
        print_knn: if true, first write each sampled training sentence as a
            ``--kNN`` line.
    """

    def _write_train_sents(_sampled_train_sents):
        """Write each sampled training sentence with its gold spans."""
        for _train_record in _sampled_train_sents:
            fout_txt.write("--kNN:%d || %s || %s\n" % (
                _train_record["sent_id"],
                " ".join(_train_record["words"]),
                " ".join(["(%s,%d,%d)" % (r, i, j)
                          for (r, i, j) in _train_record["tags"]])))

    def _write_gold_and_pred_spans(_i, _j, _pred_label_id, _gold_labels):
        """Write the span (_i, _j) with its predicted and gold labels."""
        gold_label = _gold_labels.get((_i, _j), "O")
        pred_label = self.rev_tag_dict[_pred_label_id]
        fout_txt.write("##(%d,%d) || %s || %s || %s\n" % (
            _i, _j, " ".join(record["words"][_i:_j + 1]), pred_label,
            gold_label))

    def _get_nearest_spans(_sampled_train_sents, _scores):
        """Split the flat ``_scores`` per sampled sentence and score its spans.

        ``_scores`` holds one entry per candidate span of the sampled
        sentences, concatenated in sentence order.
        """
        _nearest_spans = []
        _prev_indx = 0
        _temp_indx = 0
        for _record in _sampled_train_sents:
            _indx_i, _indx_j = get_span_indices(
                n_words=len(_record["words"]),
                max_span_len=self.max_span_len)
            _temp_indx += len(_indx_i)
            _temp_scores = _scores[_prev_indx:_temp_indx]
            assert len(_temp_scores) == len(_indx_i) == len(_indx_j)
            _nearest_spans.extend(
                get_scores_and_spans(spans=_record["tags"],
                                     scores=_temp_scores,
                                     sent_id=_record["sent_id"],
                                     indx_i=_indx_i,
                                     indx_j=_indx_j))
            _prev_indx = _temp_indx
        return _nearest_spans

    def _write_nearest_spans_for_each_span(_sampled_train_sents, _scores):
        """Write the ten highest-scoring training spans for one test span."""
        nearest_spans = _get_nearest_spans(_sampled_train_sents, _scores)
        nearest_spans.sort(key=lambda span: span[-1], reverse=True)
        for rank, (r, sent_id, i, j, score) in enumerate(nearest_spans[:10]):
            mention = " ".join(train_data[sent_id]["words"][i:j + 1])
            text = "{} || {} || sent:{} || ({},{}) || {:.3g}".format(
                r, mention, sent_id, i, j, score)
            fout_txt.write("####RANK:%d %s\n" % (rank, text))

    sampled_train_sents = [train_data[sent_id] for sent_id in sampled_sent_ids]
    if print_knn:
        _write_train_sents(sampled_train_sents)

    sims = batch_sims[0]  # 1D: n_spans, 2D: n_instances
    preds = batch_preds[0]  # 1D: n_spans
    indx_i, indx_j = get_span_indices(n_words=len(record["words"]),
                                      max_span_len=self.max_span_len)
    # Gold label per (start, end); the first tag wins on duplicate boundaries,
    # matching a first-match list lookup.
    gold_labels = {}
    for label, i, j in record["tags"]:
        gold_labels.setdefault((i, j), label)
    assert len(sims) == len(preds) == len(indx_i) == len(indx_j)

    for scores, pred_label_id, i, j in zip(sims, preds, indx_i, indx_j):
        if pred_label_id == NULL_LABEL_ID and (i, j) not in gold_labels:
            continue
        _write_gold_and_pred_spans(i, j, pred_label_id, gold_labels)
        _write_nearest_spans_for_each_span(sampled_train_sents, scores)
    # Blank line separates this sentence's block from the next one.
    fout_txt.write("\n")
def save_span_representation(self, data_name, preprocessor):
    """Save per-span representations and gold/predicted labels for a dataset.

    Side effects:
      * the preprocessed validation dataset is written to
        ``<save_path>/tmp.json``;
      * for each sentence, the ``self.span_rep`` output is stored as a
        float32 HDF5 dataset named after its ``sent_id`` in
        ``<checkpoint_path>/<data_name>.span_reps.hdf5``;
      * gold and predicted labels for every candidate span are written as
        ``[label, start, end]`` lists to
        ``<checkpoint_path>/<data_name>.spans.json``.

    The HDF5 file is opened in a ``with`` block so it is closed (and
    flushed) even if prediction raises part-way through.

    Args:
        data_name: prefix of the output file names.
        preprocessor: provides ``load_dataset`` and ``build_dataset``.
    """
    self.logger.info(str(self.cfg))

    # Load validation data
    valid_data = preprocessor.load_dataset(
        self.cfg["data_path"], keep_number=True,
        lowercase=self.cfg["char_lowercase"])
    valid_data = valid_data[:self.cfg["data_size"]]
    dataset = preprocessor.build_dataset(valid_data, self.word_dict,
                                         self.char_dict, self.tag_dict)
    dataset_path = os.path.join(self.cfg["save_path"], "tmp.json")
    write_json(dataset_path, dataset)
    self.logger.info("Valid sentences: {:>7}".format(len(dataset)))

    # Main loop
    start_time = time.time()
    results = []
    hdf5_path = os.path.join(self.cfg["checkpoint_path"],
                             "%s.span_reps.hdf5" % data_name)
    with h5py.File(hdf5_path, 'w') as fout_hdf5:
        print("PREDICTION START")
        for record, data in zip(valid_data, dataset):
            valid_sent_id = record["sent_id"]
            batch = self.batcher.make_each_batch(
                batch_words=[data["words"]],
                batch_chars=[data["chars"]],
                max_span_len=self.max_span_len,
                batch_tags=[data["tags"]])

            if (valid_sent_id + 1) % 100 == 0:
                print("%d" % (valid_sent_id + 1), flush=True, end=" ")

            # Predict spans; only the first (and only) batch row is used.
            feed_dict = self._get_feed_dict(batch)
            preds, span_reps = self.sess.run([self.predicts, self.span_rep],
                                             feed_dict=feed_dict)
            golds = batch["tags"][0]
            preds = preds[0]
            span_reps = span_reps[0]
            assert len(span_reps) == len(golds) == len(preds)

            # Make gold/predicted spans (every candidate span, null included)
            indx_i, indx_j = get_span_indices(
                n_words=len(record["words"]),
                max_span_len=self.max_span_len)
            assert len(preds) == len(indx_i) == len(indx_j)
            pred_spans = [[self.rev_tag_dict[label_id], int(i), int(j)]
                          for label_id, i, j in zip(preds, indx_i, indx_j)]
            gold_spans = [[self.rev_tag_dict[label_id], int(i), int(j)]
                          for label_id, i, j in zip(golds, indx_i, indx_j)]

            # Write the result
            fout_hdf5.create_dataset(name=str(valid_sent_id),
                                     dtype='float32',
                                     data=span_reps)
            results.append({
                "sent_id": valid_sent_id,
                "words": record["words"],
                "gold_spans": gold_spans,
                "pred_spans": pred_spans
            })

    write_json(
        os.path.join(self.cfg["checkpoint_path"],
                     "%s.spans.json" % data_name), results)
    self.logger.info("-- Time: %f seconds\nFINISHED." %
                     (time.time() - start_time))