Example 1
    def train_discriminator(self, discriminator, batch_std_sample_ranking=None, batch_gen_sample_ranking=None, PL_discriminator=None, replace_trick_4_discriminator=None):

        # discriminator scores for the ground-truth (std) ranking and for the generator's ranking
        d_batch_preds_4_std = discriminator.predict(batch_std_sample_ranking, train=True)
        d_batch_preds_4_gen = discriminator.predict(batch_gen_sample_ranking, train=True)

        # debugging
        # discriminator.stop_training(d_batch_preds_4_std)
        # discriminator.stop_training(d_batch_preds_4_gen)

        if PL_discriminator:  # listwise Plackett-Luce log ranking probability
            d_batch_log_ranking_prob_4_std = log_ranking_prob_Plackett_Luce(d_batch_preds_4_std)
            d_batch_log_ranking_prob_4_gen = log_ranking_prob_Plackett_Luce(d_batch_preds_4_gen)
        else:  # pairwise Bradley-Terry log ranking probability
            d_batch_log_ranking_prob_4_std = log_ranking_prob_Bradley_Terry(d_batch_preds_4_std)
            d_batch_log_ranking_prob_4_gen = log_ranking_prob_Bradley_Terry(d_batch_preds_4_gen)

        if replace_trick_4_discriminator:  # replace trick
            dis_loss = torch.sum(
                d_batch_log_ranking_prob_4_gen - d_batch_log_ranking_prob_4_std)  # objective to minimize

        else:  # standard cross-entropy loss: -[log D(std) + log(1 - D(gen))]
            dis_loss = - (torch.sum(d_batch_log_ranking_prob_4_std) + torch.sum(
                torch.log(1.0 - torch.exp(d_batch_log_ranking_prob_4_gen))))

        discriminator.optimizer.zero_grad()
        dis_loss.backward()
        discriminator.optimizer.step()

    def get_reward(self,
                   reward_d_batch_preds_4_gen_sorted_as_g,
                   PL_discriminator=None,
                   replace_trick_4_generator=None,
                   drop_discriminator_log_4_reward=None):

        if PL_discriminator:
            reward_d_batch_log_ranking_prob_4_gen = log_ranking_prob_Plackett_Luce(
                reward_d_batch_preds_4_gen_sorted_as_g)
        else:
            reward_d_batch_log_ranking_prob_4_gen = log_ranking_prob_Bradley_Terry(
                reward_d_batch_preds_4_gen_sorted_as_g)

        if replace_trick_4_generator:
            if drop_discriminator_log_4_reward:
                batch_rewards = -torch.exp(
                    reward_d_batch_log_ranking_prob_4_gen)  # -D(gen ranking)
            else:
                batch_rewards = -reward_d_batch_log_ranking_prob_4_gen  # -log D(gen ranking)
        else:
            if drop_discriminator_log_4_reward:
                batch_rewards = 1.0 - torch.exp(
                    reward_d_batch_log_ranking_prob_4_gen)  # 1 - D(gen ranking)
            else:
                batch_rewards = torch.log(
                    1.0 - torch.exp(reward_d_batch_log_ranking_prob_4_gen))  # log(1 - D(gen ranking))

        return batch_rewards
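Both methods above delegate the ranking-probability computation to log_ranking_prob_Plackett_Luce and log_ranking_prob_Bradley_Terry, which live elsewhere in the codebase. As a rough guide to what they return, here is a minimal sketch, assuming scores of shape [batch_size, ranking_size] that are already listed in the order of the ranking being scored; the logcumsumexp and pairwise log-sigmoid formulations below are illustrative assumptions, not the repository's actual implementation.

    import torch
    import torch.nn.functional as F

    def log_ranking_prob_Plackett_Luce(batch_preds):
        # Plackett-Luce: log P(ranking) = sum_i [ s_i - log(sum_{j >= i} exp(s_j)) ],
        # computed stably with logcumsumexp over the reversed score sequence
        suffix_logsumexp = torch.flip(torch.logcumsumexp(torch.flip(batch_preds, dims=[1]), dim=1), dims=[1])
        return torch.sum(batch_preds - suffix_logsumexp, dim=1)

    def log_ranking_prob_Bradley_Terry(batch_preds):
        # Bradley-Terry: sum of log sigmoid(s_i - s_j) over all pairs where i is ranked above j
        diffs = batch_preds.unsqueeze(2) - batch_preds.unsqueeze(1)  # [batch, n, n] score differences
        pair_mask = torch.triu(torch.ones_like(diffs), diagonal=1)   # keep only pairs with i ranked above j
        return torch.sum(F.logsigmoid(diffs) * pair_mask, dim=(1, 2))

Under either helper the outputs are per-ranking log-probabilities, so the discriminator loss above sums them over the batch, while get_reward either exponentiates them back to probabilities or uses them directly as log-space rewards.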
Example 3
    def burn_in(self, train_data=None):
        # supervised warm-up of both generator and discriminator on ground-truth rankings
        if self.optimal_train:
            for entry in train_data:
                qid, batch_ranking = entry[0], entry[1]
                if gpu: batch_ranking = batch_ranking.to(device)  # move to GPU when configured

                # generator warm-up: maximize the Plackett-Luce likelihood of the ground-truth ranking
                g_batch_pred = self.super_generator.predict(batch_ranking, train=True)
                g_batch_log_ranking_prob = log_ranking_prob_Plackett_Luce(g_batch_pred)
                g_loss = -torch.mean(g_batch_log_ranking_prob)

                # alternative debugging
                #g_batch_logcumsumexps = apply_LogCumsumExp(g_batch_pred)
                #g_loss = torch.sum(g_batch_logcumsumexps - g_batch_pred)

                self.super_generator.optimizer.zero_grad()
                g_loss.backward()
                self.super_generator.optimizer.step()

                # discriminator warm-up: maximize the likelihood of the ground-truth ranking
                d_batch_pred = self.super_discriminator.predict(batch_ranking, train=True)

                if self.PL_discriminator:
                    d_batch_log_ranking_prob = log_ranking_prob_Plackett_Luce(d_batch_pred)
                else:
                    d_batch_log_ranking_prob = log_ranking_prob_Bradley_Terry(d_batch_pred)

                d_loss = -torch.mean(d_batch_log_ranking_prob)  # objective to minimize

                # alternative debugging
                #d_batch_logcumsumexps = apply_LogCumsumExp(d_batch_pred)
                #d_loss = torch.sum(d_batch_logcumsumexps - d_batch_pred)

                self.super_discriminator.optimizer.zero_grad()
                d_loss.backward()
                self.super_discriminator.optimizer.step()
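For context, burn_in is a supervised warm-up run before adversarial training begins. A minimal usage sketch follows, assuming machine is an instance of the surrounding class and train_data yields (qid, batch_ranking) tuples; num_burn_in_epochs is an illustrative name, not taken from the source.

    machine.optimal_train = True
    for _ in range(num_burn_in_epochs):
        machine.burn_in(train_data=train_data)
    # adversarial updates of super_generator and super_discriminator would follow the warm-up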