import time

import numpy as np
import torch
from torch.autograd import Variable
from torch.nn.utils import clip_grad_value_

# getBatch_dis, ReplayMemory, Sample, write_tensor, write_tensor_reward and
# write_tensor_action are repo-local helpers assumed to be in scope here.


def gen_fake(generator, agent, trainSample, batch_size, embed_dim, device,
             write_item, write_target, write_reward, write_action, action_num,
             max_length=5, recom_length=None):
    """Generate fake (simulated) sessions and write them to disk for the
    discriminator; returns the tensors of the last generated batch."""
    for stidx in range(0, trainSample.length(), batch_size):
        click_batch, length, _, reward_batch, action_batch = getBatch_dis(
            stidx, stidx + batch_size, trainSample, embed_dim, recom_length)
        click_batch = click_batch.to(device)
        reward_batch = reward_batch.to(device)
        action_batch = action_batch.to(device)
        if recom_length is None:
            recom_length = action_batch.size(1)
        replay = ReplayMemory(generator, agent, len(length), max_length,
                              action_num, recom_length)
        # Roll out the generator/agent pair without tracking gradients.
        with torch.no_grad():
            replay.init_click_sample((click_batch, length), reward_batch,
                                     action_batch)
            replay.gen_sample(batch_size, False)
        seq_samples, lengths, seq_rewards, seq_actions = (
            replay.clicks, replay.lengths, replay.tgt_rewards, replay.actions)
        seq_rewards = torch.round(seq_rewards)  # Binarize predicted rewards.
        write_tensor(seq_samples, lengths, write_item, write_target, 'dis',
                     real=False)
        write_tensor_reward(seq_rewards, lengths, write_reward)
        write_tensor_action(seq_actions, lengths, write_action)
    return seq_samples, lengths, seq_rewards, seq_actions
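
# Usage sketch (illustrative only): gen_fake streams simulated sessions to
# open file handles so the discriminator can later be trained on real vs.
# generated data. The models, data sample and file names below are
# assumptions, not fixed by this module.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   with open('gen_click.txt', 'w') as write_item, \
#        open('gen_target.txt', 'w') as write_target, \
#        open('gen_reward.txt', 'w') as write_reward, \
#        open('gen_action.txt', 'w') as write_action:
#       gen_fake(generator, agent, trainSample, batch_size=32, embed_dim=64,
#                device=device, write_item=write_item,
#                write_target=write_target, write_reward=write_reward,
#                write_action=write_action, action_num=2, max_length=5)
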

def train_gen_pg_each(generator, agent, discriminator, epoch, trainSample,
                      subnum, optimizer_agent, optimizer_usr, batch_size,
                      embed_dim, recom_length, max_length, real_label_num,
                      device, gen_ratio, pretrain=False, shuffle_index=None):
    """Run one epoch of adversarial policy-gradient training for the user
    model (generator) and the recommendation agent; returns the running
    loss buffer."""
    generator.train()
    agent.train()
    print('\nTRAINING : Epoch ' + str(epoch))
    all_costs = []
    logs = []
    decay = 0.95
    last_time = time.time()
    # Decay the learning rates after the first epoch.
    if epoch > 1:
        optimizer_agent.param_groups[0]['lr'] *= decay
        optimizer_usr.param_groups[0]['lr'] *= decay
    print('Learning rate_agent : {0}'.format(optimizer_agent.param_groups[0]['lr']))
    print('Learning rate_usr : {0}'.format(optimizer_usr.param_groups[0]['lr']))
    # Draw a per-epoch subsample of the training data.
    trainSample_sub = Sample()
    trainSample_sub.subSample_copy(subnum, trainSample, shuffle_index)
    for stidx in range(0, trainSample_sub.length(), batch_size):
        # Prepare batch.
        embed_batch, length, _, reward_batch, action_batch = getBatch_dis(
            stidx, stidx + batch_size, trainSample_sub, embed_dim, recom_length)
        embed_batch = Variable(embed_batch.to(device))
        reward_batch = Variable(reward_batch.to(device))
        action_batch = Variable(action_batch.to(device))
        k = embed_batch.size(0)  # Actual batch size (last batch may be short).
        # Mix real sessions with generator roll-outs at ratio gen_ratio.
        replay = ReplayMemory(generator, agent, int((1 + gen_ratio) * k),
                              max_length, real_label_num, action_batch.size(1))
        replay.init_click((embed_batch, length), reward_batch, action_batch)
        replay.gen_sample(batch_size, True, discriminator)
        tgt_reward = replay.tgt_rewards.type(torch.FloatTensor).to(device)
        gen_reward = replay.gen_rewards.type(torch.FloatTensor).to(device)
        usr_prob = replay.usr_probs.to(device)
        agent_prob = replay.agent_probs.to(device)
        # Probability the user model assigns to the realized (rounded) reward.
        tgt_prob = torch.abs(1.0 - torch.round(tgt_reward) - tgt_reward)
        tgt_reward = torch.round(tgt_reward)
        if not pretrain:
            # User-model loss: REINFORCE with the discriminator score
            # (gen_reward) as the return.
            loss_usr = -((torch.log(usr_prob + 1e-12)
                          + torch.log(tgt_prob + 1e-12)) * gen_reward).sum() / k
        # Cumulative reward for the agent: discriminator score scaled by the
        # rounded click reward, then turned into a value estimate.
        tgt_reward = gen_reward * (1 + tgt_reward)
        tgt_value = generator.value(tgt_reward)
        # Alternative objective (kept from the original, commented out):
        # loss_agent = -(torch.log(agent_prob + 1e-12) * (gen_reward + tgt_value)).sum() / k
        loss_agent = -(torch.log(agent_prob + 1e-12) * tgt_value).sum() / k
        all_costs.append(loss_agent.data.cpu().numpy())
        # Backward passes.
        optimizer_agent.zero_grad()
        optimizer_usr.zero_grad()
        if not pretrain:
            loss_usr.backward(retain_graph=True)
            # (Debug hook: per-layer gradients of the user behavior model can
            # be printed here, e.g. via generator.embedding / generator.encoder.)
            # Element-wise gradient clipping before the user-model update.
            clip_grad_value_(filter(lambda p: p.requires_grad,
                                    generator.parameters()), 1)
            optimizer_usr.step()
        loss_agent.backward()
        # Element-wise gradient clipping before the agent update.
        clip_grad_value_(filter(lambda p: p.requires_grad,
                                agent.parameters()), 1)
        optimizer_agent.step()
        # Log every 100 batches.
        if len(all_costs) == 100:
            logs.append('{0} ; loss {1} ; seq/s {2}'.format(
                stidx, round(np.mean(all_costs), 2),
                int(len(all_costs) * batch_size / (time.time() - last_time))))
            print(logs[-1])
            last_time = time.time()
            all_costs = []
    return all_costs
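
# Usage sketch (illustrative only): a minimal outer training loop. The
# optimizer choice, learning rates and hyperparameter values below are
# assumptions; only the call signature comes from train_gen_pg_each above.
#
#   optimizer_usr = torch.optim.Adam(generator.parameters(), lr=1e-3)
#   optimizer_agent = torch.optim.Adam(agent.parameters(), lr=1e-3)
#   for epoch in range(1, 11):
#       shuffle_index = np.random.permutation(trainSample.length())
#       train_gen_pg_each(generator, agent, discriminator, epoch, trainSample,
#                         subnum=trainSample.length() // 2,
#                         optimizer_agent=optimizer_agent,
#                         optimizer_usr=optimizer_usr, batch_size=32,
#                         embed_dim=64, recom_length=None, max_length=5,
#                         real_label_num=2, device=device, gen_ratio=1.0,
#                         shuffle_index=shuffle_index)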