Example #1
    def __init__(self, eval_dict, data_dict, sf_para_dict=None, ad_para_dict=None, g_key='BT', sigma=1.0, gpu=False, device=None):
        super(IRFGAN_Pair, self).__init__(eval_dict=eval_dict, data_dict=data_dict, gpu=gpu, device=device)

        self.f_div_id = ad_para_dict['f_div_id']
        self.samples_per_query = ad_para_dict['samples_per_query']
        self.activation_f, self.conjugate_f = get_f_divergence_functions(self.f_div_id)
        self.dict_diff = dict()
        self.tensor = torch.cuda.FloatTensor if self.gpu else torch.FloatTensor

        assert g_key == 'BT'
        '''
        Underlying formulation of generation:
        (1) BT: the probability of observing a pair of ordered documents is formulated via the
        Bradley-Terry model, i.e., p(d_i > d_j) = 1/(1 + exp(-sigma*(s_i - s_j))); sigma defaults to 1.0.
        '''
        self.g_key = g_key
        self.sigma = sigma # only used w.r.t. SR

        g_sf_para_dict = sf_para_dict

        d_sf_para_dict = copy.deepcopy(g_sf_para_dict)
        d_sf_para_dict[sf_para_dict['sf_id']]['apply_tl_af'] = False

        self.generator = Point_Generator(sf_para_dict=g_sf_para_dict, gpu=gpu, device=device)
        self.discriminator = Point_Discriminator(sf_para_dict=d_sf_para_dict, gpu=gpu, device=device)
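
For intuition, the Bradley-Terry probability in the docstring above can be checked with a few lines of standalone PyTorch. This is a minimal sketch: the scores s are made-up values, and sigma=1.0 matches the default.

import torch

# Made-up generator scores for three documents of one query.
s = torch.tensor([2.0, 0.5, -1.0])
sigma = 1.0

# Bradley-Terry: p(d_i > d_j) = 1/(1 + exp(-sigma*(s_i - s_j))),
# i.e., a sigmoid over the pairwise score differences.
mat_diffs = s.unsqueeze(1) - s.unsqueeze(0)   # [3, 3] matrix of s_i - s_j
mat_bt_probs = torch.sigmoid(sigma * mat_diffs)

# Complementary pairs sum to one: p(d_0 > d_1) + p(d_1 > d_0) == 1.
print(mat_bt_probs[0, 1].item(), mat_bt_probs[1, 0].item())  # ~0.8176, ~0.1824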
Example #2
class IRFGAN_Point(AdversarialMachine):
    ''' Pointwise IR-FGAN: the generator and discriminator operate on individual documents. '''
    def __init__(self, eval_dict, data_dict, sf_para_dict=None, ad_para_dict=None, gpu=False, device=None):
        super(IRFGAN_Point, self).__init__(eval_dict=eval_dict, data_dict=data_dict, gpu=gpu, device=device)

        self.f_div_id = ad_para_dict['f_div_id']
        ''' muted due to default train_discriminator_generator_single_step() '''
        #self.d_epoches = ad_para_dict['d_epoches']
        #self.g_epoches = ad_para_dict['g_epoches']
        #self.ad_training_order = ad_para_dict['ad_training_order']
        self.samples_per_query = ad_para_dict['samples_per_query']

        self.activation_f, self.conjugate_f = get_f_divergence_functions(self.f_div_id)

        #sf_para_dict['ffnns']['apply_tl_af'] = False
        g_sf_para_dict = sf_para_dict

        d_sf_para_dict = copy.deepcopy(g_sf_para_dict)

        self.generator = Point_Generator(sf_para_dict=g_sf_para_dict)
        self.discriminator = Point_Discriminator(sf_para_dict=d_sf_para_dict)

    def fill_global_buffer(self, train_data, dict_buffer=None):
        ''' Buffer the number of positive documents per query '''
        assert train_data.presort is True  # this is required for efficient truth sampling

        for entry in train_data:
            qid, _, batch_label = entry[0], entry[1], entry[2]
            if qid not in dict_buffer:
                boolean_mat = torch.gt(batch_label, 0)
                num_pos = torch.sum(boolean_mat) # number of positive documents
                dict_buffer[qid] = num_pos


    def mini_max_train(self, train_data=None, generator=None, discriminator=None, global_buffer=None, single_step=True):
        if single_step:
            stop_training = self.train_discriminator_generator_single_step(train_data=train_data, generator=generator,
                                                            discriminator=discriminator, global_buffer=global_buffer)
            return stop_training
        else:
            if self.ad_training_order == 'DG': # being consistent with the provided code
                for d_epoch in range(self.d_epoches):
                    if d_epoch % 10 == 0:
                        generated_data = self.generate_data(train_data=train_data, generator=generator,
                                                            global_buffer=global_buffer)

                    self.train_discriminator(train_data=train_data, generated_data=generated_data, discriminator=discriminator)  # train discriminator

                for g_epoch in range(self.g_epoches):
                    stop_training = self.train_generator(train_data=train_data, generator=generator,
                                                         discriminator=discriminator, global_buffer=global_buffer)  # train generator
                    if stop_training: return stop_training

            else:  # consistent with Algorithm 1 in the paper
                for g_epoch in range(self.g_epoches):
                    stop_training = self.train_generator(train_data=train_data, generator=generator,
                                                         discriminator=discriminator, global_buffer=global_buffer)  # train generator
                    if stop_training: return stop_training

                for d_epoch in range(self.d_epoches):
                    if d_epoch % 10 == 0:
                        generated_data = self.generate_data(train_data=train_data, generator=generator,
                                                            global_buffer=global_buffer)

                    self.train_discriminator(train_data=train_data, generated_data=generated_data, discriminator=discriminator)  # train discriminator

            stop_training = False
            return stop_training


    def train_discriminator(self, train_data=None, generated_data=None, discriminator=None, **kwargs):
        for entry in train_data:
            qid, batch_ranking = entry[0], entry[1]

            if qid in generated_data:
                if self.gpu: batch_ranking = batch_ranking.to(self.device)

                pos_inds, neg_inds = generated_data[qid]

                true_docs = batch_ranking[0, pos_inds, :]
                fake_docs = batch_ranking[0, neg_inds, :]

                true_preds = discriminator.predict(true_docs, train=True)
                fake_preds = discriminator.predict(fake_docs, train=True)

                # f-GAN objective: maximize E_true[g_f(D)] - E_fake[f*(g_f(D))] w.r.t.
                # the discriminator; dis_loss is the negation of this lower bound.
                dis_loss = torch.mean(self.conjugate_f(self.activation_f(fake_preds))) - torch.mean(self.activation_f(true_preds))

                discriminator.optimizer.zero_grad()
                dis_loss.backward()
                discriminator.optimizer.step()


    def train_generator(self, train_data=None, generated_data=None, generator=None, discriminator=None,
                        global_buffer=None):
        for entry in train_data:
            qid, batch_ranking, batch_label = entry[0], entry[1], entry[2]

            num_pos = global_buffer[qid]
            if num_pos < 1: continue

            batch_pred = generator.predict(batch_ranking)  # [batch, size_ranking]
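            # Softmax over the per-document scores yields the generator's
            # sampling distribution over this query's documents.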
            pred_probs = F.softmax(torch.squeeze(batch_pred), dim=0)

            neg_inds = torch.multinomial(pred_probs, self.samples_per_query, replacement=False)
            fake_docs = batch_ranking[0, neg_inds, :]

            d_fake_preds = discriminator.predict(fake_docs)
            d_fake_preds = self.conjugate_f(self.activation_f(d_fake_preds))

            # REINFORCE-style update: the log-probability of each sampled document
            # is weighted by the discriminator-derived reward f*(g_f(D(x))).
            ger_loss = -torch.mean(torch.log(pred_probs[neg_inds]) * d_fake_preds)

            generator.optimizer.zero_grad()
            ger_loss.backward()
            generator.optimizer.step()

        stop_training = False
        return stop_training


    def generate_data(self, train_data=None, generator=None, global_buffer=None):
        ''' Sample (true, generated) document indices for training the discriminator '''
        generated_data = dict()
        for entry in train_data:
            qid, batch_ranking, _ = entry[0], entry[1], entry[2]
            samples = self.per_query_generation(qid=qid, batch_ranking=batch_ranking, generator=generator,
                                                global_buffer=global_buffer)
            if samples is not None:
                generated_data[qid] = samples

        return generated_data

    def per_query_generation(self, qid, batch_ranking, generator, global_buffer):
        num_pos = global_buffer[qid]

        if num_pos >= 1:
            valid_num = min(num_pos, self.samples_per_query)
            pos_inds = torch.randperm(num_pos)[0:valid_num] # randomly select positive documents

            batch_pred = generator.predict(batch_ranking)  # [batch, size_ranking]
            pred_probs = F.softmax(torch.squeeze(batch_pred), dim=0)

            # Draw generated ('fake') documents from the generator's distribution;
            # sampling is with replacement, so duplicates are possible.
            neg_inds = torch.multinomial(pred_probs, valid_num, replacement=True)

            return (pos_inds, neg_inds) # torch.LongTensor as index
        else:
            return None


    def train_discriminator_generator_single_step(self, train_data=None, generator=None, discriminator=None,
                                                  global_buffer=None):
        ''' Train both discriminator and generator with a single step per query '''
        for entry in train_data:
            qid, batch_ranking, batch_label = entry[0], entry[1], entry[2]
            if self.gpu: batch_ranking = batch_ranking.to(self.device)

            num_pos = global_buffer[qid]
            if num_pos < 1: continue

            valid_num = min(num_pos, self.samples_per_query)
            true_inds = torch.randperm(num_pos)[0:valid_num]  # randomly select positive documents

            batch_preds = generator.predict(batch_ranking, train=True)  # [batch, size_ranking]
            pred_probs = F.softmax(torch.squeeze(batch_preds), dim=0)

            if torch.isnan(pred_probs).any():
                stop_training = True
                return stop_training

            fake_inds = torch.multinomial(pred_probs, valid_num, replacement=False)

            # real data and generated data
            true_docs = batch_ranking[0, true_inds, :]
            fake_docs = batch_ranking[0, fake_inds, :]
            true_docs = torch.unsqueeze(true_docs, dim=0)
            fake_docs = torch.unsqueeze(fake_docs, dim=0)

            ''' optimize discriminator '''
            true_preds = discriminator.predict(true_docs, train=True)
            fake_preds = discriminator.predict(fake_docs, train=True)

            dis_loss = torch.mean(self.conjugate_f(self.activation_f(fake_preds))) - torch.mean(self.activation_f(true_preds))  # objective to minimize w.r.t. discriminator

            discriminator.optimizer.zero_grad()
            dis_loss.backward()
            discriminator.optimizer.step()

            ''' optimize generator '''
            d_fake_preds = discriminator.predict(fake_docs)
            d_fake_preds = self.conjugate_f(self.activation_f(d_fake_preds))

            ger_loss = -torch.mean(torch.log(pred_probs[fake_inds]) * d_fake_preds)

            generator.optimizer.zero_grad()
            ger_loss.backward()
            generator.optimizer.step()

        stop_training = False
        return stop_training

    def reset_generator(self):
        self.generator.reset_parameters()

    def reset_discriminator(self):
        self.discriminator.reset_parameters()

    def get_generator(self):
        return self.generator

    def get_discriminator(self):
        return self.discriminator
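
Both classes rely on get_f_divergence_functions to supply the output activation g_f and the Fenchel conjugate f* of the chosen divergence; its implementation is not shown here. As a minimal sketch under that assumption, the (g_f, f*) pairs for the KL and GAN divergences from the f-GAN paper (Nowozin et al., 2016) would look as follows; the function name and dictionary below are hypothetical.

import torch

def get_f_divergence_functions_sketch(f_div_id):
    # Hypothetical sketch: (activation g_f, conjugate f*) pairs following
    # the f-GAN paper (Nowozin et al., 2016).
    pairs = {
        'KL': (lambda v: v,                                  # g_f(v) = v
               lambda t: torch.exp(t - 1.0)),                # f*(t) = exp(t - 1)
        'GAN': (lambda v: -torch.log1p(torch.exp(-v)),       # g_f(v) = -log(1 + e^{-v})
                lambda t: -torch.log(1.0 - torch.exp(t))),   # f*(t) = -log(1 - e^{t})
    }
    return pairs[f_div_id]

# With such a pair, dis_loss = mean(f*(g_f(fake))) - mean(g_f(true)) is the
# negation of the variational lower bound on the chosen f-divergence, so
# minimizing it w.r.t. the discriminator tightens the bound.
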
class IRFGAN_Pair(AdversarialMachine):
    ''' Pairwise IR-FGAN: the generator and discriminator operate on document pairs. '''
    def __init__(self,
                 eval_dict,
                 data_dict,
                 sf_para_dict=None,
                 ad_para_dict=None,
                 g_key='BT',
                 sigma=1.0,
                 gpu=False,
                 device=None):
        super(IRFGAN_Pair, self).__init__(eval_dict=eval_dict,
                                          data_dict=data_dict,
                                          gpu=gpu,
                                          device=device)

        self.f_div_id = ad_para_dict['f_div_id']
        self.samples_per_query = ad_para_dict['samples_per_query']
        self.activation_f, self.conjugate_f = get_f_divergence_functions(
            self.f_div_id)
        self.dict_diff = dict()
        self.tensor = torch.cuda.FloatTensor if self.gpu else torch.FloatTensor

        assert g_key == 'BT'
        '''
        Underlying formulation of generation:
        (1) BT: the probability of observing a pair of ordered documents is formulated via the
        Bradley-Terry model, i.e., p(d_i > d_j) = 1/(1 + exp(-sigma*(s_i - s_j))); sigma defaults to 1.0.
        '''
        self.g_key = g_key
        self.sigma = sigma  # only used w.r.t. SR

        g_sf_para_dict = sf_para_dict

        d_sf_para_dict = copy.deepcopy(g_sf_para_dict)
        d_sf_para_dict['ffnns']['apply_tl_af'] = False

        self.generator = Point_Generator(sf_para_dict=g_sf_para_dict)
        self.discriminator = Point_Discriminator(sf_para_dict=d_sf_para_dict)

    def fill_global_buffer(self, train_data, dict_buffer=None):
        ''' Buffer the number of positive documents, and the number of non-positive documents per query '''
        assert train_data.presort is True  # this is required for efficient truth sampling

        if train_data.data_id in MSLETOR_SEMI:
            for entry in train_data:
                qid, _, batch_label = entry[0], entry[1], entry[2]
                if qid not in dict_buffer:
                    pos_boolean_mat = torch.gt(batch_label, 0)
                    num_pos = torch.sum(pos_boolean_mat)

                    explicit_boolean_mat = torch.ge(batch_label, 0)
                    num_explicit = torch.sum(explicit_boolean_mat)

                    ranking_size = batch_label.size(1)
                    num_neg_unk = ranking_size - num_pos
                    num_unk = ranking_size - num_explicit

                    num_unique_labels = torch.unique(batch_label).size(0)

                    dict_buffer[qid] = (num_pos, num_explicit, num_neg_unk,
                                        num_unk, num_unique_labels)
        else:
            for entry in train_data:
                qid, _, batch_label = entry[0], entry[1], entry[2]
                if qid not in dict_buffer:
                    pos_boolean_mat = torch.gt(batch_label, 0)
                    num_pos = torch.sum(pos_boolean_mat)

                    ranking_size = batch_label.size(1)

                    num_explicit = ranking_size
                    num_neg_unk = ranking_size - num_pos
                    num_unk = 0

                    num_unique_labels = torch.unique(batch_label).size(0)

                    dict_buffer[qid] = (num_pos, num_explicit, num_neg_unk,
                                        num_unk, num_unique_labels)

    def mini_max_train(self,
                       train_data=None,
                       generator=None,
                       discriminator=None,
                       global_buffer=None):
        '''
        Training here cannot follow irgan_pair's scheme (which still relies on single documents
        rather than pairs), since ir-fgan requires sampling from two distributions.
        '''
        stop_training = self.train_discriminator_generator_single_step(
            train_data=train_data,
            generator=generator,
            discriminator=discriminator,
            global_buffer=global_buffer)
        return stop_training

    def train_discriminator_generator_single_step(self,
                                                  train_data=None,
                                                  generator=None,
                                                  discriminator=None,
                                                  global_buffer=None):
        ''' Train both discriminator and generator with a single step per query '''
        stop_training = False
        for entry in train_data:
            qid, batch_ranking, batch_label = entry[0], entry[1], entry[2]
            if self.gpu: batch_ranking = batch_ranking.type(self.tensor)

            sorted_std_labels = torch.squeeze(batch_label, dim=0)

            num_pos, num_explicit, num_neg_unk, num_unk, num_unique_labels = global_buffer[
                qid]

            if num_unique_labels < 2:  # fewer than two unique labels (e.g., all ones) yields no pairs
                continue

            true_head_inds, true_tail_inds = generate_true_pairs(
                qid=qid,
                sorted_std_labels=sorted_std_labels,
                num_pairs=self.samples_per_query,
                dict_diff=self.dict_diff,
                global_buffer=global_buffer)

            batch_preds = generator.predict(
                batch_ranking, train=True)  # [batch, size_ranking]

            # TODO: determine how to apply an activation
            point_preds = torch.squeeze(batch_preds)

            if torch.isnan(point_preds).any():
                print('Encountered NaN in generator predictions.')
                stop_training = True
                return stop_training

            # -- generate samples
            if 'BT' == self.g_key:
                mat_diffs = torch.unsqueeze(
                    point_preds, dim=1) - torch.unsqueeze(point_preds, dim=0)
                mat_bt_probs = torch.sigmoid(mat_diffs)  # default sigma=1.0

                fake_head_inds, fake_tail_inds = sample_points_Bernoulli(
                    mat_bt_probs, num_pairs=self.samples_per_query)
            else:
                raise NotImplementedError
            # --

            # real data and generated data
            true_head_docs = batch_ranking[:, true_head_inds, :]
            true_tail_docs = batch_ranking[:, true_tail_inds, :]
            fake_head_docs = batch_ranking[:, fake_head_inds, :]
            fake_tail_docs = batch_ranking[:, fake_tail_inds, :]
            ''' optimize discriminator '''
            true_head_preds = discriminator.predict(true_head_docs, train=True)
            true_tail_preds = discriminator.predict(true_tail_docs, train=True)
            true_preds = true_head_preds - true_tail_preds
            fake_head_preds = discriminator.predict(fake_head_docs, train=True)
            fake_tail_preds = discriminator.predict(fake_tail_docs, train=True)
            fake_preds = fake_head_preds - fake_tail_preds

            dis_loss = torch.mean(
                self.conjugate_f(self.activation_f(fake_preds))) - torch.mean(
                    self.activation_f(true_preds)
                )  # objective to minimize w.r.t. discriminator
            discriminator.optimizer.zero_grad()
            dis_loss.backward()
            discriminator.optimizer.step()
            ''' optimize generator '''
            d_fake_head_preds = discriminator.predict(fake_head_docs)
            d_fake_tail_preds = discriminator.predict(fake_tail_docs)
            d_fake_preds = self.conjugate_f(
                self.activation_f(d_fake_head_preds - d_fake_tail_preds))

            if 'BT' == self.g_key:
                log_g_probs = torch.log(mat_bt_probs[fake_head_inds,
                                                     fake_tail_inds].view(
                                                         1, -1))
            else:
                raise NotImplementedError

            g_batch_loss = -torch.mean(log_g_probs * d_fake_preds)

            generator.optimizer.zero_grad()
            g_batch_loss.backward()
            generator.optimizer.step()

        # after iterating over train_data
        return stop_training

    def reset_generator(self):
        self.generator.reset_parameters()

    def reset_discriminator(self):
        self.discriminator.reset_parameters()

    def get_generator(self):
        return self.generator

    def get_discriminator(self):
        return self.discriminator
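
The implementation of sample_points_Bernoulli is not shown above. As a minimal sketch, under the assumption that it draws (head, tail) index pairs in proportion to the Bradley-Terry matrix entries, a flattened-multinomial version could look like this; the function name below is hypothetical.

import torch

def sample_points_bernoulli_sketch(mat_bt_probs, num_pairs):
    # Hypothetical sketch: treat each (i, j) entry of the Bradley-Terry matrix
    # as an unnormalized preference for the ordered pair d_i > d_j and sample
    # index pairs accordingly.
    n = mat_bt_probs.size(0)
    masked = mat_bt_probs * (1.0 - torch.eye(n))  # mask self-pairs (i == j)
    flat_inds = torch.multinomial(masked.flatten(), num_pairs, replacement=True)
    head_inds = torch.div(flat_inds, n, rounding_mode='floor')  # row index i
    tail_inds = flat_inds % n                                   # column index j
    return head_inds, tail_inds

# Usage mirroring the generator step above:
scores = torch.tensor([1.2, 0.3, -0.7, 0.9])
mat_bt_probs = torch.sigmoid(scores.unsqueeze(1) - scores.unsqueeze(0))
fake_head_inds, fake_tail_inds = sample_points_bernoulli_sketch(mat_bt_probs, num_pairs=5)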