Example #1
0
def clip_query_data(qid,
                    list_docids=None,
                    feature_mat=None,
                    std_label_vec=None,
                    binary_rele=False,
                    unknown_as_zero=False,
                    clip_query=None,
                    min_docs=None,
                    min_rele=1,
                    presort=None):
    """Clip (filter) and optionally presort the data associated with one query.

    Args:
        qid: query identifier, returned unchanged.
        list_docids: per-document ids (currently unused by this function).
        feature_mat: 2-D array of document feature vectors, one row per doc.
        std_label_vec: 1-D array of graded relevance labels, aligned with rows
            of ``feature_mat``.
        binary_rele: if True, cap labels at 1 (graded -> binary relevance).
        unknown_as_zero: if True, raise negative ("unknown") labels to zero.
        clip_query: if truthy, drop queries that are too small or have no
            relevant documents.
        min_docs: minimum number of documents required when ``clip_query`` is
            enabled; must be provided in that case.
        min_rele: minimum number of relevant (label > 0) documents required
            when ``clip_query`` is enabled.
        presort: must be explicitly True or False; when True, documents are
            sorted by descending label with ties shuffled.

    Returns:
        ``(qid, feature_mat, std_label_vec)`` or ``None`` if the query is
        clipped away.

    Raises:
        ValueError: if ``presort`` is left unspecified, or ``min_docs`` is
            missing while ``clip_query`` is enabled.
    """
    if binary_rele:
        # Cap graded labels at 1; a_min=-10 leaves any negative "unknown"
        # marks untouched for the unknown_as_zero step below.
        std_label_vec = np.clip(std_label_vec, a_min=-10, a_max=1)
    if unknown_as_zero:
        # Lift negative ("unknown") labels up to zero.
        std_label_vec = np.clip(std_label_vec, a_min=0, a_max=10)

    if clip_query:
        if min_docs is None:
            # Fail loudly instead of the opaque TypeError from `shape[0] < None`.
            raise ValueError('min_docs must be specified when clip_query is enabled')
        if feature_mat.shape[0] < min_docs:
            # Skip queries with fewer documents than the pre-specified min_docs.
            return None
        if (std_label_vec > 0).sum() < min_rele:
            # Skip queries with no relevant documents: meaningless for both
            # training and testing.
            return None

    if presort is None:
        # Was `assert presort is not None`; asserts vanish under `python -O`,
        # so validate explicitly.
        raise ValueError('presort must be explicitly set to True or False')
    if presort:
        # Presorting by descending label (shuffling ties) saves time later for
        # evaluation and for models that need an optimal ranking.
        des_inds = np_arg_shuffle_ties(std_label_vec, descending=True)
        feature_mat, std_label_vec = feature_mat[des_inds], std_label_vec[des_inds]
    return (qid, feature_mat, std_label_vec)
Example #2
0
def clip_query_data(qid,
                    list_docids=None,
                    feature_mat=None,
                    std_label_vec=None,
                    binary_rele=False,
                    unknown_as_zero=False,
                    clip_query=None,
                    min_docs=None,
                    min_rele=1,
                    presort=True):
    """Clip the data associated with one query, then optionally presort it.

    Args:
        qid: query identifier, returned unchanged.
        list_docids: per-document ids (currently unused by this function).
        feature_mat: 2-D array of document feature vectors, one row per doc.
        std_label_vec: 1-D array of graded relevance labels aligned with rows
            of ``feature_mat``.
        binary_rele: if True, cap labels at 1 (graded -> binary relevance).
        unknown_as_zero: if True, raise negative ("unknown") labels to zero.
        clip_query: if truthy, drop queries that are too small or have no
            relevant documents.
        min_docs: minimum number of documents required under ``clip_query``.
        min_rele: minimum number of relevant (label > 0) documents required
            under ``clip_query``.
        presort: sort documents by descending label, shuffling ties.

    Returns:
        ``(qid, feature_mat, std_label_vec)`` or ``None`` if the query is
        clipped away.
    """
    if binary_rele:
        # Cap graded labels at 1 -> binary relevance.
        std_label_vec = np.clip(std_label_vec, a_min=-10, a_max=1)
    if unknown_as_zero:
        # Lift negative ("unknown") labels up to zero.
        std_label_vec = np.clip(std_label_vec, a_min=0, a_max=10)

    if clip_query:
        # Skip queries with fewer documents than the pre-specified min_docs.
        if feature_mat.shape[0] < min_docs:
            return None
        # Skip queries with no relevant documents: meaningless for both
        # training and testing.
        if (std_label_vec > 0).sum() < min_rele:
            return None

    # BUG FIX: the original returned from the clip_query branch before this
    # point, so presort was silently skipped for clipped queries despite the
    # presort=True default (and in contrast to the companion implementation
    # that presorts after clipping).
    if presort:
        # Sample an ordering by shuffling label ties, then sort descending.
        des_inds = np_arg_shuffle_ties(std_label_vec, descending=True)
        feature_mat, std_label_vec = feature_mat[des_inds], std_label_vec[des_inds]
    return (qid, feature_mat, std_label_vec)
Example #3
0
    def __init__(self,
                 train,
                 file,
                 data_id=None,
                 data_dict=None,
                 sample_rankings_per_q=1,
                 shuffle=True,
                 hot=False,
                 eval_dict=None,
                 buffer=True,
                 given_scaler=None):
        """Load (or build and disk-buffer) per-query torch tensors for a dataset.

        Args:
            train: whether this split is the training split (stored on self).
            file: path to the raw dataset file.
            data_id: identifier of a known dataset; used to derive a default
                data_dict when data_dict is None.
            data_dict: explicit data configuration; takes precedence over
                data_id lookup.
            sample_rankings_per_q: number of tie-shuffled rankings sampled per
                query (> 1 stacks them along a batch dimension).
            shuffle: iteration-order flag (stored on self, not used here).
            hot: if True, additionally compute one-hot label representations
                and per-grade counts for each query.
            eval_dict: evaluation settings; may request label masking via the
                'mask_label', 'mask_ratio' and 'mask_type' keys.
            buffer: if True, pickle the converted tensors next to the source
                file so later runs can load them directly.
            given_scaler: externally supplied feature scaler
                (accepted but not used in this constructor).

        Raises:
            NotImplementedError: for unsupported datasets or mask types.
        """
        assert data_id is not None or data_dict is not None
        if data_dict is None:
            data_dict = self.get_default_data_dict(data_id=data_id)

        self.train = train
        # BUG FIX: keep the parameter on the instance; the sampling loop below
        # previously read self.sample_rankings_per_q, which was never assigned
        # in this constructor, crashing the sample_rankings_per_q > 1 path.
        self.sample_rankings_per_q = sample_rankings_per_q

        if data_dict['data_id'] in MSLETOR or data_dict['data_id'] in MSLRWEB \
                or data_dict['data_id'] in YAHOO_LTR or data_dict['data_id'] in YAHOO_LTR_5Fold \
                or data_dict['data_id'] in ISTELLA_LTR \
                or data_dict['data_id'] == 'IRGAN_MQ2008_Semi': # supported datasets

            self.check_load_setting(data_dict, eval_dict)

            perquery_file = get_buffer_file_name(data_id=data_id,
                                                 file=file,
                                                 data_dict=data_dict)

            # Derive the torch buffer file name; it encodes the sampling count
            # and hot-label options so different settings never collide.
            if sample_rankings_per_q > 1:
                if hot:
                    torch_perquery_file = perquery_file.replace(
                        '.np', '_'.join([
                            'SP',
                            str(sample_rankings_per_q), 'Hot', '.torch'
                        ]))
                else:
                    torch_perquery_file = perquery_file.replace(
                        '.np',
                        '_'.join(['SP',
                                  str(sample_rankings_per_q), '.torch']))
            else:
                if hot:
                    torch_perquery_file = perquery_file.replace(
                        '.np', '_Hot.torch')
                else:
                    torch_perquery_file = perquery_file.replace(
                        '.np', '.torch')

            if eval_dict is not None:
                mask_label, mask_ratio, mask_type = eval_dict[
                    'mask_label'], eval_dict['mask_ratio'], eval_dict[
                        'mask_type']
                print(eval_dict)
                if mask_label:
                    # Masked variants get their own buffer file as well.
                    mask_label_str = '_'.join(
                        [mask_type, 'Ratio', '{:,g}'.format(mask_ratio)])
                    torch_perquery_file = torch_perquery_file.replace(
                        '.torch', '_' + mask_label_str + '.torch')
            else:
                mask_label = False

            if os.path.exists(torch_perquery_file):
                # Fast path: the converted tensors were buffered on a prior run.
                print('loading buffered file ...')
                self.list_torch_Qs = pickle_load(torch_perquery_file)
            else:
                self.list_torch_Qs = []

                scale_data = data_dict['scale_data']
                scaler_id = data_dict[
                    'scaler_id'] if 'scaler_id' in data_dict else None
                list_Qs = iter_queries(in_file=file,
                                       data_dict=data_dict,
                                       scale_data=scale_data,
                                       scaler_id=scaler_id,
                                       perquery_file=perquery_file,
                                       buffer=buffer)

                list_inds = list(range(len(list_Qs)))
                for ind in list_inds:
                    qid, doc_reprs, doc_labels = list_Qs[ind]

                    if sample_rankings_per_q > 1:
                        assert mask_label is not True  # not supported since it is rarely used.

                        # Sample several rankings per query by re-shuffling
                        # label ties, then stack them along a batch dimension.
                        list_ranking = []
                        list_labels = []
                        for _ in range(sample_rankings_per_q):
                            des_inds = np_arg_shuffle_ties(
                                doc_labels,
                                descending=True)  # sampling by shuffling ties
                            list_ranking.append(doc_reprs[des_inds])
                            list_labels.append(doc_labels[des_inds])

                        batch_rankings = np.stack(list_ranking, axis=0)
                        batch_std_labels = np.stack(list_labels, axis=0)

                        torch_batch_rankings = torch.from_numpy(
                            batch_rankings).type(torch.FloatTensor)
                        torch_batch_std_labels = torch.from_numpy(
                            batch_std_labels).type(torch.FloatTensor)
                    else:
                        torch_batch_rankings = torch.from_numpy(
                            doc_reprs).type(torch.FloatTensor)
                        torch_batch_rankings = torch.unsqueeze(
                            torch_batch_rankings,
                            dim=0)  # a consistent batch dimension of size 1

                        torch_batch_std_labels = torch.from_numpy(
                            doc_labels).type(torch.FloatTensor)
                        torch_batch_std_labels = torch.unsqueeze(
                            torch_batch_std_labels, dim=0)

                        if mask_label:  # masking
                            if mask_type == 'rand_mask_rele':
                                torch_batch_rankings, torch_batch_std_labels = random_mask_rele_labels(
                                    batch_ranking=torch_batch_rankings,
                                    batch_label=torch_batch_std_labels,
                                    mask_ratio=mask_ratio,
                                    mask_value=0,
                                    presort=data_dict['presort'])

                            elif mask_type == 'rand_mask_all':
                                masked_res = random_mask_all_labels(
                                    batch_ranking=torch_batch_rankings,
                                    batch_label=torch_batch_std_labels,
                                    mask_ratio=mask_ratio,
                                    mask_value=0,
                                    presort=data_dict['presort'])
                                if masked_res is not None:
                                    torch_batch_rankings, torch_batch_std_labels = masked_res
                                else:
                                    # Masking wiped the whole query; drop it.
                                    continue
                            else:
                                raise NotImplementedError
                    if hot:
                        assert mask_label is not True  # not supported since it is rarely used.
                        max_rele_level = data_dict['max_rele_level']
                        assert max_rele_level is not None

                        torch_batch_std_hot_labels = get_one_hot_reprs(
                            torch_batch_std_labels)
                        batch_cnts = batch_count(
                            batch_std_labels=torch_batch_std_labels,
                            max_rele_grade=max_rele_level,
                            descending=True)

                        self.list_torch_Qs.append(
                            (qid, torch_batch_rankings, torch_batch_std_labels,
                             torch_batch_std_hot_labels, batch_cnts))
                    else:
                        self.list_torch_Qs.append((qid, torch_batch_rankings,
                                                   torch_batch_std_labels))
                #buffer
                #print('Num of q:', len(self.list_torch_Qs))
                if buffer:
                    # Persist the converted tensors for fast reloads.
                    parent_dir = Path(torch_perquery_file).parent
                    if not os.path.exists(parent_dir):
                        os.makedirs(parent_dir)
                    pickle_save(self.list_torch_Qs, torch_perquery_file)
        else:
            raise NotImplementedError

        self.hot = hot
        self.shuffle = shuffle