Exemplo n.º 1
0
def test_ap():
    """Print AP@k values for three hand-made rankings as a manual sanity check.

    todo-as-note: the denominator should be carefully checked when using AP@k.

    Nothing is asserted here; the values in the trailing comments are the
    outputs the printed tensors are expected to show.
    """
    def as_batch(labels):
        # The AP helpers are given a single ranking reshaped to 1 x n.
        return labels.view(1, -1)

    # Case 1: five relevant documents in total, but the system retrieves
    # only three of them.
    ranked = torch.Tensor([1.0, 0.0, 1.0, 0.0, 1.0])
    ideal = torch.Tensor([1.0, 1.0, 1.0, 1.0, 1.0])
    multi_cutoff_ap = torch_ap_at_ks(as_batch(ranked), as_batch(ideal), ks=[1, 3, 5])
    print(multi_cutoff_ap.size(), multi_cutoff_ap)  # tensor([1.0000, 0.5556, 0.4533])
    single_cutoff_ap = torch_ap_at_k(as_batch(ranked), as_batch(ideal), k=3)
    # presumably 0.5556, i.e. the k=3 entry of the tensor above -- TODO confirm
    print(single_cutoff_ap.size(), single_cutoff_ap)

    # Case 2: same system ranking, but only three documents are relevant
    # in total according to the ideal labels.
    ranked = torch.Tensor([1.0, 0.0, 1.0, 0.0, 1.0])
    ideal = torch.Tensor([1.0, 1.0, 1.0, 0.0, 0.0])
    multi_cutoff_ap = torch_ap_at_ks(as_batch(ranked), as_batch(ideal), ks=[1, 3, 5])
    print(multi_cutoff_ap)  # tensor([1.0000, 0.5556, 0.7556])

    # Case 3: four relevant documents in total, and the system retrieves
    # all four of them within its seven results.
    ranked = torch.Tensor([1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0])
    ideal = torch.Tensor([1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
    multi_cutoff_ap = torch_ap_at_ks(as_batch(ranked), as_batch(ideal), ks=[1, 2, 3, 5, 7])
    print(multi_cutoff_ap)  # tensor([1.0000, 1.0000, 0.6667, 0.6875, 0.8304])
    single_cutoff_ap = torch_ap_at_k(as_batch(ranked), as_batch(ideal), k=5)
    # presumably 0.6875, i.e. the k=5 entry of the tensor above -- TODO confirm
    print(single_cutoff_ap)
    print()
Exemplo n.º 2
0
    def cal_metric_at_ks(self,
                         model_id,
                         all_std_labels=None,
                         all_preds=None,
                         group=None,
                         ks=None):
        """
        Compute nDCG, nERR, AP and precision at several cutoff values.

        The flat label/prediction arrays are split into consecutive per-query
        segments using the sizes in ``group``; each segment is ranked by
        descending prediction and scored, and the scores are averaged over
        all queries.

        :param model_id: model identifier; when it starts with ``'LightGBM'``,
                         ``group`` is treated as a numpy array and converted
                         to a list of Python ints
        :param all_std_labels: numpy array with the ground-truth labels of all
                               queries, concatenated (cast to float32)
        :param all_preds: numpy array with the predicted scores, aligned with
                          ``all_std_labels`` (cast to float32)
        :param group: iterable with the number of documents of each query, in
                      the order the queries appear in the flat arrays
        :param ks: list of cutoff values; defaults to ``[1, 3, 5, 10]``
        :return: an 8-tuple ``(avg_ndcg_at_ks, avg_nerr_at_ks, avg_ap_at_ks,
                 avg_p_at_ks, list_ndcg_at_ks_per_q, list_err_at_ks_per_q,
                 list_ap_at_ks_per_q, list_p_at_ks_per_q)``: the first four are
                 numpy arrays averaged over queries (NaN when ``group`` is
                 empty, since that is 0/0), the last four are lists holding one
                 numpy array per query (the "err" list holds the nERR values)
        """
        # A None sentinel avoids sharing one mutable default list across calls.
        if ks is None:
            ks = [1, 3, 5, 10]

        cnt = torch.zeros(1)

        sum_ndcg_at_ks = torch.zeros(len(ks))
        sum_nerr_at_ks = torch.zeros(len(ks))
        sum_ap_at_ks = torch.zeros(len(ks))
        sum_p_at_ks = torch.zeros(len(ks))

        list_ndcg_at_ks_per_q = []
        list_err_at_ks_per_q = []
        list_ap_at_ks_per_q = []
        list_p_at_ks_per_q = []

        tor_all_std_labels = torch.from_numpy(all_std_labels.astype(np.float32))
        tor_all_preds = torch.from_numpy(all_preds.astype(np.float32))

        if model_id.startswith('LightGBM'):
            # np.int was only an alias of the builtin int and has been removed
            # from NumPy (1.24+), so the builtin is used directly.
            group = group.astype(int).tolist()

        head = 0  # start offset of the current query inside the flat arrays
        for gr in group:
            tor_per_query_std_labels = tor_all_std_labels[head:head + gr]
            tor_per_query_preds = tor_all_preds[head:head + gr]
            head += gr

            # Labels in the order induced by the predicted scores (best first).
            _, tor_sorted_inds = torch.sort(tor_per_query_preds,
                                            descending=True)
            sys_sorted_labels = tor_per_query_std_labels[tor_sorted_inds]

            # Labels in the best possible order, used as the normaliser.
            ideal_sorted_labels, _ = torch.sort(tor_per_query_std_labels,
                                                descending=True)

            ndcg_at_ks = torch_nDCG_at_ks(
                sys_sorted_labels=sys_sorted_labels,
                ideal_sorted_labels=ideal_sorted_labels,
                ks=ks,
                multi_level_rele=True)
            list_ndcg_at_ks_per_q.append(ndcg_at_ks.numpy())

            nerr_at_ks = torch_nerr_at_ks(
                sys_sorted_labels=sys_sorted_labels,
                ideal_sorted_labels=ideal_sorted_labels,
                ks=ks,
                multi_level_rele=True)
            list_err_at_ks_per_q.append(nerr_at_ks.numpy())

            ap_at_ks = torch_ap_at_ks(sys_sorted_labels=sys_sorted_labels,
                                      ideal_sorted_labels=ideal_sorted_labels,
                                      ks=ks)
            list_ap_at_ks_per_q.append(ap_at_ks.numpy())

            p_at_ks = torch_p_at_ks(sys_sorted_labels=sys_sorted_labels, ks=ks)
            list_p_at_ks_per_q.append(p_at_ks.numpy())

            sum_ndcg_at_ks = torch.add(sum_ndcg_at_ks, ndcg_at_ks)
            sum_nerr_at_ks = torch.add(sum_nerr_at_ks, nerr_at_ks)
            sum_ap_at_ks = torch.add(sum_ap_at_ks, ap_at_ks)
            sum_p_at_ks = torch.add(sum_p_at_ks, p_at_ks)
            cnt += 1

        # Average over the number of queries seen.
        avg_ndcg_at_ks = (sum_ndcg_at_ks / cnt).numpy()
        avg_nerr_at_ks = (sum_nerr_at_ks / cnt).numpy()
        avg_ap_at_ks = (sum_ap_at_ks / cnt).numpy()
        avg_p_at_ks = (sum_p_at_ks / cnt).numpy()

        return avg_ndcg_at_ks, avg_nerr_at_ks, avg_ap_at_ks, avg_p_at_ks,\
               list_ndcg_at_ks_per_q, list_err_at_ks_per_q, list_ap_at_ks_per_q, list_p_at_ks_per_q