예제 #1
0
def test(td_ct, args, conf, summary_writer, epoch_id):
    """Evaluate on the full (unshuffled) test dataset and write summaries.

    Every batch of ``conf.test_npz_list`` is run through ``td_ct.test``.
    Three metrics are accumulated: a global AUC over all items, plus RMSE
    and correlation computed per sequence (the flat predictions are split
    with ``sequence_unconcat`` using ``seq_lens()``). They are written under
    ``test/auc``, ``test/seq_rmse`` and ``test/seq_correlation`` at step
    ``epoch_id``.

    ``args`` is not used.
    """
    test_dataset = NpzDataset(conf.test_npz_list,
                              conf.npz_config_path,
                              conf.requested_npz_names,
                              if_random_shuffle=False)
    batches = test_dataset.get_data_generator(conf.batch_size)

    auc = AUCMetrics()
    seq_rmse = SequenceRMSEMetrics()
    seq_corr = SequenceCorrelationMetrics()
    for tensor_dict in batches:
        batch = BatchData(conf, tensor_dict)
        outputs = td_ct.test(EvalFeedConvertor.train_test(batch))
        labels = np.array(outputs['click_id']).flatten()
        # Column 1 of click_prob is taken as the positive-click score.
        scores = np.array(outputs['click_prob'])[:, 1]
        seq_labels = sequence_unconcat(labels, batch.seq_lens())
        seq_scores = sequence_unconcat(scores, batch.seq_lens())

        auc.add(labels=labels, y_scores=scores)
        for one_labels, one_scores in zip(seq_labels, seq_scores):
            seq_rmse.add(labels=one_labels, preds=one_scores)
            seq_corr.add(labels=one_labels, preds=one_scores)

    summaries = [
        ('test/auc', auc.overall_auc()),
        ('test/seq_rmse', seq_rmse.overall_rmse()),
        ('test/seq_correlation', seq_corr.overall_correlation()),
    ]
    for tag, value in summaries:
        add_scalar_summary(summary_writer, epoch_id, tag, value)
def train_model(X_train, X_val, model, training_logger, bs=10000,
                n_steps=1501, eval_every=100):
    """Train ``model`` for a fixed number of mini-batch steps with periodic validation.

    Args:
        X_train: training data; mini-batches are drawn from
            ``BatchData(X_train, bs).get_batch()``.
        X_val: validation data, passed whole to ``model.eval``.
        model: object with ``train_step(batch)`` (its return value is logged
            as the step's logprob) and ``eval(data)`` returning an object
            with a ``.numpy()`` method.
        training_logger: receives ``add(step, logprob)`` on every step and
            ``add_val(step, val_logprob)`` on validation steps.
        bs: mini-batch size.
        n_steps: number of training steps. The default 1501 runs steps
            0..1500, so with the default interval the final step is also a
            validation step.
        eval_every: validate whenever ``step % eval_every == 0``. A value
            <= 0 disables validation.
    """
    batch_data = BatchData(X_train, bs)
    for step in range(n_steps):
        logprob = model.train_step(batch_data.get_batch())
        training_logger.add(step, logprob)
        if eval_every > 0 and step % eval_every == 0:
            val_logprob = model.eval(X_val).numpy()
            training_logger.add_val(step, val_logprob)
예제 #3
0
def main_batch_rl(args):
    """
    Include online inference and offline training.

    Builds the simulator controller and the generator controller(s) selected
    by ``args.gen_type``, then runs 50 epochs. Each epoch:
      * streams ``args.train_npz_list`` once and calls ``offline_training``
        on every batch, skipping batches of size 1;
      * calls ``online_inference_for_test`` with up to ``max_test_steps``
        batches.
    For ``gen_type == 'env_rl'`` a separate, non-one-pass env dataset stream
    is created per epoch and forwarded as ``env_rl_data_gen``.
    """
    ct_sim = get_ct_sim(args.sim_exp, args.use_cuda, args.train_mode, args.sim_cell_type, args.output_dim)
    assert ct_sim.ckp_step > 0, (ct_sim.ckp_step)
    dict_gen_ct = {}
    if args.gen_type in ['env', 'env_credit', 'env_rl']:
        if args.gen_type == 'env_rl':
            assert args.env_output_type == 'click', \
                ('env_rl only support click env, which will be used as a simulator', args.env_output_type)
        ct_env = get_ct_env(args.env_exp, args.use_cuda, args.train_mode, args.env_output_type, args.output_dim)
        dict_gen_ct['env'] = ct_env
    if args.gen_type in ['env_credit', 'mc_credit']:
        ct_credit = get_ct_credit(args.credit_exp, args.use_cuda, args.train_mode, args.credit_scale)
        dict_gen_ct['credit'] = ct_credit
    if args.gen_type in ['rl', 'env_rl']:
        ct_rl = get_ct_rl(args.rl_exp, args.use_cuda, args.train_mode, args.rl_gamma, args.rl_Q_type)
        dict_gen_ct['rl'] = ct_rl
    if args.gen_type == 'ddpg':
        ct_ddpg = get_ct_ddpg(args.ddpg_exp, args.use_cuda, args.train_mode, args.ddpg_gamma)
        dict_gen_ct['ddpg'] = ct_ddpg

    ### dataset
    sim_conf = ct_sim.alg.model.conf
    dataset = NpzDataset(args.train_npz_list,
                         sim_conf.npz_config_path,
                         sim_conf.requested_names,
                         if_random_shuffle=True,
                         one_pass=True)

    if args.gen_type == 'env_rl':       # env_rl will need data from env_conf
        env_conf = ct_env.alg.model.conf
        env_dataset = NpzDataset(args.train_npz_list,
                                 env_conf.npz_config_path,
                                 env_conf.requested_names,
                                 if_random_shuffle=True,
                                 one_pass=False)

    summary_writer = tf.summary.FileWriter(args.summary_dir)
    n_epochs = 50
    max_test_steps = 1000
    for epoch_id in range(n_epochs):
        if args.gen_type == 'env_rl':
            env_data_gen = env_dataset.get_data_generator(env_conf.batch_size)
            thread_env_data_gen = threaded_generator(env_data_gen, capacity=10)
        else:
            thread_env_data_gen = None

        data_gen = dataset.get_data_generator(sim_conf.batch_size)
        thread_data_gen = threaded_generator(data_gen, capacity=100)
        # True until the first batch of this epoch is actually trained on.
        # Keying if_save on batch_id == 0 would skip it whenever the first
        # batch is filtered out below.
        save_pending = True
        for tensor_dict in thread_data_gen:
            batch_data = BatchData(sim_conf, tensor_dict)
            if batch_data.batch_size() == 1:    # otherwise, rl will crash
                continue
            # presumably if_save triggers a checkpoint inside offline_training — confirm
            offline_training(args, epoch_id, [batch_data], dict_gen_ct, summary_writer,
                             if_save=save_pending, env_rl_data_gen=thread_env_data_gen)
            save_pending = False
        # Evaluate after every epoch.
        online_inference_for_test(args, epoch_id, max_test_steps, ct_sim, dict_gen_ct, summary_writer)
예제 #4
0
def log_train(td_ct, args, conf, summary_writer, replay_memory, epoch_id):
    """Train the RL generator on logged (original-order) data via a replay memory.

    For every training batch:
      * the logged order (``np.arange`` over each sequence, i.e. unchanged)
        and the logged ``click_id`` values are appended to ``replay_memory``;
      * one uniformly sampled memory entry is used for a ``td_ct.train`` step.
    Every 10 batches, the windowed means of reward, loss and the mean of the
    first row of ``c_Q`` are written to ``summary_writer``.

    Args:
        td_ct: generator controller exposing ``train(feed_dict)``.
        args: unused.
        conf: config with train_npz_list, npz_config_path,
            requested_npz_names and batch_size.
        summary_writer: passed to ``add_scalar_summary``.
        replay_memory: list-like (append / len / indexing); mutated in place.
        epoch_id: used to build a global summary step.
    """
    dataset = NpzDataset(conf.train_npz_list,
                         conf.npz_config_path,
                         conf.requested_npz_names,
                         if_random_shuffle=True)
    data_gen = dataset.get_data_generator(conf.batch_size)

    list_reward = []
    list_loss = []
    list_first_Q = []
    # Rough number of batches per epoch; only used to make the summary step
    # increase monotonically across epochs.
    guessed_batch_num = 11500
    # No batch is pre-fetched here: nothing in this loop needs the previous
    # batch, so pulling one up front would only discard training data (and
    # raise StopIteration on an empty dataset).
    for batch_id, tensor_dict in enumerate(data_gen):
        ### keep the logged order
        batch_data = BatchData(conf, tensor_dict)
        batch_data.set_decode_len(batch_data.seq_lens())
        order = [np.arange(d) for d in batch_data.decode_len()]

        ### get reward
        reordered_batch_data = batch_data.get_reordered(order)
        reordered_batch_data.set_decode_len(batch_data.decode_len())
        reward = batch_data.tensor_dict['click_id'].values

        ### save to replay_memory
        replay_memory.append((reordered_batch_data, reward))

        ### train on a uniformly sampled memory entry
        memory_batch_data, memory_reward = replay_memory[np.random.randint(
            len(replay_memory))]
        feed_dict = GenRLFeedConvertor.train_test(memory_batch_data, memory_reward)
        fetch_dict = td_ct.train(feed_dict)

        ### logging (reward of the current batch, not of the memory sample)
        list_reward.append(np.mean(reward))
        list_loss.append(np.array(fetch_dict['loss']))
        list_first_Q.append(np.mean(np.array(fetch_dict['c_Q'])[0]))
        if batch_id % 10 == 0:
            global_batch_id = epoch_id * guessed_batch_num + batch_id
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_reward', np.mean(list_reward))
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_loss', np.mean(list_loss))
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_1st_Q', np.mean(list_first_Q))
            list_reward = []
            list_loss = []
            list_first_Q = []
예제 #5
0
def test(ct, args, conf, summary_writer, epoch_id, item_shuffle=False):
    """Evaluate click RMSE and click accuracy metrics on the full test dataset.

    Every batch of ``args.test_npz_list`` is run through the controller in
    evaluation mode. Results are written under ``test/click_rmse``,
    ``test/click_accuracy`` and ``test/<key>`` for each entry of
    ``AccuracyMetrics.overall_metrics()``, at step ``epoch_id``.

    Args:
        ct: simulator controller. Assumed to expose ``test(feed)`` like the
            ``td_ct`` controllers elsewhere in this file — TODO confirm.
        item_shuffle: unused; kept for call compatibility.
    """
    dataset = NpzDataset(args.test_npz_list,
                         conf.npz_config_path,
                         conf.requested_names,
                         if_random_shuffle=True)
    data_gen = dataset.get_data_generator(conf.batch_size)

    click_rmse_metric = RMSEMetrics()
    click_accu_metric = AccuracyMetrics()
    for tensor_dict in threaded_generator(data_gen, capacity=100):
        batch_data = BatchData(conf, tensor_dict)
        # Evaluation must not update parameters. ct.train here would fit the
        # model on the test set.
        fetch_dict = ct.test(SimFeedConvertor.train_test(batch_data))
        click_id = np.array(fetch_dict['click_id']).flatten()
        click_prob = np.array(fetch_dict['click_prob'])
        click_score = click_prob_2_score(click_prob).flatten()
        click_rmse_metric.add(labels=click_id, preds=click_score)
        click_accu_metric.add(labels=click_id, probs=click_prob)

    add_scalar_summary(summary_writer, epoch_id, 'test/click_rmse',
                       click_rmse_metric.overall_rmse())
    add_scalar_summary(summary_writer, epoch_id, 'test/click_accuracy',
                       click_accu_metric.overall_accuracy())
    for key, value in click_accu_metric.overall_metrics().items():
        add_scalar_summary(summary_writer, epoch_id, 'test/%s' % key, value)
예제 #6
0
def main_calculate_credit_variance(args, credit_type):
    """
    Calculate variance of credit by varying following items.

    Takes one base batch. For every position ``pos`` up to its first sequence
    length, builds ``n_vary`` mixed batches in which the items after ``pos``
    are replaced by those of another reference batch of the same batch size.
    For each mixed batch:
      * clicks are sampled from ``ct_sim``'s ``click_prob``;
      * credits are computed with ``generate_credit_one_batch(ct_env, ...)``;
      * the credit at ``pos`` is kept.
    Prints ``credit_type``, the shape of the stacked credits and the mean
    (over state-action pairs) of the std across the n_vary variants.

    Raises:
        ValueError: if the one-pass data stream runs out before ``n_vary``
            same-sized reference batches are found for some position.
    """
    def sampling(click_prob):
        """
        Sample one class index per row.

        click_prob: (n, n_class) array; each row is used as the probability
        vector for np.random.choice. Returns an int64 array of shape (n,).
        """
        n_class = click_prob.shape[1]
        return np.int64([np.random.choice(n_class, 1, p=p) for p in click_prob]).reshape([-1])

    assert args.gen_type == 'env'
    ct_sim = get_ct_sim(args.sim_exp, args.use_cuda, args.train_mode, args.sim_cell_type, args.output_dim)
    assert ct_sim.ckp_step > 0, (ct_sim.ckp_step)
    ct_env = get_ct_env(args.env_exp, args.use_cuda, args.train_mode, args.env_output_type, args.output_dim)
    assert ct_env.ckp_step > 0, (ct_env.ckp_step)

    ### dataset
    sim_conf = ct_sim.alg.model.conf
    dataset = NpzDataset(args.train_npz_list,
                         sim_conf.npz_config_path,
                         sim_conf.requested_names,
                         if_random_shuffle=True,
                         one_pass=True)
    data_gen = dataset.get_data_generator(sim_conf.batch_size)
    thread_data_gen = threaded_generator(data_gen, capacity=100)

    n_vary = 64
    base_batch_data = BatchData(sim_conf, thread_data_gen.next())
    batch_size = base_batch_data.batch_size()
    batch_credits = []
    # NOTE(review): credits are reshaped to (batch_size, -1) and indexed by
    # pos, which assumes every sequence in the base batch has length
    # seq_lens()[0] — confirm.
    for pos in range(base_batch_data.seq_lens()[0]):
        list_credits = []
        # Resumes the shared one-pass stream where the previous position stopped.
        for tensor_dict in thread_data_gen:
            ref_batch_data = BatchData(sim_conf, tensor_dict)
            if ref_batch_data.batch_size() != batch_size:
                continue
            mix_batch_data = base_batch_data.replace_following_items(pos + 1, ref_batch_data)
            sim_fetch_dict = ct_sim.inference(SimFeedConvertor.inference(mix_batch_data))
            sim_response = sampling(np.array(sim_fetch_dict['click_prob'])).reshape([-1, 1]).astype('int64')
            mix_batch_data.tensor_dict['click_id'] = FakeTensor(sim_response, mix_batch_data.seq_lens())
            credit = generate_credit_one_batch(ct_env,
                                               mix_batch_data,
                                               credit_type=credit_type,
                                               credit_gamma=args.credit_gamma,
                                               globbase=None)
            credit = credit.reshape([batch_size, -1])
            list_credits.append(credit[:, pos].reshape(-1, 1))
            # Stop right after the n_vary-th sample so that no extra batch is
            # pulled from the one-pass stream and thrown away.
            if len(list_credits) == n_vary:
                break
        if len(list_credits) < n_vary:
            raise ValueError(
                'data stream exhausted at position %d: collected %d of %d '
                'reference batches with batch size %d'
                % (pos, len(list_credits), n_vary, batch_size))
        batch_credits.append(np.concatenate(list_credits, 1))  # (b, n_vary)
    batch_credits = np.concatenate(batch_credits, 0)    # (seq_len*b, n_vary)
    print(credit_type)
    print(batch_credits.shape)
    print('(s,a)-wise credit variance', np.mean(np.std(batch_credits, 1)))
예제 #7
0
def train(ct, args, conf, summary_writer, epoch_id):
    """Run one pass of simulator training over ``args.train_npz_list``.

    Every ``conf.prt_interval`` batches, the mean loss since the previous
    report is logged. At the end, the mean of the per-batch mean losses is
    written to ``train/loss`` at step ``epoch_id``.
    """
    train_set = NpzDataset(args.train_npz_list,
                           conf.npz_config_path,
                           conf.requested_names,
                           if_random_shuffle=True)
    raw_batches = train_set.get_data_generator(conf.batch_size)

    window_losses = []   # losses since the last log line
    epoch_losses = []    # one mean loss per batch, for the epoch summary
    prefetched = threaded_generator(raw_batches, capacity=100)
    for step, tensor_dict in enumerate(prefetched):
        feed = SimFeedConvertor.train_test(BatchData(conf, tensor_dict))
        loss = np.array(ct.train(feed)['loss'])
        window_losses.append(loss)
        epoch_losses.append(np.mean(loss))
        if step % conf.prt_interval != 0:
            continue
        logging.info('batch_id:%d loss:%f' %
                     (step, np.mean(window_losses)))
        window_losses = []

    add_scalar_summary(summary_writer, epoch_id, 'train/loss',
                       np.mean(epoch_losses))
예제 #8
0
def train(td_ct, args, conf, summary_writer, epoch_id):
    """Run one pass of supervised generator training over ``conf.train_npz_list``.

    Every ``conf.prt_interval`` batches, the mean loss since the previous
    report is logged. At the end, the mean of the per-batch mean losses is
    written to ``train/loss`` at step ``epoch_id``. ``args`` is not used.
    """
    train_set = NpzDataset(conf.train_npz_list,
                           conf.npz_config_path,
                           conf.requested_npz_names,
                           if_random_shuffle=True)
    batches = train_set.get_data_generator(conf.batch_size)

    epoch_losses = []    # one mean loss per batch, for the epoch summary
    window_losses = []   # losses since the last log line
    for step, tensor_dict in enumerate(batches):
        feed = GenSLFeedConvertor.train_test(BatchData(conf, tensor_dict))
        loss = np.array(td_ct.train(feed)['loss'])
        window_losses.append(loss)
        epoch_losses.append(np.mean(loss))
        if step % conf.prt_interval == 0:
            logging.info('batch_id:%d loss:%f' %
                         (step, np.mean(window_losses)))
            window_losses = []

    add_scalar_summary(summary_writer, epoch_id, 'train/loss',
                       np.mean(epoch_losses))
예제 #9
0
def evaluate(td_ct, eval_td_ct, args, conf, epoch_id):
    """Compare greedy decoding with best-of-n softmax sampling on the test set.

    For up to ``max_batch_id`` test batches (each expanded with candidates
    from the previous batch), ``td_ct`` proposes item orders two ways:
      * greedy sampling (eps=0);
      * softmax sampling (eta=0.1), repeated ``max(list_n)`` times.
    Each order is scored by ``eval_td_ct`` as the per-sequence sum of column
    1 of ``click_prob``. For softmax, the best score among the first n
    samples is kept for each n in ``list_n``. Running means are logged every
    10 batches and at the end, and the ``PatternCounter`` is saved to
    ``tmp/<args.exp>-eval_<args.eval_model>.pkl``.

    If the test stream runs out early, evaluation stops and still logs and
    saves what was processed. ``epoch_id`` is not used.
    """
    # IMPORTANT: fixed seed so every run sees the same candidates, which are
    # selected with np.random.choice.
    np.random.seed(0)
    dataset = NpzDataset(conf.test_npz_list,
                         conf.npz_config_path,
                         conf.requested_npz_names,
                         if_random_shuffle=False)
    batch_size = 250
    data_gen = dataset.get_data_generator(batch_size)

    max_batch_id = 200
    list_n = [1, 20, 40]
    max_sampling_time = np.max(list_n)
    dict_reward = {'eps_greedy': [], 'softmax': {n: [] for n in list_n}}
    p_counter = PatternCounter()

    def get_list_wise_reward(batch_data, order):
        """Score each sequence's order as the sum of its predicted click probs."""
        reordered_batch_data = batch_data.get_reordered(order)
        fetch_dict = eval_td_ct.inference(
            EvalFeedConvertor.inference(reordered_batch_data))
        reward = np.array(fetch_dict['click_prob'])[:, 1]
        reward_unconcat = sequence_unconcat(reward,
                                            [len(od) for od in order])
        return [np.sum(rw) for rw in reward_unconcat]

    def greedy_sampling(batch_data):
        """Decode with eps=0; return (per-sequence orders, list-wise rewards (b,))."""
        fetch_dict = td_ct.eps_greedy_sampling(
            GenRLFeedConvertor.eps_greedy_sampling(batch_data, eps=0))
        sampled_id = np.array(fetch_dict['sampled_id']).reshape([-1])
        order = sequence_unconcat(sampled_id, batch_data.decode_len())
        list_wise_reward = get_list_wise_reward(batch_data, order)  # (b,)
        return order, list_wise_reward

    def softmax_sampling(batch_data, n_samples):
        """Draw n_samples softmax decodings.

        Returns (orders indexed [sample][seq], rewards array (n_samples, b)).
        """
        mat_list_wise_reward = []
        mat_order = []
        for _ in range(n_samples):
            fetch_dict = td_ct.softmax_sampling(
                GenRLFeedConvertor.softmax_sampling(batch_data, eta=0.1))
            sampled_id = np.array(fetch_dict['sampled_id']).reshape([-1])
            order = sequence_unconcat(sampled_id, batch_data.decode_len())
            mat_order.append(order)
            mat_list_wise_reward.append(get_list_wise_reward(batch_data, order))
        return mat_order, np.array(mat_list_wise_reward)

    last_batch_data = BatchData(conf, data_gen.next())
    for batch_id in range(max_batch_id):
        try:
            tensor_dict = data_gen.next()
        except StopIteration:
            logging.info('test data exhausted after %d batches' % batch_id)
            break
        batch_data = BatchData(conf, tensor_dict)
        batch_data.set_decode_len(batch_data.seq_lens())
        batch_data.expand_candidates(last_batch_data, batch_data.seq_lens())
        p_counter.add_log_pattern(batch_data)

        ### eps_greedy_sampling
        order, list_wise_reward = greedy_sampling(batch_data)
        dict_reward['eps_greedy'] += list_wise_reward
        p_counter.add_sampled_pattern('eps_greedy', batch_data, order)

        ### softmax_sampling: best of the first n samples for each n
        mat_order, mat_list_wise_reward = softmax_sampling(
            batch_data, max_sampling_time)
        for n in list_n:
            dict_reward['softmax'][n] += np.max(mat_list_wise_reward[:n],
                                                0).tolist()
            max_indice = np.argmax(mat_list_wise_reward[:n], 0)  # (b,)
            max_order = [
                mat_order[max_id][b_id]
                for b_id, max_id in enumerate(max_indice)
            ]
            p_counter.add_sampled_pattern('softmax_%d' % n, batch_data,
                                          max_order)

        ### log
        if batch_id % 10 == 0:
            logging.info('batch_id:%d eps_greedy %f' %
                         (batch_id, np.mean(dict_reward['eps_greedy'])))
            for n in list_n:
                logging.info('batch_id:%d softmax_%d %f' %
                             (batch_id, n, np.mean(dict_reward['softmax'][n])))
            p_counter.Print()

        # Rebuilt from the raw dict because batch_data was mutated above.
        last_batch_data = BatchData(conf, tensor_dict)

    ### log
    logging.info('final eps_greedy %f' % np.mean(dict_reward['eps_greedy']))
    for n in list_n:
        logging.info('final softmax_%d %f' %
                     (n, np.mean(dict_reward['softmax'][n])))
    p_counter.Print()

    ### save
    pickle_file = 'tmp/%s-eval_%s.pkl' % (args.exp, args.eval_model)
    p_counter.save(pickle_file)
예제 #10
0
def eps_greedy_sampling(td_ct, eval_td_ct, args, conf, summary_writer,
                        epoch_id):
    """Measure the reward of greedy (eps=0) decoding on the test set.

    The first test batch only seeds candidate expansion. Batches with ids
    0..100 (inclusive) are then decoded by ``td_ct``, and each reordered
    batch is scored by ``eval_td_ct`` (column 1 of ``click_prob``). The mean
    of the per-batch mean rewards is written to
    ``eps_greedy_sampling/reward-<args.eval_exp>`` at step ``epoch_id``.
    """
    test_set = NpzDataset(conf.test_npz_list,
                          conf.npz_config_path,
                          conf.requested_npz_names,
                          if_random_shuffle=False)
    batches = test_set.get_data_generator(conf.batch_size)

    last_batch_id = 100
    batch_rewards = []
    prev_batch = BatchData(conf, batches.next())
    for batch_id, tensor_dict in enumerate(batches):
        cur_batch = BatchData(conf, tensor_dict)
        cur_batch.set_decode_len(cur_batch.seq_lens())
        cur_batch.expand_candidates(prev_batch, cur_batch.seq_lens())

        sampled = td_ct.eps_greedy_sampling(
            GenRLFeedConvertor.eps_greedy_sampling(cur_batch, eps=0))
        flat_ids = np.array(sampled['sampled_id']).reshape([-1])
        order = sequence_unconcat(flat_ids, cur_batch.decode_len())

        scored = eval_td_ct.inference(
            EvalFeedConvertor.inference(cur_batch.get_reordered(order)))
        batch_rewards.append(np.mean(np.array(scored['click_prob'])[:, 1]))

        if batch_id == last_batch_id:
            break
        prev_batch = BatchData(conf, tensor_dict)

    add_scalar_summary(summary_writer, epoch_id,
                       'eps_greedy_sampling/reward-%s' % args.eval_exp,
                       np.mean(batch_rewards))
예제 #11
0
def train(td_ct, eval_td_ct, args, conf, summary_writer, replay_memory,
          epoch_id):
    """Run one epoch of RL training for the generator with a replay memory.

    For each training batch (expanded with candidates from the previous batch):
      1. ``td_ct`` samples an order with eps-greedy (eps=0.2);
      2. ``eval_td_ct`` scores the reordered batch, and column 1 of
         ``click_prob`` is the per-item reward;
      3. the reordered batch (candidates kept) and that reward are appended
         to ``replay_memory`` (mutated in place);
      4. one uniformly sampled memory entry is used for a ``td_ct.train`` step.
    Every 10 batches, the windowed means of the current batch's reward, the
    loss and the first row of ``c_Q`` are written to ``summary_writer``.
    ``args`` is not used.
    """
    dataset = NpzDataset(conf.train_npz_list,
                         conf.npz_config_path,
                         conf.requested_npz_names,
                         if_random_shuffle=True)
    data_gen = dataset.get_data_generator(conf.batch_size)

    list_reward = []
    list_loss = []
    list_first_Q = []
    # Rough number of batches per epoch; only used to make the summary step
    # increase monotonically across epochs.
    guessed_batch_num = 11500
    # The first batch only provides candidates for the second one.
    last_batch_data = BatchData(conf, data_gen.next())
    for batch_id, tensor_dict in enumerate(data_gen):
        ### eps_greedy_sampling
        batch_data = BatchData(conf, tensor_dict)
        batch_data.set_decode_len(batch_data.seq_lens())
        batch_data.expand_candidates(last_batch_data, batch_data.seq_lens())

        fetch_dict = td_ct.eps_greedy_sampling(
            GenRLFeedConvertor.eps_greedy_sampling(batch_data, eps=0.2))
        sampled_id = np.array(fetch_dict['sampled_id']).reshape([-1])
        order = sequence_unconcat(sampled_id, batch_data.decode_len())

        ### get reward
        reordered_batch_data = batch_data.get_reordered(order)
        fetch_dict = eval_td_ct.inference(
            EvalFeedConvertor.inference(reordered_batch_data))
        reward = np.array(fetch_dict['click_prob'])[:, 1]

        ### save to replay_memory
        reordered_batch_data2 = batch_data.get_reordered_keep_candidate(order)
        reordered_batch_data2.set_decode_len(batch_data.decode_len())
        replay_memory.append((reordered_batch_data2, reward))

        ### train on a uniformly sampled memory entry
        memory_batch_data, memory_reward = replay_memory[np.random.randint(
            len(replay_memory))]
        feed_dict = GenRLFeedConvertor.train_test(memory_batch_data, memory_reward)
        fetch_dict = td_ct.train(feed_dict)

        ### logging
        # Log the reward of the order sampled for *this* batch. Reusing the
        # name `reward` for the memory sample made this curve follow random
        # past entries instead of the current policy.
        list_reward.append(np.mean(reward))
        list_loss.append(np.array(fetch_dict['loss']))
        list_first_Q.append(np.mean(np.array(fetch_dict['c_Q'])[0]))
        if batch_id % 10 == 0:
            global_batch_id = epoch_id * guessed_batch_num + batch_id
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_reward', np.mean(list_reward))
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_loss', np.mean(list_loss))
            add_scalar_summary(summary_writer, global_batch_id,
                               'train/rl_1st_Q', np.mean(list_first_Q))
            list_reward = []
            list_loss = []
            list_first_Q = []

        # Rebuilt from the raw dict because batch_data was mutated above.
        last_batch_data = BatchData(conf, tensor_dict)
예제 #12
0
def online_inference_for_test(args, epoch_id, max_steps, ct_sim, dict_gen_ct, summary_writer):
    """
    Do inference on the test set for at most `max_steps` batches.

    The first test batch only seeds candidate expansion. For each following
    batch, candidates are expanded under a per-batch seed, ``click_id`` is
    removed, and ``inference_one_batch`` (eps=0) produces orders and
    simulated responses. The mean over sequences of the summed simulated
    responses is written to ``inference/test_sim_responses`` at step
    ``epoch_id``. If no batch is processed, a warning is logged and no
    summary is written.
    """
    sim_conf = ct_sim.alg.model.conf
    dataset = NpzDataset(args.test_npz_list,
                         sim_conf.npz_config_path,
                         sim_conf.requested_names,
                         if_random_shuffle=False,
                         one_pass=True)
    data_gen = dataset.get_data_generator(sim_conf.batch_size)
    thread_data_gen = threaded_generator(data_gen, capacity=100)

    list_sim_responses = []
    ### online inference
    last_batch_data = BatchData(sim_conf, thread_data_gen.next())
    for batch_id, tensor_dict in enumerate(thread_data_gen):
        # `>=` caps the run at max_steps batches, matching online_inference
        # (`>` processed max_steps + 1).
        if batch_id >= max_steps:
            break
        # Seed per batch so the random steps up to the reseed are identical
        # across epochs; then reseed from fresh entropy.
        np.random.seed(batch_id)
        batch_data = BatchData(sim_conf, tensor_dict)
        batch_data.set_decode_len(batch_data.seq_lens())
        batch_data.expand_candidates(last_batch_data, batch_data.seq_lens())
        np.random.seed(None)
        del batch_data.tensor_dict['click_id']

        orders, sim_responses = inference_one_batch(args.gen_type, ct_sim, dict_gen_ct, batch_data, eps=0) # , (b, decode_len)

        # Only the responses are needed for the test metric. Nothing is kept
        # in a replay memory, so the reordered batch is not built.
        list_sim_responses.append(sim_responses)
        last_batch_data = BatchData(sim_conf, tensor_dict)

        if batch_id % 100 == 0:
            logging.info('inference test batch %d' % batch_id)

    if not list_sim_responses:
        logging.warning('inference test epoch %d: no batch processed, summary skipped' % epoch_id)
        return
    list_sum_response = np.sum(np.concatenate(list_sim_responses, 0), 1)    # (b,)
    add_scalar_summary(summary_writer, epoch_id, 'inference/test_sim_responses', np.mean(list_sum_response))
예제 #13
0
def online_inference(args, epoch_id, max_steps, data_gen, ct_sim, dict_gen_ct, summary_writer, if_print=True):
    """
    Do inference for `max_steps` batches.

    One batch is first pulled from ``data_gen`` to seed candidate expansion.
    Then ``max_steps`` further batches are pulled. For each batch:
      * candidates are expanded under a seed derived from
        (epoch_id, batch_id), and ``click_id`` is removed;
      * batches of size 1 are skipped;
      * ``inference_one_batch`` (eps=args.infer_eps) produces orders and
        simulated responses, and the reordered batch is stored.
    When ``if_print`` is set, progress is logged every 100 batches and the
    mean summed simulated response is written to ``inference/sim_responses``.
    No summary is written if no batch was kept.

    Returns:
        list of batches returned by ``batch_data.get_reordered(orders, sim_responses)``.
    """
    sim_conf = ct_sim.alg.model.conf

    replay_memory = []
    list_sim_responses = []
    ### online inference
    last_batch_data = BatchData(sim_conf, data_gen.next())
    for batch_id in range(max_steps):
        # Seed so the random steps up to the reseed are reproducible per
        # (epoch, batch); then reseed from fresh entropy.
        np.random.seed(epoch_id * max_steps + batch_id)
        tensor_dict = data_gen.next()
        batch_data = BatchData(sim_conf, tensor_dict)
        batch_data.set_decode_len(batch_data.seq_lens())
        batch_data.expand_candidates(last_batch_data, batch_data.seq_lens())
        np.random.seed(None)
        del batch_data.tensor_dict['click_id']

        if batch_data.batch_size() == 1:    # otherwise, rl will crash
            continue

        orders, sim_responses = inference_one_batch(args.gen_type, ct_sim, dict_gen_ct, batch_data, eps=args.infer_eps) # , (b, decode_len)

        # save to replay memory
        sim_batch_data = batch_data.get_reordered(orders, sim_responses)
        replay_memory.append(sim_batch_data)
        list_sim_responses.append(sim_responses)
        last_batch_data = BatchData(sim_conf, tensor_dict)

        if batch_id % 100 == 0 and if_print:
            logging.info('inference epoch %d batch %d' % (epoch_id, batch_id))

    if if_print:
        if list_sim_responses:
            list_sum_response = np.sum(np.concatenate(list_sim_responses, 0), 1)    # (b,)
            add_scalar_summary(summary_writer, epoch_id, 'inference/sim_responses', np.mean(list_sum_response))
        else:
            # np.concatenate([]) would raise; happens with max_steps == 0 or
            # when every batch had size 1.
            logging.warning('inference epoch %d: no batch kept, summary skipped' % epoch_id)
    return replay_memory