Example #1
0
def ndcg_at_ks(ranker=None, test_data=None, ks=(1, 5, 10), label_type=LABEL_TYPE.MultiLabel):
    '''
    Compute the average nDCG@k over all queries in test_data, one value per cutoff in ks.

    There is no check based on the assumption (say light_filtering() is called)
    that each test instance Q includes at least k(k=max(ks)) documents, and at least one relevant document.
    Or there will be errors.

    :param ranker: model exposing predict(batch_ranking) -> per-document relevance scores
    :param test_data: iterable of (qid, batch_ranking, batch_labels) with shapes
        [batch, ranking_size, num_features] and [batch, ranking_size];
        its `presort` attribute signals the labels are already ideally ordered
    :param ks: cutoff values at which nDCG is computed (default changed from a
        mutable list to an equivalent tuple; only iterated/len()'d downstream)
    :param label_type: relevance-label type forwarded to torch_nDCG_at_ks
    :return: 1-D tensor of len(ks) averaged nDCG values
    '''
    sum_ndcg_at_ks = torch.zeros(len(ks))
    cnt = torch.zeros(1)
    # Idiom fix: `True if x else False` is just bool(x).
    already_sorted = bool(test_data.presort)
    for qid, batch_ranking, batch_labels in test_data: # _, [batch, ranking_size, num_features], [batch, ranking_size]
        if gpu: batch_ranking = batch_ranking.to(device)
        batch_rele_preds = ranker.predict(batch_ranking)
        if gpu: batch_rele_preds = batch_rele_preds.cpu()

        # Order documents by predicted relevance, best first.
        _, batch_sorted_inds = torch.sort(batch_rele_preds, dim=1, descending=True)

        # Ground-truth labels rearranged into the system's ranking order.
        batch_sys_sorted_labels = torch.gather(batch_labels, dim=1, index=batch_sorted_inds)
        if already_sorted:
            # Presorted data: labels are already the ideal (descending) ordering.
            batch_ideal_sorted_labels = batch_labels
        else:
            batch_ideal_sorted_labels, _ = torch.sort(batch_labels, dim=1, descending=True)

        batch_ndcg_at_ks = torch_nDCG_at_ks(batch_sys_sorted_labels=batch_sys_sorted_labels,
                                            batch_ideal_sorted_labels=batch_ideal_sorted_labels,
                                            ks=ks, label_type=label_type)

        # default batch_size=1 due to testing data, hence squeeze along dim 0
        sum_ndcg_at_ks = torch.add(sum_ndcg_at_ks, torch.squeeze(batch_ndcg_at_ks, dim=0))
        cnt += 1

    avg_ndcg_at_ks = sum_ndcg_at_ks / cnt
    return avg_ndcg_at_ks
def test_ndcg():
    """Smoke-check torch_nDCG_at_ks on a hand-worked 7-document ranking."""
    sys_sorted_labels = torch.Tensor([1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0])
    std_sorted_labels = torch.Tensor([1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
    # Evaluate at every cutoff 1..7 on a single-query batch.
    result = torch_nDCG_at_ks(sys_sorted_labels.view(1, -1),
                              std_sorted_labels.view(1, -1),
                              ks=list(range(1, 8)))
    # Expected: tensor([1.0000, 1.0000, 0.7654, 0.8048, 0.8048, 0.8048, 0.9349])
    print(result)
Example #3
0
def ndcg_at_ks(ranker=None,
               test_data=None,
               ks=[1, 5, 10],
               multi_level_rele=True,
               batch_mode=True):
    '''
    There is no check based on the assumption (say light_filtering() is called)
    that each test instance Q includes at least k(k=max(ks)) documents, and at least one relevant document.
    Or there will be errors.
    '''
    sum_ndcg_at_ks = torch.zeros(len(ks))
    cnt = torch.zeros(1)
    for qid, batch_ranking, batch_label in test_data:  # _, [batch, ranking_size, num_features], [batch, ranking_size]
        if gpu: batch_ranking = batch_ranking.to(device)

        if batch_mode:
            batch_rele_preds = ranker.predict(batch_ranking)
            rele_preds = torch.squeeze(batch_rele_preds)
        else:
            rele_preds = ranker.predict(torch.squeeze(batch_ranking))
            rele_preds = torch.squeeze(rele_preds)

        if gpu: rele_preds = rele_preds.cpu()

        std_labels = torch.squeeze(batch_label)

        _, sorted_inds = torch.sort(rele_preds, descending=True)

        sys_sorted_labels = std_labels[sorted_inds]
        ideal_sorted_labels, _ = torch.sort(std_labels, descending=True)

        ndcg_at_ks = torch_nDCG_at_ks(sys_sorted_labels=sys_sorted_labels,
                                      ideal_sorted_labels=ideal_sorted_labels,
                                      ks=ks,
                                      multi_level_rele=multi_level_rele)
        sum_ndcg_at_ks = torch.add(sum_ndcg_at_ks, ndcg_at_ks)
        cnt += 1

    avg_ndcg_at_ks = sum_ndcg_at_ks / cnt
    return avg_ndcg_at_ks
Example #4
0
    def cal_metric_at_ks(self,
                         model_id,
                         all_std_labels=None,
                         all_preds=None,
                         group=None,
                         ks=(1, 3, 5, 10)):
        """
        Compute nDCG, nERR, AP and Precision at every cutoff in ks, averaged
        over queries, plus the per-query values.

        :param model_id: model identifier; for ids starting with 'LightGBM',
            `group` arrives as a numpy array and is converted to a python list
        :param all_std_labels: ground-truth labels of all queries, concatenated (numpy array)
        :param all_preds: predicted scores of all queries, concatenated (numpy array)
        :param group: per-query document counts used to slice the flat arrays
        :param ks: cutoff values (default changed from a mutable list to an
            equivalent tuple; only iterated/len()'d downstream)
        :return: (avg_ndcg_at_ks, avg_nerr_at_ks, avg_ap_at_ks, avg_p_at_ks,
                  list_ndcg_at_ks_per_q, list_err_at_ks_per_q,
                  list_ap_at_ks_per_q, list_p_at_ks_per_q) — averages as numpy
                  arrays, per-query values as lists of numpy arrays
        """
        cnt = torch.zeros(1)

        sum_ndcg_at_ks = torch.zeros(len(ks))
        sum_nerr_at_ks = torch.zeros(len(ks))
        sum_ap_at_ks = torch.zeros(len(ks))
        sum_p_at_ks = torch.zeros(len(ks))

        list_ndcg_at_ks_per_q = []
        list_err_at_ks_per_q = []
        list_ap_at_ks_per_q = []
        list_p_at_ks_per_q = []

        tor_all_std_labels, tor_all_preds = \
            torch.from_numpy(all_std_labels.astype(np.float32)), torch.from_numpy(all_preds.astype(np.float32))

        head = 0
        if model_id.startswith('LightGBM'):
            # BUGFIX: np.int was deprecated in NumPy 1.20 and removed in 1.24,
            # so astype(np.int) raises AttributeError on modern NumPy; the
            # documented replacement is the builtin int.
            group = group.astype(int).tolist()
        for gr in group:
            # Slice this query's flat span out of the concatenated tensors.
            tor_per_query_std_labels = tor_all_std_labels[head:head + gr]
            tor_per_query_preds = tor_all_preds[head:head + gr]
            head += gr

            # Order documents by predicted score, best first.
            _, tor_sorted_inds = torch.sort(tor_per_query_preds,
                                            descending=True)

            # Labels in system order vs. the ideal (descending) ordering
            # used as the normalizer for nDCG/nERR.
            sys_sorted_labels = tor_per_query_std_labels[tor_sorted_inds]
            ideal_sorted_labels, _ = torch.sort(tor_per_query_std_labels,
                                                descending=True)

            ndcg_at_ks = torch_nDCG_at_ks(
                sys_sorted_labels=sys_sorted_labels,
                ideal_sorted_labels=ideal_sorted_labels,
                ks=ks,
                multi_level_rele=True)
            list_ndcg_at_ks_per_q.append(ndcg_at_ks.numpy())

            nerr_at_ks = torch_nerr_at_ks(
                sys_sorted_labels=sys_sorted_labels,
                ideal_sorted_labels=ideal_sorted_labels,
                ks=ks,
                multi_level_rele=True)
            list_err_at_ks_per_q.append(nerr_at_ks.numpy())

            ap_at_ks = torch_ap_at_ks(sys_sorted_labels=sys_sorted_labels,
                                      ideal_sorted_labels=ideal_sorted_labels,
                                      ks=ks)
            list_ap_at_ks_per_q.append(ap_at_ks.numpy())

            p_at_ks = torch_p_at_ks(sys_sorted_labels=sys_sorted_labels, ks=ks)
            list_p_at_ks_per_q.append(p_at_ks.numpy())

            # Accumulate per-query metric vectors for the final averages.
            sum_ndcg_at_ks = torch.add(sum_ndcg_at_ks, ndcg_at_ks)
            sum_nerr_at_ks = torch.add(sum_nerr_at_ks, nerr_at_ks)
            sum_ap_at_ks = torch.add(sum_ap_at_ks, ap_at_ks)
            sum_p_at_ks = torch.add(sum_p_at_ks, p_at_ks)
            cnt += 1

        tor_avg_ndcg_at_ks = sum_ndcg_at_ks / cnt
        avg_ndcg_at_ks = tor_avg_ndcg_at_ks.data.numpy()

        tor_avg_nerr_at_ks = sum_nerr_at_ks / cnt
        avg_nerr_at_ks = tor_avg_nerr_at_ks.data.numpy()

        tor_avg_ap_at_ks = sum_ap_at_ks / cnt
        avg_ap_at_ks = tor_avg_ap_at_ks.data.numpy()

        tor_avg_p_at_ks = sum_p_at_ks / cnt
        avg_p_at_ks = tor_avg_p_at_ks.data.numpy()

        return avg_ndcg_at_ks, avg_nerr_at_ks, avg_ap_at_ks, avg_p_at_ks,\
               list_ndcg_at_ks_per_q, list_err_at_ks_per_q, list_ap_at_ks_per_q, list_p_at_ks_per_q