Example #1
    def per_query_generation(self,
                             qid=None,
                             batch_ranking=None,
                             batch_label=None,
                             pos_and_neg=None,
                             generator=None,
                             samples_per_query=None,
                             shuffle_ties=None,
                             top_k=None,
                             temperature=None):
        '''
        :param qid: query id (not used in this snippet)
        :param batch_ranking: [batch=1, ranking_size, num_features] feature tensor of one query's documents
        :param batch_label: [batch=1, ranking_size] ground-truth relevance labels, where '-1' marks masked documents
        :param pos_and_neg: True for discriminator optimization (return positive and negative rankings), False for generator optimization
        :param generator: model whose predict() scores each document
        :param samples_per_query: number of stochastic rankings to sample per query
        :param shuffle_ties: if True, build truth rankings by randomly breaking label ties; otherwise sample from labels via gumbel_softmax
        :param top_k: if given, truncate each sampled ranking to its top-k documents; None uses all documents
        :param temperature: gumbel_softmax temperature for sampling from the generator's scores
        :return: (positive, negative) sample rankings if pos_and_neg, else the sorted generator probabilities and their indices
        '''
        g_batch_pred = generator.predict(batch_ranking)  # [batch, size_ranking]
        batch_gen_stochastic_prob = gumbel_softmax(
            g_batch_pred,
            samples_per_query=samples_per_query,
            temperature=temperature,
            cuda=self.gpu,
            cuda_device=self.device)
        sorted_batch_gen_stochastic_probs, batch_gen_sto_sorted_inds = torch.sort(
            batch_gen_stochastic_prob, dim=1, descending=True)

        if pos_and_neg:  # for training discriminator
            used_batch_label = batch_label

            if shuffle_ties:
                '''
                There is no need to first filter out documents labelled '-1': with the descending sorting, only the top entries are used.
                The only required condition is that the number of non-minus-one documents is larger than top_k, which is guaranteed by the customized mask_data()
                '''
                per_query_label = torch.squeeze(used_batch_label)
                list_std_sto_sorted_inds = []
                for i in range(samples_per_query):
                    shuffle_ties_inds = arg_shuffle_ties(per_query_label,
                                                         descending=True)
                    list_std_sto_sorted_inds.append(shuffle_ties_inds)

                batch_std_sto_sorted_inds = torch.stack(
                    list_std_sto_sorted_inds, dim=0)
            else:
                '''
                Still sampling from a Plackett-Luce model, with a small temperature.
                Note: masked labels ('-1') cannot be fed to gumbel_softmax directly,
                so this path does not support the mask_label setting.
                '''
                batch_std_stochastic_prob = gumbel_softmax(
                    used_batch_label,
                    samples_per_query=samples_per_query,
                    temperature=self.temperature_for_std_sampling)
                _, batch_std_sto_sorted_inds = torch.sort(
                    batch_std_stochastic_prob, dim=1, descending=True
                )  # sort documents by the label-based sampled probabilities

            list_pos_ranking, list_neg_ranking = [], []
            if top_k is None:  # using all documents
                for i in range(samples_per_query):
                    # positive ranking: documents ordered by (tie-shuffled or sampled) ground-truth labels
                    pos_inds = batch_std_sto_sorted_inds[i, :]
                    pos_ranking = batch_ranking[0, pos_inds, :]
                    list_pos_ranking.append(pos_ranking)

                    # negative ranking: documents ordered by the generator's stochastic sample
                    neg_inds = batch_gen_sto_sorted_inds[i, :]
                    neg_ranking = batch_ranking[0, neg_inds, :]
                    list_neg_ranking.append(neg_ranking)
            else:
                for i in range(samples_per_query):
                    pos_inds = batch_std_sto_sorted_inds[i, 0:top_k]
                    pos_ranking = batch_ranking[0, pos_inds, :]  # sampled sublist of documents
                    list_pos_ranking.append(pos_ranking)

                    neg_inds = batch_gen_sto_sorted_inds[i, 0:top_k]
                    neg_ranking = batch_ranking[0, neg_inds, :]
                    list_neg_ranking.append(neg_ranking)

            batch_std_sample_ranking = torch.stack(list_pos_ranking, dim=0)
            batch_gen_sample_ranking = torch.stack(list_neg_ranking, dim=0)

            return batch_std_sample_ranking, batch_gen_sample_ranking

        else:  # for training generator
            if top_k is None:
                return sorted_batch_gen_stochastic_probs, batch_gen_sto_sorted_inds
            else:
                # truncated lists are needed to cope with the ranking_size mismatch
                list_g_sort_top_preds, list_g_sort_top_inds = [], []
                for i in range(samples_per_query):
                    neg_inds = batch_gen_sto_sorted_inds[i, 0:top_k]
                    list_g_sort_top_inds.append(neg_inds)

                    top_gen_stochastic_probs = sorted_batch_gen_stochastic_probs[i, 0:top_k]
                    list_g_sort_top_preds.append(top_gen_stochastic_probs)

                top_sorted_batch_gen_stochastic_probs = torch.stack(
                    list_g_sort_top_preds, dim=0)
                return top_sorted_batch_gen_stochastic_probs, list_g_sort_top_inds
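
Both examples depend on a gumbel_softmax helper that is not shown here. Below is a minimal, hypothetical sketch of such a sampler, assuming it draws samples_per_query independent Gumbel-perturbed softmax distributions over one query's document scores; the signature mirrors the call sites above, but the library's actual helper may differ.

import torch
import torch.nn.functional as F


def gumbel_softmax(batch_pred, samples_per_query, temperature=1.0,
                   cuda=False, cuda_device=None):
    '''Hypothetical sketch: draw samples_per_query Gumbel-softmax samples
    from one query's document scores of shape [1, ranking_size].'''
    if cuda:
        batch_pred = batch_pred.to(cuda_device)
    # One row of scores per sample, so each row receives independent noise.
    expanded = batch_pred.expand(samples_per_query, -1)
    # Gumbel(0, 1) noise: -log(-log(U)), U ~ Uniform(0, 1); eps avoids log(0).
    gumbel = -torch.log(-torch.log(torch.rand_like(expanded) + 1e-20) + 1e-20)
    # Lower temperatures push each row closer to a one-hot (hard) choice.
    return F.softmax((expanded + gumbel) / temperature, dim=1)

Sorting each returned row in descending order, as the method does, yields one sampled document ordering per row; with Gumbel noise this corresponds to sampling from a Plackett-Luce model, which matches the Plackett-Luce remark in the comments above.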
Example #2
    def per_query_generation(self,
                             qid=None,
                             batch_ranking=None,
                             batch_label=None,
                             pos_and_neg=None,
                             generator=None,
                             samples_per_query=None,
                             top_k=None,
                             temperature=None):
        '''
        :param pos_and_neg: True for discriminator optimization (return positive and negative rankings), False for generator optimization
        '''
        g_batch_pred = generator.predict(batch_ranking)  # [batch, size_ranking]
        batch_gen_stochastic_prob = gumbel_softmax(
            g_batch_pred,
            samples_per_query=samples_per_query,
            temperature=temperature,
            cuda=self.gpu,
            cuda_device=self.device)
        sorted_batch_gen_stochastic_probs, batch_gen_sto_sorted_inds = torch.sort(
            batch_gen_stochastic_prob, dim=1, descending=True)

        if pos_and_neg:  # for training discriminator
            used_batch_label = batch_label
            ''' Generate the truth ranking by shuffling ties.
            There is no need to first filter out documents labelled '-1': with the descending sorting, only the top entries are used.
            The only required condition is that the number of non-minus-one documents is larger than top_k, which is guaranteed by the customized mask_data()
            '''
            per_query_label = torch.squeeze(used_batch_label)
            list_std_sto_sorted_inds = []
            for i in range(samples_per_query):
                shuffle_ties_inds = arg_shuffle_ties(per_query_label,
                                                     descending=True)
                list_std_sto_sorted_inds.append(shuffle_ties_inds)
            batch_std_sto_sorted_inds = torch.stack(list_std_sto_sorted_inds,
                                                    dim=0)

            list_pos_ranking, list_neg_ranking = [], []
            if top_k is None:  # using all documents
                for i in range(samples_per_query):
                    # positive ranking: documents ordered by tie-shuffled ground-truth labels
                    pos_inds = batch_std_sto_sorted_inds[i, :]
                    pos_ranking = batch_ranking[0, pos_inds, :]
                    list_pos_ranking.append(pos_ranking)

                    # negative ranking: documents ordered by the generator's stochastic sample
                    neg_inds = batch_gen_sto_sorted_inds[i, :]
                    neg_ranking = batch_ranking[0, neg_inds, :]
                    list_neg_ranking.append(neg_ranking)
            else:
                for i in range(samples_per_query):
                    pos_inds = batch_std_sto_sorted_inds[i, 0:top_k]
                    pos_ranking = batch_ranking[0, pos_inds, :]  # sampled sublist of documents
                    list_pos_ranking.append(pos_ranking)

                    neg_inds = batch_gen_sto_sorted_inds[i, 0:top_k]
                    neg_ranking = batch_ranking[0, neg_inds, :]
                    list_neg_ranking.append(neg_ranking)

            batch_std_sample_ranking = torch.stack(list_pos_ranking, dim=0)
            batch_gen_sample_ranking = torch.stack(list_neg_ranking, dim=0)

            return batch_std_sample_ranking, batch_gen_sample_ranking

        else:  # for training generator
            if top_k is None:
                return sorted_batch_gen_stochastic_probs, batch_gen_sto_sorted_inds
            else:
                # truncated lists are needed to cope with the ranking_size mismatch
                list_g_sort_top_preds, list_g_sort_top_inds = [], []
                for i in range(samples_per_query):
                    neg_inds = batch_gen_sto_sorted_inds[i, 0:top_k]
                    list_g_sort_top_inds.append(neg_inds)

                    top_gen_stochastic_probs = sorted_batch_gen_stochastic_probs[i, 0:top_k]
                    list_g_sort_top_preds.append(top_gen_stochastic_probs)

                top_sorted_batch_gen_stochastic_probs = torch.stack(
                    list_g_sort_top_preds, dim=0)
                return top_sorted_batch_gen_stochastic_probs, list_g_sort_top_inds
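
The discriminator path additionally relies on arg_shuffle_ties, which is also not shown. A minimal sketch, assuming it returns indices that sort a 1-D label vector in descending order while breaking ties at random:

import torch


def arg_shuffle_ties(per_query_label, descending=True):
    '''Hypothetical sketch: argsort labels with ties broken randomly.'''
    # Permute at random first; sorting the permuted labels then yields a
    # valid ordering whose equal-label groups are internally shuffled.
    perm = torch.randperm(per_query_label.size(0))
    sorted_inds = torch.argsort(per_query_label[perm], descending=descending)
    return perm[sorted_inds]

For reference, a discriminator-phase call might look like the following, where ranker and the hyperparameter values are assumptions for illustration:

pos_rankings, neg_rankings = ranker.per_query_generation(
    qid=qid, batch_ranking=batch_ranking, batch_label=batch_label,
    pos_and_neg=True, generator=generator,
    samples_per_query=5, top_k=10, temperature=0.5)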