示例#1
0
def run():
    """Write contradiction-prediction explanations to an HTML report.

    Items are streamed from the "contradiction_prediction" pickle stream.
    Each item is unpacked as ``((input_ids, _, _), (logit, explain))``; the
    tokens and the explanation scores are split into premise / hypothesis
    parts, the scores are normalized, and one "P:" row and one "H:" row are
    printed below the stringified logit.

    At most 102 items are rendered.
    """
    tokenizer = get_tokenizer()
    reader = StreamPickleReader("contradiction_prediction")
    html = HtmlVisualizer("contradiction_prediction.html")

    num_written = 0
    while reader.has_next():
        (input_ids, _, _), (logit, explain) = reader.get_item()

        tokens = tokenizer.convert_ids_to_tokens(input_ids)
        p_tokens, h_tokens = split_p_h_with_input_ids(tokens, input_ids)
        p_raw, h_raw = split_p_h_with_input_ids(explain, input_ids)

        p_cells = [Cell("P:")] + cells_from_tokens(p_tokens, normalize(p_raw))
        h_cells = [Cell("H:")] + cells_from_tokens(h_tokens, normalize(h_raw))

        html.write_paragraph(str(logit))
        for row in (p_cells, h_cells):
            html.multirow_print(row)

        # Stop once the 102nd item has been written.
        num_written += 1
        if num_written > 101:
            break
示例#2
0
def show_simple(run_name, data_id, tex_visulizer):
    """Pick up to 20 explained instances per predicted class and emit them.

    Entries are loaded from the pickle named ``save_view_<run_name>_<data_id>``.
    For every entry the predicted class is the argmax of its logits; only the
    explanation scores of the tag paired with that prediction
    (match / mismatch / conflict) are used to color the premise and
    hypothesis rows.

    Args:
        run_name: Run identifier used to build the pickle name.
        data_id: Data identifier used to build the pickle name.
        tex_visulizer: Object whose ``write_instance(pred_str, rows)`` is
            called once per selected instance.

    Returns:
        A list of three lists (indexed by predicted class) holding
        ``(pred_str, [p_row, h_row])`` tuples.
    """
    num_select = 20
    tag_by_pred = ["match", "mismatch", "conflict"]
    label_by_pred = ["entailment", "neutral" , "contradiction"]

    pickle_name = "save_view_{}_{}".format(run_name, data_id)
    tokenizer = get_tokenizer()
    data_loader = get_modified_data_loader2(HPSENLI3(), NLIExTrainConfig())
    explain_entries = load_from_pickle(pickle_name)

    selected_instances = [[], [], []]
    for input_ids, logits, scores in explain_entries:
        pred = np.argmax(logits)

        p_ids, h_ids = data_loader.split_p_h_with_input_ids(input_ids, input_ids)
        p_tokens = restore_capital_letter(tokenizer.convert_ids_to_tokens(p_ids))
        h_tokens = restore_capital_letter(tokenizer.convert_ids_to_tokens(h_ids))

        # Locate the score row of the tag that corresponds to the prediction.
        tag_idx = data_generator.NLI.nli_info.tags.index(tag_by_pred[pred])
        tag_name = data_generator.NLI.nli_info.tags[tag_idx]
        p_raw, h_raw = data_loader.split_p_h_with_input_ids(scores[tag_idx], input_ids)

        p_row = [Cell("\\textbf{P:}")] + cells_from_tokens(p_tokens, normalize(p_raw))
        h_row = [Cell("\\textbf{H:}")] + cells_from_tokens(h_tokens, normalize(h_raw))
        for row in (p_row, h_row):
            apply_color(row, tag_name)

        bucket = selected_instances[pred]
        if len(bucket) < num_select:
            bucket.append((label_by_pred[pred], [p_row, h_row]))

        # Every class has its quota: no need to look at further entries.
        if all(len(b) == num_select for b in selected_instances):
            break

    for bucket in selected_instances:
        for pred_str, rows in bucket:
            tex_visulizer.write_instance(pred_str, rows)

    return selected_instances
示例#3
0
def loss_view(dir_path):
    """Render masked-LM examples from pickled files as highlighted HTML rows.

    Every file returned by ``get_dir_files(dir_path)`` is unpickled into an
    iterable of ``(input_ids, masked_input_ids, masked_lm_example_loss)``
    triples. For each triple the original and masked token sequences are
    merged with ``mask_resolve_1`` and the tokens flagged by ``is_mask`` are
    highlighted in "ukp_lm_grad_high.html".

    Args:
        dir_path: Directory whose files are read.
    """
    tokenizer = get_tokenizer()
    html_writer = HtmlVisualizer("ukp_lm_grad_high.html", dark_mode=False)

    for file_path in get_dir_files(dir_path):
        # NOTE: unpickling can execute arbitrary code; only point this at
        # directories whose contents are trusted.
        # The context manager closes the handle, which the previous bare
        # open() call never did.
        with open(file_path, "rb") as f:
            items = pickle.load(f)

        # The per-example loss is part of each record but is not displayed.
        for input_ids, masked_input_ids, _masked_lm_example_loss in items:
            tokens = mask_resolve_1(
                tokenizer.convert_ids_to_tokens(input_ids),
                tokenizer.convert_ids_to_tokens(masked_input_ids))
            highlight = lmap(is_mask, tokens)

            cells = cells_from_tokens(tokens, highlight)
            html_writer.multirow_print(cells)
示例#4
0
def print_as_html(fn, out_name="out_name.html", step=512):
    """Dump the ``input_ids`` of every record in *fn* as token rows in HTML.

    Each record's ids are converted to tokens and printed in consecutive
    chunks of *step* ids; a "----------" paragraph separates records.

    Args:
        fn: Record file passed to ``load_record``.
        out_name: Name handed to ``HtmlVisualizer``. Defaults to the
            previously hard-coded "out_name.html".
        step: Number of ids printed per chunk. Defaults to the previously
            hard-coded 512.

    Raises:
        ValueError: If *step* is not positive.
    """
    if step <= 0:
        raise ValueError("step must be a positive integer")

    examples = load_record(fn)
    tokenizer = tokenizer_wo_tf.FullTokenizer(
        os.path.join(data_path, "bert_voca.txt"))

    html_output = HtmlVisualizer(out_name)

    for feature in examples:
        masked_inputs = feature["input_ids"].int64_list.value
        for start in range(0, len(masked_inputs), step):
            # Named `chunk` so the builtin `slice` is no longer shadowed.
            chunk = masked_inputs[start:start + step]
            tokens = tokenizer.convert_ids_to_tokens(chunk)
            html_output.multirow_print(cells_from_tokens(tokens))
        html_output.write_paragraph("----------")
示例#5
0
def main():
    """Write per-token score tables for claims to an HTML report.

    For each claim id in the loaded score groups, a headline with the claim
    text is written, followed by one section per document: a token row
    colored by the first score of each token, then one extra row per
    additional score position ("-" where a token has no score at that
    position). Output goes through
    ``HtmlVisualizer("cppnc_value_per_token_score.html")``.
    """
    print("Loading scores...")
    cid_grouped: Dict[str, Dict[str, List[Dict]]] = load_cppnc_score_wrap()
    baseline_cid_grouped = load_baseline("train_baseline")
    gold = get_claim_perspective_id_dict()
    tokenizer = get_tokenizer()
    claim_d = load_train_claim_d()

    print("Start analyzing")
    html = HtmlVisualizer("cppnc_value_per_token_score.html")
    claim_cnt = 0
    for cid, pid_entries_d in cid_grouped.items():
        # Self-assignment only attaches a type annotation to the loop variable.
        pid_entries_d: Dict[str, List[Dict]] = pid_entries_d
        pid_entries: List[Tuple[str, List[Dict]]] = list(pid_entries_d.items())
        # cid is a str key here, while the baseline / gold / claim lookups
        # are indexed with int(cid).
        baseline_pid_entries = baseline_cid_grouped[int(cid)]
        baseline_score_d = fetch_score_per_pid(baseline_pid_entries)
        gold_pids = gold[int(cid)]

        ret = collect_score_per_doc(baseline_score_d, get_score_from_entry, gold_pids,
                                                                  pid_entries)
        passage_tokens_d = collect_passage_tokens(pid_entries)
        doc_info_d: Dict[int, Tuple[str, int]] = ret[0]
        doc_value_arr: List[List[float]] = ret[1]

        # Group (doc_id, passage_idx, average value) triples by document id.
        kdp_result_grouped = defaultdict(list)
        for doc_idx, doc_values in enumerate(doc_value_arr):
            doc_id, passage_idx = doc_info_d[doc_idx]
            avg_score = average(doc_values)
            kdp_result = doc_id, passage_idx, avg_score
            kdp_result_grouped[doc_id].append(kdp_result)

        s = "{} : {}".format(cid, claim_d[int(cid)])
        html.write_headline(s)
        claim_cnt += 1
        # NOTE(review): the headline above is written before this check, so
        # the 11th claim gets a headline but no content. The scoring work
        # above is also done for it before the loop stops. Confirm whether
        # that is intended.
        if claim_cnt > 10:
            break

        # NOTE(review): this takes the element at index 2 of every value list
        # in doc_value_arr, not the averages computed above; a list shorter
        # than 3 raises IndexError. Confirm this is the intended quantity.
        scores: List[float] = list([r[2] for r in doc_value_arr])

        foreach(html.write_paragraph, lmap(str, scores))

        for doc_id, kdp_result_list in kdp_result_grouped.items():
            html.write_headline(doc_id)
            tokens, per_token_score = combine_collect_score(tokenizer, doc_id, passage_tokens_d, kdp_result_list)
            str_tokens = tokenizer.convert_ids_to_tokens(tokens)
            row = cells_from_tokens(str_tokens)
            # Color the token row by the first score of each token. An empty
            # per-token score list would raise IndexError here.
            for idx in range(len(str_tokens)):
                score = per_token_score[idx][0]
                # abs(score) of 0.01 or more saturates at 100.
                norm_score = min(abs(score) * 10000, 100)
                color = "B" if score > 0 else "R"
                row[idx].highlight_score = norm_score
                row[idx].target_color = color

            rows = [row]
            # Emit one extra row per score position (nth = 0, 1, ...) until a
            # position is reached where no token has a score left.
            nth = 0
            any_score_found = True
            while any_score_found:
                any_score_found = False
                score_list = []
                for idx in range(len(str_tokens)):
                    if nth < len(per_token_score[idx]):
                        score = per_token_score[idx][nth]
                        any_score_found = True
                    else:
                        # Placeholder for tokens with fewer than nth+1 scores.
                        score = "-"
                    score_list.append(score)

                # Redefined on every pass; it only reads its argument, so it
                # does not depend on any loop variable.
                def get_cell(score):
                    if score == "-":
                        return Cell("-")
                    else:
                        # 0.01 -> 100
                        norm_score = min(abs(score) * 10000, 100)
                        color = "B" if score > 0 else "R"
                        return Cell("", highlight_score=norm_score, target_color=color)

                nth += 1
                # The final pass (nothing found) adds no row and ends the loop.
                if any_score_found:
                    row = lmap(get_cell, score_list)
                    rows.append(row)
            html.multirow_print_from_cells_list(rows)
 def cells_from_tokens(self, tokens, scores=None, stop_at_pad=True):
     """Forward to the module-level ``cells_from_tokens`` and return its result.

     ``self`` is not used; the arguments are passed through positionally
     in the same order.
     """
     return cells_from_tokens(tokens, scores, stop_at_pad)
示例#7
0
def show_all(run_name, data_id, only_predicted_tag=False):
    """Write one explained instance per predicted class to HTML and TeX.

    Entries are loaded from the pickle named ``save_view_<run_name>_<data_id>``.
    For each entry the premise / hypothesis token rows are followed by one
    score row per explanation tag. Mismatch rows are colored "G" and conflict
    rows "R".

    The previous implementation compared the integer prediction with the
    strings "0", "1" and "2". Those comparisons were always False, so the
    score rows of all tags were always printed. That effective behaviour is
    kept as the default; the filtering the dead code hinted at is available
    through *only_predicted_tag*.

    Args:
        run_name: Run identifier used to build the pickle / output names.
        data_id: Data identifier used to build the pickle / output names.
        only_predicted_tag: When True, print only the score rows of the tag
            paired with the prediction (0 -> match, 1 -> mismatch,
            2 -> conflict). Defaults to False (print all tags).

    Returns:
        A list of three lists (indexed by predicted class) holding
        ``(pred_str, p_tokens, h_tokens, p_score_list, h_score_list)`` tuples.
    """
    num_tags = 3
    num_select = 1
    pred_labels = ["Entailment", "Neutral", "Contradiction"]
    tag_by_pred = ["match", "mismatch", "conflict"]
    color_by_tag = {"mismatch": "G", "conflict": "R"}

    pickle_name = "save_view_{}_{}".format(run_name, data_id)
    tokenizer = get_tokenizer()

    data_loader = get_modified_data_loader2(HPSENLI3(), NLIExTrainConfig())

    explain_entries = load_from_pickle(pickle_name)

    visualizer = HtmlVisualizer(pickle_name + ".html")
    tex_visulizer = TexTableNLIVisualizer(pickle_name + ".tex")
    tex_visulizer.begin_table()

    def format_scores(raw_scores, tag_name):
        """Build a row of cells showing raw scores, shaded by normalized score."""
        norm_scores = normalize(raw_scores)
        cells = [Cell("{0:.2f}".format(raw), norm, False, False)
                 for raw, norm in zip(raw_scores, norm_scores)]
        color = color_by_tag.get(tag_name)
        if color is not None:
            set_cells_color(cells, color)
        return cells

    selected_instances = [[], [], []]
    for input_ids, logits, scores in explain_entries:
        # Plain int so that list indexing, str() and comparisons behave the
        # same regardless of the numpy integer type argmax returns.
        pred = int(np.argmax(logits))
        p, h = data_loader.split_p_h_with_input_ids(input_ids, input_ids)
        p_tokens = tokenizer.convert_ids_to_tokens(p)
        h_tokens = tokenizer.convert_ids_to_tokens(h)

        p_rows = [cells_from_tokens(p_tokens)]
        h_rows = [cells_from_tokens(h_tokens)]

        p_score_list = []
        h_score_list = []
        for j in range(num_tags):
            tag_name = data_generator.NLI.nli_info.tags[j]
            p_score, h_score = data_loader.split_p_h_with_input_ids(scores[j], input_ids)

            if not only_predicted_tag or tag_name == tag_by_pred[pred]:
                p_rows.append(format_scores(p_score, tag_name))
                h_rows.append(format_scores(h_score, tag_name))

            # Raw scores of every tag are returned even when a row is hidden.
            p_score_list.append(p_score)
            h_score_list.append(h_score)

        pred_str = pred_labels[pred]
        out_entry = pred_str, p_tokens, h_tokens, p_score_list, h_score_list

        if len(selected_instances[pred]) < num_select:
            selected_instances[pred].append(out_entry)
            visualizer.write_headline(pred_str)
            visualizer.multirow_print_from_cells_list(p_rows)
            visualizer.multirow_print_from_cells_list(h_rows)
            visualizer.write_instance(pred_str, p_rows, h_rows)

            tex_visulizer.write_paragraph(str(pred))
            tex_visulizer.multirow_print_from_cells_list(p_rows, width=13)
            tex_visulizer.multirow_print_from_cells_list(h_rows, width=13)

        # Stop as soon as every class has its quota of instances.
        if all(len(s) == num_select for s in selected_instances):
            break

    tex_visulizer.close_table()
    return selected_instances
def write_deletion_score_to_html(out_file_name, summarized_table: List[Entry],
                                 info: Dict[int, Dict]):
    """Write a token-deletion analysis of every entry to an HTML report.

    For each entry the tokens are shaded by how much the recorded
    contribution of that token is negative (relative to the most negative
    one). Tokens without a recorded contribution are shaded gray. The report
    lists the question, the original prediction, the five lowest
    contributions and, when at least one contribution is negative, the
    prediction of the case with the largest drop.

    Entries with no recorded contribution, or with no negative contribution,
    are still reported; the previous implementation raised on them
    (``drops[0]`` on an empty list, ``softmax(None)``).

    Args:
        out_file_name: Name handed to ``HtmlVisualizer``.
        summarized_table: Entries providing ``input_ids``, ``contribution``,
            ``case_logits_d``, ``base_logits`` and ``data_id``.
        info: Mapping looked up with ``str(entry.data_id[0])``; each value
            provides the 'text1' and 'text2' keys.
            NOTE(review): the annotation says int keys, the lookup uses str.
    """
    text_to_info = claim_text_to_info()
    html = HtmlVisualizer(out_file_name)
    tokenizer = get_biobert_tokenizer()
    num_print = 0
    for entry in summarized_table:
        tokens = tokenizer.convert_ids_to_tokens(entry.input_ids)
        idx_sep1, idx_sep2 = get_sep_loc(entry.input_ids)
        cells = cells_from_tokens(tokens)

        # Positions considered: everything before the first [PAD], minus [SEP].
        content_indices = []
        for idx, token in enumerate(tokens):
            if token == "[PAD]":
                break
            if token != '[SEP]':
                content_indices.append(idx)

        drops = [(idx, entry.contribution[idx])
                 for idx in content_indices if idx in entry.contribution]
        drops.sort(key=get_second)
        # None when the entry has no recorded contribution at all.
        largest_drop = drops[0][1] if drops else None

        max_drop = 0
        max_drop_idx = -1
        max_drop_case_logit = None
        for idx in content_indices:
            if idx in entry.contribution:
                raw_score = entry.contribution[idx]
                if max_drop > raw_score:
                    max_drop = raw_score
                    max_drop_idx = idx
                    max_drop_case_logit = entry.case_logits_d[idx]

                # Only drops (negative contributions) are highlighted; a
                # negative score implies drops is non-empty here.
                if raw_score < 0:
                    score = abs(raw_score / largest_drop) * 200
                else:
                    score = 0
                color = "B"
            else:
                score = 150
                color = "Gray"
            cells[idx].highlight_score = score
            cells[idx].target_color = color

        base_probs = scipy.special.softmax(entry.base_logits)
        info_entry = info[str(entry.data_id[0])]
        claim1_info: Dict = text_to_info[info_entry['text1']]
        claim2_info: Dict = text_to_info[info_entry['text2']]
        question = claim1_info['question']
        assertion1 = claim1_info['assertion']
        assertion2 = claim2_info['assertion']
        original_prediction_summary = make_prediction_summary_str(
            base_probs)
        html.write_bar()
        html.write_paragraph("Question: {}".format(question))
        html.write_paragraph("Original prediction: " +
                             original_prediction_summary)
        html.write_paragraph("Max drop")

        if drops:
            rows = [[Cell(str(idx)), Cell(tokens[idx]), Cell(score)]
                    for idx, score in drops[:5]]
            html.write_table(rows)

        if max_drop_idx >= 0:
            min_token = tokens[max_drop_idx]
            html.write_paragraph("> \"{}\": {} ".format(min_token, max_drop))
            max_drop_case_prob = scipy.special.softmax(max_drop_case_logit)
            max_drop_prediction_summary = make_prediction_summary_str(
                max_drop_case_prob)
            html.write_paragraph("> " + max_drop_prediction_summary)
        else:
            # No negative contribution: there is no "max drop" case to show.
            html.write_paragraph("> (no negative contribution)")

        p = [Cell("Claim1 ({}):".format(assertion1))] + cells[1:idx_sep1]
        h = [Cell("Claim2 ({}):".format(assertion2))
             ] + cells[idx_sep1 + 1:idx_sep2]
        html.write_table([p])
        html.write_table([h])
        num_print += 1

    print("printed {} of {}".format(num_print, len(summarized_table)))