class LeakGANInstructor(BasicInstructor):
    """Training driver for a LeakGAN generator/discriminator pair on oracle data.

    The schedule implemented by ``_run`` is: interleaved pre-training rounds
    (discriminator, then generator MLE), followed by adversarial training.

    Attributes and helpers used here but not defined in this class (``log``,
    ``sig``, ``oracle``, ``oracle_data``, ``dis_criterion``, ``nll_oracle``,
    ``nll_gen``, ``nll_div``, ``all_metrics``, ``init_model``,
    ``optimize_multi``, ``train_dis_epoch``, ``eval_dis``) are presumably
    provided by ``BasicInstructor`` -- confirm against the base class.
    """

    def __init__(self, opt):
        """Build the models and their optimizers from the global ``cfg``.

        :param opt: options object, forwarded unchanged to the base class.
        """
        super(LeakGANInstructor, self).__init__(opt)

        # generator, discriminator
        self.gen = LeakGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim,
                             cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx,
                             cfg.goal_size, cfg.step_size, cfg.CUDA)
        self.dis = LeakGAN_D(cfg.dis_embed_dim, cfg.vocab_size,
                             cfg.padding_idx, gpu=cfg.CUDA)
        self.init_model()

        # optimizer: the generator exposes two parameter groups, each with its
        # own Adam instance; gen_opt keeps them in [mana_opt, work_opt] order,
        # matching the [mana_loss, work_loss] order passed to optimize_multi.
        mana_params, work_params = self.gen.split_params()
        mana_opt = optim.Adam(mana_params, lr=cfg.gen_lr)
        work_opt = optim.Adam(work_params, lr=cfg.gen_lr)
        self.gen_opt = [mana_opt, work_opt]
        self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)

    def _run(self):
        """Run interleaved pre-training rounds, then adversarial training."""
        for inter_num in range(cfg.inter_epoch):
            self.log.info('>>> Interleaved Round %d...' % inter_num)
            self.sig.update()  # update signal
            if self.sig.pre_sig:
                # ===DISCRIMINATOR PRE-TRAINING===
                if not cfg.dis_pretrain:
                    self.log.info('Starting Discriminator Training...')
                    self.train_discriminator(cfg.d_step, cfg.d_epoch)
                    if cfg.if_save and not cfg.if_test:
                        torch.save(self.dis.state_dict(),
                                   cfg.pretrained_dis_path)
                        print('Save pre-trained discriminator: {}'.format(
                            cfg.pretrained_dis_path))

                # ===GENERATOR MLE TRAINING===
                if not cfg.gen_pretrain:
                    self.log.info('Starting Generator MLE Training...')
                    self.pretrain_generator(cfg.MLE_train_epoch)
                    if cfg.if_save and not cfg.if_test:
                        torch.save(self.gen.state_dict(),
                                   cfg.pretrained_gen_path)
                        print('Save pre-trained generator: {}'.format(
                            cfg.pretrained_gen_path))
            else:
                self.log.info(
                    '>>> Stop by pre_signal! Skip to adversarial training...')
                break

        # ===ADVERSARIAL TRAINING===
        self.log.info('Starting Adversarial Training...')
        self.log.info('Initial generator: %s' %
                      (str(self.cal_metrics(fmt_str=True))))

        for adv_epoch in range(cfg.ADV_train_epoch):
            self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch)
            self.sig.update()
            if self.sig.adv_sig:
                self.adv_train_generator(cfg.ADV_g_step)  # Generator
                self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch,
                                         'ADV')  # Discriminator

                # Checkpoint every adv_log_step epochs and always on the
                # final epoch.
                if (adv_epoch % cfg.adv_log_step == 0
                        or adv_epoch == cfg.ADV_train_epoch - 1):
                    if cfg.if_save and not cfg.if_test:
                        self._save('ADV', adv_epoch)
            else:
                self.log.info(
                    '>>> Stop by adv_signal! Finishing adversarial training...'
                )
                break

    def _test(self):
        """Announce test mode and execute the regular ``_run`` schedule."""
        print('>>> Begin test...')
        self._run()

    def pretrain_generator(self, epochs):
        """
        Max Likelihood Pretraining for the gen

        - gen_opt: [mana_opt, work_opt]

        :param epochs: maximum number of passes over ``self.oracle_data.loader``;
            the loop stops early as soon as ``self.sig.pre_sig`` is falsy.
        """
        for epoch in range(epochs):
            self.sig.update()
            if self.sig.pre_sig:
                pre_mana_loss = 0
                pre_work_loss = 0

                # ===Train===
                for data in self.oracle_data.loader:
                    # Only the target sequence is consumed by pretrain_loss,
                    # so the batch's 'input' entry is not fetched or moved.
                    target = data['target']
                    if cfg.CUDA:
                        target = target.cuda()

                    mana_loss, work_loss = self.gen.pretrain_loss(
                        target, self.dis)
                    self.optimize_multi(self.gen_opt, [mana_loss, work_loss])
                    pre_mana_loss += mana_loss.item()
                    pre_work_loss += work_loss.item()
                pre_mana_loss = pre_mana_loss / len(self.oracle_data.loader)
                pre_work_loss = pre_work_loss / len(self.oracle_data.loader)

                # ===Test===
                if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1:
                    self.log.info(
                        '[MLE-GEN] epoch %d : pre_mana_loss = %.4f, pre_work_loss = %.4f, %s'
                        % (epoch, pre_mana_loss, pre_work_loss,
                           self.cal_metrics(fmt_str=True)))

                    if cfg.if_save and not cfg.if_test:
                        self._save('MLE', epoch)
            else:
                self.log.info(
                    '>>> Stop by pre signal, skip to adversarial training...')
                break

    def adv_train_generator(self, g_step, current_k=0):
        """
        The gen is trained using policy gradients, using the reward from the
        discriminator. Training is done for g_step batches.

        :param g_step: number of generator update steps; each step samples one
            batch of ``cfg.batch_size`` sequences.
        :param current_k: forwarded unchanged to
            ``rollout_func.get_reward_leakgan``.
        """
        rollout_func = rollout.ROLLOUT(self.gen, cfg.CUDA)
        adv_mana_loss = 0
        adv_work_loss = 0
        for step in range(g_step):
            with torch.no_grad():
                gen_samples = self.gen.sample(
                    cfg.batch_size, cfg.batch_size, self.dis,
                    train=True)  # !!! train=True, the only place
                _, target = GenDataIter.prepare(gen_samples, gpu=cfg.CUDA)

            # ===Train===
            rewards = rollout_func.get_reward_leakgan(
                target, cfg.rollout_num, self.dis,
                current_k).cpu()  # reward with MC search
            mana_loss, work_loss = self.gen.adversarial_loss(
                target, rewards, self.dis)

            # update parameters
            self.optimize_multi(self.gen_opt, [mana_loss, work_loss])
            adv_mana_loss += mana_loss.item()
            adv_work_loss += work_loss.item()

        # ===Test===
        # Guard the average so that g_step == 0 logs zeros instead of raising
        # ZeroDivisionError.
        divisor = max(g_step, 1)
        self.log.info(
            '[ADV-GEN] adv_mana_loss = %.4f, adv_work_loss = %.4f, %s' %
            (adv_mana_loss / divisor, adv_work_loss / divisor,
             self.cal_metrics(fmt_str=True)))

    def train_discriminator(self, d_step, d_epoch, phase='MLE'):
        """
        Training the discriminator on real_data_samples (positive) and
        generated samples from gen (negative). Samples are drawn d_step times,
        and the discriminator is trained for d_epoch epochs on each draw.

        :param d_step: number of times fresh positive/negative samples are drawn.
        :param d_epoch: number of training epochs run on each drawn sample set.
        :param phase: tag inserted into the log line (e.g. 'MLE' or 'ADV').
        """
        # prepare loader for validate
        pos_val = self.oracle.sample(8 * cfg.batch_size, cfg.batch_size)
        neg_val = self.gen.sample(8 * cfg.batch_size, cfg.batch_size, self.dis)
        dis_eval_data = DisDataIter(pos_val, neg_val)

        # Kept local (not module globals). NaN makes it obvious in the log
        # that no training epoch ran when d_epoch == 0.
        d_loss = train_acc = float('nan')

        for step in range(d_step):
            # prepare loader for training
            pos_samples = self.oracle.sample(
                cfg.samples_num, cfg.batch_size)  # re-sample the Oracle Data
            neg_samples = self.gen.sample(cfg.samples_num, cfg.batch_size,
                                          self.dis)
            dis_data = DisDataIter(pos_samples, neg_samples)

            for epoch in range(d_epoch):
                # ===Train===
                d_loss, train_acc = self.train_dis_epoch(
                    self.dis, dis_data.loader, self.dis_criterion,
                    self.dis_opt)

            # ===Test===
            _, eval_acc = self.eval_dis(self.dis, dis_eval_data.loader,
                                        self.dis_criterion)
            self.log.info(
                '[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f, eval_acc = %.4f,'
                % (phase, step, d_loss, train_acc, eval_acc))

    def cal_metrics(self, fmt_str=False):
        """Re-sample the generator and evaluate every metric in ``all_metrics``.

        :param fmt_str: when truthy, return one ``'name = score, ...'`` string;
            otherwise return the list of raw scores.
        """
        # Prepare data for evaluation
        gen_data = GenDataIter(
            self.gen.sample(cfg.samples_num, cfg.batch_size, self.dis))

        # Reset metrics
        self.nll_oracle.reset(self.oracle, gen_data.loader)
        self.nll_gen.reset(self.gen, self.oracle_data.loader,
                           leak_dis=self.dis)
        self.nll_div.reset(self.gen, gen_data.loader, leak_dis=self.dis)

        if fmt_str:
            return ', '.join([
                '%s = %s' % (metric.get_name(), metric.get_score())
                for metric in self.all_metrics
            ])
        else:
            return [metric.get_score() for metric in self.all_metrics]

    def _save(self, phase, epoch):
        """Write a generator checkpoint and one batch of generated samples.

        :param phase: tag embedded in both file names (e.g. 'MLE' or 'ADV').
        :param epoch: epoch index, zero-padded to five digits in the file names.
        """
        torch.save(
            self.gen.state_dict(),
            cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch))
        save_sample_path = cfg.save_samples_root + 'samples_{}_{:05d}.txt'.format(
            phase, epoch)
        samples = self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis)
        write_tensor(save_sample_path, samples)
class LeakGANInstructor(BasicInstructor):
    """Training driver for a LeakGAN generator/discriminator pair on a text corpus.

    The schedule implemented by ``_run`` is: interleaved pre-training rounds
    (discriminator, then generator MLE), followed by adversarial training.
    Generation quality is reported as BLEU-3 plus the generator NLL measured
    on ``self.train_data``.

    Attributes and helpers used here but not defined in this class (``log``,
    ``sig``, ``train_data``, ``test_data``, ``index_word_dict``,
    ``init_model``, ``optimize_multi``, ``train_dis_epoch``) are presumably
    provided by ``BasicInstructor`` -- confirm against the base class.
    """

    def __init__(self, opt):
        """Build models, optimizers, criteria, data iterators and BLEU metrics.

        :param opt: options object, forwarded unchanged to the base class.
        """
        super(LeakGANInstructor, self).__init__(opt)

        # generator, discriminator
        self.gen = LeakGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim,
                             cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx,
                             cfg.goal_size, cfg.step_size, cfg.CUDA)
        self.dis = LeakGAN_D(cfg.dis_embed_dim, cfg.vocab_size,
                             cfg.padding_idx, gpu=cfg.CUDA)
        self.init_model()

        # optimizer: the generator exposes two parameter groups, each with its
        # own Adam instance; gen_opt keeps them in [mana_opt, work_opt] order,
        # matching the [mana_loss, work_loss] order passed to optimize_multi.
        mana_params, work_params = self.gen.split_params()
        mana_opt = optim.Adam(mana_params, lr=cfg.gen_lr)
        work_opt = optim.Adam(work_params, lr=cfg.gen_lr)
        self.gen_opt = [mana_opt, work_opt]
        self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)

        # Criterion
        self.mle_criterion = nn.NLLLoss()
        self.dis_criterion = nn.CrossEntropyLoss()

        # DataLoader
        self.gen_data = GenDataIter(
            self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis))
        # NOTE(review): generated samples are passed first and real samples
        # second here, whereas train_discriminator calls
        # dis_data.reset(real, generated). The contents are replaced by that
        # reset before any training, so this only matters if the constructor
        # order is (pos, neg) and dis_data is read before a reset -- confirm
        # against DisDataIter.
        self.dis_data = DisDataIter(self.gen_data.random_batch()['target'],
                                    self.train_data.random_batch()['target'])

        # Metrics
        self.bleu = BLEU(test_text=tensor_to_tokens(self.gen_data.target,
                                                    self.index_word_dict),
                         real_text=tensor_to_tokens(
                             self.test_data.target,
                             self.test_data.index_word_dict),
                         gram=3)
        # Generated text scored against itself; not referenced by any other
        # method of this class.
        self.self_bleu = BLEU(
            test_text=tensor_to_tokens(self.gen_data.target,
                                       self.index_word_dict),
            real_text=tensor_to_tokens(self.gen_data.target,
                                       self.index_word_dict),
            gram=3)

    def _run(self):
        """Run interleaved pre-training rounds, then adversarial training."""
        for inter_num in range(cfg.inter_epoch):
            self.log.info('>>> Interleaved Round %d...' % inter_num)
            self.sig.update()  # update signal
            if self.sig.pre_sig:
                # =====DISCRIMINATOR PRE-TRAINING=====
                if not cfg.dis_pretrain:
                    self.log.info('Starting Discriminator Training...')
                    self.train_discriminator(cfg.d_step, cfg.d_epoch)
                    if cfg.if_save and not cfg.if_test:
                        torch.save(self.dis.state_dict(),
                                   cfg.pretrained_dis_path)
                        print('Save pre-trained discriminator: {}'.format(
                            cfg.pretrained_dis_path))

                # =====GENERATOR MLE TRAINING=====
                if not cfg.gen_pretrain:
                    self.log.info('Starting Generator MLE Training...')
                    self.pretrain_generator(cfg.MLE_train_epoch)
                    if cfg.if_save and not cfg.if_test:
                        torch.save(self.gen.state_dict(),
                                   cfg.pretrained_gen_path)
                        print('Save pre-trained generator: {}'.format(
                            cfg.pretrained_gen_path))
            else:
                self.log.info(
                    '>>> Stop by pre_signal! Skip to adversarial training...')
                break

        # =====ADVERSARIAL TRAINING=====
        self.log.info('Starting Adversarial Training...')
        self.log.info('Initial generator: %s' %
                      (str(self.cal_metrics(fmt_str=True))))

        for adv_epoch in range(cfg.ADV_train_epoch):
            self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch)
            self.sig.update()
            if self.sig.adv_sig:
                self.adv_train_generator(cfg.ADV_g_step)  # Generator
                self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch,
                                         'ADV')  # Discriminator

                # Checkpoint every adv_log_step epochs and always on the final
                # epoch, so the last trained state is never lost when the
                # epoch count is not aligned with the log step.
                if (adv_epoch % cfg.adv_log_step == 0
                        or adv_epoch == cfg.ADV_train_epoch - 1):
                    if cfg.if_save and not cfg.if_test:
                        self._save('ADV', adv_epoch)
            else:
                self.log.info(
                    '>>> Stop by adv_signal! Finishing adversarial training...'
                )
                break

    def _test(self):
        """Announce test mode and execute the regular ``_run`` schedule."""
        print('>>> Begin test...')
        self._run()

    def pretrain_generator(self, epochs):
        """
        Max Likelihood Pretraining for the gen

        - gen_opt: [mana_opt, work_opt]

        :param epochs: maximum number of passes over ``self.train_data.loader``;
            the loop stops early as soon as ``self.sig.pre_sig`` is falsy.
        """
        for epoch in range(epochs):
            self.sig.update()
            if self.sig.pre_sig:
                pre_mana_loss = 0
                pre_work_loss = 0

                # =====Train=====
                for data in self.train_data.loader:
                    # Only the target sequence is consumed by pretrain_loss,
                    # so the batch's 'input' entry is not fetched or moved.
                    target = data['target']
                    if cfg.CUDA:
                        target = target.cuda()

                    mana_loss, work_loss = self.gen.pretrain_loss(
                        target, self.dis)
                    self.optimize_multi(self.gen_opt, [mana_loss, work_loss])
                    pre_mana_loss += mana_loss.item()
                    pre_work_loss += work_loss.item()
                pre_mana_loss = pre_mana_loss / len(self.train_data.loader)
                pre_work_loss = pre_work_loss / len(self.train_data.loader)

                # =====Test=====
                # Report/checkpoint every pre_log_step epochs and always on
                # the final epoch so the finished model is evaluated and saved.
                if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1:
                    self.log.info(
                        '[MLE-GEN] epoch %d : pre_mana_loss = %.4f, pre_work_loss = %.4f, %s'
                        % (epoch, pre_mana_loss, pre_work_loss,
                           self.cal_metrics(fmt_str=True)))

                    if cfg.if_save and not cfg.if_test:
                        self._save('MLE', epoch)
            else:
                self.log.info(
                    '>>> Stop by pre signal, skip to adversarial training...')
                break

    def adv_train_generator(self, g_step, current_k=0):
        """
        The gen is trained using policy gradients, using the reward from the
        discriminator. Training is done for g_step batches.

        :param g_step: number of generator update steps; each step samples one
            batch of ``cfg.batch_size`` sequences.
        :param current_k: forwarded unchanged to
            ``rollout_func.get_reward_leakgan``.
        """
        rollout_func = rollout.ROLLOUT(self.gen, cfg.CUDA)
        adv_mana_loss = 0
        adv_work_loss = 0
        for step in range(g_step):
            with torch.no_grad():
                gen_samples = self.gen.sample(
                    cfg.batch_size, cfg.batch_size, self.dis,
                    train=True)  # !!! train=True, the only place
                _, target = self.gen_data.prepare(gen_samples, gpu=cfg.CUDA)

            # =====Train=====
            rewards = rollout_func.get_reward_leakgan(
                target, cfg.rollout_num, self.dis,
                current_k).cpu()  # reward with MC search
            mana_loss, work_loss = self.gen.adversarial_loss(
                target, rewards, self.dis)

            # update parameters
            self.optimize_multi(self.gen_opt, [mana_loss, work_loss])
            adv_mana_loss += mana_loss.item()
            adv_work_loss += work_loss.item()

        # =====Test=====
        # Guard the average so that g_step == 0 logs zeros instead of raising
        # ZeroDivisionError.
        divisor = max(g_step, 1)
        self.log.info(
            '[ADV-GEN] adv_mana_loss = %.4f, adv_work_loss = %.4f, %s' %
            (adv_mana_loss / divisor, adv_work_loss / divisor,
             self.cal_metrics(fmt_str=True)))

    def train_discriminator(self, d_step, d_epoch, phrase='MLE'):
        """
        Training the discriminator on real_data_samples (positive) and
        generated samples from gen (negative). Samples are drawn d_step times,
        and the discriminator is trained for d_epoch epochs on each draw.

        :param d_step: number of times fresh negative samples are drawn.
        :param d_epoch: number of training epochs run on each drawn sample set.
        :param phrase: tag inserted into the log line (e.g. 'MLE' or 'ADV').
        """
        # NaN makes it obvious in the log that no training epoch ran when
        # d_epoch == 0 (previously an UnboundLocalError).
        d_loss = train_acc = float('nan')

        for step in range(d_step):
            # prepare loader for training
            pos_samples = self.train_data.target
            neg_samples = self.gen.sample(cfg.samples_num, cfg.batch_size,
                                          self.dis)
            self.dis_data.reset(pos_samples, neg_samples)

            for epoch in range(d_epoch):
                # =====Train=====
                d_loss, train_acc = self.train_dis_epoch(
                    self.dis, self.dis_data.loader, self.dis_criterion,
                    self.dis_opt)

            # =====Test=====
            self.log.info(
                '[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' %
                (phrase, step, d_loss, train_acc))

    def cal_metrics(self, fmt_str=False):
        """Re-sample the generator and compute BLEU-3 and generator NLL.

        :param fmt_str: when truthy, return a formatted
            ``'BLEU-3 = ..., gen_NLL = ...,'`` string built from
            ``bleu_score[0]`` and the NLL; otherwise return the tuple
            ``(bleu_score, gen_nll)``.
        """
        self.gen_data.reset(
            self.gen.sample(cfg.samples_num, cfg.batch_size, self.dis))
        self.bleu.test_text = tensor_to_tokens(self.gen_data.target,
                                               self.index_word_dict)
        bleu_score = self.bleu.get_score(ignore=False)

        # Mean per-batch NLL of the generator over the training loader.
        with torch.no_grad():
            gen_nll = 0
            for data in self.train_data.loader:
                target = data['target']
                if cfg.CUDA:
                    target = target.cuda()
                loss = self.gen.batchNLLLoss(target, self.dis)
                gen_nll += loss.item()
            gen_nll /= len(self.train_data.loader)

        if fmt_str:
            return 'BLEU-3 = %.4f, gen_NLL = %.4f,' % (bleu_score[0], gen_nll)
        return bleu_score, gen_nll

    def _save(self, phrase, epoch):
        """Write a generator checkpoint and one batch of generated samples.

        :param phrase: tag embedded in both file names (e.g. 'MLE' or 'ADV').
        :param epoch: epoch index, zero-padded to five digits in the file names.
        """
        torch.save(
            self.gen.state_dict(),
            cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phrase, epoch))
        save_sample_path = cfg.save_samples_root + 'samples_{}_{:05d}.txt'.format(
            phrase, epoch)
        samples = self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis)
        write_tokens(save_sample_path,
                     tensor_to_tokens(samples, self.index_word_dict))