def log_train(td_ct, args, conf, summary_writer, replay_memory, epoch_id):
    """Train on logged data, keeping every sequence in its logged order.

    For each batch read from ``conf.train_npz_list`` the items are kept in
    their original order (``np.arange`` over each decode length) and the
    logged ``click_id`` values are used as the reward.  The pair is appended
    to ``replay_memory``; one entry is then drawn uniformly at random from
    the memory and used for a single ``td_ct.train`` step.

    Every 10 batches the mean reward, mean loss and mean of the first
    ``c_Q`` entry accumulated since the previous report are written through
    ``add_scalar_summary`` and the accumulators are reset.

    Args:
        td_ct: object exposing ``train(feed_dict)`` that returns a dict
            with ``'loss'`` and ``'c_Q'``.
        args: not read by this function.
        conf: configuration providing ``train_npz_list``,
            ``npz_config_path``, ``requested_npz_names`` and ``batch_size``.
        summary_writer: forwarded to ``add_scalar_summary``.
        replay_memory: list-like container; extended in place with
            ``(batch_data, reward)`` tuples and sampled by integer index.
        epoch_id: epoch index, used to offset the summary step.
    """
    dataset = NpzDataset(conf.train_npz_list,
                         conf.npz_config_path,
                         conf.requested_npz_names,
                         if_random_shuffle=True)
    data_gen = dataset.get_data_generator(conf.batch_size)

    list_reward = []
    list_loss = []
    list_first_Q = []
    # Stride between epochs on the summary axis; presumably an estimate of
    # the number of batches per epoch -- confirm against the dataset size.
    guessed_batch_num = 11500
    batch_id = 0

    # Built-in next() works on Python 2.6+ and Python 3, whereas the
    # generator method ``.next()`` no longer exists on Python 3.
    # NOTE(review): last_batch_data is never read in this function; the
    # statement is kept so the first batch is still consumed as before.
    last_batch_data = BatchData(conf, next(data_gen))
    for tensor_dict in data_gen:
        ### keep the logged order (no sampling from the model)
        batch_data = BatchData(conf, tensor_dict)
        batch_data.set_decode_len(batch_data.seq_lens())
        order = [np.arange(d) for d in batch_data.decode_len()]

        ### get reward
        reordered_batch_data = batch_data.get_reordered(order)
        reordered_batch_data.set_decode_len(batch_data.decode_len())
        reward = batch_data.tensor_dict['click_id'].values

        ### save to replay_memory
        replay_memory.append((reordered_batch_data, reward))

        ### train on a uniformly sampled replay entry
        memory_batch_data, reward = replay_memory[np.random.randint(
            len(replay_memory))]
        feed_dict = GenRLFeedConvertor.train_test(memory_batch_data, reward)
        fetch_dict = td_ct.train(feed_dict)

        ### logging
        list_reward.append(np.mean(reward))
        list_loss.append(np.array(fetch_dict['loss']))
        list_first_Q.append(np.mean(np.array(fetch_dict['c_Q'])[0]))
        if batch_id % 10 == 0:
            global_batch_id = epoch_id * guessed_batch_num + batch_id
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_reward', np.mean(list_reward))
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_loss', np.mean(list_loss))
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_1st_Q', np.mean(list_first_Q))
            list_reward = []
            list_loss = []
            list_first_Q = []

        last_batch_data = BatchData(conf, tensor_dict)
        batch_id += 1
def eps_greedy_sampling(td_ct, eval_td_ct, args, conf, summary_writer,
                        epoch_id):
    """Score the greedy (``eps=0``) policy on test data and log the reward.

    Each batch from ``conf.test_npz_list`` has its candidates expanded with
    the previous batch, is ordered by ``td_ct.eps_greedy_sampling`` with
    ``eps=0``, and the reordered batch is scored by ``eval_td_ct``.  The
    reward of a batch is column 1 of the returned ``'click_prob'``.  The
    mean of the per-batch mean rewards is written as one scalar tagged
    ``'eps_greedy_sampling/reward-<args.eval_exp>'`` at step ``epoch_id``.

    Args:
        td_ct: object exposing ``eps_greedy_sampling(feed)`` that returns a
            dict with ``'sampled_id'``.
        eval_td_ct: object exposing ``inference(feed)`` that returns a dict
            with ``'click_prob'``.
        args: only ``args.eval_exp`` is read (for the summary tag).
        conf: configuration providing ``test_npz_list``,
            ``npz_config_path``, ``requested_npz_names`` and ``batch_size``.
        summary_writer: forwarded to ``add_scalar_summary``.
        epoch_id: step value of the emitted summary.
    """
    dataset = NpzDataset(conf.test_npz_list,
                         conf.npz_config_path,
                         conf.requested_npz_names,
                         if_random_shuffle=False)
    data_gen = dataset.get_data_generator(conf.batch_size)

    list_reward = []
    # Built-in next() works on Python 2.6+ and Python 3, whereas the
    # generator method ``.next()`` no longer exists on Python 3.
    # The first batch only serves as the candidate source for the second.
    last_batch_data = BatchData(conf, next(data_gen))
    batch_id = 0
    for tensor_dict in data_gen:
        ### eps_greedy_sampling
        batch_data = BatchData(conf, tensor_dict)
        batch_data.set_decode_len(batch_data.seq_lens())
        batch_data.expand_candidates(last_batch_data, batch_data.seq_lens())
        fetch_dict = td_ct.eps_greedy_sampling(
            GenRLFeedConvertor.eps_greedy_sampling(batch_data, eps=0))
        sampled_id = np.array(fetch_dict['sampled_id']).reshape([-1])
        order = sequence_unconcat(sampled_id, batch_data.decode_len())

        ### get reward
        reordered_batch_data = batch_data.get_reordered(order)
        fetch_dict = eval_td_ct.inference(
            EvalFeedConvertor.inference(reordered_batch_data))
        reward = np.array(fetch_dict['click_prob'])[:, 1]

        ### logging
        list_reward.append(np.mean(reward))

        # Stop after the batch with id 100 has been scored (its reward is
        # already in list_reward at this point).
        if batch_id == 100:
            break

        last_batch_data = BatchData(conf, tensor_dict)
        batch_id += 1

    add_scalar_summary(summary_writer, epoch_id,
                       'eps_greedy_sampling/reward-%s' % args.eval_exp,
                       np.mean(list_reward))
def online_inference_for_test(args, epoch_id, max_steps, ct_sim, dict_gen_ct,
                              summary_writer):
    """Do inference on the test set and log the mean simulated response.

    Batches are read once (``one_pass=True``, no shuffling) from
    ``args.test_npz_list`` through a threaded generator.  For each batch the
    candidates are expanded with the previous batch under a fixed NumPy seed
    (``batch_id``), the logged ``'click_id'`` entry is removed, and
    ``inference_one_batch`` is run with ``eps=0``.  The per-sequence sums of
    the simulated responses are averaged and written as
    ``'inference/test_sim_responses'`` at step ``epoch_id``.

    Args:
        args: provides ``test_npz_list`` and ``gen_type``.
        epoch_id: step value of the emitted summary.
        max_steps: the loop stops once ``batch_id`` exceeds this value.
        ct_sim: simulator wrapper; its config is read from
            ``ct_sim.alg.model.conf``.
        dict_gen_ct: forwarded to ``inference_one_batch``.
        summary_writer: forwarded to ``add_scalar_summary``.
    """
    sim_conf = ct_sim.alg.model.conf
    dataset = NpzDataset(args.test_npz_list,
                         sim_conf.npz_config_path,
                         sim_conf.requested_names,
                         if_random_shuffle=False,
                         one_pass=True)
    data_gen = dataset.get_data_generator(sim_conf.batch_size)
    thread_data_gen = threaded_generator(data_gen, capacity=100)
    list_sim_responses = []

    ### online inference
    # Built-in next() works on Python 2.6+ and Python 3, whereas the
    # generator method ``.next()`` no longer exists on Python 3.
    last_batch_data = BatchData(sim_conf, next(thread_data_gen))
    for batch_id, tensor_dict in enumerate(thread_data_gen):
        # NOTE(review): with ``>`` the batches 0..max_steps are processed,
        # i.e. max_steps + 1 batches -- confirm this is intended.
        if batch_id > max_steps:
            break

        # Seed only the candidate expansion so it is reproducible per
        # batch, then restore non-deterministic seeding.
        np.random.seed(batch_id)
        batch_data = BatchData(sim_conf, tensor_dict)
        batch_data.set_decode_len(batch_data.seq_lens())
        batch_data.expand_candidates(last_batch_data, batch_data.seq_lens())
        np.random.seed(None)
        del batch_data.tensor_dict['click_id']

        # original note: sim_responses is (b, decode_len)
        orders, sim_responses = inference_one_batch(
            args.gen_type, ct_sim, dict_gen_ct, batch_data, eps=0)

        # NOTE(review): sim_batch_data is not used afterwards; the call is
        # kept in case get_reordered has side effects -- confirm and remove.
        sim_batch_data = batch_data.get_reordered(orders, sim_responses)
        list_sim_responses.append(sim_responses)
        last_batch_data = BatchData(sim_conf, tensor_dict)

        if batch_id % 100 == 0:
            logging.info('inference test batch %d', batch_id)

    # np.concatenate raises ValueError on an empty list, which happens when
    # the test data holds a single batch; skip the summary in that case.
    if not list_sim_responses:
        logging.warning('inference test: no batch processed, '
                        'skip summary for epoch %d', epoch_id)
        return

    # original note: result is (b,)
    list_sum_response = np.sum(np.concatenate(list_sim_responses, 0), 1)
    add_scalar_summary(summary_writer, epoch_id,
                       'inference/test_sim_responses',
                       np.mean(list_sum_response))
def online_inference(args, epoch_id, max_steps, data_gen, ct_sim, dict_gen_ct,
                     summary_writer, if_print=True):
    """Do inference for `max_steps` batches and collect the results.

    One batch is first pulled from ``data_gen`` as the initial candidate
    source.  Then, ``max_steps`` times, another batch is pulled, its
    candidates are expanded with the previous batch under a fixed NumPy
    seed (``epoch_id * max_steps + batch_id``), the logged ``'click_id'``
    entry is removed and ``inference_one_batch`` is run with
    ``eps=args.infer_eps``.  Batches whose ``batch_size()`` is 1 are skipped.

    Args:
        args: provides ``gen_type`` and ``infer_eps``.
        epoch_id: used for seeding, log messages and the summary step.
        max_steps: number of batches pulled inside the loop.
        data_gen: iterator of tensor dicts; ``max_steps + 1`` items are
            consumed in total.
        ct_sim: simulator wrapper; its config is read from
            ``ct_sim.alg.model.conf``.
        dict_gen_ct: forwarded to ``inference_one_batch``.
        summary_writer: forwarded to ``add_scalar_summary``.
        if_print: when true, progress is logged every 100 batches and the
            mean summed response is written as ``'inference/sim_responses'``.

    Returns:
        list: the reordered batch data of every processed batch, in order.
    """
    sim_conf = ct_sim.alg.model.conf
    replay_memory = []
    list_sim_responses = []

    ### online inference
    # Built-in next() works on Python 2.6+ and Python 3, whereas the
    # generator method ``.next()`` no longer exists on Python 3.
    last_batch_data = BatchData(sim_conf, next(data_gen))
    for batch_id in range(max_steps):
        # Seed only the candidate expansion so it is reproducible per
        # (epoch, batch), then restore non-deterministic seeding.
        np.random.seed(epoch_id * max_steps + batch_id)
        tensor_dict = next(data_gen)
        batch_data = BatchData(sim_conf, tensor_dict)
        batch_data.set_decode_len(batch_data.seq_lens())
        batch_data.expand_candidates(last_batch_data, batch_data.seq_lens())
        np.random.seed(None)
        del batch_data.tensor_dict['click_id']

        if batch_data.batch_size() == 1:  # otherwise, rl will crash
            # Note: last_batch_data is deliberately left untouched here,
            # exactly as before.
            continue

        # original note: sim_responses is (b, decode_len)
        orders, sim_responses = inference_one_batch(
            args.gen_type, ct_sim, dict_gen_ct, batch_data,
            eps=args.infer_eps)

        # save to replay memory
        sim_batch_data = batch_data.get_reordered(orders, sim_responses)
        replay_memory.append(sim_batch_data)
        list_sim_responses.append(sim_responses)
        last_batch_data = BatchData(sim_conf, tensor_dict)

        if batch_id % 100 == 0 and if_print:
            logging.info('inference epoch %d batch %d', epoch_id, batch_id)

    if if_print:
        if list_sim_responses:
            # original note: result is (b,)
            list_sum_response = np.sum(
                np.concatenate(list_sim_responses, 0), 1)
            add_scalar_summary(summary_writer, epoch_id,
                               'inference/sim_responses',
                               np.mean(list_sum_response))
        else:
            # np.concatenate raises ValueError on an empty list (max_steps
            # of 0, or every batch skipped); report instead of crashing.
            logging.warning('inference epoch %d: no batch processed, '
                            'skip summary', epoch_id)
    return replay_memory
def train(td_ct, eval_td_ct, args, conf, summary_writer, replay_memory,
          epoch_id):
    """Train the RL model with eps-greedy sampling and an evaluator reward.

    For each batch read from ``conf.train_npz_list`` the candidates are
    expanded with the previous batch, an order is sampled through
    ``td_ct.eps_greedy_sampling`` with ``eps=0.2``, and the reordered batch
    is scored by ``eval_td_ct``; column 1 of ``'click_prob'`` is the reward.
    The candidate-preserving reordering and its reward are appended to
    ``replay_memory``; one entry is then drawn uniformly at random from the
    memory and used for a single ``td_ct.train`` step.

    Every 10 batches the mean reward, mean loss and mean of the first
    ``c_Q`` entry accumulated since the previous report are written through
    ``add_scalar_summary`` and the accumulators are reset.

    Args:
        td_ct: object exposing ``eps_greedy_sampling(feed)`` (returns a dict
            with ``'sampled_id'``) and ``train(feed_dict)`` (returns a dict
            with ``'loss'`` and ``'c_Q'``).
        eval_td_ct: object exposing ``inference(feed)`` that returns a dict
            with ``'click_prob'``.
        args: not read by this function.
        conf: configuration providing ``train_npz_list``,
            ``npz_config_path``, ``requested_npz_names`` and ``batch_size``.
        summary_writer: forwarded to ``add_scalar_summary``.
        replay_memory: list-like container; extended in place with
            ``(batch_data, reward)`` tuples and sampled by integer index.
        epoch_id: epoch index, used to offset the summary step.
    """
    dataset = NpzDataset(conf.train_npz_list,
                         conf.npz_config_path,
                         conf.requested_npz_names,
                         if_random_shuffle=True)
    data_gen = dataset.get_data_generator(conf.batch_size)

    list_reward = []
    list_loss = []
    list_first_Q = []
    # Stride between epochs on the summary axis; presumably an estimate of
    # the number of batches per epoch -- confirm against the dataset size.
    guessed_batch_num = 11500
    batch_id = 0

    # Built-in next() works on Python 2.6+ and Python 3, whereas the
    # generator method ``.next()`` no longer exists on Python 3.
    # The first batch only serves as the candidate source for the second.
    last_batch_data = BatchData(conf, next(data_gen))
    for tensor_dict in data_gen:
        ### eps_greedy_sampling
        batch_data = BatchData(conf, tensor_dict)
        batch_data.set_decode_len(batch_data.seq_lens())
        batch_data.expand_candidates(last_batch_data, batch_data.seq_lens())
        fetch_dict = td_ct.eps_greedy_sampling(
            GenRLFeedConvertor.eps_greedy_sampling(batch_data, eps=0.2))
        sampled_id = np.array(fetch_dict['sampled_id']).reshape([-1])
        order = sequence_unconcat(sampled_id, batch_data.decode_len())

        ### get reward
        reordered_batch_data = batch_data.get_reordered(order)
        fetch_dict = eval_td_ct.inference(
            EvalFeedConvertor.inference(reordered_batch_data))
        reward = np.array(fetch_dict['click_prob'])[:, 1]

        ### save to replay_memory
        reordered_batch_data2 = batch_data.get_reordered_keep_candidate(order)
        reordered_batch_data2.set_decode_len(batch_data.decode_len())
        replay_memory.append((reordered_batch_data2, reward))

        ### train on a uniformly sampled replay entry
        # ``reward`` is rebound here, so the logged reward below is the one
        # of the sampled replay entry, not of the current batch.
        memory_batch_data, reward = replay_memory[np.random.randint(
            len(replay_memory))]
        feed_dict = GenRLFeedConvertor.train_test(memory_batch_data, reward)
        fetch_dict = td_ct.train(feed_dict)

        ### logging
        list_reward.append(np.mean(reward))
        list_loss.append(np.array(fetch_dict['loss']))
        list_first_Q.append(np.mean(np.array(fetch_dict['c_Q'])[0]))
        if batch_id % 10 == 0:
            global_batch_id = epoch_id * guessed_batch_num + batch_id
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_reward', np.mean(list_reward))
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_loss', np.mean(list_loss))
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_1st_Q', np.mean(list_first_Q))
            list_reward = []
            list_loss = []
            list_first_Q = []

        last_batch_data = BatchData(conf, tensor_dict)
        batch_id += 1