def run():
    tokenizer = get_tokenizer()
    spr = StreamPickleReader("contradiction_prediction")
    html = HtmlVisualizer("contradiction_prediction.html")
    cnt = 0
    while spr.has_next():
        item = spr.get_item()
        e, p = item
        input_ids, _, _ = e
        logit, explain = p
        tokens = tokenizer.convert_ids_to_tokens(input_ids)
        p, h = split_p_h_with_input_ids(tokens, input_ids)
        p_score, h_score = split_p_h_with_input_ids(explain, input_ids)
        p_score = normalize(p_score)
        h_score = normalize(h_score)
        p_cells = [Cell("P:")] + cells_from_tokens(p, p_score)
        h_cells = [Cell("H:")] + cells_from_tokens(h, h_score)
        html.write_paragraph(str(logit))
        html.multirow_print(p_cells)
        html.multirow_print(h_cells)
        if cnt > 100:
            break
        cnt += 1
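# The normalize() helper used above is project-internal and not shown in this
# file. A minimal hypothetical sketch, assuming it min-max scales raw
# attribution scores into the 0-100 highlight range that Cell/HtmlVisualizer
# expect; the actual implementation may differ.
def normalize_sketch(scores):
    scores = list(scores)
    lo, hi = min(scores), max(scores)
    if hi == lo:
        # All scores equal: nothing to highlight.
        return [0.0 for _ in scores]
    return [(s - lo) / (hi - lo) * 100 for s in scores]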
def show_simple(run_name, data_id, tex_visualizer):
    num_select = 20
    pickle_name = "save_view_{}_{}".format(run_name, data_id)
    tokenizer = get_tokenizer()
    data_loader = get_modified_data_loader2(HPSENLI3(), NLIExTrainConfig())
    explain_entries = load_from_pickle(pickle_name)
    selected_instances = [[], [], []]
    for idx, entry in enumerate(explain_entries):
        x0, logits, scores = entry
        pred = np.argmax(logits)
        input_ids = x0
        p, h = data_loader.split_p_h_with_input_ids(input_ids, input_ids)
        p_tokens = tokenizer.convert_ids_to_tokens(p)
        h_tokens = tokenizer.convert_ids_to_tokens(h)
        p_tokens = restore_capital_letter(p_tokens)
        h_tokens = restore_capital_letter(h_tokens)
        # Use the explanation scores for the tag that corresponds to the predicted label.
        target_tag = ["match", "mismatch", "conflict"][pred]
        tag_idx = data_generator.NLI.nli_info.tags.index(target_tag)
        tag_name = data_generator.NLI.nli_info.tags[tag_idx]
        p_score, h_score = data_loader.split_p_h_with_input_ids(scores[tag_idx], input_ids)
        p_score = normalize(p_score)
        h_score = normalize(h_score)
        p_row = [Cell("\\textbf{P:}")] + cells_from_tokens(p_tokens, p_score)
        h_row = [Cell("\\textbf{H:}")] + cells_from_tokens(h_tokens, h_score)
        pred_str = ["entailment", "neutral", "contradiction"][pred]
        apply_color(p_row, tag_name)
        apply_color(h_row, tag_name)
        if len(selected_instances[pred]) < num_select:
            e = pred_str, [p_row, h_row]
            selected_instances[pred].append(e)
        if all([len(s) == num_select for s in selected_instances]):
            break
    for insts in selected_instances:
        for e in insts:
            pred_str, rows = e
            tex_visualizer.write_instance(pred_str, rows)
    return selected_instances
def loss_view(dir_path):
    tokenizer = get_tokenizer()
    html_writer = HtmlVisualizer("ukp_lm_grad_high.html", dark_mode=False)
    for file_path in get_dir_files(dir_path):
        with open(file_path, "rb") as f:
            items = pickle.load(f)
        for e in items:
            input_ids, masked_input_ids, masked_lm_example_loss = e
            tokens = mask_resolve_1(
                tokenizer.convert_ids_to_tokens(input_ids),
                tokenizer.convert_ids_to_tokens(masked_input_ids))
            highlight = lmap(is_mask, tokens)
            cells = cells_from_tokens(tokens, highlight)
            html_writer.multirow_print(cells)
def print_as_html(fn):
    examples = load_record(fn)
    tokenizer = tokenizer_wo_tf.FullTokenizer(os.path.join(data_path, "bert_voca.txt"))
    html_output = HtmlVisualizer("out_name.html")
    for feature in examples:
        masked_inputs = feature["input_ids"].int64_list.value
        idx = 0
        step = 512
        # Print the sequence in chunks of 512 tokens.
        while idx < len(masked_inputs):
            chunk = masked_inputs[idx:idx + step]
            tokens = tokenizer.convert_ids_to_tokens(chunk)
            idx += step
            cells = cells_from_tokens(tokens)
            html_output.multirow_print(cells)
        html_output.write_paragraph("----------")
def main():
    print("Loading scores...")
    cid_grouped: Dict[str, Dict[str, List[Dict]]] = load_cppnc_score_wrap()
    baseline_cid_grouped = load_baseline("train_baseline")
    gold = get_claim_perspective_id_dict()
    tokenizer = get_tokenizer()
    claim_d = load_train_claim_d()
    print("Start analyzing")
    html = HtmlVisualizer("cppnc_value_per_token_score.html")
    claim_cnt = 0
    for cid, pid_entries_d in cid_grouped.items():
        pid_entries_d: Dict[str, List[Dict]] = pid_entries_d
        pid_entries: List[Tuple[str, List[Dict]]] = list(pid_entries_d.items())
        baseline_pid_entries = baseline_cid_grouped[int(cid)]
        baseline_score_d = fetch_score_per_pid(baseline_pid_entries)
        gold_pids = gold[int(cid)]
        ret = collect_score_per_doc(baseline_score_d, get_score_from_entry, gold_pids, pid_entries)
        passage_tokens_d = collect_passage_tokens(pid_entries)
        doc_info_d: Dict[int, Tuple[str, int]] = ret[0]
        doc_value_arr: List[List[float]] = ret[1]
        # Group per-passage average scores by document.
        kdp_result_grouped = defaultdict(list)
        for doc_idx, doc_values in enumerate(doc_value_arr):
            doc_id, passage_idx = doc_info_d[doc_idx]
            avg_score = average(doc_values)
            kdp_result = doc_id, passage_idx, avg_score
            kdp_result_grouped[doc_id].append(kdp_result)
        s = "{} : {}".format(cid, claim_d[int(cid)])
        html.write_headline(s)
        claim_cnt += 1
        if claim_cnt > 10:
            break
        scores: List[float] = list([r[2] for r in doc_value_arr])
        foreach(html.write_paragraph, lmap(str, scores))
        for doc_id, kdp_result_list in kdp_result_grouped.items():
            html.write_headline(doc_id)
            tokens, per_token_score = combine_collect_score(tokenizer, doc_id, passage_tokens_d, kdp_result_list)
            str_tokens = tokenizer.convert_ids_to_tokens(tokens)
            row = cells_from_tokens(str_tokens)
            # First row: color each token by its first score.
            for idx in range(len(str_tokens)):
                score = per_token_score[idx][0]
                norm_score = min(abs(score) * 10000, 100)
                color = "B" if score > 0 else "R"
                row[idx].highlight_score = norm_score
                row[idx].target_color = color
            rows = [row]

            def get_cell(score):
                if score == "-":
                    return Cell("-")
                else:
                    # e.g. a raw score of 0.01 maps to the maximum highlight of 100
                    norm_score = min(abs(score) * 10000, 100)
                    color = "B" if score > 0 else "R"
                    return Cell("", highlight_score=norm_score, target_color=color)

            # Remaining rows: the nth score of each token, until no token has one left.
            nth = 0
            any_score_found = True
            while any_score_found:
                any_score_found = False
                score_list = []
                for idx in range(len(str_tokens)):
                    if nth < len(per_token_score[idx]):
                        score = per_token_score[idx][nth]
                        any_score_found = True
                    else:
                        score = "-"
                    score_list.append(score)
                nth += 1
                if any_score_found:
                    row = lmap(get_cell, score_list)
                    rows.append(row)
            html.multirow_print_from_cells_list(rows)
def cells_from_tokens(self, tokens, scores=None, stop_at_pad=True):
    return cells_from_tokens(tokens, scores, stop_at_pad)
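# The module-level cells_from_tokens() wrapped above is defined elsewhere in
# the project. A minimal hypothetical sketch of its assumed behavior: wrap each
# token in a Cell with its highlight score and stop at the first "[PAD]" token
# when stop_at_pad is set. Cell's constructor arguments here are an assumption,
# not the confirmed API.
def cells_from_tokens_sketch(tokens, scores=None, stop_at_pad=True):
    cells = []
    for i, token in enumerate(tokens):
        if stop_at_pad and token == "[PAD]":
            break
        highlight = scores[i] if scores is not None else 0
        cells.append(Cell(token, highlight))
    return cells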
def show_all(run_name, data_id):
    num_tags = 3
    num_select = 1
    pickle_name = "save_view_{}_{}".format(run_name, data_id)
    tokenizer = get_tokenizer()
    data_loader = get_modified_data_loader2(HPSENLI3(), NLIExTrainConfig())
    explain_entries = load_from_pickle(pickle_name)
    visualizer = HtmlVisualizer(pickle_name + ".html")
    tex_visualizer = TexTableNLIVisualizer(pickle_name + ".tex")
    tex_visualizer.begin_table()
    selected_instances = [[], [], []]
    for idx, entry in enumerate(explain_entries):
        x0, logits, scores = entry
        pred = np.argmax(logits)
        input_ids = x0
        p, h = data_loader.split_p_h_with_input_ids(input_ids, input_ids)
        p_tokens = tokenizer.convert_ids_to_tokens(p)
        h_tokens = tokenizer.convert_ids_to_tokens(h)
        p_rows = [cells_from_tokens(p_tokens)]
        h_rows = [cells_from_tokens(h_tokens)]
        p_score_list = []
        h_score_list = []
        for j in range(num_tags):
            tag_name = data_generator.NLI.nli_info.tags[j]
            p_score, h_score = data_loader.split_p_h_with_input_ids(scores[j], input_ids)
            # Only add the explanation tag that corresponds to the predicted label.
            add = True
            if pred == 0:
                add = tag_name == "match"
            if pred == 1:
                add = tag_name == "mismatch"
            if pred == 2:
                add = tag_name == "conflict"

            def format_scores(raw_scores):
                def format_float(s):
                    return "{0:.2f}".format(s)
                norm_scores = normalize(raw_scores)
                cells = [Cell(format_float(s1), s2, False, False)
                         for s1, s2 in zip(raw_scores, norm_scores)]
                if tag_name == "mismatch":
                    set_cells_color(cells, "G")
                elif tag_name == "conflict":
                    set_cells_color(cells, "R")
                return cells

            if add:
                p_rows.append(format_scores(p_score))
                h_rows.append(format_scores(h_score))
                p_score_list.append(p_score)
                h_score_list.append(h_score)
        pred_str = ["Entailment", "Neutral", "Contradiction"][pred]
        out_entry = pred_str, p_tokens, h_tokens, p_score_list, h_score_list
        if len(selected_instances[pred]) < num_select:
            selected_instances[pred].append(out_entry)
            visualizer.write_headline(pred_str)
            visualizer.multirow_print_from_cells_list(p_rows)
            visualizer.multirow_print_from_cells_list(h_rows)
            visualizer.write_instance(pred_str, p_rows, h_rows)
            tex_visualizer.write_paragraph(str(pred))
            tex_visualizer.multirow_print_from_cells_list(p_rows, width=13)
            tex_visualizer.multirow_print_from_cells_list(h_rows, width=13)
        if all([len(s) == num_select for s in selected_instances]):
            break
    tex_visualizer.close_table()
    return selected_instances
def write_deletion_score_to_html(out_file_name, summarized_table: List[Entry], info: Dict[int, Dict]):
    text_to_info = claim_text_to_info()
    html = HtmlVisualizer(out_file_name)
    tokenizer = get_biobert_tokenizer()
    num_print = 0
    for entry in summarized_table:
        tokens = tokenizer.convert_ids_to_tokens(entry.input_ids)
        idx_sep1, idx_sep2 = get_sep_loc(entry.input_ids)
        max_change = 0
        max_drop = 0
        cells = cells_from_tokens(tokens)
        # Collect (token index, score drop) pairs; after sorting, the most negative drop is first.
        drops = []
        for idx in range(len(tokens)):
            if tokens[idx] == "[PAD]":
                break
            if tokens[idx] == '[SEP]':
                continue
            if idx in entry.contribution:
                raw_score = entry.contribution[idx]
                e = idx, raw_score
                drops.append(e)
        drops.sort(key=get_second)
        _, largest_drop = drops[0]
        max_drop_idx = -1
        max_drop_case_logit = None
        # Color each token: blue intensity proportional to its drop relative to the largest drop;
        # tokens without a contribution score are grayed out.
        for idx in range(len(tokens)):
            if tokens[idx] == "[PAD]":
                break
            if tokens[idx] == '[SEP]':
                continue
            if idx in entry.contribution:
                raw_score = entry.contribution[idx]
                max_change = max(abs(raw_score), max_change)
                if max_drop > raw_score:
                    max_drop = raw_score
                    max_drop_idx = idx
                    max_drop_case_logit = entry.case_logits_d[idx]
                if raw_score < 0:
                    score = abs(raw_score / largest_drop) * 200
                    color = "B"
                else:
                    score = 0
                    color = "B"
            else:
                score = 150
                color = "Gray"
            cells[idx].highlight_score = score
            cells[idx].target_color = color
        if max_change < 0.05 and False:  # small-change filter currently disabled
            pass
        else:
            base_probs = scipy.special.softmax(entry.base_logits)
            info_entry = info[str(entry.data_id[0])]
            claim1_info: Dict = text_to_info[info_entry['text1']]
            claim2_info: Dict = text_to_info[info_entry['text2']]
            question = claim1_info['question']
            assertion1 = claim1_info['assertion']
            assertion2 = claim2_info['assertion']
            original_prediction_summary = make_prediction_summary_str(base_probs)
            html.write_bar()
            html.write_paragraph("Question: {}".format(question))
            html.write_paragraph("Original prediction: " + original_prediction_summary)
            html.write_paragraph("Max drop")
            rows = []
            for idx, score in drops[:5]:
                row = [Cell(str(idx)), Cell(tokens[idx]), Cell(score)]
                rows.append(row)
            html.write_table(rows)
            min_token = tokens[max_drop_idx]
            html.write_paragraph("> \"{}\": {} ".format(min_token, max_drop))
            max_drop_case_prob = scipy.special.softmax(max_drop_case_logit)
            max_drop_prediction_summary = make_prediction_summary_str(max_drop_case_prob)
            html.write_paragraph("> " + max_drop_prediction_summary)
            p = [Cell("Claim1 ({}):".format(assertion1))] + cells[1:idx_sep1]
            h = [Cell("Claim2 ({}):".format(assertion2))] + cells[idx_sep1 + 1:idx_sep2]
            html.write_table([p])
            html.write_table([h])
            num_print += 1
    print("printed {} of {}".format(num_print, len(summarized_table)))