class TRGANInstructor(BasicInstructor):
    def __init__(self, opt):
        super(TRGANInstructor, self).__init__(opt)

        # generator, discriminator
        self.gen = RelGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim,
                            cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
        self.dis = RelGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, cfg.padding_idx,
                            gpu=cfg.CUDA)
        self.dis_D = RelGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, cfg.padding_idx,
                              gpu=cfg.CUDA)
        self.init_model()

        # Optimizer
        self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
        self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr)
        self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)
        self.dis_D_opt = optim.Adam(self.dis_D.parameters(), lr=cfg.dis_D_lr)

    def init_model(self):
        if cfg.oracle_pretrain:
            if not os.path.exists(cfg.oracle_state_dict_path):
                create_oracle()
            self.oracle.load_state_dict(torch.load(cfg.oracle_state_dict_path))

        if cfg.dis_pretrain:
            self.log.info('Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path))
            self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path))

        if cfg.gen_pretrain:
            self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path))
            self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device)))

        if cfg.CUDA:
            self.oracle = self.oracle.cuda()
            self.gen = self.gen.cuda()
            self.dis = self.dis.cuda()
            self.dis_D = self.dis_D.cuda()

    def _run(self):
        # ===PRE-TRAINING (GENERATOR)===
        if not cfg.gen_pretrain:
            self.log.info('Starting Generator MLE Training...')
            self.pretrain_generator(cfg.MLE_train_epoch)
            if cfg.if_save and not cfg.if_test:
                torch.save(self.gen.state_dict(), cfg.pretrained_gen_path)
                print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path))

        # ===ADVERSARIAL TRAINING===
        self.log.info('Starting Adversarial Training...')
        progress = tqdm(range(cfg.ADV_train_epoch))
        for adv_epoch in progress:
            self.sig.update()
            if self.sig.adv_sig:
                g_loss = self.adv_train_generator(cfg.ADV_g_step)  # Generator
                d_loss = self.adv_train_discriminator(cfg.ADV_d_step)  # Discriminator
                self.update_temperature(adv_epoch, cfg.ADV_train_epoch)  # update temperature

                progress.set_description(
                    'g_loss: %.4f, d_loss: %.4f, temperature: %.4f' % (g_loss, d_loss, self.gen.temperature))

                # TEST
                if adv_epoch % cfg.adv_log_step == 0:
                    self.log.info('[ADV] epoch %d: g_loss: %.4f, d_loss: %.4f, %s' % (
                        adv_epoch, g_loss, d_loss, self.cal_metrics(fmt_str=True)))
                    if cfg.if_save and not cfg.if_test:
                        self._save('ADV', adv_epoch)
            else:
                self.log.info('>>> Stop by adv_signal! Finishing adversarial training...')
                progress.close()
                break

    def _test(self):
        print('>>> Begin test...')
        self._run()

    def pretrain_generator(self, epochs):
        """
        Max Likelihood Pre-training for the generator
        """
        for epoch in range(epochs):
            self.sig.update()
            if self.sig.pre_sig:
                # ===Train===
                pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt)

                # ===Test===
                if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1:
                    self.log.info(
                        '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True)))
                    if cfg.if_save and not cfg.if_test:
                        self._save('MLE', epoch)
            else:
                self.log.info('>>> Stop by pre signal, skip to adversarial training...')
                break

    def adv_train_generator(self, g_step):
        criterion = nn.BCELoss()
        total_loss = 0

        # Estimate the "old" density ratio P0 from the auxiliary discriminator before updating it
        with torch.no_grad():
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                gen_samples = gen_samples.cuda()
            D0 = torch.sigmoid(self.dis_D(gen_samples))
            P0 = (1. - D0) / torch.clamp(D0, min=1e-7)

        for step in range(g_step):
            real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float()
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            real_label = torch.full((D0.shape[0],), 1.)
            fake_label = torch.full((D0.shape[0],), 0.)
            if cfg.CUDA:
                real_samples, gen_samples, real_label, fake_label = \
                    real_samples.cuda(), gen_samples.cuda(), real_label.cuda(), fake_label.cuda()

            # Update the auxiliary discriminator dis_D with a standard BCE objective
            errDD_real = criterion(torch.sigmoid(self.dis_D(real_samples)), real_label)
            errDD_fake = criterion(torch.sigmoid(self.dis_D(gen_samples.detach())), fake_label)
            self.optimize(self.dis_D_opt, errDD_real + errDD_fake, self.dis_D)

            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True).cuda()
            real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float().cuda()

            # PPO-style clipped importance ratio between the new and old density-ratio estimates
            D1 = torch.sigmoid(self.dis_D(gen_samples))
            P1 = (1. - D1)
            ratio = (P1 / torch.clamp(D1 * P0, min=1e-7))
            ratio_clipped = torch.clamp(ratio, 1.0 - cfg.clip_param, 1.0 + cfg.clip_param)

            # ===Train===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            surr1 = ratio * d_out_fake
            surr2 = ratio_clipped * d_out_fake
            target = torch.where(d_out_fake > 0, torch.min(surr1, surr2), torch.max(surr1, surr2))
            g_loss, _ = get_losses(d_out_real, target, cfg.loss_type)
            # g_loss = -d_out_fake.mean()
            self.optimize(self.gen_adv_opt, g_loss, self.gen)
            total_loss += g_loss.item()

        return total_loss / g_step if g_step != 0 else 0

    def calc_gradient_penalty(self, real_data, fake_data):
        BATCH_SIZE = real_data.shape[0]
        alpha = torch.rand(BATCH_SIZE, 1)
        alpha = alpha.expand(BATCH_SIZE, real_data.nelement() // BATCH_SIZE).contiguous().view(real_data.shape)
        alpha = alpha.cuda()

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        interpolates = interpolates.cuda()
        interpolates = autograd.Variable(interpolates, requires_grad=True)

        disc_interpolates = self.dis(interpolates)

        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradients = gradients.contiguous().view(gradients.size(0), -1)

        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    def adv_train_discriminator(self, d_step):
        total_loss = 0
        for step in range(d_step):
            real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float()
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()

            # ===Train===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)
            if cfg.GP:
                gradient_penalty = self.calc_gradient_penalty(real_samples.data, gen_samples.data)
                d_loss = d_loss + cfg.LAMBDA * gradient_penalty

            self.optimize(self.dis_opt, d_loss, self.dis)
            total_loss += d_loss.item()

        return total_loss / d_step if d_step != 0 else 0

    def update_temperature(self, i, N):
        self.gen.temperature = get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)

    @staticmethod
    def optimize(opt, loss, model=None, retain_graph=False):
        """Backward pass with clip_grad_norm_ applied before the optimizer step."""
        opt.zero_grad()
        loss.backward(retain_graph=retain_graph)
        if model is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_norm)
        opt.step()
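
# The adv_train_generator above weights the main discriminator's fake-sample scores by an
# importance ratio estimated from the auxiliary discriminator dis_D and clips it PPO-style.
# Below is a minimal, self-contained sketch of that clipped surrogate, assuming plain
# per-sample tensors and a hypothetical clip_param; it mirrors the logic above and is
# illustrative only, not part of the training code.
def clipped_surrogate_sketch(d_out_fake, d0_sigmoid, d1_sigmoid, clip_param=0.2, eps=1e-7):
    """Clipped importance-weighted surrogate for fake-sample discriminator scores.

    d_out_fake : main discriminator scores on the current fake batch
    d0_sigmoid : sigmoid outputs of the auxiliary critic on the *old* fake batch
    d1_sigmoid : sigmoid outputs of the auxiliary critic on the *new* fake batch
    """
    p0 = (1. - d0_sigmoid) / torch.clamp(d0_sigmoid, min=eps)           # old density-ratio estimate
    ratio = (1. - d1_sigmoid) / torch.clamp(d1_sigmoid * p0, min=eps)   # new/old importance ratio
    ratio_clipped = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    surr1 = ratio * d_out_fake
    surr2 = ratio_clipped * d_out_fake
    # Pessimistic branch selection, matching the torch.where in adv_train_generator
    return torch.where(d_out_fake > 0, torch.min(surr1, surr2), torch.max(surr1, surr2))
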
class RelGANInstructor(BasicInstructor):
    def __init__(self, opt):
        super(RelGANInstructor, self).__init__(opt)

        # generator, discriminator
        self.gen = RelGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim,
                            cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
        self.dis = RelGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, cfg.padding_idx,
                            gpu=cfg.CUDA)
        self.init_model()

        # Optimizer
        self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
        self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr)
        self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)

        # Criterion
        self.mle_criterion = nn.NLLLoss()

        # DataLoader
        self.gen_data = GenDataIter(self.gen.sample(cfg.batch_size, cfg.batch_size))

    def _run(self):
        # =====PRE-TRAINING (GENERATOR)=====
        if not cfg.gen_pretrain:
            self.log.info('Starting Generator MLE Training...')
            self.pretrain_generator(cfg.MLE_train_epoch)
            if cfg.if_save and not cfg.if_test:
                torch.save(self.gen.state_dict(), cfg.pretrained_gen_path)
                print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path))

        self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True)))

        # =====ADVERSARIAL TRAINING=====
        self.log.info('Starting Adversarial Training...')
        progress = tqdm(range(cfg.ADV_train_epoch))
        for adv_epoch in progress:
            self.sig.update()
            if self.sig.adv_sig:
                g_loss = self.adv_train_generator(cfg.ADV_g_step)  # Generator
                d_loss = self.adv_train_discriminator(cfg.ADV_d_step)  # Discriminator
                self.update_temperature(adv_epoch, cfg.ADV_train_epoch)  # update temperature

                progress.set_description(
                    'g_loss: %.4f, d_loss: %.4f, temperature: %.4f' % (g_loss, d_loss, self.gen.temperature))

                # TEST
                if adv_epoch % cfg.adv_log_step == 0:
                    self.log.info('[ADV] epoch %d: g_loss: %.4f, d_loss: %.4f, %s' % (
                        adv_epoch, g_loss, d_loss, self.cal_metrics(fmt_str=True)))
                    if cfg.if_save and not cfg.if_test:
                        self._save('ADV', adv_epoch)
            else:
                self.log.info('>>> Stop by adv_signal! Finishing adversarial training...')
                progress.close()
                break

    def _test(self):
        print('>>> Begin test...')
        self._run()

    def pretrain_generator(self, epochs):
        """
        Max Likelihood Pre-training for the generator
        """
        for epoch in range(epochs):
            self.sig.update()
            if self.sig.pre_sig:
                # =====Train=====
                pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt)

                # =====Test=====
                if epoch % cfg.pre_log_step == 0:
                    self.log.info(
                        '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True)))
                    if cfg.if_save and not cfg.if_test:
                        self._save('MLE', epoch)
            else:
                self.log.info('>>> Stop by pre signal, skip to adversarial training...')
                break

        if cfg.if_save and not cfg.if_test:
            self._save('MLE', epoch)

    def adv_train_generator(self, g_step):
        total_loss = 0
        for step in range(g_step):
            real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float()
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()

            # =====Train=====
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            g_loss, _ = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            self.optimize(self.gen_adv_opt, g_loss, self.gen)
            total_loss += g_loss.item()

        return total_loss / g_step if g_step != 0 else 0

    def adv_train_discriminator(self, d_step):
        total_loss = 0
        for step in range(d_step):
            real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float()
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()

            # =====Train=====
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            self.optimize(self.dis_opt, d_loss, self.dis)
            total_loss += d_loss.item()

        return total_loss / d_step if d_step != 0 else 0

    def update_temperature(self, i, N):
        self.gen.temperature = get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)

    @staticmethod
    def optimize(opt, loss, model=None, retain_graph=False):
        """Backward pass with clip_grad_norm_ applied before the optimizer step."""
        opt.zero_grad()
        loss.backward(retain_graph=retain_graph)
        if model is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_norm)
        opt.step()
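
# Both adversarial steps above delegate the actual objective to
# get_losses(d_out_real, d_out_fake, cfg.loss_type). For reference, a minimal sketch of the
# relativistic standard GAN objective RelGAN typically uses (loss_type == 'rsgan') is given
# below; the real helper lives elsewhere in the repo and may support other loss types.
def rsgan_losses_sketch(d_out_real, d_out_fake):
    """Relativistic standard GAN losses: each side is scored relative to the other."""
    ones = torch.ones_like(d_out_fake)
    g_loss = F.binary_cross_entropy_with_logits(d_out_fake - d_out_real, ones)
    d_loss = F.binary_cross_entropy_with_logits(d_out_real - d_out_fake, ones)
    return g_loss, d_loss
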
class RelGANInstructor(BasicInstructor):
    def __init__(self, opt):
        super(RelGANInstructor, self).__init__(opt)

        norm = opt.norm
        assert norm in ['none', 'spectral', 'gradnorm', 'gp']

        # generator, discriminator
        print('norm ', norm)
        self.norm = norm
        self.gen = RelGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim,
                            cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
        self.dis = RelGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, cfg.padding_idx,
                            gpu=cfg.CUDA, norm=norm).cuda()
        if norm == 'gradnorm':
            print('use gradnorm')
            self.dis = GradNorm(self.dis).cuda()
        self.init_model()

        # Optimizer
        if norm == 'gradnorm':
            self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr, betas=(0.5, 0.999))
            self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr, betas=(0.5, 0.999))
            self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr, betas=(0.5, 0.999))
        else:
            self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
            self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr)
            self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)

        os.makedirs(cfg.log_filename.replace('.txt', ''), exist_ok=True)
        self.logger = SummaryWriter(cfg.log_filename.replace('.txt', '') + '_' + norm)

    def _run(self):
        # ===PRE-TRAINING (GENERATOR)===
        if os.path.exists(cfg.pretrained_gen_path):
            checkpoint = torch.load(cfg.pretrained_gen_path)
            generation_weights = self.gen.state_dict()
            # Only load the checkpoint if every key and shape matches the current model
            match = True
            for key, value in checkpoint.items():
                if key not in generation_weights:
                    match = False
                elif generation_weights[key].shape != checkpoint[key].shape:
                    match = False
            if match:
                self.gen.load_state_dict(checkpoint)
                print('Load pre-trained generator: {}'.format(cfg.pretrained_gen_path))
        elif not cfg.gen_pretrain:
            self.log.info('Starting Generator MLE Training...')
            self.pretrain_generator(cfg.MLE_train_epoch)
            if cfg.if_save and not cfg.if_test:
                torch.save(self.gen.state_dict(), cfg.pretrained_gen_path)
                print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path))

        # ===ADVERSARIAL TRAINING===
        self.log.info('Starting Adversarial Training...')
        progress = tqdm(range(cfg.ADV_train_epoch), dynamic_ncols=True)
        for adv_epoch in progress:
            self.sig.update()
            if self.sig.adv_sig:
                start = time()
                g_loss = self.adv_train_generator(cfg.ADV_g_step)  # Generator
                d_loss = self.adv_train_discriminator(cfg.ADV_d_step)  # Discriminator
                self.update_temperature(adv_epoch, cfg.ADV_train_epoch)  # update temperature

                progress.set_description(
                    'g_loss: %.4f, d_loss: %.4f, temperature: %.4f' % (g_loss, d_loss, self.gen.temperature))

                if adv_epoch % 10 == 0:
                    self.logger.add_scalar('train/d_loss', float(d_loss), adv_epoch)
                    self.logger.add_scalar('train/g_loss', float(g_loss), adv_epoch)
                    self.logger.add_scalar('train/temperature', self.gen.temperature, adv_epoch)

                # TEST
                if adv_epoch % cfg.adv_log_step == 0:
                    metrics = self.cal_metrics(fmt_str=False)
                    for key, value in metrics.items():
                        if isinstance(value, list):
                            for idx, v in enumerate(value):
                                self.logger.add_scalar('train/' + key + '/' + str(idx), v, adv_epoch)
                        else:
                            self.logger.add_scalar('train/' + key, value, adv_epoch)
                    self.logger.flush()
                    self.log.info('[ADV] epoch %d: g_loss: %.4f, d_loss: %.4f' % (adv_epoch, g_loss, d_loss))
                    if cfg.if_save and not cfg.if_test:
                        self._save('GEN', adv_epoch)
            else:
                self.log.info('>>> Stop by adv_signal! Finishing adversarial training...')
                progress.close()
                break

    def _test(self):
        print('>>> Begin test...')
        self._run()

    def pretrain_generator(self, epochs):
        """
        Max Likelihood Pre-training for the generator
        """
        for epoch in range(epochs):
            self.sig.update()
            if self.sig.pre_sig:
                # ===Train===
                pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt)

                # ===Test===
                if (epoch % cfg.pre_log_step == 0 or epoch == epochs - 1) and epoch > 0:
                    metrics = self.cal_metrics(fmt_str=False)
                    for key, value in metrics.items():
                        if isinstance(value, list):
                            for idx, v in enumerate(value):
                                self.logger.add_scalar('pretrain/' + key + '/' + str(idx), v, epoch)
                        else:
                            self.logger.add_scalar('pretrain/' + key, value, epoch)
                    self.logger.add_scalar('pretrain/loss', pre_loss, epoch)
                    self.logger.flush()
                    self.log.info('[MLE-GEN] epoch %d : pre_loss = %.4f' % (epoch, pre_loss))
                    if cfg.if_save and not cfg.if_test:
                        self._save('MLE', epoch)
            else:
                self.log.info('>>> Stop by pre signal, skip to adversarial training...')
                break

    def adv_train_generator(self, g_step):
        total_loss = 0
        for step in range(g_step):
            real_samples = self.train_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()
            real_samples = F.one_hot(real_samples, cfg.vocab_size).float()

            # ===Train===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            g_loss, _ = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            self.optimize(self.gen_adv_opt, g_loss, self.gen)
            total_loss += g_loss.item()

        return total_loss / g_step if g_step != 0 else 0

    def adv_train_discriminator(self, d_step):
        total_loss = 0
        for step in range(d_step):
            real_samples = self.train_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()
            real_samples = F.one_hot(real_samples, cfg.vocab_size).float()

            # ===Train===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)
            if self.norm == 'gp':
                gp_loss = cacl_gradient_penalty(self.dis, real_samples, gen_samples)
                d_loss += gp_loss * 10

            self.optimize(self.dis_opt, d_loss, self.dis)
            total_loss += d_loss.item()

        return total_loss / d_step if d_step != 0 else 0

    def update_temperature(self, i, N):
        self.gen.temperature = get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)

    @staticmethod
    def optimize(opt, loss, model=None, retain_graph=False):
        opt.zero_grad()
        loss.backward(retain_graph=retain_graph)
        if model is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_norm)
        opt.step()
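
# All three instructors anneal the Gumbel-softmax temperature via
# get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt). A minimal sketch of a few
# plausible schedules (constant, linear, and exponential ramps from 1 to the target
# temperature) is shown below; the schedule names and exact curves in the repo's helper
# may differ, so treat this as an assumption-laden illustration.
def fixed_temperature_sketch(temper, i, N, adapt='exp'):
    """Return the temperature to use at adversarial epoch i out of N."""
    if adapt == 'no':
        return temper                              # constant temperature
    if adapt == 'lin':
        return 1 + i / (N - 1) * (temper - 1)      # linear ramp from 1 to temper
    if adapt == 'exp':
        return temper ** (i / N)                   # exponential ramp from 1 to temper
    raise NotImplementedError('unknown temperature schedule: %s' % adapt)
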