def train_discriminator(self, discriminator, batch_std_sample_ranking=None, batch_gen_sample_ranking=None,
                        PL_discriminator=None, replace_trick_4_discriminator=None):
    d_batch_preds_4_std = discriminator.predict(batch_std_sample_ranking, train=True)
    d_batch_preds_4_gen = discriminator.predict(batch_gen_sample_ranking, train=True)
    # debugging
    # discriminator.stop_training(d_batch_preds_4_std)
    # discriminator.stop_training(d_batch_preds_4_gen)

    if PL_discriminator:
        d_batch_log_ranking_prob_4_std = log_ranking_prob_Plackett_Luce(d_batch_preds_4_std)
        d_batch_log_ranking_prob_4_gen = log_ranking_prob_Plackett_Luce(d_batch_preds_4_gen)
    else:
        d_batch_log_ranking_prob_4_std = log_ranking_prob_Bradley_Terry(d_batch_preds_4_std)
        d_batch_log_ranking_prob_4_gen = log_ranking_prob_Bradley_Terry(d_batch_preds_4_gen)

    if replace_trick_4_discriminator:  # replace trick
        dis_loss = torch.sum(d_batch_log_ranking_prob_4_gen - d_batch_log_ranking_prob_4_std)  # objective to minimize
    else:
        # standard cross-entropy loss; exponentiate the log-probability back to a
        # probability before taking log(1 - p), since the helpers return log-probs
        dis_loss = - (torch.sum(d_batch_log_ranking_prob_4_std)
                      + torch.sum(torch.log(1.0 - torch.exp(d_batch_log_ranking_prob_4_gen))))

    discriminator.optimizer.zero_grad()
    dis_loss.backward()
    discriminator.optimizer.step()
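# --- illustrative helper (not from the original source) ---
# train_discriminator assumes a helper that scores a whole ranking under the
# Plackett-Luce model. A minimal sketch follows, assuming batch_preds has shape
# [batch_size, ranking_size] with each row ordered as the ranking being scored;
# the codebase's own helper may differ.
import torch

def log_ranking_prob_Plackett_Luce(batch_preds):
    """Log Plackett-Luce probability of each ranking in the batch:
    log P(pi) = sum_i [ s_i - logsumexp(s_i, ..., s_n) ], computed with a
    reversed logcumsumexp so every position sees the log-sum-exp of its suffix.
    """
    suffix_logsumexp = torch.flip(torch.logcumsumexp(torch.flip(batch_preds, dims=[1]), dim=1), dims=[1])
    return torch.sum(batch_preds - suffix_logsumexp, dim=1)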
def get_reward(self, reward_d_batch_preds_4_gen_sorted_as_g, PL_discriminator=None,
               replace_trick_4_generator=None, drop_discriminator_log_4_reward=None):
    if PL_discriminator:
        reward_d_batch_log_ranking_prob_4_gen = log_ranking_prob_Plackett_Luce(reward_d_batch_preds_4_gen_sorted_as_g)
    else:
        reward_d_batch_log_ranking_prob_4_gen = log_ranking_prob_Bradley_Terry(reward_d_batch_preds_4_gen_sorted_as_g)

    if replace_trick_4_generator:
        if drop_discriminator_log_4_reward:  # drop the log: reward with -p rather than -log(p)
            batch_rewards = -torch.exp(reward_d_batch_log_ranking_prob_4_gen)
        else:
            batch_rewards = -reward_d_batch_log_ranking_prob_4_gen
    else:
        if drop_discriminator_log_4_reward:  # drop the log: reward with 1 - p rather than log(1 - p)
            batch_rewards = 1.0 - torch.exp(reward_d_batch_log_ranking_prob_4_gen)
        else:
            # exponentiate back to a probability before taking log(1 - p)
            batch_rewards = torch.log(1.0 - torch.exp(reward_d_batch_log_ranking_prob_4_gen))

    return batch_rewards
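# --- illustrative helper (not from the original source) ---
# get_reward also relies on a Bradley-Terry scorer. A minimal sketch, assuming the
# ranking probability is the product of pairwise sigmoid(s_i - s_j) over all pairs
# where i is ranked above j; again, the codebase's own helper may differ.
import torch
import torch.nn.functional as F

def log_ranking_prob_Bradley_Terry(batch_preds):
    """Log Bradley-Terry probability of each ranking in the batch:
    log P(pi) = sum_{i<j} log sigmoid(s_i - s_j), with rows ordered as the ranking.
    """
    diffs = batch_preds.unsqueeze(2) - batch_preds.unsqueeze(1)  # [batch, n, n], entry (i, j) = s_i - s_j
    pairwise_log_probs = F.logsigmoid(diffs)
    ranking_size = batch_preds.size(1)
    # strict upper triangle keeps exactly the pairs where i is ranked above j
    mask = torch.triu(torch.ones(ranking_size, ranking_size, device=batch_preds.device), diagonal=1)
    return torch.sum(pairwise_log_probs * mask, dim=(1, 2))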
def burn_in(self, train_data=None):
    if self.optimal_train:
        for entry in train_data:
            qid, batch_ranking = entry[0], entry[1]
            if gpu: batch_ranking = batch_ranking.to(device)

            # pre-train the generator with a maximum-likelihood Plackett-Luce objective
            g_batch_pred = self.super_generator.predict(batch_ranking, train=True)
            g_batch_log_ranking_prob = log_ranking_prob_Plackett_Luce(g_batch_pred)
            g_loss = -torch.mean(g_batch_log_ranking_prob)
            # alternative debugging
            # g_batch_logcumsumexps = apply_LogCumsumExp(g_batch_pred)
            # g_loss = torch.sum(g_batch_logcumsumexps - g_batch_pred)

            self.super_generator.optimizer.zero_grad()
            g_loss.backward()
            self.super_generator.optimizer.step()

            # pre-train the discriminator with the same maximum-likelihood objective
            d_batch_pred = self.super_discriminator.predict(batch_ranking, train=True)
            if self.PL_discriminator:
                d_batch_log_ranking_prob = log_ranking_prob_Plackett_Luce(d_batch_pred)
            else:
                d_batch_log_ranking_prob = log_ranking_prob_Bradley_Terry(d_batch_pred)
            d_loss = -torch.mean(d_batch_log_ranking_prob)  # objective to minimize
            # alternative debugging
            # d_batch_logcumsumexps = apply_LogCumsumExp(d_batch_pred)
            # d_loss = torch.sum(d_batch_logcumsumexps - d_batch_pred)

            self.super_discriminator.optimizer.zero_grad()
            d_loss.backward()
            self.super_discriminator.optimizer.step()
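# --- illustrative usage (not from the original source) ---
# A quick sanity check of the two helper sketches above on synthetic scores: both
# return log-probabilities, so every value should be non-positive, and exponentiating
# recovers a probability in (0, 1), as train_discriminator and get_reward assume.
if __name__ == '__main__':
    import torch
    torch.manual_seed(0)
    batch_preds = torch.randn(4, 10)  # 4 sampled rankings over 10 documents
    log_pl = log_ranking_prob_Plackett_Luce(batch_preds)
    log_bt = log_ranking_prob_Bradley_Terry(batch_preds)
    assert torch.all(log_pl <= 0) and torch.all(log_bt <= 0)
    print('PL ranking probs:', torch.exp(log_pl))
    print('BT ranking probs:', torch.exp(log_bt))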