def __init__(self, opt): super(DPGANInstructor, self).__init__(opt) # generator, discriminator self.gen = DPGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) self.dis = DPGAN_D(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) self.init_model() # Optimizer self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr) self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr) self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)
class DPGANInstructor(BasicInstructor): def __init__(self, opt): super(DPGANInstructor, self).__init__(opt) # generator, discriminator self.gen = DPGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) self.dis = DPGAN_D(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) self.init_model() # Optimizer self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr) self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr) self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr) def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: self.log.info('Starting Generator MLE Training...') self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) # # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: self.log.info('Starting Discriminator Training...') self.train_discriminator(cfg.d_step, cfg.d_epoch, 'MLE') if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) # ===ADVERSARIAL TRAINING=== self.log.info('Starting Adversarial Training...') self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: if cfg.if_save and not cfg.if_test: self._save('ADV', adv_epoch) else: self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') break def _test(self): print('>>> Begin test...') self._run() pass def pretrain_generator(self, epochs): """ Max Likelihood Pre-training for the generator """ for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) if cfg.if_save and not cfg.if_test: self._save('MLE', epoch) else: self.log.info('>>> Stop by pre signal, skip to adversarial training...') break def adv_train_generator(self, g_step): """ The gen is trained using policy gradients, using the reward from the discriminator. Training is done for num_batches batches. """ discount_rate = 1 total_g_loss = 0 dis_count_list = [discount_rate ** i for i in range(cfg.max_seq_len)] dis_count_matrix = torch.Tensor(dis_count_list).unsqueeze(0).repeat(cfg.batch_size, 1) if cfg.CUDA: dis_count_matrix = dis_count_matrix.cuda() for step in range(g_step): inp = self.oracle_data.random_batch()['input'] if cfg.CUDA: inp = inp.cuda() gen_sample, gen_sample_log_prob = self.gen.sample_teacher_forcing(inp) word_reward, sentence_reward = self.dis.getReward(gen_sample) sentence_reward = sentence_reward.repeat(1, cfg.max_seq_len) reward_matrix = sentence_reward * word_reward * dis_count_matrix for i in range(cfg.max_seq_len): reward_matrix[:, i] = reward_matrix[:, i:].sum(dim=-1) adv_loss = torch.sum(gen_sample_log_prob * reward_matrix) self.optimize(self.gen_adv_opt, adv_loss, self.gen) total_g_loss += adv_loss.item() # ===Test=== self.log.info( '[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss / (g_step * cfg.batch_size), self.cal_metrics(fmt_str=True))) def train_discriminator(self, d_step, d_epoch, phase='MLE'): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. """ # prepare loader for validate global d_loss, train_acc pos_val = self.oracle.sample(8 * cfg.batch_size, 4 * cfg.batch_size) neg_val = self.gen.sample(8 * cfg.batch_size, 4 * cfg.batch_size) for step in range(d_step): # prepare loader for training pos_samples = self.oracle_samples # not re-sample the Oracle data neg_samples = self.gen.sample(pos_samples.size(0), 4 * cfg.batch_size) for epoch in range(d_epoch): # ===Train=== self.train_dis_epoch(self.dis, pos_samples, neg_samples, self.dis_opt) # ===Test=== pos_reward, neg_reward = self.eval_dis(self.dis, pos_val, neg_val) self.log.info('[%s-DIS] d_step %d: pos_reward = %.4f, neg_reward = %.4f,' % ( phase, step, pos_reward.item(), neg_reward.item())) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) def eval_dis(self, model, pos_val, neg_val): _, pos_reward = model.getReward(pos_val) _, neg_reward = model.getReward(neg_val) return torch.mean(pos_reward), torch.mean(neg_reward) def train_dis_epoch(self, model, pos_samples, neg_samples, optimizer): num_samples = pos_samples.size(0) num_batch = num_samples // cfg.batch_size for i in range(num_batch): pos_sample = pos_samples[i * cfg.batch_size: (i + 1) * cfg.batch_size] neg_sample = neg_samples[i * cfg.batch_size: (i + 1) * cfg.batch_size] _, pos_reward = model.getReward(pos_sample) _, neg_reward = model.getReward(neg_sample) loss = -torch.mean(pos_reward) + torch.mean(neg_reward) self.optimize(optimizer, loss, model)