def per_query_generation(self, qid=None, batch_ranking=None, batch_label=None,
                         pos_and_neg=None, generator=None, samples_per_query=None,
                         shuffle_ties=None, top_k=None, temperature=None):
    '''
    Draw stochastic per-query rankings from the generator and, when optimizing the
    discriminator, matching ground-truth ("positive") rankings from the labels.

    :param qid: query id (not used in the body; kept for interface compatibility)
    :param batch_ranking: document feature tensor; indexed as [0, inds, :], so a
        batch of one query is assumed — TODO confirm shape against caller
    :param batch_label: ground-truth relevance labels for the query's documents
    :param pos_and_neg: True -> discriminator optimization (return positive & negative
        document rankings); False -> generator optimization (return generator
        probabilities and sorted indices)
    :param generator: scoring model exposing predict()
    :param samples_per_query: number of stochastic ranking samples drawn per query
    :param shuffle_ties: if truthy, build ground-truth orderings by shuffling ties in
        the labels; otherwise sample orderings via gumbel_softmax over the labels
    :param top_k: if not None, truncate every sampled ranking to its top-k documents
    :param temperature: gumbel_softmax temperature applied to the generator's scores
    :return: (batch_std_sample_ranking, batch_gen_sample_ranking) when pos_and_neg,
        otherwise (sorted generator probabilities, sorted indices)
    '''
    g_batch_pred = generator.predict(batch_ranking)  # [batch, size_ranking]
    # BUG FIX: the original referenced bare `gpu` / `device` (undefined names);
    # use the instance attributes, consistent with the sibling implementation
    # of this method elsewhere in the file.
    batch_gen_stochastic_prob = gumbel_softmax(
        g_batch_pred, samples_per_query=samples_per_query, temperature=temperature,
        cuda=self.gpu, cuda_device=self.device)
    sorted_batch_gen_stochastic_probs, batch_gen_sto_sorted_inds = torch.sort(
        batch_gen_stochastic_prob, dim=1, descending=True)

    if pos_and_neg:  # for training the discriminator
        used_batch_label = batch_label
        if shuffle_ties:
            # There is no need to first filter out documents labelled '-1': the
            # descending sort pushes them to the tail and only the top entries are
            # used. The only required condition is that the number of non-minus-one
            # documents is larger than top_k, which builds upon the customized
            # mask_data().
            per_query_label = torch.squeeze(used_batch_label)
            list_std_sto_sorted_inds = []
            for _ in range(samples_per_query):
                shuffle_ties_inds = arg_shuffle_ties(per_query_label, descending=True)
                list_std_sto_sorted_inds.append(shuffle_ties_inds)
            batch_std_sto_sorted_inds = torch.stack(list_std_sto_sorted_inds, dim=0)
        else:
            # Still using Plackett-Luce sampling, with a small temperature.
            # NOTE(review): when self.eval_dict['mask_label'] is set, gumbel_softmax
            # cannot be applied directly to '-1' labels (formerly a commented-out
            # NotImplementedError) — confirm callers never hit this combination.
            batch_std_stochastic_prob = gumbel_softmax(
                used_batch_label, samples_per_query=samples_per_query,
                temperature=self.temperature_for_std_sampling)
            # sort documents according to the sampled relevance
            _, batch_std_sto_sorted_inds = torch.sort(
                batch_std_stochastic_prob, dim=1, descending=True)

        list_pos_ranking, list_neg_ranking = [], []
        if top_k is None:  # using all documents
            for i in range(samples_per_query):
                pos_inds = batch_std_sto_sorted_inds[i, :]
                list_pos_ranking.append(batch_ranking[0, pos_inds, :])
                neg_inds = batch_gen_sto_sorted_inds[i, :]
                list_neg_ranking.append(batch_ranking[0, neg_inds, :])
        else:
            for i in range(samples_per_query):
                pos_inds = batch_std_sto_sorted_inds[i, 0:top_k]
                # sampled sublist of documents
                list_pos_ranking.append(batch_ranking[0, pos_inds, :])
                neg_inds = batch_gen_sto_sorted_inds[i, 0:top_k]
                list_neg_ranking.append(batch_ranking[0, neg_inds, :])

        batch_std_sample_ranking = torch.stack(list_pos_ranking, dim=0)
        batch_gen_sample_ranking = torch.stack(list_neg_ranking, dim=0)
        return batch_std_sample_ranking, batch_gen_sample_ranking
    else:  # for training the generator
        if top_k is None:
            return sorted_batch_gen_stochastic_probs, batch_gen_sto_sorted_inds
        else:
            # plain lists are required to cope with ranking_size mismatch
            list_g_sort_top_preds, list_g_sort_top_inds = [], []
            for i in range(samples_per_query):
                list_g_sort_top_inds.append(batch_gen_sto_sorted_inds[i, 0:top_k])
                list_g_sort_top_preds.append(
                    sorted_batch_gen_stochastic_probs[i, 0:top_k])
            top_sorted_batch_gen_stochastic_probs = torch.stack(
                list_g_sort_top_preds, dim=0)
            return top_sorted_batch_gen_stochastic_probs, list_g_sort_top_inds
def per_query_generation(self, qid=None, batch_ranking=None, batch_label=None,
                         pos_and_neg=None, generator=None, samples_per_query=None,
                         top_k=None, temperature=None):
    '''
    Sample stochastic rankings for one query from the generator.
    :param pos_and_neg: corresponding to discriminator optimization or generator optimization
    '''
    raw_scores = generator.predict(batch_ranking)  # [batch, size_ranking]
    gen_probs = gumbel_softmax(raw_scores, samples_per_query=samples_per_query,
                               temperature=temperature, cuda=self.gpu,
                               cuda_device=self.device)
    sorted_gen_probs, gen_sorted_inds = torch.sort(gen_probs, dim=1, descending=True)

    if pos_and_neg:  # discriminator optimization
        # Generate truth-rankings by shuffling ties in the labels.
        # There is no need to first filter out documents labelled '-1': the
        # descending sort pushes them to the tail and only the top entries are used.
        # The only required condition is that the number of non-minus-one documents
        # is larger than top_k, which builds upon the customized mask_data().
        flat_labels = torch.squeeze(batch_label)
        std_sorted_inds = torch.stack(
            [arg_shuffle_ties(flat_labels, descending=True)
             for _ in range(samples_per_query)], dim=0)

        # One slice object covers both the "all documents" and "top-k" cases.
        span = slice(None) if top_k is None else slice(0, top_k)
        pos_rankings = [batch_ranking[0, std_sorted_inds[i, span], :]
                        for i in range(samples_per_query)]
        neg_rankings = [batch_ranking[0, gen_sorted_inds[i, span], :]
                        for i in range(samples_per_query)]
        return torch.stack(pos_rankings, dim=0), torch.stack(neg_rankings, dim=0)

    # generator optimization
    if top_k is None:
        return sorted_gen_probs, gen_sorted_inds
    # a plain list of index tensors is kept to cope with ranking_size mismatch
    top_inds = [gen_sorted_inds[i, 0:top_k] for i in range(samples_per_query)]
    top_probs = torch.stack(
        [sorted_gen_probs[i, 0:top_k] for i in range(samples_per_query)], dim=0)
    return top_probs, top_inds