def log_train(td_ct, args, conf, summary_writer, replay_memory, epoch_id):
    """log_train"""
    dataset = NpzDataset(conf.train_npz_list,
                         conf.npz_config_path,
                         conf.requested_npz_names,
                         if_random_shuffle=True)
    data_gen = dataset.get_data_generator(conf.batch_size)

    list_reward = []
    list_loss = []
    list_first_Q = []
    guessed_batch_num = 11500
    batch_id = 0
    last_batch_data = BatchData(conf, next(data_gen))
    for tensor_dict in data_gen:
        ### keep the logged item order (no eps-greedy sampling here)
        batch_data = BatchData(conf, tensor_dict)
        batch_data.set_decode_len(batch_data.seq_lens())
        order = [np.arange(d) for d in batch_data.decode_len()]

        ### get reward from the logged clicks
        reordered_batch_data = batch_data.get_reordered(order)
        reordered_batch_data.set_decode_len(batch_data.decode_len())
        reward = batch_data.tensor_dict['click_id'].values

        ### save to replay_memory
        replay_memory.append((reordered_batch_data, reward))

        ### train on a randomly sampled experience
        memory_batch_data, reward = replay_memory[np.random.randint(
            len(replay_memory))]
        feed_dict = GenRLFeedConvertor.train_test(memory_batch_data, reward)
        fetch_dict = td_ct.train(feed_dict)

        ### logging
        list_reward.append(np.mean(reward))
        list_loss.append(np.array(fetch_dict['loss']))
        list_first_Q.append(np.mean(np.array(fetch_dict['c_Q'])[0]))
        if batch_id % 10 == 0:
            global_batch_id = epoch_id * guessed_batch_num + batch_id
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_reward', np.mean(list_reward))
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_loss', np.mean(list_loss))
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_1st_Q', np.mean(list_first_Q))
            list_reward = []
            list_loss = []
            list_first_Q = []

        last_batch_data = BatchData(conf, tensor_dict)  # kept for parity with train(); unused here
        batch_id += 1
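
# The `replay_memory` passed into log_train/train only needs append(), len()
# and integer indexing. Below is a minimal sketch of a bounded buffer that
# satisfies that interface -- an assumption for illustration; the actual
# buffer is constructed by the caller and is not defined in this file.
def make_replay_memory_sketch(capacity=10000):
    """Hypothetical helper: bounded FIFO buffer of (batch_data, reward) pairs."""
    import collections
    # deque(maxlen=...) silently drops the oldest experience once full,
    # and supports replay_memory[np.random.randint(len(replay_memory))].
    return collections.deque(maxlen=capacity)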
def eps_greedy_sampling(td_ct, eval_td_ct, args, conf, summary_writer, epoch_id):
    """eps_greedy_sampling"""
    dataset = NpzDataset(conf.test_npz_list,
                         conf.npz_config_path,
                         conf.requested_npz_names,
                         if_random_shuffle=False)
    data_gen = dataset.get_data_generator(conf.batch_size)

    list_reward = []
    last_batch_data = BatchData(conf, next(data_gen))
    batch_id = 0
    for tensor_dict in data_gen:
        ### greedy sampling (eps=0) over the expanded candidate set
        batch_data = BatchData(conf, tensor_dict)
        batch_data.set_decode_len(batch_data.seq_lens())
        batch_data.expand_candidates(last_batch_data, batch_data.seq_lens())
        fetch_dict = td_ct.eps_greedy_sampling(
            GenRLFeedConvertor.eps_greedy_sampling(batch_data, eps=0))
        sampled_id = np.array(fetch_dict['sampled_id']).reshape([-1])
        order = sequence_unconcat(sampled_id, batch_data.decode_len())

        ### get reward from the evaluation model
        reordered_batch_data = batch_data.get_reordered(order)
        fetch_dict = eval_td_ct.inference(
            EvalFeedConvertor.inference(reordered_batch_data))
        reward = np.array(fetch_dict['click_prob'])[:, 1]

        ### logging
        list_reward.append(np.mean(reward))
        if batch_id == 100:
            break

        last_batch_data = BatchData(conf, tensor_dict)
        batch_id += 1

    add_scalar_summary(summary_writer, epoch_id,
                       'eps_greedy_sampling/reward-%s' % args.eval_exp,
                       np.mean(list_reward))
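
# Illustrative sketch only (assumption): `sequence_unconcat`, presumably
# defined elsewhere in the project, appears to split the flat array of sampled
# ids back into one array per sequence using the per-sequence decode lengths.
# A minimal equivalent, not the project's actual implementation:
def sequence_unconcat_sketch(flat_ids, seq_lens):
    """Split `flat_ids` into a list of arrays whose lengths are `seq_lens`."""
    chunks, offset = [], 0
    for n in seq_lens:
        chunks.append(np.asarray(flat_ids[offset:offset + n]))
        offset += n
    return chunks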
def train(td_ct, eval_td_ct, args, conf, summary_writer, replay_memory, epoch_id):
    """train"""
    dataset = NpzDataset(conf.train_npz_list,
                         conf.npz_config_path,
                         conf.requested_npz_names,
                         if_random_shuffle=True)
    data_gen = dataset.get_data_generator(conf.batch_size)

    list_reward = []
    list_loss = []
    list_first_Q = []
    guessed_batch_num = 11500
    batch_id = 0
    last_batch_data = BatchData(conf, next(data_gen))
    for tensor_dict in data_gen:
        ### eps_greedy_sampling (eps=0.2) over the expanded candidate set
        batch_data = BatchData(conf, tensor_dict)
        batch_data.set_decode_len(batch_data.seq_lens())
        batch_data.expand_candidates(last_batch_data, batch_data.seq_lens())
        fetch_dict = td_ct.eps_greedy_sampling(
            GenRLFeedConvertor.eps_greedy_sampling(batch_data, eps=0.2))
        sampled_id = np.array(fetch_dict['sampled_id']).reshape([-1])
        order = sequence_unconcat(sampled_id, batch_data.decode_len())

        ### get reward from the evaluation model
        reordered_batch_data = batch_data.get_reordered(order)
        fetch_dict = eval_td_ct.inference(
            EvalFeedConvertor.inference(reordered_batch_data))
        reward = np.array(fetch_dict['click_prob'])[:, 1]

        ### save to replay_memory, keeping the candidate set in the stored experience
        reordered_batch_data2 = batch_data.get_reordered_keep_candidate(order)
        reordered_batch_data2.set_decode_len(batch_data.decode_len())
        replay_memory.append((reordered_batch_data2, reward))

        ### train on a randomly sampled experience
        memory_batch_data, reward = replay_memory[np.random.randint(
            len(replay_memory))]
        feed_dict = GenRLFeedConvertor.train_test(memory_batch_data, reward)
        fetch_dict = td_ct.train(feed_dict)

        ### logging
        list_reward.append(np.mean(reward))
        list_loss.append(np.array(fetch_dict['loss']))
        list_first_Q.append(np.mean(np.array(fetch_dict['c_Q'])[0]))
        if batch_id % 10 == 0:
            global_batch_id = epoch_id * guessed_batch_num + batch_id
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_reward', np.mean(list_reward))
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_loss', np.mean(list_loss))
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_1st_Q', np.mean(list_first_Q))
            list_reward = []
            list_loss = []
            list_first_Q = []

        last_batch_data = BatchData(conf, tensor_dict)
        batch_id += 1
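
# A hedged sketch of how the routines above might be wired together per epoch.
# This is an assumption for illustration only: the real driver (argument
# parsing, computation-task and model construction, epoch schedule) lives
# elsewhere and is not reproduced here; `num_epochs` and the replay capacity
# are hypothetical parameters.
def run_epochs_sketch(td_ct, eval_td_ct, args, conf, summary_writer, num_epochs=10):
    """Hypothetical driver: warm up on logged data, then alternate train/eval."""
    import collections
    replay_memory = collections.deque(maxlen=10000)  # hypothetical capacity
    for epoch_id in range(num_epochs):
        if epoch_id == 0:
            # seed the replay memory with logged orders and logged clicks
            log_train(td_ct, args, conf, summary_writer, replay_memory, epoch_id)
        else:
            # eps-greedy exploration (eps=0.2) plus replay training
            train(td_ct, eval_td_ct, args, conf, summary_writer, replay_memory, epoch_id)
        # greedy (eps=0) evaluation against the evaluation model
        eps_greedy_sampling(td_ct, eval_td_ct, args, conf, summary_writer, epoch_id)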