Esempio n. 1
0
def show(html_visualizer: HtmlVisualizer,
         features: List[ParagraphClaimPersFeature]):
    """Write claim/perspective pairs and their scored paragraphs as HTML.

    For every feature the claim text, perspective text and label are
    written, followed by each paragraph of ``feature.feature``. Paragraph
    tokens that also occur (case-insensitively) among the NLTK tokens of the
    claim or perspective get a highlight score of 100; all others get 0.

    The cid of the first feature is printed up front, so ``features`` must
    not be empty.
    """
    print("Cid: ", features[0].claim_pers.cid)
    for feature in features:
        claim_pers = feature.claim_pers
        html_visualizer.write_paragraph("Claim: " + claim_pers.claim_text)
        html_visualizer.write_paragraph("Perspective: " + claim_pers.p_text)

        claim_tokens = nltk.word_tokenize(claim_pers.claim_text)
        perspective_tokens = nltk.word_tokenize(claim_pers.p_text)
        # Lower-cased vocabulary shared by the claim and the perspective.
        vocabulary = {tok.lower() for tok in claim_tokens + perspective_tokens}
        print(vocabulary)

        html_visualizer.write_paragraph(
            "Label : {}".format(claim_pers.label))
        for scored in feature.feature:
            cells = [
                Cell(tok, 100 if tok.lower() in vocabulary else 0)
                for tok in scored.paragraph.tokens
            ]
            html_visualizer.write_paragraph("---")
            html_visualizer.multirow_print(cells, width=20)
Esempio n. 2
0
def draw2(in_file, out_file):
    """Dump tokens, probabilities and losses of prediction entries to HTML.

    Reads predictions from ``in_file`` (resolved against ``output_path``) and
    writes to ``out_file``. Entries with index 0..100 are rendered. Tokens at
    masked LM positions are rewritten as ``[i:token]`` where ``i`` is the
    order of the mask position.
    """
    data = EstimatorPredictionViewerGosford(os.path.join(output_path, in_file))
    html_writer = HtmlVisualizer(out_file, dark_mode=False)

    tokenizer = get_tokenizer()  # result is not used below

    # (row label, vector key) pairs, in the order they are fetched and shown.
    vector_rows = (
        ("prob1:", "prob1"),
        ("prob2:", "prob2"),
        ("real_loss1:", "per_example_loss1"),
        ("real_loss2:", "per_example_loss2"),
    )
    for inst_i, entry in enumerate(data):
        if inst_i > 100:
            break

        tokens = entry.get_tokens("input_ids")
        labelled_vectors = [(label, entry.get_vector(key))
                            for label, key in vector_rows]
        mask_positions = entry.get_vector("masked_lm_positions")

        # Tag every masked position with its order among the masks.
        for order, loc in enumerate(mask_positions):
            tokens[loc] = "[{}:{}]".format(order, tokens[loc])

        html_writer.multirow_print(data.cells_from_tokens(tokens))

        table = [[Cell(label)] + data.cells_from_anything(vector)
                 for label, vector in labelled_vectors]
        html_writer.multirow_print_from_cells_list(table)
Esempio n. 3
0
def show(out_file_name, summarized_table: List[Entry]):
    """Render per-token contribution scores of each entry to an HTML file.

    Each entry unpacks to ``(input_ids, prob, contributions)``. Tokens up to
    the first ``[PAD]`` are shown; a token whose index is in
    ``contributions`` is coloured "R" for a positive and "B" for a
    non-positive value with strength ``abs(value) * 100``, all other tokens
    are drawn gray. The token row is only written when the largest absolute
    contribution is not below 0.05; the score paragraph is always written.
    """
    html = HtmlVisualizer(out_file_name)
    tokenizer = get_tokenizer()
    num_print = 0
    for input_ids, prob, contributions in summarized_table:
        tokens = tokenizer.convert_ids_to_tokens(input_ids)
        html.write_paragraph("Score : {}".format(prob))

        cells = []
        largest_change = 0
        for idx in range(len(input_ids)):
            token = tokens[idx]
            if token == "[PAD]":
                break
            if idx not in contributions:
                cells.append(
                    Cell(token, highlight_score=150, target_color="Gray"))
                continue

            raw_score = contributions[idx]
            largest_change = max(abs(raw_score), largest_change)
            strength = abs(raw_score) * 100
            hue = "R" if raw_score > 0 else "B"
            cells.append(
                Cell(token, highlight_score=strength, target_color=hue))

        # Written as a negation so that the comparison semantics of the
        # original ``< 0.05`` test are kept exactly.
        if not (largest_change < 0.05):
            html.multirow_print(cells, 30)
            num_print += 1

    print("printed {} of {}".format(num_print, len(summarized_table)))
Esempio n. 4
0
def run():
    """Write premise/hypothesis explanation scores to an HTML report.

    Items are read from the "contradiction_prediction" pickle stream. Each
    item unpacks to ``((input_ids, _, _), (logit, explain))``. Tokens and
    explanation scores are split into premise and hypothesis parts, the
    scores are normalized, and both rows are written together with the
    logit. The loop stops once the counter exceeds 100, i.e. after at most
    102 items.
    """
    tokenizer = get_tokenizer()
    reader = StreamPickleReader("contradiction_prediction")
    html = HtmlVisualizer("contradiction_prediction.html")

    n_written = 0
    while reader.has_next():
        (input_ids, _, _), (logit, explain) = reader.get_item()

        tokens = tokenizer.convert_ids_to_tokens(input_ids)
        premise, hypothesis = split_p_h_with_input_ids(tokens, input_ids)
        premise_score, hypothesis_score = split_p_h_with_input_ids(
            explain, input_ids)

        premise_score = normalize(premise_score)
        hypothesis_score = normalize(hypothesis_score)
        premise_row = [Cell("P:")] + cells_from_tokens(premise, premise_score)
        hypothesis_row = [Cell("H:")] + cells_from_tokens(
            hypothesis, hypothesis_score)

        html.write_paragraph(str(logit))
        html.multirow_print(premise_row)
        html.multirow_print(hypothesis_row)

        # The limit is checked after writing, before incrementing.
        if n_written > 100:
            break
        n_written += 1
Esempio n. 5
0
def view_grad_overlap_per_mask():
    """Show per-mask overlap scores and top-5 LM predictions as HTML.

    Reads "ukp_lm_probs.pickle" and writes to the file of the same stem with
    an ".html" suffix. For every entry, each masked position is prefixed with
    its position index in the token row, and a table row lists the position,
    its overlap score and the five terms with the highest log probability
    (shown as probabilities via ``exp``).
    """
    filename = "ukp_lm_probs.pickle"
    out_name = filename.split(".")[0] + ".html"

    html_writer = HtmlVisualizer(out_name, dark_mode=False)
    data = EstimatorPredictionViewerGosford(filename)
    tokenizer = data.tokenizer

    for inst_i, entry in enumerate(data):
        tokens = entry.get_mask_resolved_input_mask_with_input()
        highlight = lmap(is_mask, tokens)
        scores = entry.get_vector("overlap_score")
        pos_list = entry.get_vector("masked_lm_positions")
        # Reshaped to 20 rows; one row of log-probs per masked position.
        log_probs = np.reshape(entry.get_vector("masked_lm_log_probs"),
                               [20, -1])

        rows = []
        for score, position, log_prob in zip(scores, pos_list, log_probs):
            tokens[position] = "{}-".format(position) + tokens[position]

            row = [Cell(position), Cell(score)]
            best_first = np.argsort(log_prob)[::-1]
            for vocab_idx in best_first[:5]:
                term = tokenizer.inv_vocab[vocab_idx]
                probability = math.exp(log_prob[vocab_idx])
                row.append(Cell(term))
                row.append(Cell(probability))
            rows.append(row)

        cells = data.cells_from_tokens(tokens, highlight)
        for score, position in zip(scores, pos_list):
            cells[position].highlight_score = score / 10000 * 255

        html_writer.multirow_print(cells, 20)
        html_writer.write_table(rows)
Esempio n. 6
0
def draw():
    """Render per-token gradients of the "pc_para_I_grad" predictions.

    Every token cell is highlighted with ``min(abs(grad) * 1e4, 255)`` and
    coloured "B" for a positive gradient, "R" otherwise. The label and the
    argmax of the softmaxed logits are written before each token row.
    """
    # Alternative inputs used previously: "pc_para_D_grad", "pc_para_H_grad"
    name = "pc_para_I_grad"
    data = EstimatorPredictionViewerGosford(name)
    html_writer = HtmlVisualizer(name + ".html", dark_mode=False)

    for inst_i, entry in enumerate(data):
        tokens = entry.get_tokens("input_ids")
        grad = entry.get_vector("gradient")
        smallest_grad = min(grad)  # computed but not used afterwards

        cells = data.cells_from_tokens(tokens)
        for i, cell in enumerate(cells):
            g = grad[i]
            cells[i].highlight_score = min(abs(g) * 1e4, 255)
            cells[i].target_color = "B" if g > 0 else "R"
        print(grad)

        prob = softmax(entry.get_vector("logits"))
        pred = np.argmax(prob)
        label = entry.get_vector("labels")

        headline = "Label={} / Pred={}".format(str(label), pred)
        html_writer.write_paragraph(headline)
        html_writer.multirow_print(cells)
Esempio n. 7
0
def view_grad_overlap():
    """Report gradient-overlap scores relative to the masked-LM loss.

    Two passes are made over "gradient_overlap_4K.pickle". The first pass
    collects all overlap scores and, for entries whose loss exceeds 1, feeds
    ``score / loss`` into an ``IntBinAverage`` keyed by the loss. The second
    pass writes every entry whose loss exceeds 1 to the HTML report, marking
    it "High" or "Low" depending on whether its normalized score is above
    the average of its integer loss bin, and counts normalized scores over
    5000.
    """
    filename = "gradient_overlap_4K.pickle"
    out_name = filename.split(".")[0] + ".html"
    html_writer = HtmlVisualizer(out_name, dark_mode=False)

    # ---- pass 1: gather statistics -------------------------------------
    data = EstimatorPredictionViewerGosford(filename)
    iba = IntBinAverage()
    scores = []
    for inst_i, entry in enumerate(data):
        loss = entry.get_vector("masked_lm_example_loss")
        score = entry.get_vector("overlap_score")

        if loss > 1:
            iba.add(loss, score / loss)
        scores.append(score)

    score_avg = average(scores)
    score_std = np.std(scores)

    avg = iba.all_average()
    std_dict = {}  # filled here but not read later in this function
    for key, values in iba.list_dict.items():
        spread = np.std(values)
        # A bin with a single sample gets the sentinel value 999.
        std_dict[key] = 999 if len(values) == 1 else spread

    def unlikeliness(value, mean, std):
        return abs(value - mean) / std

    # ---- pass 2: write the report --------------------------------------
    data = EstimatorPredictionViewerGosford(filename)
    print("num record : ", data.data_len)
    cnt = 0
    for inst_i, entry in enumerate(data):
        tokens = entry.get_mask_resolved_input_mask_with_input()
        loss = entry.get_vector("masked_lm_example_loss")
        highlight = lmap(is_mask, tokens)
        score = entry.get_vector("overlap_score")
        print(score)
        cells = data.cells_from_tokens(tokens, highlight)

        if not (loss > 1):
            continue

        bin_key = int(loss)
        norm_score = score / loss
        if norm_score > 5000:
            cnt += 1
        expectation = avg[bin_key]

        # ``or True`` forces the branch: the outlier filter is switched off.
        if unlikeliness(score, score_avg, score_std) > 2 or True:
            html_writer.multirow_print(cells, 20)
            verdict = "High" if norm_score > expectation else "Low"
            html_writer.write_paragraph(verdict)
            html_writer.write_paragraph("Norm score: " + str(norm_score))
            html_writer.write_paragraph("score: " + str(score))
            html_writer.write_paragraph("masked_lm_example_loss: " +
                                        str(loss))
            html_writer.write_paragraph("expectation: " + str(expectation))
    print("number over 5000: ", cnt)
Esempio n. 8
0
def loss_view(dir_path):
    """Render masked-LM examples from every pickle file in ``dir_path``.

    Each pickle file holds an iterable of
    ``(input_ids, masked_input_ids, masked_lm_example_loss)`` triples. For
    each triple the original and masked token sequences are merged with
    ``mask_resolve_1`` and written to "ukp_lm_grad_high.html" with the mask
    positions highlighted. The loss value is unpacked but not displayed.

    Args:
        dir_path: directory whose files are listed via ``get_dir_files``.
    """
    tokenizer = get_tokenizer()
    html_writer = HtmlVisualizer("ukp_lm_grad_high.html", dark_mode=False)

    for file_path in get_dir_files(dir_path):
        # Use a context manager so the file handle is closed right after
        # loading instead of being left for the garbage collector.
        # NOTE: pickle.load can execute arbitrary code; only use this on
        # files produced by trusted jobs.
        with open(file_path, "rb") as f:
            items = pickle.load(f)

        for input_ids, masked_input_ids, _masked_lm_example_loss in items:
            tokens = mask_resolve_1(
                tokenizer.convert_ids_to_tokens(input_ids),
                tokenizer.convert_ids_to_tokens(masked_input_ids))
            highlight = lmap(is_mask, tokens)

            cells = cells_from_tokens(tokens, highlight)
            html_writer.multirow_print(cells)
Esempio n. 9
0
def print_as_html(fn):
    """Write the tokens of every record in ``fn`` to "out_name.html".

    The ``input_ids`` feature of each record is converted to tokens in
    chunks of 512 ids; each chunk is written as one multi-row block and a
    dashed separator paragraph follows every record.
    """
    examples = load_record(fn)
    vocab_path = os.path.join(data_path, "bert_voca.txt")
    tokenizer = tokenizer_wo_tf.FullTokenizer(vocab_path)

    html_output = HtmlVisualizer("out_name.html")

    step = 512
    for feature in examples:
        masked_inputs = feature["input_ids"].int64_list.value
        for start in range(0, len(masked_inputs), step):
            chunk = masked_inputs[start:start + step]
            tokens = tokenizer.convert_ids_to_tokens(chunk)
            html_output.multirow_print(cells_from_tokens(tokens))
        html_output.write_paragraph("----------")
Esempio n. 10
0
def loss_view():
    """Render masked input tokens and per-example LM losses of "sero_pred".

    The raw pickle is first opened only to print the shapes of the
    ``masked_lm_example_loss`` and ``masked_input_ids`` arrays of the first
    record. Then every entry is written to "sero_pred.html": the token cells
    as table rows of 20 cells, followed by the loss values.
    """
    filename = "sero_pred.pickle"
    p = os.path.join(output_path, filename)
    # Close the file deterministically instead of leaking the handle.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(p, "rb") as f:
        data = pickle.load(f)
    print(data[0]["masked_lm_example_loss"].shape)
    print(data[0]["masked_input_ids"].shape)

    html_writer = HtmlVisualizer("sero_pred.html", dark_mode=False)

    row_width = 20
    data = EstimatorPredictionViewerGosford(filename)
    for inst_i, entry in enumerate(data):
        losses = entry.get_vector("masked_lm_example_loss")
        print(losses)
        tokens = entry.get_tokens("masked_input_ids")
        cells = data.cells_from_tokens(tokens)

        row = []
        for cell in cells:
            row.append(cell)
            if len(row) == row_width:
                html_writer.write_table([row])
                row = []
        # Fix: the last, shorter row used to be dropped whenever the number
        # of cells was not a multiple of the row width.
        if row:
            html_writer.write_table([row])

        html_writer.multirow_print(data.cells_from_anything(losses),
                                   row_width)
Esempio n. 11
0
def main():
    """Visualize score differences between two windowed prediction runs.

    For job index 1, predictions are read from "windowed_{job}.score" and
    "windowed_small_{job}.score" under ``output_path``/robust. The two runs
    use segment lengths derived from max sequence lengths 128 and 112.
    Segments of both runs that end at the same position are paired; the score
    difference (first minus second) is attributed to the document tokens
    ``doc[st:st2]`` that only the longer segment covers. The result is
    written to "windowed.html", with tokens coloured by the sign of the
    difference and highlighted by its magnitude.
    """
    n_factor = 16
    step_size = 16
    max_seq_length = 128
    max_seq_length2 = 128 - 16
    batch_size = 8
    info_file_path = at_output_dir("robust", "seg_info")
    queries = load_robust_04_query("desc")
    qid_list = get_robust_qid_list()

    f_handler = get_format_handler("qc")
    info: Dict = load_combine_info_jsons(info_file_path,
                                         f_handler.get_mapping(),
                                         f_handler.drop_kdp())
    print(len(info))
    tokenizer = get_tokenizer()

    for job_idx in [1]:
        qid = qid_list[job_idx]
        query = queries[str(qid)]
        q_term_length = len(tokenizer.tokenize(query))
        data_path1 = os.path.join(output_path, "robust",
                                  "windowed_{}.score".format(job_idx))
        data_path2 = os.path.join(output_path, "robust",
                                  "windowed_small_{}.score".format(job_idx))
        data1 = OutputViewer(data_path1, n_factor, batch_size)
        data2 = OutputViewer(data_path2, n_factor, batch_size)
        # 3 positions and the query tokens are subtracted from each window
        # (presumably reserved for special tokens - TODO confirm).
        segment_len = max_seq_length - 3 - q_term_length
        segment_len2 = max_seq_length2 - 3 - q_term_length

        outputs = []
        for d1, d2 in zip(data1, data2):
            # for each query, doc pairs
            cur_info1 = info[d1['data_id']]
            cur_info2 = info[d2['data_id']]
            query_doc_id1 = f_handler.get_pair_id(cur_info1)
            query_doc_id2 = f_handler.get_pair_id(cur_info2)

            # Both runs must describe the same (query, doc) pair.
            assert query_doc_id1 == query_doc_id2

            doc = d1['doc']
            probs = get_probs(d1['logits'])
            probs2 = get_probs(d2['logits'])
            n_pred_true = np.count_nonzero(np.less(0.5, probs))
            print(n_pred_true, len(probs))

            seg_scores: List[Tuple[int, int, float]] = get_piece_scores(
                n_factor, probs, segment_len, step_size)
            seg_scores2: List[Tuple[int, int, float]] = get_piece_scores(
                n_factor, probs2, segment_len2, step_size)
            ss_list = []
            for st, ed, score in seg_scores:
                try:
                    # Pair with the segment of the second run that ends at
                    # the same position.
                    st2, ed2, score2 = find_where(lambda x: x[1] == ed,
                                                  seg_scores2)
                    assert ed == ed2
                    assert st < st2
                    tokens = tokenizer.convert_ids_to_tokens(doc[st:st2])
                    diff = score - score2
                    ss = ScoredPiece(st, st2, diff, tokens)
                    ss_list.append(ss)
                except StopIteration:
                    # No counterpart segment in the second run: skip it.
                    pass
            outputs.append((probs, probs2, query_doc_id1, ss_list))

        html = HtmlVisualizer("windowed.html")

        for probs, probs2, query_doc_id, ss_list in outputs:
            html.write_paragraph(str(query_doc_id))
            html.write_paragraph("Query: " + query)

            ss_list.sort(key=lambda ss: ss.st)
            prev_end = None
            cells = []
            prob_str1 = lmap(two_digit_float, probs)
            # Literal "8.88" is prepended to the first row only (looks like
            # an alignment placeholder - TODO confirm).
            prob_str1 = ["8.88"] + prob_str1
            prob_str2 = lmap(two_digit_float, probs2)
            html.write_paragraph(" ".join(prob_str1))
            html.write_paragraph(" ".join(prob_str2))

            for ss in ss_list:
                if prev_end is not None:
                    # Pieces are expected to be contiguous.
                    assert prev_end == ss.st
                else:
                    print(ss.st)

                score = abs(int(100 * ss.score))
                # Fix: choose the colour from the signed difference. The
                # previous code tested the absolute value, so negative
                # differences were never drawn as "R".
                color = "B" if ss.score > 0 else "R"
                cells.extend(
                    [Cell(t, score, target_color=color) for t in ss.tokens])
                prev_end = ss.ed

            html.multirow_print(cells)
Esempio n. 12
0
def binary_feature_demo(datapoint_list):
    """Write an HTML report of term-overlap features for ranked passages.

    For each datapoint (read attributes: ``cid``, ``pid``, ``claim_text``,
    ``p_text``, ``label``) passages are retrieved through
    ``PassageRankedListInterface``. Every retrieved document is scored by the
    summed idf of the claim/perspective terms it mentions and by a BM25
    re-estimate, and its tokens are written with mentioned terms highlighted.
    Output goes to "pc_binary_feature.html".

    The outer loop stops after ``dp_idx`` exceeds 10 and the inner loop after
    ``local_print_cnt`` exceeds 10.
    """
    ci = PassageRankedListInterface(make_passage_query, Q_CONFIG_ID_BM25)
    # Terms without a document-frequency entry are collected here; the set is
    # not read anywhere else in this function.
    not_found_set = set()
    _, clue12_13_df = load_clueweb12_B13_termstat()
    # Used as the collection size N in the idf and BM25 formulas below.
    cdf = 50 * 1000 * 1000
    html = HtmlVisualizer("pc_binary_feature.html")

    def idf_scorer(doc, claim_text, perspective_text):
        """Return (score, max_score, mentioned_terms) for one document.

        ``score`` sums idf over the claim/perspective terms found in ``doc``
        (after ``re_tokenize``); ``max_score`` sums idf over all
        claim/perspective terms.
        """
        cp_tokens = nltk.word_tokenize(claim_text) + nltk.word_tokenize(
            perspective_text)
        cp_tokens = lmap(lambda x: x.lower(), cp_tokens)
        cp_tokens = set(cp_tokens)
        mentioned_terms = lfilter(lambda x: x in doc, cp_tokens)
        mentioned_terms = re_tokenize(mentioned_terms)

        def idf(term):
            if term not in clue12_13_df:
                # ``in string.printable`` is a substring test, so this also
                # matches multi-character runs of that string.
                if term in string.printable:
                    return 0
                not_found_set.add(term)

            # NOTE(review): still evaluated for terms recorded as not found;
            # whether the lookup raises KeyError depends on the type of
            # clue12_13_df, which is not visible here.
            return math.log((cdf + 0.5) / (clue12_13_df[term] + 0.5))

        score = sum(lmap(idf, mentioned_terms))
        max_score = sum(lmap(idf, cp_tokens))
        # print(claim_text, perspective_text)
        # print(mentioned_terms)
        # print(score, max_score)
        return score, max_score, mentioned_terms

    def bm25_estimator(doc: Counter, claim_text: str, perspective_text: str):
        """Return (term, tf, df, bm25 score) tuples for query terms in ``doc``."""
        cp_tokens = nltk.word_tokenize(claim_text) + nltk.word_tokenize(
            perspective_text)
        cp_tokens = lmap(lambda x: x.lower(), cp_tokens)
        # With k1 = 0 the term-frequency factor below reduces to f / (K + f).
        k1 = 0

        def BM25_3(f, qf, df, N, dl, avdl) -> float:
            # ``qf`` is accepted but not used in the formula.
            K = compute_K(dl, avdl)
            first = math.log((N - df + 0.5) / (df + 0.5))
            second = ((k1 + 1) * f) / (K + f)
            return first * second

        # Document length = total token count.
        dl = sum(doc.values())
        info = []
        for q_term in set(cp_tokens):
            if q_term in doc:
                # 1200 is passed as the average document length (avdl).
                score = BM25_3(doc[q_term], 0, clue12_13_df[q_term], cdf, dl,
                               1200)
                info.append((q_term, doc[q_term], clue12_13_df[q_term], score))
        return info

    print_cnt = 0  # not used below
    for dp_idx, x in enumerate(datapoint_list):
        ranked_list: List[GalagoRankEntry] = ci.query_passage(
            x.cid, x.pid, x.claim_text, x.p_text)
        html.write_paragraph(x.claim_text)
        html.write_paragraph(x.p_text)
        html.write_paragraph("{}".format(x.label))

        local_print_cnt = 0
        lines = []
        for ranked_entry in ranked_list:
            # Any KeyError raised inside this body skips the ranked entry;
            # paragraphs already written for it stay in the output.
            try:
                doc_id = ranked_entry.doc_id
                galago_score = ranked_entry.score

                tokens = load_doc(doc_id)
                doc_tf = Counter(tokens)
                # NOTE(review): Counter(...) never returns None, so the else
                # branch below looks unreachable.
                if doc_tf is not None:
                    score, max_score, mentioned_terms = idf_scorer(
                        doc_tf, x.claim_text, x.p_text)
                    # "Matched" = at least 75% of the maximum idf mass.
                    matched = score > max_score * 0.75
                else:
                    matched = "Unk"
                    score = "Unk"
                    max_score = "Unk"

                def get_cell(token):
                    # Highlight tokens that are among the mentioned terms.
                    if token in mentioned_terms:
                        return Cell(token, highlight_score=50)
                    else:
                        return Cell(token)

                line = [doc_id, galago_score, matched, score, max_score]
                lines.append(line)
                html.write_paragraph("{0} / {1:.2f}".format(
                    doc_id, galago_score))
                html.write_paragraph("{}/{}".format(score, max_score))
                bm25_info = bm25_estimator(doc_tf, x.claim_text, x.p_text)
                # Index 3 of each info tuple is the BM25 score.
                bm25_score = sum(lmap(lambda x: x[3], bm25_info))
                html.write_paragraph(
                    "bm25 re-estimate : {}".format(bm25_score))
                html.write_paragraph("{}".format(bm25_info))
                html.multirow_print(lmap(get_cell, tokens))
                local_print_cnt += 1
                if local_print_cnt > 10:
                    break
            except KeyError:
                pass

        # Index 2 of each line is the ``matched`` flag; presumably idx_where
        # returns the indices where the predicate holds - verify.
        matched_idx = idx_where(lambda x: x[2], lines)
        if not matched_idx:
            html.write_paragraph("No match")
        else:
            # Keep the table only up to the last matched document.
            last_matched = matched_idx[-1]
            lines = lines[:last_matched + 1]
            rows = lmap(lambda line: lmap(Cell, line), lines)
            html.write_table(rows)

        if dp_idx > 10:
            break
Esempio n. 13
0
def main():
    """Write per-token relevance scores of Robust04 documents to HTML.

    Token scores are collected from "rob_dense_pred.score" with a step size
    of 16 and a window size of 128. Documents whose maximum segment score is
    at least 0.6 are rendered (at most 10) to "robust_desc_128_step16.html";
    each token cell carries a tooltip with all of its scores and is
    highlighted according to ``transform(score) * 200``.

    Raises:
        IndexError: if the number of scores of a document is not within
            ``[len(doc_tokens), len(doc_tokens) + window_size)``.
    """
    prediction_file_path = at_output_dir("robust", "rob_dense_pred.score")
    info_file_path = at_job_man_dir1("robust_predict_desc_128_step16_info")
    queries: Dict[str, str] = load_robust_04_query("desc")
    tokenizer = get_tokenizer()
    query_token_len_d = {}
    for qid, q_text in queries.items():
        query_token_len_d[qid] = len(tokenizer.tokenize(q_text))
    step_size = 16
    window_size = 128
    out_entries: List[DocTokenScore] = collect_token_scores(
        info_file_path, prediction_file_path, query_token_len_d, step_size,
        window_size)

    qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
    judgement_d = load_qrels_structured(qrel_path)

    html = HtmlVisualizer("robust_desc_128_step16.html", use_tooltip=True)

    tprint("loading tokens pickles")
    tokens_d: Dict[str, List[str]] = load_pickle_from(
        os.path.join(sydney_working_dir, "RobustPredictTokens3", "1"))
    tprint("Now printing")
    n_printed = 0

    def transform(x):
        # Cubic curve centred at 0.5: maps 0 -> 0, 0.5 -> 0.375, 1 -> 0.75.
        return 3 * (math.pow(x - 0.5, 3) + math.pow(0.5, 3))

    for e in out_entries:
        max_score = e.max_segment_score()
        if max_score < 0.6:
            continue
        n_printed += 1
        if n_printed > 10:
            break
        doc_tokens: List[str] = tokens_d[e.doc_id]
        score_len = len(e.scores)
        judgement: Dict[str, int] = judgement_d[e.query_id]
        label = judgement[e.doc_id]

        upper_bound = len(doc_tokens) + window_size
        if not len(doc_tokens) <= score_len < upper_bound:
            print("doc length : ", len(doc_tokens))
            print("score len:", score_len)
            # Fix: report the bound that is actually checked (window_size);
            # the old message printed len + step_size instead.
            print("doc length + window_size: ", upper_bound)
            raise IndexError(
                "score length {} is outside [{}, {})".format(
                    score_len, len(doc_tokens), upper_bound))

        row = []
        q_text = queries[e.query_id]
        html.write_paragraph("qid: " + e.query_id)
        html.write_paragraph("q_text: " + q_text)
        html.write_paragraph("Pred: {0:.2f}".format(max_score))
        html.write_paragraph("Label: {0:.2f}".format(label))

        for idx in range(score_len):
            # Scores may extend past the document; pad with a marker token.
            token = doc_tokens[idx] if idx < len(doc_tokens) else '[-]'

            full_scores = e.full_scores[idx]
            full_score_str = " ".join(lmap(two_digit_float, full_scores))
            score = e.scores[idx]
            normalized_score = transform(score) * 200
            c = get_tooltip_cell(token, full_score_str)
            c.highlight_score = normalized_score
            row.append(c)

        html.multirow_print(row, 16)
Esempio n. 14
0
def join_docs_and_lm():
    """Highlight retrieved documents by claim-specific log-odds scores.

    Uses the first 10 training claims. For each claim a smoothed topic
    language model (alpha = 0.1 against the averaged background model) is
    built and the log-odds of topic vs. background is computed. The gold
    perspective clusters, the 30 top log-odds terms and the top 10 retrieved
    documents (tokens scored by the log-odds of their stem) are written to
    "doc_lm_joined.html". Documents whose processing raises ``KeyError`` are
    skipped.
    """
    gold = get_claim_perspective_id_dict()

    claims: List[Dict] = get_claims_from_ids(list(load_train_claim_ids()))
    claims = claims[:10]
    top_n = 10
    q_res_path = FilePath(
        "/mnt/nfs/work3/youngwookim/data/perspective/train_claim/q_res_100")
    ranked_list: Dict[
        str, List[SimpleRankedListEntry]] = load_galago_ranked_list(q_res_path)
    preload_docs(ranked_list, claims, top_n)

    claim_lms = build_gold_lms(claims)
    claim_lms_d = {claim_lm.cid: claim_lm for claim_lm in claim_lms}
    bg_lm = average_counters(lmap(lambda entry: entry.LM, claim_lms))
    log_bg_lm = get_lm_log(bg_lm)

    # Punctuation is added to the shared stopword collection.
    stopwords.update([".", ",", "!", "?"])

    alpha = 0.1

    html_visualizer = HtmlVisualizer("doc_lm_joined.html")

    def get_cell_from_token2(token, probs):
        # Not called within this function; kept as in the original.
        if token.lower() in stopwords:
            probs = 0
        scaled = probs * 1e5
        return Cell(token, min(100, scaled))

    tokenizer = PCTokenizer()
    for claim in claims:
        claim_id = claim['cId']
        q_res: List[SimpleRankedListEntry] = ranked_list[str(claim_id)]
        html_visualizer.write_headline(
            "{} : {}".format(claim_id, claim['text']))

        clusters: List[List[int]] = gold[claim_id]
        for cluster in clusters:
            html_visualizer.write_paragraph("---")
            for text in lmap(perspective_getter, cluster):
                html_visualizer.write_paragraph(text)
            html_visualizer.write_paragraph("---")

        claim_lm = claim_lms_d[claim_id]
        topic_lm_prob = smooth(claim_lm.LM, bg_lm, alpha)
        log_topic_lm = get_lm_log(smooth(claim_lm.LM, bg_lm, alpha))
        log_odd: Counter = subtract(log_topic_lm, log_bg_lm)

        top_terms = "\t".join(left(log_odd.most_common(30)))
        html_visualizer.write_paragraph("Log odd top: " + top_terms)
        not_found = set()

        def get_log_odd(word):
            stem = tokenizer.stemmer.stem(word)
            if stem not in log_odd:
                not_found.add(stem)
            return log_odd[stem]

        def get_probs(word):
            # Not called within this function; kept as in the original.
            stem = tokenizer.stemmer.stem(word)
            if stem not in topic_lm_prob:
                not_found.add(stem)
            return topic_lm_prob[stem]

        def to_cell(word):
            return get_cell_from_token(word, get_log_odd(word))

        for rank in range(top_n):
            try:
                doc = load_doc(q_res[rank].doc_id)
                cells = lmap(to_cell, doc)
                html_visualizer.write_headline("Doc rank {}".format(rank))
                html_visualizer.multirow_print(cells, width=20)
            except KeyError:
                pass
        html_visualizer.write_paragraph("Not found: {}".format(not_found))
Esempio n. 15
0
def main():
    """Write per-token value scores of one claim's documents to HTML.

    The claim id is taken from ``sys.argv[1]``. Per-document scores are
    collected relative to the "train_baseline" scores and grouped by
    document id; for every document the tokens are written with a highlight
    of ``min(abs(avg) * 10000, 100)`` and colour "B" for positive, "R"
    otherwise. Additional rows with the n-th individual score of every token
    are assembled, but only the first (averaged) row is printed.
    """
    cid = sys.argv[1]
    print("Loading scores...")
    baseline_cid_grouped = load_baseline("train_baseline")
    gold = get_claim_perspective_id_dict()
    tokenizer = get_tokenizer()
    claim_d = load_train_claim_d()

    print("Start analyzing")
    html = HtmlVisualizer("cppnc_value_per_token_score_{}.html".format(cid))
    pid_entries_d: Dict[str, List[Dict]] = load_entries(cid)
    pid_entries: List[Tuple[str, List[Dict]]] = list(pid_entries_d.items())
    baseline_score_d = fetch_score_per_pid(baseline_cid_grouped[int(cid)])
    gold_pids = gold[int(cid)]

    ret = collect_score_per_doc(baseline_score_d, get_score_from_entry,
                                gold_pids, pid_entries)
    passage_tokens_d = collect_passage_tokens(pid_entries)
    doc_info_d: Dict[int, Tuple[str, int]] = ret[0]
    doc_value_arr: List[List[float]] = ret[1]

    # Group (doc_id, passage_idx, average score) triples by document.
    kdp_result_grouped = defaultdict(list)
    for doc_idx, doc_values in enumerate(doc_value_arr):
        doc_id, passage_idx = doc_info_d[doc_idx]
        kdp_result_grouped[doc_id].append(
            (doc_id, passage_idx, average(doc_values)))

    html.write_headline("{} : {}".format(cid, claim_d[int(cid)]))

    # Third value of every per-document score list.
    scores: List[float] = [values[2] for values in doc_value_arr]
    foreach(html.write_paragraph, lmap(str, scores))

    def get_cell(score):
        # "-" marks a token that has no score at the current depth.
        if score == "-":
            return Cell("-")
        # 0.01 -> 100
        strength = min(abs(score) * 20000, 255)
        hue = "B" if score > 0 else "R"
        return Cell("", highlight_score=strength, target_color=hue)

    for doc_id, kdp_result_list in kdp_result_grouped.items():
        html.write_headline(doc_id)
        tokens, per_token_score = combine_collect_score(
            tokenizer, doc_id, passage_tokens_d, kdp_result_list)
        str_tokens = tokenizer.convert_ids_to_tokens(tokens)
        n_tokens = len(str_tokens)

        head_row = cells_from_tokens(str_tokens)
        for idx in range(n_tokens):
            mean_score = average(per_token_score[idx])
            head_row[idx].highlight_score = min(abs(mean_score) * 10000, 100)
            head_row[idx].target_color = "B" if mean_score > 0 else "R"

        rows = [head_row]
        depth = 0
        more_scores = True
        while more_scores:
            more_scores = False
            layer = []
            for idx in range(n_tokens):
                token_scores = per_token_score[idx]
                if depth < len(token_scores):
                    layer.append(token_scores[depth])
                    more_scores = True
                else:
                    layer.append("-")

            depth += 1
            if more_scores:
                rows.append(lmap(get_cell, layer))

        # Only the averaged row is written; the per-depth rows are not.
        html.multirow_print(rows[0])
Esempio n. 16
0
def main():
    """Write ablation-based per-token scores of Robust04 documents to HTML.

    Token scores are obtained with ``token_score_by_ablation`` from
    "rob_dense2_pred.score" (step size 16, window size 128). Entries whose
    maximum score is below 0.6 are treated as negatives and are only kept
    while the number of negatives does not exceed the number of positives.
    Processing stops once more than 500 entries have been selected. Each
    token cell gets a tooltip listing its score differences, colour "B" for a
    positive average difference and "R" otherwise, and a highlight of
    ``transform(abs(avg)) * 200``. Output goes to
    "robust_desc_128_step16_2.html".

    Entries whose score length is outside
    ``[len(doc_tokens), len(doc_tokens) + window_size)`` are reported on
    stdout and skipped.
    """
    prediction_file_path = at_output_dir("robust", "rob_dense2_pred.score")
    info_file_path = at_job_man_dir1("robust_predict_desc_128_step16_2_info")
    queries: Dict[str, str] = load_robust_04_query("desc")
    tokenizer = get_tokenizer()
    query_token_len_d = {}
    for qid, q_text in queries.items():
        query_token_len_d[qid] = len(tokenizer.tokenize(q_text))
    step_size = 16
    window_size = 128
    out_entries: List[AnalyzedDoc] = token_score_by_ablation(
        info_file_path, prediction_file_path, query_token_len_d, step_size,
        window_size)

    qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
    judgement_d = load_qrels_structured(qrel_path)

    html = HtmlVisualizer("robust_desc_128_step16_2.html", use_tooltip=True)

    tprint("loading tokens pickles")
    tokens_d: Dict[str, List[str]] = load_pickle_from(
        os.path.join(sydney_working_dir, "RobustPredictTokens3", "1"))
    tprint("Now printing")
    n_printed = 0

    def transform(x):
        # Cubic curve centred at 0.5: maps 0 -> 0, 0.5 -> 0.375, 1 -> 0.75.
        return 3 * (math.pow(x - 0.5, 3) + math.pow(0.5, 3))

    n_pos = 0
    n_neg = 0
    for e in out_entries:
        max_score: float = max(
            lmap(SegmentScorePair.get_max_score,
                 flatten(e.token_info.values())))
        if max_score < 0.6:
            # Keep a negative only while negatives do not outnumber
            # positives.
            if n_neg > n_pos:
                continue
            else:
                n_neg += 1
        else:
            n_pos += 1

        n_printed += 1
        if n_printed > 500:
            break

        doc_tokens: List[str] = tokens_d[e.doc_id]
        score_len = max(e.token_info.keys()) + 1
        judgement: Dict[str, int] = judgement_d[e.query_id]
        label = judgement[e.doc_id]

        upper_bound = len(doc_tokens) + window_size
        if not len(doc_tokens) <= score_len < upper_bound:
            print("doc length : ", len(doc_tokens))
            print("score len:", score_len)
            # Fix: report the bound that is actually checked (window_size);
            # the old message printed len + step_size instead.
            print("doc length + window_size: ", upper_bound)
            continue

        row = []
        q_text = queries[e.query_id]
        html.write_paragraph("qid: " + e.query_id)
        html.write_paragraph("q_text: " + q_text)
        html.write_paragraph("Pred: {0:.2f}".format(max_score))
        html.write_paragraph("Label: {0:.2f}".format(label))

        for idx in range(score_len):
            # Scores may extend past the document; pad with a marker token.
            token = doc_tokens[idx] if idx < len(doc_tokens) else '[-]'
            token_info: List[SegmentScorePair] = e.token_info[idx]
            full_scores: List[float] = lmap(SegmentScorePair.get_score_diff,
                                            token_info)

            full_score_str = " ".join(lmap(two_digit_float, full_scores))
            # 1 ~ -1
            score = average(full_scores)
            if score > 0:
                color = "B"
            else:
                color = "R"
            normalized_score = transform(abs(score)) * 200
            c = get_tooltip_cell(token, full_score_str)
            c.highlight_score = normalized_score
            c.target_color = color
            row.append(c)

        html.multirow_print(row, 16)