Example #1
0
def test_nerr():
    """Check nERR@{1, 2, 3} for a single, un-batched query.

    The system ranking places documents with labels [3, 2, 4]; the ideal
    ranking of the same documents is [4, 3, 2]. The expected values were
    previously only printed (so the test could never fail); they are now
    asserted to within 1e-4.

    NOTE(review): a later ``def test_nerr`` in this file rebinds the same
    name, so a test runner collecting by name only sees that later one.
    """
    sys_sorted_labels = torch.Tensor([3.0, 2.0, 4.0])
    std_sorted_labels = torch.Tensor([4.0, 3.0, 2.0])
    nerr_at_ks = torch_nerr_at_ks(sys_sorted_labels,
                                  std_sorted_labels,
                                  ks=[1, 2, 3])
    # Match the result's dtype: torch.allclose rejects mixed dtypes.
    expected = torch.tensor([0.4667, 0.5154, 0.6640], dtype=nerr_at_ks.dtype)
    assert torch.allclose(nerr_at_ks.reshape(-1), expected, atol=1e-4), nerr_at_ks
def test_nerr():
    """Check nERR@{1, 2, 3} when the query is passed in batch form.

    Same data as the single-query case: system ranking labels [3, 2, 4],
    ideal ranking [4, 3, 2], each reshaped to a batch of one row with
    ``view(1, -1)``. The expected values were previously only printed; they
    are now asserted to within 1e-4.
    """
    sys_sorted_labels = torch.Tensor([3.0, 2.0, 4.0])
    std_sorted_labels = torch.Tensor([4.0, 3.0, 2.0])
    # convert to batch mode: shape (1, num_docs)
    batch_nerr_at_ks = torch_nerr_at_ks(sys_sorted_labels.view(1, -1),
                                        std_sorted_labels.view(1, -1),
                                        ks=[1, 2, 3])
    expected = torch.tensor([0.4667, 0.5154, 0.6640], dtype=batch_nerr_at_ks.dtype)
    # Flatten so the check holds whether or not the batch dimension is kept.
    assert torch.allclose(batch_nerr_at_ks.reshape(-1), expected, atol=1e-4), batch_nerr_at_ks
Example #3
0
    def cal_metric_at_ks(self,
                         model_id,
                         all_std_labels=None,
                         all_preds=None,
                         group=None,
                         ks=None):
        """
        Compute nDCG, nERR, AP and P at several cutoff values, both per query
        and averaged over all queries.

        :param model_id: identifier of the evaluated model; when it starts with
            'LightGBM', ``group`` is cast to a list of Python ints before use.
        :param all_std_labels: ground-truth labels of all documents, the
            documents of each query occupying one contiguous segment
            (array-like; converted to float32).
        :param all_preds: predicted scores aligned element-wise with
            ``all_std_labels`` (array-like; converted to float32).
        :param group: sizes of the consecutive per-query segments, in data
            order. Required in practice despite the ``None`` default.
        :param ks: cutoff values; ``None`` (the default) means [1, 3, 5, 10].
        :return: a tuple
            (avg_ndcg_at_ks, avg_nerr_at_ks, avg_ap_at_ks, avg_p_at_ks,
             list_ndcg_at_ks_per_q, list_err_at_ks_per_q,
             list_ap_at_ks_per_q, list_p_at_ks_per_q).
            Each ``avg_*`` is a numpy array with the mean over queries (NaN
            when ``group`` is empty, since the query count is then 0). Each
            ``list_*`` holds one numpy array per query, in ``group`` order.
        """
        if ks is None:
            # Build the default per call instead of sharing one mutable list.
            ks = [1, 3, 5, 10]

        num_queries = 0
        sum_ndcg_at_ks = torch.zeros(len(ks))
        sum_nerr_at_ks = torch.zeros(len(ks))
        sum_ap_at_ks = torch.zeros(len(ks))
        sum_p_at_ks = torch.zeros(len(ks))

        list_ndcg_at_ks_per_q = []
        list_err_at_ks_per_q = []
        list_ap_at_ks_per_q = []
        list_p_at_ks_per_q = []

        # np.asarray also admits plain sequences; astype always copies, so the
        # tensors never alias the caller's arrays.
        tor_all_std_labels = torch.from_numpy(np.asarray(all_std_labels).astype(np.float32))
        tor_all_preds = torch.from_numpy(np.asarray(all_preds).astype(np.float32))

        if model_id.startswith('LightGBM'):
            # `np.int` (an alias of the builtin int) was removed in NumPy 1.24.
            group = np.asarray(group).astype(int).tolist()

        head = 0
        for gr in group:
            tor_per_query_std_labels = tor_all_std_labels[head:head + gr]
            tor_per_query_preds = tor_all_preds[head:head + gr]
            head += gr

            # Labels re-ordered by descending predicted score (system ranking).
            _, tor_sorted_inds = torch.sort(tor_per_query_preds, descending=True)
            sys_sorted_labels = tor_per_query_std_labels[tor_sorted_inds]
            # Labels sorted in descending order, given to the helpers as the ideal ranking.
            ideal_sorted_labels, _ = torch.sort(tor_per_query_std_labels, descending=True)

            ndcg_at_ks = torch_nDCG_at_ks(sys_sorted_labels=sys_sorted_labels,
                                          ideal_sorted_labels=ideal_sorted_labels,
                                          ks=ks,
                                          multi_level_rele=True)
            nerr_at_ks = torch_nerr_at_ks(sys_sorted_labels=sys_sorted_labels,
                                          ideal_sorted_labels=ideal_sorted_labels,
                                          ks=ks,
                                          multi_level_rele=True)
            ap_at_ks = torch_ap_at_ks(sys_sorted_labels=sys_sorted_labels,
                                      ideal_sorted_labels=ideal_sorted_labels,
                                      ks=ks)
            p_at_ks = torch_p_at_ks(sys_sorted_labels=sys_sorted_labels, ks=ks)

            list_ndcg_at_ks_per_q.append(ndcg_at_ks.numpy())
            list_err_at_ks_per_q.append(nerr_at_ks.numpy())
            list_ap_at_ks_per_q.append(ap_at_ks.numpy())
            list_p_at_ks_per_q.append(p_at_ks.numpy())

            # Out-of-place adds, accumulated in query order as before.
            sum_ndcg_at_ks = sum_ndcg_at_ks + ndcg_at_ks
            sum_nerr_at_ks = sum_nerr_at_ks + nerr_at_ks
            sum_ap_at_ks = sum_ap_at_ks + ap_at_ks
            sum_p_at_ks = sum_p_at_ks + p_at_ks
            num_queries += 1

        # Tensor / 0 yields NaN rather than raising, matching the former
        # zero-valued tensor counter when `group` is empty.
        avg_ndcg_at_ks = (sum_ndcg_at_ks / num_queries).detach().numpy()
        avg_nerr_at_ks = (sum_nerr_at_ks / num_queries).detach().numpy()
        avg_ap_at_ks = (sum_ap_at_ks / num_queries).detach().numpy()
        avg_p_at_ks = (sum_p_at_ks / num_queries).detach().numpy()

        return avg_ndcg_at_ks, avg_nerr_at_ks, avg_ap_at_ks, avg_p_at_ks, \
               list_ndcg_at_ks_per_q, list_err_at_ks_per_q, list_ap_at_ks_per_q, list_p_at_ks_per_q