示例#1
0
def compute_mrr(D, I, qids, ref_dict):
    """Return the "MRR @10" metric for the retrieval results ``D``/``I``.

    ``D`` and ``I`` are bundled into a dict and passed through
    ``all_gather``; the gathered pieces are concatenated along axis 1 with
    ``concat_key``. Only the process for which ``is_first_worker()`` is true
    ranks and scores the results; every other process returns ``0.0``.

    Args:
        D: Score array for the local results; higher values rank first.
            (Assumed to be 2-D with one row per query -- TODO confirm.)
        I: Array of passage ids aligned element-wise with ``D``. Negative
            ids are skipped when building the candidate lists.
        qids: Query ids; ``qids[i]`` labels row ``i`` of the merged arrays.
        ref_dict: Reference judgments forwarded unchanged to
            ``quality_checks_qids`` and ``compute_metrics``.

    Returns:
        The value stored under ``"MRR @10"`` in the ``compute_metrics``
        result on the first worker, ``0.0`` elsewhere.
    """
    gathered = all_gather({"D": D, "I": I})
    if not is_first_worker():
        return 0.0

    scores = concat_key(gathered, "D", axis=1)
    pids = concat_key(gathered, "I", axis=1)
    print(scores.shape, pids.shape)

    # Original author's note: padding uses negative pids with distance -128;
    # if those reach the top of the ranking something is wrong.
    ascending = np.argsort(scores, axis=1)
    descending = ascending[:, ::-1]
    top_idx = descending[:, :10]
    ranked_pids = np.take_along_axis(pids, top_idx, axis=1)

    candidate_dict = {}
    for row, qid in enumerate(qids):
        # Each query gets a fixed 1000-slot list, zero-filled past the
        # entries actually written below.
        ranking = candidate_dict.setdefault(qid, [0] * 1000)
        used = set()
        slot = 0
        for pid in ranked_pids[row]:
            if pid >= 0 and pid not in used:
                ranking[slot] = pid
                used.add(pid)
                slot += 1

    _, message = quality_checks_qids(ref_dict, candidate_dict)
    if message != '':
        print(message)

    mrr = compute_metrics(ref_dict, candidate_dict)["MRR @10"]
    print(mrr)
    return mrr
示例#2
0
def compute_mrr_last(D, I, qids, ref_dict, dev_query_positive_id):
    """Return ``(MRR @10, mean recall@1000)`` for the results ``D``/``I``.

    ``D`` and ``I`` are bundled into a dict and passed through
    ``all_gather``; the gathered pieces are concatenated along axis 1 with
    ``concat_key``. Only the process for which ``is_first_worker()`` is true
    ranks and scores the results; every other process returns ``(0.0, 0.0)``.

    Args:
        D: Score array for the local results; higher values rank first.
            (Assumed to be 2-D with one row per query -- TODO confirm.)
        I: Array of passage ids aligned element-wise with ``D``. Negative
            ids are skipped when building the candidate lists.
        qids: Query ids; ``qids[i]`` labels row ``i`` of the merged arrays.
        ref_dict: Reference judgments forwarded unchanged to
            ``quality_checks_qids`` and ``compute_metrics``.
        dev_query_positive_id: Relevance judgments handed to
            ``pytrec_eval.RelevanceEvaluator`` after ``convert_to_string_id``.

    Returns:
        A ``(mrr, final_recall)`` tuple. ``mrr`` is the ``"MRR @10"`` entry
        of ``compute_metrics``; ``final_recall`` is the mean of the
        per-query ``recall_1000`` values, or ``0.0`` when the evaluator
        returns no queries.
    """
    knn_pkl = {"D": D, "I": I}
    all_knn_list = all_gather(knn_pkl)
    mrr = 0.0
    final_recall = 0.0
    if is_first_worker():
        # Single cut-off shared by the ranking slice, the candidate list
        # size and the recall measure name.
        topN = 1000

        D_merged = concat_key(all_knn_list, "D", axis=1)
        I_merged = concat_key(all_knn_list, "I", axis=1)
        print(D_merged.shape, I_merged.shape)
        # Original author's note: padding uses negative pids with distance
        # -128; if those reach the top of the ranking something is wrong.
        idx = np.argsort(D_merged, axis=1)[:, ::-1][:, :topN]
        sorted_I = np.take_along_axis(I_merged, idx, axis=1)

        prediction = {}
        candidate_dict = {}
        for i, qid in enumerate(qids):
            seen_pids = set()
            if qid not in candidate_dict:
                prediction[qid] = {}
                candidate_dict[qid] = [0] * topN
            j = 0
            for pid in sorted_I[i]:
                if pid >= 0 and pid not in seen_pids:
                    candidate_dict[qid][j] = pid
                    # Score is the negated 1-based rank, so better-ranked
                    # passages get the larger score.
                    prediction[qid][pid] = -(j + 1)
                    j += 1
                    seen_pids.add(pid)

        _, message = quality_checks_qids(ref_dict, candidate_dict)
        if message != '':
            print(message)

        mrr_metrics = compute_metrics(ref_dict, candidate_dict)
        mrr = mrr_metrics["MRR @10"]
        print(mrr)

        evaluator = pytrec_eval.RelevanceEvaluator(
            convert_to_string_id(dev_query_positive_id), {'recall'})
        result = evaluator.evaluate(convert_to_string_id(prediction))

        recall_key = "recall_" + str(topN)
        # Guard against an empty evaluation result, which previously caused
        # a ZeroDivisionError; final_recall then stays at 0.0.
        if result:
            recall = sum(
                per_query[recall_key] for per_query in result.values())
            final_recall = recall / len(result)
        print('final_recall: ', final_recall)

    return mrr, final_recall