Example #1
0
def main():
    """Write an HTML report of pair-deletion scores for the Alamri pilot pairs.

    Loads the pickled deletion results saved under ``"alamri_pair"`` and
    summarizes them per (premise, hypothesis) pair. The pairs are read from
    ``alamri_pilot/true_pair_small.csv``. One section per pair is written to
    ``alamri_pairing_deletion.html``, containing:

    - the prediction for the base score ``score_d[(-1, -1)]``;
    - a table of the premise-side effect (``effect_d[key][0]``);
    - a table of the hypothesis-side effect (``effect_d[key][1]``).

    Each table has premise tokens as columns and hypothesis tokens as rows.
    Each cell's tooltip shows ``score_d[(p_idx, -1)] -> score_d[(p_idx, h_idx)]``.
    Presumably that is premise-only deletion vs. deleting both tokens.
    """
    save_name = "alamri_pair"
    info_entries, output_d = load_from_pickle(save_name)
    html = HtmlVisualizer("alamri_pairing_deletion.html", use_tooltip=True)
    initial_text = load_p_h_pair_text(
        at_output_dir("alamri_pilot", "true_pair_small.csv"))
    per_group_summary: List[PerGroupSummary] = summarize_pair_deletion_results(
        info_entries, output_d)

    def float_arr_to_str_arr(float_arr):
        """Format every value of ``float_arr`` with ``two_digit_float``."""
        return list(map(two_digit_float, float_arr))

    for data_idx, (p, h) in enumerate(initial_text):
        # NOTE(review): assumes summaries come back in the same order as the
        # CSV pairs — confirm against summarize_pair_deletion_results.
        group_summary = per_group_summary[data_idx]

        p_tokens = p.split()
        h_tokens = h.split()

        # The (-1, -1) entry is the reference score for the prediction line;
        # presumably -1 means "no token deleted on that side".
        base_score = group_summary.score_d[(-1, -1)]
        pred_str = make_prediction_summary_str(base_score)
        html.write_paragraph("Prediction: {}".format(pred_str))

        keys = list(group_summary.score_d.keys())
        p_idx_max = max(left(keys))
        h_idx_max = max(right(keys))

        def get_pair_score_by_p(key):
            """Return the premise-side component of ``effect_d[key]``."""
            p_score, _ = group_summary.effect_d[key]
            return p_score

        def get_pair_score_by_h(key):
            """Return the hypothesis-side component of ``effect_d[key]``."""
            _, h_score = group_summary.effect_d[key]
            return h_score

        def get_table(get_pair_score_at):
            """Build rows: a header of premise tokens, then one row per
            hypothesis token with a tooltip cell per premise token."""
            head = [Cell("")] + [Cell(t) for t in p_tokens]
            rows = [head]
            for h_idx in range(h_idx_max + 1):
                row = [Cell(h_tokens[h_idx])]
                for p_idx in range(p_idx_max + 1):
                    s = get_pair_score_at((p_idx, h_idx))
                    one_del_score = group_summary.score_d[(p_idx, -1)]
                    two_del_score = group_summary.score_d[(p_idx, h_idx)]
                    tooltip_str = "{} -> {}".format(
                        float_arr_to_str_arr(one_del_score),
                        float_arr_to_str_arr(two_del_score))
                    row.append(
                        get_tooltip_cell(two_digit_float(s), tooltip_str))
                rows.append(row)
            return rows

        html.write_table(get_table(get_pair_score_by_p))
        html.write_table(get_table(get_pair_score_by_h))
        html.write_bar()
def _content_token_indices(tokens):
    """Yield indices of tokens before the first "[PAD]", skipping "[SEP]"."""
    for idx, token in enumerate(tokens):
        if token == "[PAD]":
            break
        if token == '[SEP]':
            continue
        yield idx


def write_deletion_score_to_html(out_file_name, summarized_table: List[Entry],
                                 info: Dict[str, Dict]):
    """Write an HTML report of per-token deletion scores for claim pairs.

    Every token before the first "[PAD]", except "[SEP]", is highlighted
    according to its entry in ``entry.contribution``:

    - negative scores are shaded in proportion to the entry's most negative
      score;
    - non-negative scores get highlight 0;
    - tokens with no contribution are drawn in gray at highlight 150.

    Each entry section then lists:

    - the question;
    - the original prediction;
    - the (up to) five most negative contributions;
    - the prediction for the largest-drop case;
    - the two claims.

    Args:
        out_file_name: Path of the HTML file to write.
        summarized_table: Entries exposing the following fields:
            ``input_ids``;
            ``contribution`` (token index -> raw score);
            ``case_logits_d`` (token index -> logits passed to softmax);
            ``base_logits``;
            ``data_id``.
        info: Maps ``str(entry.data_id[0])`` to a dict holding the claim
            texts under "text1" and "text2".
    """
    text_to_info = claim_text_to_info()
    html = HtmlVisualizer(out_file_name)
    tokenizer = get_biobert_tokenizer()
    num_print = 0
    for entry in summarized_table:
        tokens = tokenizer.convert_ids_to_tokens(entry.input_ids)
        idx_sep1, idx_sep2 = get_sep_loc(entry.input_ids)
        cells = cells_from_tokens(tokens)
        content_indices = list(_content_token_indices(tokens))

        # (token index, raw score) for scored tokens, most negative first.
        # The sort is stable, so ties keep token order.
        drops = [(idx, entry.contribution[idx]) for idx in content_indices
                 if idx in entry.contribution]
        drops.sort(key=get_second)
        # 0 when nothing is scored; then no raw_score below can be negative.
        largest_drop = drops[0][1] if drops else 0

        for idx in content_indices:
            if idx in entry.contribution:
                raw_score = entry.contribution[idx]
                if raw_score < 0:
                    # largest_drop <= raw_score < 0, so this never divides by 0.
                    score = abs(raw_score / largest_drop) * 200
                else:
                    score = 0
                color = "B"
            else:
                score = 150
                color = "Gray"
            cells[idx].highlight_score = score
            cells[idx].target_color = color

        base_probs = scipy.special.softmax(entry.base_logits)
        info_entry = info[str(entry.data_id[0])]
        claim1_info: Dict = text_to_info[info_entry['text1']]
        claim2_info: Dict = text_to_info[info_entry['text2']]
        question = claim1_info['question']
        assertion1 = claim1_info['assertion']
        assertion2 = claim2_info['assertion']
        original_prediction_summary = make_prediction_summary_str(base_probs)
        html.write_bar()
        html.write_paragraph("Question: {}".format(question))
        html.write_paragraph("Original prediction: " +
                             original_prediction_summary)
        html.write_paragraph("Max drop")

        rows = [[Cell(str(idx)), Cell(tokens[idx]), Cell(score)]
                for idx, score in drops[:5]]
        if rows:
            html.write_table(rows)

        if largest_drop < 0:
            # The first (lowest) drop is the token whose deletion lowered the
            # score the most.
            max_drop_idx = drops[0][0]
            min_token = tokens[max_drop_idx]
            html.write_paragraph(
                "> \"{}\": {} ".format(min_token, largest_drop))
            max_drop_case_prob = scipy.special.softmax(
                entry.case_logits_d[max_drop_idx])
            max_drop_prediction_summary = make_prediction_summary_str(
                max_drop_case_prob)
            html.write_paragraph("> " + max_drop_prediction_summary)
        else:
            # Without this guard, max_drop_idx would stay -1. That would show
            # tokens[-1] and call softmax(None), which raises.
            html.write_paragraph("> No token deletion lowered the score")

        p = [Cell("Claim1 ({}):".format(assertion1))] + cells[1:idx_sep1]
        h = [Cell("Claim2 ({}):".format(assertion2))
             ] + cells[idx_sep1 + 1:idx_sep2]
        html.write_table([p])
        html.write_table([h])
        num_print += 1

    print("printed {} of {}".format(num_print, len(summarized_table)))