def run():
    tokenizer = get_tokenizer()
    spr = StreamPickleReader("contradiction_prediction")
    html = HtmlVisualizer("contradiction_prediction.html")
    cnt = 0
    while spr.has_next():
        item = spr.get_item()
        example, pred = item
        input_ids, _, _ = example
        logit, explain = pred
        tokens = tokenizer.convert_ids_to_tokens(input_ids)
        # Split tokens and the per-token explanation scores into premise/hypothesis segments.
        p, h = split_p_h_with_input_ids(tokens, input_ids)
        p_score, h_score = split_p_h_with_input_ids(explain, input_ids)
        p_score = normalize(p_score)
        h_score = normalize(h_score)
        p_cells = [Cell("P:")] + cells_from_tokens(p, p_score)
        h_cells = [Cell("H:")] + cells_from_tokens(h, h_score)
        html.write_paragraph(str(logit))
        html.multirow_print(p_cells)
        html.multirow_print(h_cells)
        if cnt > 100:
            break
        cnt += 1

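# `normalize` above is not defined in this listing. A minimal sketch, assuming it rescales
# per-token explanation scores into the 0-100 range that Cell highlighting uses elsewhere
# in this file; the actual definition may differ.
def normalize(scores):
    scores = list(scores)
    max_v = max(scores) if scores else 1.0
    if max_v <= 0:
        return [0.0 for _ in scores]
    return [s / max_v * 100 for s in scores]
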
def show_tfrecord(file_path):
    itr = load_record_v2(file_path)
    tokenizer = get_tokenizer()
    name = os.path.basename(file_path)
    html = HtmlVisualizer(name + ".html")
    for features in itr:
        input_ids = take(features["input_ids"])
        alt_emb_mask = take(features["alt_emb_mask"])
        tokens = tokenizer.convert_ids_to_tokens(input_ids)
        p_tokens, h_tokens = split_p_h_with_input_ids(tokens, input_ids)
        p_mask, h_mask = split_p_h_with_input_ids(alt_emb_mask, input_ids)
        p_cells = [Cell(p_tokens[i], 100 if p_mask[i] else 0) for i in range(len(p_tokens))]
        h_cells = [Cell(h_tokens[i], 100 if h_mask[i] else 0) for i in range(len(h_tokens))]
        label = take(features["label_ids"])[0]
        html.write_paragraph("Label : {}".format(label))
        html.write_table([p_cells])
        html.write_table([h_cells])

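# `take` is defined elsewhere; the helper is assumed to unwrap a tf.train.Feature into a
# plain Python list of ints (consistent with the direct `.int64_list.value` access in
# create_instances below). A sketch under that assumption:
def take(feature):
    return list(feature.int64_list.value)
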
def count_contradiction(data):
    # cont / n_cont: token counts for predicted-contradiction vs. other examples,
    # df: overall token document frequency, pred_count: histogram of predicted labels.
    cont = Counter()
    n_cont = Counter()
    df = Counter()
    pred_count = Counter()
    high_count = 0
    for entry in data:
        logits = entry.get_vector("logits")
        input_ids = entry.get_vector("input_ids")
        tokens = entry.get_tokens("input_ids")
        probs = softmax(logits)
        pred = np.argmax(probs)
        pred_count[pred] += 1
        if probs[2] > 0.5:  # label index 2 is treated as contradiction
            counter = cont
            if high_count < 100:
                p, h = split_p_h_with_input_ids(tokens, input_ids)
                if valid_condition(p, h):
                    high_count += 1
                    print(probs[2])
                    print("P:" + pretty_tokens(p, True))
                    print("H:" + pretty_tokens(h, True))
        else:
            counter = n_cont
        for t in tokens:
            if t == "[PAD]":
                break
            df[t] += 1
            counter[t] += 1
    return cont, high_count, n_cont, pred_count

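# `valid_condition` (also used in save_contradiction_pred and save_contradiction below) is
# defined elsewhere. A hypothetical sketch only, assuming it filters out degenerate
# premise/hypothesis pairs by length; the real criterion may well be different.
def valid_condition(p, h):
    return 3 <= len(h) <= 30 and 3 <= len(p) <= 60
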
def write_feature_to_html(feature, html, tokenizer):
    input_ids = take(feature['input_ids'])
    label_ids = take(feature['label_ids'])
    seg1, seg2 = split_p_h_with_input_ids(input_ids, input_ids)
    text1 = tokenizer.convert_ids_to_tokens(seg1)
    text2 = tokenizer.convert_ids_to_tokens(seg2)
    text1 = pretty_tokens(text1, True)
    text2 = pretty_tokens(text2, True)
    html.write_headline("{}".format(label_ids[0]))
    html.write_paragraph(text1)
    html.write_paragraph(text2)

def show_prediction(filename, file_path, correctness_1, correctness_2):
    data = EstimatorPredictionViewerGosford(filename)
    itr = load_record_v2(file_path)
    tokenizer = get_tokenizer()
    name = os.path.basename(filename)
    html = HtmlVisualizer(name + ".html")
    idx = 0
    for entry in data:
        features = next(itr)
        input_ids = entry.get_vector("input_ids")
        input_ids2 = take(features["input_ids"])
        assert np.all(input_ids == input_ids2)
        alt_emb_mask = take(features["alt_emb_mask"])
        tokens = tokenizer.convert_ids_to_tokens(input_ids)
        p_tokens, h_tokens = split_p_h_with_input_ids(tokens, input_ids)
        p_mask, h_mask = split_p_h_with_input_ids(alt_emb_mask, input_ids)
        p_cells = [Cell(p_tokens[i], 100 if p_mask[i] else 0) for i in range(len(p_tokens))]
        h_cells = [Cell(h_tokens[i], 100 if h_mask[i] else 0) for i in range(len(h_tokens))]
        label = take(features["label_ids"])[0]
        logits = entry.get_vector("logits")
        pred = np.argmax(logits)
        # Only visualize examples that at least one of the two runs got wrong.
        if not correctness_1[idx] or not correctness_2[idx]:
            html.write_paragraph("Label : {} Correct: {}/{}".format(
                label, correctness_1[idx], correctness_2[idx]))
            html.write_table([p_cells])
            html.write_table([h_cells])
        idx += 1

def view_entailment(data_name):
    data, pickle_path = load_data(data_name)
    for entry in data:
        logits = entry.get_vector("logits")
        input_ids = entry.get_vector("input_ids")
        tokens = entry.get_tokens("input_ids")
        probs = softmax(logits)
        pred = np.argmax(probs)
        if probs[0] > 0.5:
            p, h = split_p_h_with_input_ids(tokens, input_ids)
            print("P:" + pretty_tokens(p, True))
            print("H:" + pretty_tokens(h, True))
            print()

def save_contradiction_pred(data):
    entries = []
    for entry in data:
        logits = entry.get_vector("logits")
        input_ids = entry.get_vector("input_ids")
        tokens = entry.get_tokens("input_ids")
        probs = softmax(logits)
        if probs[2] > 0.5:
            p, h = split_p_h_with_input_ids(tokens, input_ids)
            if valid_condition(p, h):
                entries.append(probs[2])
                if len(entries) == 100:
                    break
    save_to_pickle(entries, "cont_model_0")

def transform_datapoint(data_point):
    input_ids = data_point['input_ids']
    max_seq_length = len(input_ids)
    assert max_seq_length == 200
    p, h = split_p_h_with_input_ids(input_ids, input_ids)
    # Sequence layout is [CLS] p [SEP] h [SEP]:
    # segment 0 covers [CLS] + p + first [SEP], segment 1 covers h + second [SEP].
    segment_ids = (2 + len(p)) * [0] + (1 + len(h)) * [1]
    input_mask = (3 + len(p) + len(h)) * [1]
    # Pad both masks out to max_seq_length.
    while len(segment_ids) < max_seq_length:
        input_mask.append(0)
        segment_ids.append(0)
    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(input_ids)
    features["input_mask"] = create_int_feature(input_mask)
    features["segment_ids"] = create_int_feature(segment_ids)
    features["label_ids"] = create_int_feature([data_point['label']])
    return features

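# Hypothetical usage sketch (not part of the original code): assuming create_int_feature
# wraps a list of ints into a tf.train.Feature, the OrderedDict returned by
# transform_datapoint can be serialized into a TFRecord file. The function name and
# output path below are illustrative only.
def write_datapoints(data_points, out_path="transformed.tfrecord"):
    import tensorflow as tf
    with tf.io.TFRecordWriter(out_path) as writer:
        for dp in data_points:
            features = transform_datapoint(dp)
            example = tf.train.Example(features=tf.train.Features(feature=features))
            writer.write(example.SerializeToString())
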
def main():
    file_path = sys.argv[1]
    name = os.path.basename(file_path)
    viewer = EstimatorPredictionViewer(file_path)
    html = HtmlVisualizer("toke_score_gold.html")
    stopwords = load_stopwords_for_query()
    skip = 10
    for entry_idx, entry in enumerate(viewer):
        if entry_idx % skip != 0:
            continue
        tokens = entry.get_tokens("input_ids")
        input_ids = entry.get_vector("input_ids")
        label_ids = entry.get_vector("label_ids")
        label_ids = np.reshape(label_ids, [-1, 2])
        log_label_ids = np.log(label_ids + 1e-10)
        seg1, seg2 = split_p_h_with_input_ids(tokens, input_ids)
        pad_idx = tokens.index("[PAD]")
        assert pad_idx > 0
        logits = entry.get_vector("logits")
        cells = []
        cells2 = []
        for idx in range(pad_idx):
            probs = label_ids[idx]
            token = tokens[idx]
            score = probs[0]
            # Blue for positive scores, red for negative; stopwords are not highlighted,
            # and tokens that also appear in the first segment are shown in green.
            color = "B" if score > 0 else "R"
            highlight_score = min(abs(score) * 10000, 100)
            if token in stopwords:
                highlight_score = 0
            if token in seg1:
                highlight_score = 50
                color = "G"
            c = Cell(token, highlight_score=highlight_score, target_color=color)
            cells.append(c)
        html.multirow_print_from_cells_list([cells, cells2])
        if entry_idx > 10000:
            break

def collect_passage_tokens(pid_entries) -> Dict[Tuple[str, int], List[int]]:
    passage_tokens_dict = {}
    for pid, entries in pid_entries:
        for doc_idx, entry in enumerate(entries):
            input_ids = entry['input_ids2']
            key = entry['kdp'].doc_id, entry['kdp'].passage_idx
            try:
                _, passage_tokens = split_p_h_with_input_ids(input_ids, input_ids)
                if key in passage_tokens_dict:
                    # Warn when the same (doc_id, passage_idx) key maps to different token ids.
                    a = passage_tokens_dict[key]
                    b = passage_tokens
                    if str(a[:200]) != str(b[:200]):
                        print(key)
                        print(str(a[:200]))
                        print(str(b[:200]))
                passage_tokens_dict[key] = passage_tokens
            except UnboundLocalError:
                # Presumably raised by split_p_h_with_input_ids when the expected
                # [SEP] structure is missing; log the offending input_ids and skip.
                print(input_ids)
    return passage_tokens_dict

def save_contradiction(data):
    f = open(os.path.join(output_path, "cont_annot.csv"), "w", encoding="utf-8", newline="")
    writer = csv.writer(f)
    rows = []
    rows.append(("premise", "hypothesis"))
    input_ids_list = []
    for entry in data:
        logits = entry.get_vector("logits")
        input_ids = entry.get_vector("input_ids")
        tokens = entry.get_tokens("input_ids")
        probs = softmax(logits)
        pred = np.argmax(probs)
        if probs[2] > 0.5:
            p, h = split_p_h_with_input_ids(tokens, input_ids)
            if valid_condition(p, h):
                e = (pretty_tokens(p, True), pretty_tokens(h, True))
                rows.append(e)
                input_ids_list.append(input_ids)
                if len(input_ids_list) == 100:
                    break
    writer.writerows(rows)
    f.close()
    save_to_pickle(input_ids_list, "cont_annot_input_ids")

def create_instances(self, input_path, target_topic, target_seq_length):
    tokenizer = get_tokenizer()
    doc_top_k = 1000
    # Keep only the training features whose topic token matches target_topic.
    all_train_data = list(load_record(input_path))
    train_data = []
    for feature in all_train_data:
        input_ids = feature["input_ids"].int64_list.value
        token_id = input_ids[1]
        topic = token_ids_to_topic[token_id]
        if target_topic == topic:
            train_data.append(feature)
    print("Selected {} from {}".format(len(train_data), len(all_train_data)))

    # Build candidate context passages from the top-ranked documents for the topic.
    doc_dict = load_tokens_for_topic(target_topic)
    token_doc_list = []
    ranked_list = sydney_get_ukp_ranked_list()[target_topic]
    print("Ranked list contains {} docs, selecting top-{}".format(len(ranked_list), doc_top_k))
    doc_ids = [doc_id for doc_id, _, _ in ranked_list[:doc_top_k]]
    for doc_id in doc_ids:
        doc = doc_dict[doc_id]
        token_doc = pool_tokens(doc, target_seq_length)
        token_doc_list.extend(token_doc)

    ranker = Ranker()
    target_tf_list = lmap(ranker.get_terms, token_doc_list)
    ranker.init_df_from_tf_list(target_tf_list)

    # Inverted index over terms that appear in fewer than 30% of the candidate passages.
    inv_index = collections.defaultdict(list)
    for doc_idx, doc_tf in enumerate(target_tf_list):
        for term in doc_tf:
            if ranker.df[term] < ranker.N * 0.3:
                inv_index[term].append(doc_idx)

    def get_candidate_from_inv_index(inv_index, terms):
        s = set()
        for t in terms:
            s.update(inv_index[t])
        return s

    # Rank candidate contexts for each source sentence by BM25 and keep the top max_context.
    source_tf_list = []
    selected_context = []
    for s_idx, feature in enumerate(train_data):
        input_ids = feature["input_ids"].int64_list.value
        topic_seg, sent = split_p_h_with_input_ids(input_ids, input_ids)
        source_tf = ranker.get_terms_from_ids(sent)
        source_tf_list.append(source_tf)
        ranked_list = []
        candidate_docs = get_candidate_from_inv_index(inv_index, source_tf.keys())
        for doc_idx in candidate_docs:
            target_tf = target_tf_list[doc_idx]
            score = ranker.bm25(source_tf, target_tf)
            ranked_list.append((doc_idx, score, target_tf))
        ranked_list.sort(key=lambda x: x[1], reverse=True)
        ranked_list = list(filter_overlap(ranked_list))
        ranked_list = ranked_list[:self.max_context]

        if s_idx < 10:
            print("--- Source sentence : \n", pretty_tokens(tokenizer.convert_ids_to_tokens(sent), True))
            print("-------------------")
            for rank, (idx, score, target_tf) in enumerate(ranked_list):
                ranker.bm25(source_tf, target_tf, True)
                print("Rank#{} {} : ".format(rank, score) + pretty_tokens(token_doc_list[idx], True))
        if s_idx % 100 == 0:
            print(s_idx)
        contexts = [token_doc_list[idx] for idx, score, _ in ranked_list]
        selected_context.append(contexts)

    for sent_idx, feature in enumerate(train_data):
        contexts = selected_context[sent_idx]
        yield feature, contexts

def recover_subtokens(input_ids) -> List[str]:
    tokens1, tokens2 = split_p_h_with_input_ids(input_ids, input_ids)
    return tokenizer.convert_ids_to_tokens(tokens2)

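# Every function above relies on `split_p_h_with_input_ids`, which is defined elsewhere.
# A minimal sketch of the assumed contract: the sequence follows the BERT pair layout
# [CLS] p [SEP] h [SEP]; input_ids is used only to locate the [SEP] positions (token id
# 102 for the uncased BERT vocabulary is an assumption here), and the paired list is
# sliced at the same boundaries. The real implementation may differ.
def split_p_h_with_input_ids_sketch(values, input_ids, sep_id=102):
    ids = list(input_ids)
    sep1 = ids.index(sep_id)
    sep2 = ids.index(sep_id, sep1 + 1)
    p = values[1:sep1]           # between [CLS] and the first [SEP]
    h = values[sep1 + 1:sep2]    # between the first and second [SEP]
    return p, h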