import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm_

# Repo-local helpers (prepare_model_dict, prepare_fake_data, recurrent_func,
# loss_func, get_rewards, generate_samples, dis_data_loader) are assumed to be
# defined elsewhere in this project.


def test_generator(use_cuda=False):
    ''' Prepare model_dict. '''
    model_dict = prepare_model_dict(use_cuda)

    ''' Prepare some fake data. '''
    dataloader = prepare_fake_data()

    ''' Start testing all recurrent functions. '''
    for i, sample in enumerate(dataloader):
        sample = Variable(sample)
        if use_cuda:
            # `async` is a reserved word in Python 3.7+; use non_blocking.
            sample = sample.cuda(non_blocking=True)

        # Test pre.
        pre_rets = recurrent_func("pre")(model_dict, sample, use_cuda)
        for key in pre_rets:
            print("{}:{}".format(key, pre_rets[key].size()))
        print("Pretrain recurrent function test finished!")
        print("\n")
        del pre_rets

        # Test adv.
        adv_rets = recurrent_func("adv")(model_dict, use_cuda)
        for key in adv_rets:
            print("{}:{}".format(key, adv_rets[key].size()))
        print("Adversarial recurrent function test finished!")
        print("\n")
        del adv_rets

        # Test rollout.
        gen_token = recurrent_func("rollout")(model_dict, sample, 4, use_cuda)
        print("gen_token:{}".format(gen_token.size()))
        print("Rollout test finished!")
        print("\n")
        del gen_token

        # Test gen.
        gen_token = recurrent_func("gen")(model_dict, use_cuda)
        print("gen_token:{}".format(gen_token.size()))
        print("Generate test finished!")
        print("\n")
        del gen_token
        break
def adversarial_train(model_dict, optimizer_dict, scheduler_dict,
                      dis_dataloader_params, vocab_size, positive_file,
                      negative_file, num_batches, gen_train_num=1,
                      dis_train_epoch=5, dis_train_num=3, max_norm=5.0,
                      rollout_num=4, use_cuda=False, temperature=1.0):
    ''' Get models, optimizers and schedulers. '''
    generator = model_dict["generator"]
    discriminator = model_dict["discriminator"]
    worker = generator.worker
    manager = generator.manager

    m_optimizer = optimizer_dict["manager"]
    w_optimizer = optimizer_dict["worker"]
    d_optimizer = optimizer_dict["discriminator"]
    m_optimizer.zero_grad()
    w_optimizer.zero_grad()

    m_lr_scheduler = scheduler_dict["manager"]
    w_lr_scheduler = scheduler_dict["worker"]
    d_lr_scheduler = scheduler_dict["discriminator"]

    ''' Adversarial train for generator. '''
    for _ in range(gen_train_num):
        m_lr_scheduler.step()
        w_lr_scheduler.step()
        m_optimizer.zero_grad()
        w_optimizer.zero_grad()
        adv_rets = recurrent_func("adv")(model_dict, use_cuda)
        real_goal = adv_rets["real_goal"]
        all_goal = adv_rets["all_goal"]
        prediction = adv_rets["prediction"]
        delta_feature = adv_rets["delta_feature"]
        delta_feature_for_worker = adv_rets["delta_feature_for_worker"]
        gen_token = adv_rets["gen_token"]
        rewards = get_rewards(model_dict, gen_token, rollout_num, use_cuda)
        m_loss = loss_func("adv_manager")(rewards, real_goal, delta_feature)
        w_loss = loss_func("adv_worker")(all_goal, delta_feature_for_worker,
                                         gen_token, prediction, vocab_size,
                                         use_cuda)
        # backward() accumulates into .grad, which is what clip_grad_norm_
        # and optimizer.step() read; a bare torch.autograd.grad() call (as in
        # the original) computes gradients but discards them. retain_graph
        # keeps the shared graph alive for the worker backward pass.
        m_loss.backward(retain_graph=True)
        w_loss.backward()
        clip_grad_norm_(manager.parameters(), max_norm=max_norm)
        clip_grad_norm_(worker.parameters(), max_norm=max_norm)
        m_optimizer.step()
        w_optimizer.step()

        del adv_rets, real_goal, all_goal, prediction, delta_feature, \
            delta_feature_for_worker, gen_token, rewards

    ''' Adversarial train for discriminator. '''
    for _ in range(dis_train_epoch):
        generate_samples(model_dict, negative_file, num_batches, use_cuda,
                         temperature)
        dis_dataloader_params["positive_filepath"] = positive_file
        dis_dataloader_params["negative_filepath"] = negative_file
        dataloader = dis_data_loader(**dis_dataloader_params)

        cross_entropy = nn.CrossEntropyLoss()
        if use_cuda:
            cross_entropy = cross_entropy.cuda()

        for _ in range(dis_train_num):
            for i, sample in enumerate(dataloader):
                data, label = sample["data"], sample["label"]
                data = Variable(data)
                label = Variable(label)
                if use_cuda:
                    data = data.cuda(non_blocking=True)
                    label = label.cuda(non_blocking=True)
                outs = discriminator(data)
                loss = cross_entropy(outs["score"], label.view(-1)) + \
                    discriminator.l2_loss()
                d_optimizer.zero_grad()
                d_lr_scheduler.step()
                loss.backward()
                d_optimizer.step()

    model_dict["discriminator"] = discriminator
    generator.worker = worker
    generator.manager = manager
    model_dict["generator"] = generator
    optimizer_dict["manager"] = m_optimizer
    optimizer_dict["worker"] = w_optimizer
    optimizer_dict["discriminator"] = d_optimizer
    scheduler_dict["manager"] = m_lr_scheduler
    scheduler_dict["worker"] = w_lr_scheduler
    scheduler_dict["discriminator"] = d_lr_scheduler
    return model_dict, optimizer_dict, scheduler_dict
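# A minimal driver sketch, not part of the original project: it shows how
# adversarial_train above might be invoked once per adversarial epoch. The
# epoch count, vocab size, batch count, and the "real.data"/"generated.data"
# file names are illustrative assumptions; only the call signature matches
# the function above.
def run_adversarial_phase(model_dict, optimizer_dict, scheduler_dict,
                          dis_dataloader_params, num_epochs=100,
                          use_cuda=False):
    for _ in range(num_epochs):
        model_dict, optimizer_dict, scheduler_dict = adversarial_train(
            model_dict, optimizer_dict, scheduler_dict, dis_dataloader_params,
            vocab_size=5000,                 # assumed, matching the tests
            positive_file="real.data",       # hypothetical corpus path
            negative_file="generated.data",  # hypothetical sample path
            num_batches=64,                  # hypothetical batch count
            use_cuda=use_cuda)
    return model_dict, optimizer_dict, scheduler_dict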
def pretrain_generator(model_dict, optimizer_dict, scheduler_dict, dataloader,
                       vocab_size, max_norm=5.0, use_cuda=False):
    ''' Get models, optimizers and schedulers. '''
    generator = model_dict["generator"]
    worker = generator.worker
    manager = generator.manager

    m_optimizer = optimizer_dict["manager"]
    w_optimizer = optimizer_dict["worker"]
    m_optimizer.zero_grad()
    w_optimizer.zero_grad()

    m_lr_scheduler = scheduler_dict["manager"]
    w_lr_scheduler = scheduler_dict["worker"]

    ''' Perform pretrain step for real data. '''
    for i, sample in enumerate(dataloader):
        m_lr_scheduler.step()
        w_lr_scheduler.step()

        sample = Variable(sample)
        if use_cuda:
            sample = sample.cuda(non_blocking=True)

        # Calculate pretrain loss.
        pre_rets = recurrent_func("pre")(model_dict, sample, use_cuda)
        real_goal = pre_rets["real_goal"]
        prediction = pre_rets["prediction"]
        delta_feature = pre_rets["delta_feature"]

        m_loss = loss_func("pre_manager")(real_goal, delta_feature)
        # backward() (rather than a bare torch.autograd.grad() call) is needed
        # to populate .grad for clip_grad_norm_ and the optimizer step;
        # retain_graph=True keeps the graph alive for the worker loss below.
        m_loss.backward(retain_graph=True)
        clip_grad_norm_(manager.parameters(), max_norm=max_norm)
        m_optimizer.step()
        m_optimizer.zero_grad()

        w_loss = loss_func("pre_worker")(sample, prediction, vocab_size,
                                         use_cuda)
        w_loss.backward()
        clip_grad_norm_(worker.parameters(), max_norm=max_norm)
        w_optimizer.step()
        w_optimizer.zero_grad()

    ''' Update model_dict, optimizer_dict and scheduler_dict. '''
    generator.worker = worker
    generator.manager = manager
    model_dict["generator"] = generator
    optimizer_dict["manager"] = m_optimizer
    optimizer_dict["worker"] = w_optimizer
    scheduler_dict["manager"] = m_lr_scheduler
    scheduler_dict["worker"] = w_lr_scheduler
    return model_dict, optimizer_dict, scheduler_dict
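# A minimal sketch, likewise not from the original project, of a pretraining
# phase that calls pretrain_generator above once per epoch. It assumes
# `dataloader` yields LongTensor batches of token ids (as the fake-data tests
# do); the epoch count and vocab size are illustrative.
def run_pretrain_phase(model_dict, optimizer_dict, scheduler_dict, dataloader,
                       vocab_size=5000, num_epochs=100, use_cuda=False):
    for _ in range(num_epochs):
        model_dict, optimizer_dict, scheduler_dict = pretrain_generator(
            model_dict, optimizer_dict, scheduler_dict, dataloader,
            vocab_size, use_cuda=use_cuda)
    return model_dict, optimizer_dict, scheduler_dict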
def adversarial_train(model_dict, optimizer_dict, scheduler_dict,
                      dis_dataloader_params, vocab_size, pos_file, neg_file,
                      batch_size, gen_train_num=1, dis_train_epoch=5,
                      dis_train_num=3, max_norm=5.0, rollout_num=4,
                      use_cuda=False, temperature=1.0, epoch=1, tot_epoch=100):
    """ Get all the models, optimizers and schedulers. """
    generator = model_dict["generator"]
    discriminator = model_dict["discriminator"]
    worker = generator.worker
    manager = generator.manager

    m_optimizer = optimizer_dict["manager"]
    w_optimizer = optimizer_dict["worker"]
    d_optimizer = optimizer_dict["discriminator"]
    # Only the manager and worker grads are zeroed here; the discriminator
    # optimizer zeroes its own grads inside its training loop below.
    m_optimizer.zero_grad()
    w_optimizer.zero_grad()

    m_lr_scheduler = scheduler_dict["manager"]
    w_lr_scheduler = scheduler_dict["worker"]
    d_lr_scheduler = scheduler_dict["discriminator"]

    # Adversarial training for the generator.
    for _ in range(gen_train_num):
        m_lr_scheduler.step()
        w_lr_scheduler.step()
        m_optimizer.zero_grad()
        w_optimizer.zero_grad()
        # Get all the return values.
        adv_rets = recurrent_func("adv")(model_dict, use_cuda)
        real_goal = adv_rets["real_goal"]
        all_goal = adv_rets["all_goal"]
        prediction = adv_rets["prediction"]
        delta_feature = adv_rets["delta_feature"]
        delta_feature_for_worker = adv_rets["delta_feature_for_worker"]
        gen_token = adv_rets["gen_token"]
        rewards = get_rewards(model_dict, gen_token, rollout_num, use_cuda)
        m_loss = loss_func("adv_manager")(rewards, real_goal, delta_feature)
        w_loss = loss_func("adv_worker")(all_goal, delta_feature_for_worker,
                                         gen_token, prediction, vocab_size,
                                         use_cuda)
        # Backpropagate the losses so that .grad is populated for clipping
        # and the optimizer steps; a bare torch.autograd.grad() call (as in
        # the original) would discard the gradients. retain_graph keeps the
        # shared graph alive for the worker backward pass.
        m_loss.backward(retain_graph=True)
        w_loss.backward()
        clip_grad_norm_(manager.parameters(), max_norm)
        clip_grad_norm_(worker.parameters(), max_norm)
        m_optimizer.step()
        w_optimizer.step()
        print("Adv-Manager loss: {:.5f} Adv-Worker loss: {:.5f}".format(
            m_loss.item(), w_loss.item()))

        del adv_rets, real_goal, all_goal, prediction, delta_feature, \
            delta_feature_for_worker, gen_token, rewards

    # Adversarial training for the discriminator.
    for n in range(dis_train_epoch):
        generate_samples(model_dict, neg_file, batch_size, use_cuda,
                         temperature)
        dis_dataloader_params["positive_filepath"] = pos_file
        dis_dataloader_params["negative_filepath"] = neg_file
        dataloader = dis_data_loader(**dis_dataloader_params)

        cross_entropy = nn.CrossEntropyLoss()
        if use_cuda:
            cross_entropy = cross_entropy.cuda()

        """
        for d-steps do
            Use current G, θm, θw to generate negative examples and combine
            with given positive examples S.
            Train discriminator Dφ for k epochs by Eq. (2).
        end for
        """
        for _ in range(dis_train_num):
            for i, sample in enumerate(dataloader):
                data, label = sample["data"], sample["label"]
                data = Variable(data)
                label = Variable(label)
                if use_cuda:
                    data = data.cuda(non_blocking=True)
                    label = label.cuda(non_blocking=True)
                outs = discriminator(data)
                loss = cross_entropy(outs["score"], label.view(-1)) + \
                    discriminator.l2_loss()
                d_optimizer.zero_grad()
                d_lr_scheduler.step()
                loss.backward()
                d_optimizer.step()
        print("{}/{} Adv-Discriminator Loss: {:.5f}".format(
            n + 1, dis_train_epoch, loss.item()))

    # Save all changes.
    model_dict["discriminator"] = discriminator
    generator.worker = worker
    generator.manager = manager
    model_dict["generator"] = generator
    optimizer_dict["manager"] = m_optimizer
    optimizer_dict["worker"] = w_optimizer
    optimizer_dict["discriminator"] = d_optimizer
    scheduler_dict["manager"] = m_lr_scheduler
    scheduler_dict["worker"] = w_lr_scheduler
    scheduler_dict["discriminator"] = d_lr_scheduler
    return model_dict, optimizer_dict, scheduler_dict
def pretrain_generator(model_dict, optimizer_dict, scheduler_dict, dataloader,
                       vocab_size, max_norm=5.0, use_cuda=False, epoch=1,
                       tot_epochs=100):
    # Get the generator's submodels.
    generator = model_dict["generator"]
    worker = generator.worker
    manager = generator.manager

    # Get the optimizers.
    m_optimizer = optimizer_dict["manager"]
    w_optimizer = optimizer_dict["worker"]
    m_optimizer.zero_grad()
    w_optimizer.zero_grad()

    m_lr_scheduler = scheduler_dict["manager"]
    w_lr_scheduler = scheduler_dict["worker"]

    """ Perform pretrain step for real data. """
    for i, sample in enumerate(dataloader):
        m_lr_scheduler.step()
        w_lr_scheduler.step()

        sample = Variable(sample)
        if use_cuda:
            sample = sample.cuda(non_blocking=True)

        # Calculate pretrain loss. The last batch can be smaller than 64
        # (e.g. 16), so skip any batch that is not exactly [64, 20];
        # comparing against torch.Size avoids allocating a throwaway tensor.
        if sample.size() == torch.Size([64, 20]):
            pre_rets = recurrent_func("pre")(model_dict, sample, use_cuda)
            real_goal = pre_rets["real_goal"]
            prediction = pre_rets["prediction"]
            delta_feature = pre_rets["delta_feature"]

            m_loss = loss_func("pre_manager")(real_goal, delta_feature)
            # backward() populates .grad for clipping and the optimizer step;
            # retain_graph=True keeps the graph alive for the worker loss.
            m_loss.backward(retain_graph=True)
            clip_grad_norm_(manager.parameters(), max_norm=max_norm)
            m_optimizer.step()
            m_optimizer.zero_grad()

            w_loss = loss_func("pre_worker")(sample, prediction, vocab_size,
                                             use_cuda)
            w_loss.backward()
            clip_grad_norm_(worker.parameters(), max_norm=max_norm)
            w_optimizer.step()
            w_optimizer.zero_grad()

            if i == 63:  # log once per call, at batch index 63
                print("Pre-Manager Loss: {:.5f}, Pre-Worker Loss: {:.5f}\n"
                      .format(m_loss.item(), w_loss.item()))

    """ Update model_dict, optimizer_dict, and scheduler_dict. """
    generator.worker = worker
    generator.manager = manager
    model_dict["generator"] = generator
    optimizer_dict["manager"] = m_optimizer
    optimizer_dict["worker"] = w_optimizer
    scheduler_dict["manager"] = m_lr_scheduler
    scheduler_dict["worker"] = w_lr_scheduler
    return model_dict, optimizer_dict, scheduler_dict
def pretrain_generator(model_dict, optimizer_dict, scheduler_dict, dataloader,
                       vocab_size, max_norm=5.0, use_cuda=False, epoch=1,
                       tot_epochs=100):
    # Get the generator's submodels.
    generator = model_dict["generator"]
    worker = generator.worker
    manager = generator.manager

    # Get the optimizers.
    m_optimizer = optimizer_dict["manager"]
    w_optimizer = optimizer_dict["worker"]
    m_optimizer.zero_grad()
    w_optimizer.zero_grad()

    m_lr_scheduler = scheduler_dict["manager"]
    w_lr_scheduler = scheduler_dict["worker"]

    """ Perform pretrain step for real data. """
    for i, sample in enumerate(dataloader):
        m_lr_scheduler.step()
        w_lr_scheduler.step()

        # sample: [batch_size, seq_len]
        sample = Variable(sample)
        if use_cuda:
            sample = sample.cuda(non_blocking=True)

        # Calculate pretrain loss. Skip partial batches. The original note
        # asked whether checking only size(0) == 64 would be cheaper; it
        # would, but the full comparison also pins the sequence length, and
        # comparing against torch.Size avoids allocating a throwaway tensor.
        if sample.size() == torch.Size([64, 20]):
            pre_rets = recurrent_func("pre")(model_dict, sample, use_cuda)
            real_goal = pre_rets["real_goal"]
            prediction = pre_rets["prediction"]
            delta_feature = pre_rets["delta_feature"]

            # real_goal and delta_feature give the manager loss;
            # prediction and sample give the worker loss.
            m_loss = loss_func("pre_manager")(real_goal, delta_feature)
            # The original called torch.autograd.grad(m_loss,
            # manager.parameters()), which computes d(loss)/d(params) but
            # never writes the result into .grad (the original annotations
            # asked what that call was for, and where loss.backward() was).
            # backward() restores the usual zero_grad -> loss -> backward ->
            # step pattern, so clip_grad_norm_ and optimizer.step() act on
            # real gradients; retain_graph keeps the graph for the worker.
            m_loss.backward(retain_graph=True)
            clip_grad_norm_(manager.parameters(), max_norm=max_norm)
            m_optimizer.step()
            m_optimizer.zero_grad()

            w_loss = loss_func("pre_worker")(sample, prediction, vocab_size,
                                             use_cuda)
            w_loss.backward()
            # max_norm defaults to 5.0; the original annotation questioned
            # the choice (it is a conventional clipping threshold).
            clip_grad_norm_(worker.parameters(), max_norm=max_norm)
            w_optimizer.step()
            w_optimizer.zero_grad()

            if i == 63:
                print("Pre-Manager Loss: {:.5f}, Pre-Worker Loss: {:.5f}\n"
                      .format(m_loss.item(), w_loss.item()))

    """ Update model_dict, optimizer_dict, and scheduler_dict. """
    generator.worker = worker
    generator.manager = manager
    model_dict["generator"] = generator
    optimizer_dict["manager"] = m_optimizer
    optimizer_dict["worker"] = w_optimizer
    scheduler_dict["manager"] = m_lr_scheduler
    scheduler_dict["worker"] = w_lr_scheduler
    return model_dict, optimizer_dict, scheduler_dict
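# A self-contained sketch (toy nn.Linear model; the names are illustrative,
# not from this project) of the behaviour the translated annotations above
# asked about: torch.autograd.grad() returns gradients as a tuple without
# touching param.grad, whereas backward() accumulates into param.grad, which
# is what clip_grad_norm_ and optimizer.step() read.
def demo_grad_vs_backward():
    toy_model = nn.Linear(4, 1)
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = nn.functional.mse_loss(toy_model(x), y)

    grads = torch.autograd.grad(loss, toy_model.parameters(),
                                retain_graph=True)
    print(grads[0].shape)          # gradients exist in the returned tuple...
    print(toy_model.weight.grad)   # ...but .grad is still None

    loss.backward()                     # backward() populates .grad
    print(toy_model.weight.grad.shape)  # now clipping and step() can use it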
def test_loss_func(use_cuda=False):
    ''' Prepare model_dict. '''
    model_dict = prepare_model_dict(use_cuda)
    generator = model_dict["generator"]
    worker = generator.worker
    manager = generator.manager

    ''' Prepare some fake data. '''
    dataloader = prepare_fake_data()

    ''' Start testing all loss functions. '''
    m_optimizer = optim.Adam(manager.parameters(), lr=0.001)
    w_optimizer = optim.Adam(worker.parameters(), lr=0.001)
    m_optimizer.zero_grad()
    w_optimizer.zero_grad()

    for i, sample in enumerate(dataloader):
        sample = Variable(sample)
        if use_cuda:
            sample = sample.cuda(non_blocking=True)

        # Test pre.
        pre_rets = recurrent_func("pre")(model_dict, sample, use_cuda)
        real_goal = pre_rets["real_goal"]
        prediction = pre_rets["prediction"]
        delta_feature = pre_rets["delta_feature"]

        m_loss = loss_func("pre_manager")(real_goal, delta_feature)
        m_loss.backward(retain_graph=True)  # populate .grad before clipping
        clip_grad_norm_(manager.parameters(), max_norm=5.0)
        m_optimizer.step()
        m_optimizer.zero_grad()

        w_loss = loss_func("pre_worker")(sample, prediction, 5000, use_cuda)
        w_loss.backward()
        clip_grad_norm_(worker.parameters(), max_norm=5.0)
        w_optimizer.step()
        w_optimizer.zero_grad()

        print("pre_m_loss={}, pre_w_loss={}".format(m_loss.item(),
                                                    w_loss.item()))
        print("Pretrain loss function test finished!")
        print("\n")

        # Test adv.
        adv_rets = recurrent_func("adv")(model_dict, use_cuda)
        real_goal = adv_rets["real_goal"]
        all_goal = adv_rets["all_goal"]
        prediction = adv_rets["prediction"]
        delta_feature = adv_rets["delta_feature"]
        delta_feature_for_worker = adv_rets["delta_feature_for_worker"]
        gen_token = adv_rets["gen_token"]
        rewards = get_rewards(model_dict, gen_token, 4, use_cuda)

        m_loss = loss_func("adv_manager")(rewards, real_goal, delta_feature)
        w_loss = loss_func("adv_worker")(all_goal, delta_feature_for_worker,
                                         gen_token, prediction, 5000,
                                         use_cuda)

        m_optimizer = optim.Adam(manager.parameters(), lr=0.001)
        w_optimizer = optim.Adam(worker.parameters(), lr=0.001)
        m_optimizer.zero_grad()
        w_optimizer.zero_grad()
        m_loss.backward(retain_graph=True)
        w_loss.backward()
        clip_grad_norm_(manager.parameters(), max_norm=5.0)
        clip_grad_norm_(worker.parameters(), max_norm=5.0)
        m_optimizer.step()
        w_optimizer.step()

        print("adv_m_loss={}, adv_w_loss={}".format(m_loss.item(),
                                                    w_loss.item()))
        print("Adversarial training loss function test finished!")
        print("\n")
        if i > 0:
            break
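# A small entry-point sketch; the __main__ guard and CPU/GPU selection are an
# assumption about how these smoke tests are meant to be run, not part of the
# original module.
if __name__ == "__main__":
    use_cuda = torch.cuda.is_available()
    test_generator(use_cuda)
    test_loss_func(use_cuda)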