def run_a3c(args):
    """Run A3C experiment."""
    GYM_ENV_NAME = args.gym_env.replace('-', '_')
    GAME_NAME = args.gym_env.replace('NoFrameskip-v4','')

    # setup folder name and path to folder
    folder = pathlib.Path(setup_folder(args, GYM_ENV_NAME))

    # setup GPU (if applicable)
    import tensorflow as tf
    gpu_options = setup_gpu(tf, args.use_gpu, args.gpu_fraction)

    ######################################################
    # setup default device
    device = "/cpu:0"

    global_t = 0
    rewards = {'train': {}, 'eval': {}}
    best_model_reward = -(sys.maxsize)
    if args.load_pretrained_model:
        class_rewards = {'class_eval': {}}

    # setup logging info for analysis, see Section 4.2 of the paper
    sil_dict = {
                # count number of SIL updates
                "sil_ctr":{},
                # total number of buffer D samples drawn during SIL
                "sil_a3c_sampled":{},
                # total number of buffer D samples (i.e., generated by A3C workers) used during SIL (i.e., passed max op)
                "sil_a3c_used":{},
                # the return of used samples for buffer D
                "sil_a3c_used_return":{},
                # total number of buffer R samples drawn during SIL
                "sil_rollout_sampled":{},
                # total number of buffer R samples (i.e., generated by refresher worker) used during SIL (i.e., passed max op)
                "sil_rollout_used":{},
                # the return of used samples for buffer R
                "sil_rollout_used_return":{},
                # number of old samples still used (even after refreshing)
                "sil_old_used":{}
                }
    sil_ctr, sil_a3c_sampled, sil_a3c_used, sil_a3c_used_return = 0, 0, 0, 0
    sil_rollout_sampled, sil_rollout_used, sil_rollout_used_return = 0, 0, 0
    sil_old_used = 0


    rollout_dict = {
                    # total number of rollouts performed
                    "rollout_ctr": {},
                    # total number of successful rollouts (i.e., Gnew > G)
                    "rollout_added_ctr":{},
                    # the return of Gnew
                    "rollout_new_return":{},
                    # the return of G
                    "rollout_old_return":{}
                    }
    rollout_ctr, rollout_added_ctr = 0, 0
    rollout_new_return = 0 # this records the total, avg = this / rollout_added_ctr
    rollout_old_return = 0 # this records the total, avg = this / rollout_added_ctr

    # setup file names
    reward_fname = folder / '{}-a3c-rewards.pkl'.format(GYM_ENV_NAME)
    sil_fname = folder / '{}-a3c-dict-sil.pkl'.format(GYM_ENV_NAME)
    rollout_fname = folder / '{}-a3c-dict-rollout.pkl'.format(GYM_ENV_NAME)
    if args.load_pretrained_model:
        class_reward_fname = folder / '{}-class-rewards.pkl'.format(GYM_ENV_NAME)

    sharedmem_fname = folder / '{}-sharedmem.pkl'.format(GYM_ENV_NAME)
    sharedmem_params_fname = folder / '{}-sharedmem-params.pkl'.format(GYM_ENV_NAME)
    sharedmem_trees_fname = folder / '{}-sharedmem-trees.pkl'.format(GYM_ENV_NAME)

    rolloutmem_fname = folder / '{}-rolloutmem.pkl'.format(GYM_ENV_NAME)
    rolloutmem_params_fname = folder / '{}-rolloutmem-params.pkl'.format(GYM_ENV_NAME)
    rolloutmem_trees_fname = folder / '{}-rolloutmem-trees.pkl'.format(GYM_ENV_NAME)

    # for removing older ckpt, save mem space
    prev_ckpt_t = -1

    stop_req = False

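    # probe the environment once just to read the size of its action space,
    # then close and discard it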
    game_state = GameState(env_id=args.gym_env)
    action_size = game_state.env.action_space.n
    game_state.close()
    del game_state.env
    del game_state

    input_shape = (args.input_shape, args.input_shape, 4)
    #######################################################
    # setup global A3C
    GameACFFNetwork.use_mnih_2015 = args.use_mnih_2015
    global_network = GameACFFNetwork(
        action_size, -1, device, padding=args.padding,
        in_shape=input_shape)
    logger.info('A3C Initial Learning Rate={}'.format(args.initial_learn_rate))

    # setup pretrained model
    global_pretrained_model = None
    local_pretrained_model = None
    pretrain_graph = None

    # if refreshing with another policy (nstep_bc > 0), a pretrained model must be
    # loaded; when refreshing with the current policy, no pretrained model is needed
    if args.use_lider and args.nstep_bc > 0:
        assert args.load_pretrained_model, "refreshing with other policies requires a pre-trained model (TA or BC)"
    else:
        assert not args.load_pretrained_model, "refreshing with the current policy; do not load pre-trained models"

    if args.load_pretrained_model:
        pretrain_graph, global_pretrained_model = setup_pretrained_model(tf,
            args, action_size, input_shape,
            device="/gpu:0" if args.use_gpu else device)
        assert global_pretrained_model is not None
        assert pretrain_graph is not None

    time.sleep(2.0)

    # setup experience memory
    shared_memory = None  # => this is buffer D
    rollout_buffer = None  # => this is buffer R
    if args.use_sil:
        shared_memory = SILReplayMemory(
            action_size, max_len=args.memory_length, gamma=args.gamma,
            clip=False if args.unclipped_reward else True,
            height=input_shape[0], width=input_shape[1],
            phi_length=input_shape[2], priority=args.priority_memory,
            reward_constant=args.reward_constant)

        if args.use_lider and not args.onebuffer:
            rollout_buffer = SILReplayMemory(
                action_size, max_len=args.memory_length, gamma=args.gamma,
                clip=False if args.unclipped_reward else True,
                height=input_shape[0], width=input_shape[1],
                phi_length=input_shape[2], priority=args.priority_memory,
                reward_constant=args.reward_constant)

        # log memory information
        shared_memory.log()
        if args.use_lider and not args.onebuffer:
            rollout_buffer.log()

    ############## Setup Thread Workers BEGIN ################
    # total of 17 threads for all experiments
    assert args.parallel_size == 17, "use 17 workers for all experiments"

    startIndex = 0
    all_workers = []

    # a3c and sil learning rate and optimizer
    learning_rate_input = tf.placeholder(tf.float32, shape=(), name="opt_lr")
    grad_applier = tf.train.RMSPropOptimizer(
        learning_rate=learning_rate_input,
        decay=args.rmsp_alpha,
        epsilon=args.rmsp_epsilon)
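    # a single shared RMSProp optimizer is used for both the A3C and SIL updates;
    # the learning rate is a placeholder so workers can feed their annealed value
    # when applying gradients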

    setup_common_worker(CommonWorker, args, action_size)

    # setup SIL worker
    sil_worker = None
    if args.use_sil:
        _device = "/gpu:0" if args.use_gpu else device

        sil_network = GameACFFNetwork(
            action_size, startIndex, device=_device,
            padding=args.padding, in_shape=input_shape)

        sil_worker = SILTrainingThread(startIndex, global_network, sil_network,
            args.initial_learn_rate,
            learning_rate_input,
            grad_applier, device=_device,
            batch_size=args.batch_size,
            use_rollout=args.use_lider,
            one_buffer=args.onebuffer,
            sampleR=args.sampleR)

        all_workers.append(sil_worker)
        startIndex += 1

    # setup refresh worker
    refresh_worker = None
    if args.use_lider:
        _device = "/gpu:0" if args.use_gpu else device

        refresh_network = GameACFFNetwork(
            action_size, startIndex, device=_device,
            padding=args.padding, in_shape=input_shape)

        refresh_local_pretrained_model = None
        # if refreshing with other policies
        if args.nstep_bc > 0:
            refresh_local_pretrained_model = PretrainedModelNetwork(
                pretrain_graph, action_size, startIndex,
                padding=args.padding,
                in_shape=input_shape, sae=False,
                tied_weights=False,
                use_denoising=False,
                noise_factor=0.3,
                loss_function='mse',
                use_slv=False, device=_device)

        refresh_worker = RefreshThread(
            thread_index=startIndex, action_size=action_size, env_id=args.gym_env,
            global_a3c=global_network, local_a3c=refresh_network,
            update_in_rollout=args.update_in_rollout, nstep_bc=args.nstep_bc,
            global_pretrained_model=global_pretrained_model,
            local_pretrained_model=refresh_local_pretrained_model,
            transformed_bellman=args.transformed_bellman,
            device=_device,
            entropy_beta=args.entropy_beta, clip_norm=args.grad_norm_clip,
            grad_applier=grad_applier,
            initial_learn_rate=args.initial_learn_rate,
            learning_rate_input=learning_rate_input)

        all_workers.append(refresh_worker)
        startIndex += 1

    # setup a3c workers
    setup_a3c_worker(A3CTrainingThread, args, startIndex)
    for i in range(startIndex, args.parallel_size):
        local_network = GameACFFNetwork(
            action_size, i, device="/cpu:0",
            padding=args.padding,
            in_shape=input_shape)

        a3c_worker = A3CTrainingThread(
            i, global_network, local_network,
            args.initial_learn_rate, learning_rate_input, grad_applier,
            device="/cpu:0", no_op_max=30)

        all_workers.append(a3c_worker)
    ############## Setup Thread Workers END ################

    # setup config for tensorflow
    config = tf.ConfigProto(
        gpu_options=gpu_options,
        log_device_placement=False,
        allow_soft_placement=True)

    # prepare sessions
    sess = tf.Session(config=config)
    pretrain_sess = None
    if global_pretrained_model:
        pretrain_sess = tf.Session(config=config, graph=pretrain_graph)

    # initialize pretrained model (load saved weights)
    if pretrain_sess:
        assert args.pretrained_model_folder is not None
        global_pretrained_model.load(
            pretrain_sess,
            args.pretrained_model_folder)

    sess.run(tf.global_variables_initializer())
    if global_pretrained_model:
        initialize_uninitialized(tf, pretrain_sess,
                                 global_pretrained_model)
    if local_pretrained_model:
        initialize_uninitialized(tf, pretrain_sess,
                                 local_pretrained_model)

    # summary writer for tensorboard
    summ_file = args.save_to + 'log/a3c/{}/'.format(GYM_ENV_NAME) + str(folder)[58:]  # str(folder)[12:]
    summary_writer = tf.summary.FileWriter(summ_file, sess.graph)

    # init or load checkpoint with saver
    root_saver = tf.train.Saver(max_to_keep=1)
    saver = tf.train.Saver(max_to_keep=1)
    best_saver = tf.train.Saver(max_to_keep=1)

    checkpoint = tf.train.get_checkpoint_state(str(folder)+'/model_checkpoints')
    if checkpoint and checkpoint.model_checkpoint_path:
        root_saver.restore(sess, checkpoint.model_checkpoint_path)
        logger.info("checkpoint loaded:{}".format(
            checkpoint.model_checkpoint_path))
        tokens = checkpoint.model_checkpoint_path.split("-")
        # set global step
        global_t = int(tokens[-1])
        logger.info(">>> global step set: {}".format(global_t))

        tmp_t = (global_t // args.eval_freq) * args.eval_freq
        logger.info(">>> tmp_t: {}".format(tmp_t))

        # set wall time
        wall_t = 0.

        # set up reward files
        best_reward_file = folder / 'model_best/best_model_reward'
        with best_reward_file.open('r') as f:
            best_model_reward = float(f.read())

        # restore rewards
        rewards = restore_dict(reward_fname, global_t)
        logger.info(">>> restored: rewards")

        # restore loggings
        sil_dict = restore_dict(sil_fname, global_t)
        sil_ctr = sil_dict['sil_ctr'][tmp_t]
        sil_a3c_sampled = sil_dict['sil_a3c_sampled'][tmp_t]
        sil_a3c_used = sil_dict['sil_a3c_used'][tmp_t]
        sil_a3c_used_return = sil_dict['sil_a3c_used_return'][tmp_t]
        sil_rollout_sampled = sil_dict['sil_rollout_sampled'][tmp_t]
        sil_rollout_used = sil_dict['sil_rollout_used'][tmp_t]
        sil_rollout_used_return = sil_dict['sil_rollout_used_return'][tmp_t]
        sil_old_used = sil_dict['sil_old_used'][tmp_t]
        logger.info(">>> restored: sil_dict")

        rollout_dict = restore_dict(rollout_fname, global_t)
        rollout_ctr = rollout_dict['rollout_ctr'][tmp_t]
        rollout_added_ctr = rollout_dict['rollout_added_ctr'][tmp_t]
        rollout_new_return = rollout_dict['rollout_new_return'][tmp_t]
        rollout_old_return = rollout_dict['rollout_old_return'][tmp_t]
        logger.info(">>> restored: rollout_dict")

        if args.load_pretrained_model:
            class_reward_file = folder / '{}-class-rewards.pkl'.format(GYM_ENV_NAME)
            class_rewards = restore_dict(class_reward_file, global_t)

        # restore replay buffers (if saved)
        if args.checkpoint_buffer:
            # restore buffer D
            if args.use_sil and args.priority_memory:
                shared_memory = restore_buffer(sharedmem_fname, shared_memory, global_t)
                shared_memory = restore_buffer_trees(sharedmem_trees_fname, shared_memory, global_t)
                shared_memory = restore_buffer_params(sharedmem_params_fname, shared_memory, global_t)
                logger.info(">>> restored: shared_memory (Buffer D)")
                shared_memory.log()
                # restore buffer R
                if args.use_lider and not args.onebuffer:
                    rollout_buffer = restore_buffer(rolloutmem_fname, rollout_buffer, global_t)
                    rollout_buffer = restore_buffer_trees(rolloutmem_trees_fname, rollout_buffer, global_t)
                    rollout_buffer = restore_buffer_params(rolloutmem_params_fname, rollout_buffer, global_t)
                    logger.info(">>> restored: rollout_buffer (Buffer R)")
                    rollout_buffer.log()

        # all restores okay; remember this step so the old ckpt can be removed later to save storage space
        prev_ckpt_t = global_t

    else:
        logger.warning("Could not find old checkpoint")
        wall_t = 0.0
        prepare_dir(folder, empty=True)
        prepare_dir(folder / 'model_checkpoints', empty=True)
        prepare_dir(folder / 'model_best', empty=True)
        prepare_dir(folder / 'frames', empty=True)

    lock = threading.Lock()

    # next saving global_t
    def next_t(current_t, freq):
        return np.ceil((current_t + 0.00001) / freq) * freq
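    # e.g., with freq=1_000_000: next_t(0, freq) -> 1_000_000 and
    # next_t(1_000_000, freq) -> 2_000_000; the small epsilon makes the result
    # strictly greater than current_t when current_t is already a multiple of freq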

    next_global_t = next_t(global_t, args.eval_freq)
    next_save_t = next_t(
        global_t, args.eval_freq*args.checkpoint_freq)

    step_t = 0

    def train_function(parallel_idx, th_ctr, ep_queue, net_updates):
        nonlocal global_t, step_t, rewards, class_rewards, lock, \
                 next_save_t, next_global_t, prev_ckpt_t
        nonlocal shared_memory, rollout_buffer
        nonlocal sil_dict, sil_ctr, sil_a3c_sampled, sil_a3c_used, sil_a3c_used_return, \
                 sil_rollout_sampled, sil_rollout_used, sil_rollout_used_return, \
                 sil_old_used
        nonlocal rollout_dict, rollout_ctr, rollout_added_ctr, \
                 rollout_new_return, rollout_old_return

        parallel_worker = all_workers[parallel_idx]
        parallel_worker.set_summary_writer(summary_writer)

        with lock:
            # Evaluate model before training
            if not stop_req and global_t == 0 and step_t == 0:
                rewards['eval'][step_t] = parallel_worker.testing(
                    sess, args.eval_max_steps, global_t, folder,
                    worker=all_workers[-1])

                # testing pretrained TA or BC in game
                if args.load_pretrained_model:
                    assert pretrain_sess is not None
                    assert global_pretrained_model is not None
                    class_rewards['class_eval'][step_t] = \
                        parallel_worker.test_loaded_classifier(global_t=global_t,
                                                    max_eps=50, # testing 50 episodes
                                                    sess=pretrain_sess,
                                                    worker=all_workers[-1],
                                                    model=global_pretrained_model)
                    # log pretrained model performance
                    class_eval_file = pathlib.Path(args.pretrained_model_folder[:21]+\
                        str(GAME_NAME)+"/"+str(GAME_NAME)+'-model-eval.txt')
                    class_std = np.std(class_rewards['class_eval'][step_t][-1])
                    class_mean = np.mean(class_rewards['class_eval'][step_t][-1])
                    with class_eval_file.open('w') as f:
                        f.write("class_mean: \n" + str(class_mean) + "\n")
                        f.write("class_std: \n" + str(class_std) + "\n")
                        f.write("class_rewards: \n" + str(class_rewards['class_eval'][step_t][-1]) + "\n")

                checkpt_file = folder / 'model_checkpoints'
                checkpt_file /= '{}_checkpoint'.format(GYM_ENV_NAME)
                saver.save(sess, str(checkpt_file), global_step=global_t)
                save_best_model(rewards['eval'][global_t][0])

                # saving worker info to dicts for analysis
                sil_dict['sil_ctr'][step_t] = sil_ctr
                sil_dict['sil_a3c_sampled'][step_t] = sil_a3c_sampled
                sil_dict['sil_a3c_used'][step_t] = sil_a3c_used
                sil_dict['sil_a3c_used_return'][step_t] = sil_a3c_used_return
                sil_dict['sil_rollout_sampled'][step_t] = sil_rollout_sampled
                sil_dict['sil_rollout_used'][step_t] = sil_rollout_used
                sil_dict['sil_rollout_used_return'][step_t] = sil_rollout_used_return
                sil_dict['sil_old_used'][step_t] = sil_old_used

                rollout_dict['rollout_ctr'][step_t] = rollout_ctr
                rollout_dict['rollout_added_ctr'][step_t] = rollout_added_ctr
                rollout_dict['rollout_new_return'][step_t] = rollout_new_return
                rollout_dict['rollout_old_return'][step_t] = rollout_old_return

                # dump pickle
                dump_pickle([rewards, sil_dict, rollout_dict],
                            [reward_fname, sil_fname, rollout_fname],
                            global_t)
                if args.load_pretrained_model:
                    dump_pickle([class_rewards], [class_reward_fname], global_t)

                logger.info('Dump pickle at step {}'.format(global_t))

                # save replay buffer (only works under priority mem)
                if args.checkpoint_buffer:
                    if shared_memory is not None and args.priority_memory:
                        params = [shared_memory.buff._next_idx, shared_memory.buff._max_priority]
                        trees = [shared_memory.buff._it_sum._value,
                                 shared_memory.buff._it_min._value]
                        dump_pickle([shared_memory.buff._storage, params, trees],
                                    [sharedmem_fname, sharedmem_params_fname, sharedmem_trees_fname],
                                    global_t)
                        logger.info('Saving shared_memory')

                    if rollout_buffer is not None and args.priority_memory:
                        params = [rollout_buffer.buff._next_idx, rollout_buffer.buff._max_priority]
                        trees = [rollout_buffer.buff._it_sum._value,
                                 rollout_buffer.buff._it_min._value]
                        dump_pickle([rollout_buffer.buff._storage, params, trees],
                                    [rolloutmem_fname, rolloutmem_params_fname, rolloutmem_trees_fname],
                                    global_t)
                        logger.info('Saving rollout_buffer')

                prev_ckpt_t = global_t

                step_t = 1

        # set start_time
        start_time = time.time() - wall_t
        parallel_worker.set_start_time(start_time)

        if parallel_worker.is_sil_thread:
            sil_interval = 0  # bigger number => slower SIL updates
            m_repeat = 4
            min_mem = args.batch_size * m_repeat
            sil_train_flag = len(shared_memory) >= min_mem
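            # SIL training starts only once buffer D holds enough transitions
            # for m_repeat full batches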

        while True:
            if stop_req:
                return

            if global_t >= (args.max_time_step * args.max_time_step_fraction):
                return

            if parallel_worker.is_sil_thread:
                # before sil starts, init local count
                local_sil_ctr = 0
                local_sil_a3c_sampled, local_sil_a3c_used, local_sil_a3c_used_return = 0, 0, 0
                local_sil_rollout_sampled, local_sil_rollout_used, local_sil_rollout_used_return = 0, 0, 0
                local_sil_old_used = 0

                if net_updates.qsize() >= sil_interval \
                   and len(shared_memory) >= min_mem:
                    sil_train_flag = True

                if sil_train_flag:
                    sil_train_flag = False

                    th_ctr.get()

                    train_out = parallel_worker.sil_train(
                        sess, global_t, shared_memory, m_repeat,
                        rollout_buffer=rollout_buffer)

                    local_sil_ctr, local_sil_a3c_sampled, local_sil_a3c_used, \
                       local_sil_a3c_used_return, \
                       local_sil_rollout_sampled, local_sil_rollout_used, \
                       local_sil_rollout_used_return, \
                       local_sil_old_used = train_out

                    th_ctr.put(1)

                    with net_updates.mutex:
                        net_updates.queue.clear()

                    if args.use_lider:
                        parallel_worker.record_sil(sil_ctr=sil_ctr,
                                              total_used=(sil_a3c_used + sil_rollout_used),
                                              num_a3c_used=sil_a3c_used,
                                              a3c_used_return=sil_a3c_used_return/(sil_a3c_used+1),#add one in case divide by zero
                                              rollout_used=sil_rollout_used,
                                              rollout_used_return=sil_rollout_used_return/(sil_rollout_used+1),
                                              old_used=sil_old_used,
                                              global_t=global_t)

                        if sil_ctr % 200 == 0 and sil_ctr > 0:
                            rollout_buffsize = 0
                            if not args.onebuffer:
                                rollout_buffsize = len(rollout_buffer)
                            log_data = (sil_ctr, len(shared_memory),
                                        rollout_buffsize,
                                        sil_a3c_used+sil_rollout_used,
                                        args.batch_size*sil_ctr,
                                        sil_a3c_used,
                                        sil_a3c_used_return/(sil_a3c_used+1),
                                        sil_rollout_used,
                                        sil_rollout_used_return/(sil_rollout_used+1),
                                        sil_old_used)
                            logger.info("SIL: sil_ctr={0:}"
                                        " sil_memory_size={1:}"
                                        " rollout_buffer_size={2:}"
                                        " total_sample_used={3:}/{4:}"
                                        " a3c_used={5:}"
                                        " a3c_used_return_avg={6:.2f}"
                                        " rollout_used={7:}"
                                        " rollout_used_return_avg={8:.2f}"
                                        " old_used={9:}".format(*log_data))
                    else:
                        parallel_worker.record_sil(sil_ctr=sil_ctr,
                                                   total_used=(sil_a3c_used + sil_rollout_used),
                                                   num_a3c_used=sil_a3c_used,
                                                   rollout_used=sil_rollout_used,
                                                   global_t=global_t)
                        if sil_ctr % 200 == 0 and sil_ctr > 0:
                            log_data = (sil_ctr, sil_a3c_used+sil_rollout_used,
                                        args.batch_size*sil_ctr,
                                        sil_a3c_used,
                                        len(shared_memory))
                            logger.info("SIL: sil_ctr={0:}"
                                        " total_sample_used={1:}/{2:}"
                                        " a3c_used={3:}"
                                        " sil_memory_size={4:}".format(*log_data))

                # Adding episodes to SIL memory is centralized here so that
                # sampling and priority updates stay consistent: new episodes
                # are added in one place, and during SIL training the memory
                # is guaranteed to be untouched.
                max_episodes = args.parallel_size
                while not ep_queue.empty():
                    data = ep_queue.get()
                    parallel_worker.episode.set_data(*data)
                    shared_memory.extend(parallel_worker.episode)
                    parallel_worker.episode.reset()
                    max_episodes -= 1
                    if max_episodes <= 0:  # this ensures that SIL has a chance to train
                        break

                diff_global_t = 0

                # centralized rollout counting
                local_rollout_ctr, local_rollout_added_ctr = 0, 0
                local_rollout_new_return, local_rollout_old_return = 0, 0

            elif parallel_worker.is_refresh_thread:
                # before refresh starts, init local count
                diff_global_t = 0
                local_rollout_ctr, local_rollout_added_ctr = 0, 0
                local_rollout_new_return, local_rollout_old_return = 0, 0

                if len(shared_memory) >= 1:
                    th_ctr.get()
                    # randomly sample a state from buffer D
                    sample = shared_memory.sample_one_random()
                    # after sample, flip refreshed to True
                    # TODO: fix this so that only *successful* refreshes are flipped to True
                    # currently counting *all* refresh as True
                    assert sample[-1] == True

                    train_out = parallel_worker.rollout(sess, folder, pretrain_sess,
                                                        global_t, sample,
                                                        args.addall,
                                                        args.max_ep_step,
                                                        args.nstep_bc,
                                                        args.update_in_rollout)

                    diff_global_t, episode_end, part_end, local_rollout_ctr, \
                        local_rollout_added_ctr, add, local_rollout_new_return, \
                        local_rollout_old_return = train_out

                    th_ctr.put(1)

                    if rollout_ctr % 20 == 0 and rollout_ctr > 0:
                        log_msg = "ROLLOUT: rollout_ctr={} added_rollout_ct={} worker={}".format(
                        rollout_ctr, rollout_added_ctr, parallel_worker.thread_idx)
                        logger.info(log_msg)
                        logger.info("ROLLOUT Gnew: {}, G: {}".format(local_rollout_new_return,
                                                                     local_rollout_old_return))

                    # part_end should always be True here (end of episode);
                    # only add if the new return is better (unless LiDER-AddAll)
                    if part_end and add:
                        if not args.onebuffer:
                            # directly put into Buffer R
                            rollout_buffer.extend(parallel_worker.episode)
                        else:
                            # Buffer D add sample is centralized when OneBuffer
                            ep_queue.put(parallel_worker.episode.get_data())

                    parallel_worker.episode.reset()

                # centralized SIL counting
                local_sil_ctr = 0
                local_sil_a3c_sampled, local_sil_a3c_used, local_sil_a3c_used_return = 0, 0, 0
                local_sil_rollout_sampled, local_sil_rollout_used, local_sil_rollout_used_return = 0, 0, 0
                local_sil_old_used = 0

            # a3c training thread worker
            else:
                th_ctr.get()

                train_out = parallel_worker.train(sess, global_t, rewards)
                diff_global_t, episode_end, part_end = train_out

                th_ctr.put(1)

                if args.use_sil:
                    net_updates.put(1)
                    if part_end:
                        ep_queue.put(parallel_worker.episode.get_data())
                        parallel_worker.episode.reset()

                # centralized SIL counting
                local_sil_ctr = 0
                local_sil_a3c_sampled, local_sil_a3c_used, local_sil_a3c_used_return = 0, 0, 0
                local_sil_rollout_sampled, local_sil_rollout_used, local_sil_rollout_used_return = 0, 0, 0
                local_sil_old_used = 0
                # centralized rollout counting
                local_rollout_ctr, local_rollout_added_ctr = 0, 0
                local_rollout_new_return, local_rollout_old_return = 0, 0

            # ensure only one thread is updating global_t at a time
            with lock:
                global_t += diff_global_t

                # centralize increasing count for SIL and Rollout
                sil_ctr += local_sil_ctr
                sil_a3c_sampled += local_sil_a3c_sampled
                sil_a3c_used += local_sil_a3c_used
                sil_a3c_used_return += local_sil_a3c_used_return
                sil_rollout_sampled += local_sil_rollout_sampled
                sil_rollout_used += local_sil_rollout_used
                sil_rollout_used_return += local_sil_rollout_used_return
                sil_old_used += local_sil_old_used

                rollout_ctr += local_rollout_ctr
                rollout_added_ctr += local_rollout_added_ctr
                rollout_new_return += local_rollout_new_return
                rollout_old_return += local_rollout_old_return

                # if during a thread's update, global_t has reached an evaluation interval
                if global_t > next_global_t:
                    next_global_t = next_t(global_t, args.eval_freq)
                    step_t = int(next_global_t - args.eval_freq)

                    # wait for all threads to be done before testing
                    while not stop_req and th_ctr.qsize() < len(all_workers):
                        time.sleep(0.001)

                    # Evaluate for 125,000 steps
                    rewards['eval'][step_t] = parallel_worker.testing(
                        sess, args.eval_max_steps, step_t, folder,
                        worker=all_workers[-1])
                    save_best_model(rewards['eval'][step_t][0])
                    last_reward = rewards['eval'][step_t][0]

                    # saving worker info to dicts
                    # SIL
                    sil_dict['sil_ctr'][step_t] = sil_ctr
                    sil_dict['sil_a3c_sampled'][step_t] = sil_a3c_sampled
                    sil_dict['sil_a3c_used'][step_t] = sil_a3c_used
                    sil_dict['sil_a3c_used_return'][step_t] = sil_a3c_used_return
                    sil_dict['sil_rollout_sampled'][step_t] = sil_rollout_sampled
                    sil_dict['sil_rollout_used'][step_t] = sil_rollout_used
                    sil_dict['sil_rollout_used_return'][step_t] = sil_rollout_used_return
                    sil_dict['sil_old_used'][step_t] = sil_old_used
                    # ROLLOUT
                    rollout_dict['rollout_ctr'][step_t] = rollout_ctr
                    rollout_dict['rollout_added_ctr'][step_t] = rollout_added_ctr
                    rollout_dict['rollout_new_return'][step_t] = rollout_new_return
                    rollout_dict['rollout_old_return'][step_t] = rollout_old_return

                    # save ckpt after done with eval
                    if global_t > next_save_t:
                        next_save_t = next_t(global_t, args.eval_freq*args.checkpoint_freq)

                        # dump pickle
                        dump_pickle([rewards, sil_dict, rollout_dict],
                                    [reward_fname, sil_fname, rollout_fname],
                                    global_t)
                        if args.load_pretrained_model:
                            dump_pickle([class_rewards], [class_reward_fname], global_t)
                        logger.info('Dump pickle at step {}'.format(global_t))

                        # save replay buffer (only works for priority mem for now)
                        if args.checkpoint_buffer:
                            if shared_memory is not None and args.priority_memory:
                                params = [shared_memory.buff._next_idx, shared_memory.buff._max_priority]
                                trees = [shared_memory.buff._it_sum._value,
                                         shared_memory.buff._it_min._value]
                                dump_pickle([shared_memory.buff._storage, params, trees],
                                            [sharedmem_fname, sharedmem_params_fname, sharedmem_trees_fname],
                                            global_t)
                                logger.info('Saved shared_memory')

                            if rollout_buffer is not None and args.priority_memory:
                                params = [rollout_buffer.buff._next_idx, rollout_buffer.buff._max_priority]
                                trees = [rollout_buffer.buff._it_sum._value,
                                         rollout_buffer.buff._it_min._value]
                                dump_pickle([rollout_buffer.buff._storage, params, trees],
                                            [rolloutmem_fname, rolloutmem_params_fname, rolloutmem_trees_fname],
                                            global_t)
                                logger.info('Saved rollout_buffer')

                        # save a3c after saving buffer -- in case saving buffer OOM
                        # so that at least we can revert back to the previous ckpt
                        checkpt_file = folder / 'model_checkpoints'
                        checkpt_file /= '{}_checkpoint'.format(GYM_ENV_NAME)
                        saver.save(sess, str(checkpt_file), global_step=global_t,
                                   write_meta_graph=False)
                        logger.info('Saved model ckpt')

                        # if everything saves okay, clean up previous ckpt to save space
                        remove_pickle([reward_fname, sil_fname, rollout_fname],
                                      prev_ckpt_t)
                        if args.load_pretrained_model:
                            remove_pickle([class_reward_fname], prev_ckpt_t)

                        remove_pickle([sharedmem_fname, sharedmem_params_fname,
                                       sharedmem_trees_fname],
                                      prev_ckpt_t)
                        if rollout_buffer is not None and args.priority_memory:
                            remove_pickle([rolloutmem_fname, rolloutmem_params_fname,
                                           rolloutmem_trees_fname],
                                          prev_ckpt_t)

                        logger.info('Removed ckpt from step {}'.format(prev_ckpt_t))

                        prev_ckpt_t = global_t


    def signal_handler(signal, frame):
        nonlocal stop_req
        logger.info('You pressed Ctrl+C!')
        stop_req = True

        if stop_req and global_t == 0:
            sys.exit(1)

    def save_best_model(test_reward):
        nonlocal best_model_reward
        if test_reward > best_model_reward:
            best_model_reward = test_reward
            best_reward_file = folder / 'model_best/best_model_reward'

            with best_reward_file.open('w') as f:
                f.write(str(best_model_reward))

            best_checkpt_file = folder / 'model_best'
            best_checkpt_file /= '{}_checkpoint'.format(GYM_ENV_NAME)
            best_saver.save(sess, str(best_checkpt_file))


    train_threads = []
    th_ctr = Queue()
    for i in range(args.parallel_size):
        th_ctr.put(1)
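    # th_ctr acts as a counting semaphore: each worker takes a token (get) before
    # training and returns it (put) afterwards, so evaluation can wait until all
    # workers are idle (qsize() == number of workers)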

    episodes_queue = None
    net_updates = None
    if args.use_sil:
        episodes_queue = Queue()
        net_updates = Queue()

    for i in range(args.parallel_size):
        worker_thread = Thread(
            target=train_function,
            args=(i, th_ctr, episodes_queue, net_updates,))
        train_threads.append(worker_thread)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # set start time
    start_time = time.time() - wall_t

    for t in train_threads:
        t.start()

    print('Press Ctrl+C to stop')

    for t in train_threads:
        t.join()

    logger.info('Now saving data. Please wait')

    # write wall time
    wall_t = time.time() - start_time
    wall_t_fname = folder / 'wall_t.{}'.format(global_t)
    with wall_t_fname.open('w') as f:
        f.write(str(wall_t))

    # save final model
    checkpoint_file = str(folder / '{}_checkpoint_a3c'.format(GYM_ENV_NAME))
    root_saver.save(sess, checkpoint_file, global_step=global_t)

    dump_final_pickle([rewards, sil_dict, rollout_dict],
                      [reward_fname, sil_fname, rollout_fname])

    logger.info('Data saved!')

    # if everything saves okay & is done training (not because of pressed Ctrl+C),
    # clean up previous ckpt to save space
    if global_t >= (args.max_time_step * args.max_time_step_fraction):
        remove_pickle([reward_fname, sil_fname, rollout_fname],
                      prev_ckpt_t)
        if args.load_pretrained_model:
            remove_pickle([class_reward_fname], prev_ckpt_t)

        remove_pickle([sharedmem_fname, sharedmem_params_fname, sharedmem_trees_fname],
                      prev_ckpt_t)
        if rollout_buffer is not None and args.priority_memory:
            remove_pickle([rolloutmem_fname, rolloutmem_params_fname, rolloutmem_trees_fname],
                          prev_ckpt_t)

        logger.info('Done training, removed ckpt from step {}'.format(prev_ckpt_t))


    sess.close()
    if pretrain_sess:
        pretrain_sess.close()
class SILTrainingThread(CommonWorker):
    """Asynchronous Actor-Critic Training Thread Class."""

    entropy_beta = 0.01
    gamma = 0.99
    finetune_upper_layers_only = False
    transformed_bellman = False
    clip_norm = 0.5
    use_grad_cam = False

    def __init__(self,
                 thread_index,
                 global_net,
                 local_net,
                 initial_learning_rate,
                 learning_rate_input,
                 grad_applier,
                 device=None,
                 batch_size=None,
                 use_rollout=False,
                 one_buffer=False,
                 sampleR=False):
        """Initialize A3CTrainingThread class."""
        assert self.action_size != -1

        self.is_sil_thread = True
        self.thread_idx = thread_index
        self.initial_learning_rate = initial_learning_rate
        self.learning_rate_input = learning_rate_input
        self.local_net = local_net
        self.batch_size = batch_size
        self.use_rollout = use_rollout
        self.one_buffer = one_buffer
        self.sampleR = sampleR

        logger.info("===SIL thread_index: {}===".format(self.thread_idx))
        logger.info("device: {}".format(device))
        logger.info("action_size: {}".format(self.action_size))
        logger.info("entropy_beta: {}".format(self.entropy_beta))
        logger.info("gamma: {}".format(self.gamma))
        logger.info("reward_type: {}".format(self.reward_type))
        logger.info("transformed_bellman: {}".format(
            colored(self.transformed_bellman,
                    "green" if self.transformed_bellman else "red")))
        logger.info("clip_norm: {}".format(self.clip_norm))
        logger.info("use_grad_cam: {}".format(
            colored(self.use_grad_cam,
                    "green" if self.use_grad_cam else "red")))

        reward_clipped = (self.reward_type == 'CLIP')

        local_vars = self.local_net.get_vars

        with tf.device(device):
            critic_lr = 0.1
            entropy_beta = 0
            w_loss = 1.0
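            # SIL-specific settings: entropy bonus disabled (entropy_beta=0),
            # full weight on the SIL loss (w_loss=1.0), and a reduced critic
            # scale (critic_lr=0.1); see prepare_sil_loss for how they are used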
            logger.info("sil batch_size: {}".format(self.batch_size))
            logger.info("sil w_loss: {}".format(w_loss))
            logger.info("sil critic_lr: {}".format(critic_lr))
            logger.info("sil entropy_beta: {}".format(entropy_beta))
            self.local_net.prepare_sil_loss(entropy_beta=entropy_beta,
                                            w_loss=w_loss,
                                            critic_lr=critic_lr)
            var_refs = [v._ref() for v in local_vars()]

            self.sil_gradients = tf.gradients(self.local_net.total_loss_sil,
                                              var_refs)

        global_vars = global_net.get_vars

        with tf.device(device):
            if self.clip_norm is not None:
                self.sil_gradients, grad_norm = tf.clip_by_global_norm(
                    self.sil_gradients, self.clip_norm)
            sil_gradients_global = list(zip(self.sil_gradients, global_vars()))
            sil_gradients_local = list(zip(self.sil_gradients, local_vars()))
            self.sil_apply_gradients = grad_applier.apply_gradients(
                sil_gradients_global)
            self.sil_apply_gradients_local = grad_applier.apply_gradients(
                sil_gradients_local)
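            # the SIL gradients are computed on the local network and applied to
            # both the global variables and the local variables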

        self.sync = self.local_net.sync_from(global_net)

        self.episode = SILReplayMemory(self.action_size,
                                       max_len=None,
                                       gamma=self.gamma,
                                       clip=reward_clipped,
                                       height=self.local_net.in_shape[0],
                                       width=self.local_net.in_shape[1],
                                       phi_length=self.local_net.in_shape[2],
                                       reward_constant=self.reward_constant)

        # temp_buffer for mixing and re-sampling (brown arrow in Figure 1);
        # initialized only when needed (A3CTBSIL & LiDER-OneBuffer do not need temp_buffer)
        self.temp_buffer = None
        if (self.use_rollout) and (not self.one_buffer):
            self.temp_buffer = SILReplayMemory(
                self.action_size,
                max_len=self.batch_size * 2,
                gamma=self.gamma,
                clip=reward_clipped,
                height=self.local_net.in_shape[0],
                width=self.local_net.in_shape[1],
                phi_length=self.local_net.in_shape[2],
                priority=True,
                reward_constant=self.reward_constant)

    def record_sil(self,
                   sil_ctr=0,
                   total_used=0,
                   num_a3c_used=0,
                   a3c_used_return=0,
                   rollout_used=0,
                   rollout_used_return=0,
                   old_used=0,
                   global_t=0,
                   mode='SIL'):
        """Record SIL."""
        summary = tf.Summary()
        summary.value.add(tag='{}/sil_ctr'.format(mode),
                          simple_value=float(sil_ctr))
        summary.value.add(tag='{}/total_num_sample_used'.format(mode),
                          simple_value=float(total_used))

        summary.value.add(tag='{}/num_a3c_used'.format(mode),
                          simple_value=float(num_a3c_used))
        summary.value.add(tag='{}/a3c_used_return'.format(mode),
                          simple_value=float(a3c_used_return))

        summary.value.add(tag='{}/num_rollout_used'.format(mode),
                          simple_value=float(rollout_used))
        summary.value.add(tag='{}/rollout_used_return'.format(mode),
                          simple_value=float(rollout_used_return))

        summary.value.add(tag='{}/num_old_used'.format(mode),
                          simple_value=float(old_used))

        self.writer.add_summary(summary, global_t)
        self.writer.flush()

    def sil_train(self, sess, global_t, sil_memory, m, rollout_buffer=None):
        """Self-imitation learning process."""
        # copy weights from shared to local
        sess.run(self.sync)
        cur_learning_rate = self._anneal_learning_rate(
            global_t, self.initial_learning_rate)

        local_sil_ctr = 0
        local_sil_a3c_sampled, local_sil_a3c_used, local_sil_a3c_used_return = 0, 0, 0
        local_sil_rollout_sampled, local_sil_rollout_used, local_sil_rollout_used_return = 0, 0, 0
        local_sil_old_used = 0

        total_used = 0
        num_a3c_used = 0
        num_rollout_used = 0
        num_rollout_sampled = 0
        num_old_used = 0

        for _ in range(m):
            d_batch_size, r_batch_size = 0, 0
            # A3CTBSIL
            if not self.use_rollout:
                d_batch_size = self.batch_size
            # or LiDER-OneBuffer
            elif self.use_rollout and self.one_buffer:
                d_batch_size = self.batch_size
            # or LiDER
            else:
                assert rollout_buffer is not None
                assert self.temp_buffer is not None
                self.temp_buffer.reset()
                if not self.sampleR:  # otherwise, LiDER-SampleR
                    d_batch_size = self.batch_size
                r_batch_size = self.batch_size
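                # LiDER: draw a batch from each buffer (only from buffer R when
                # LiDER-SampleR), mix them in temp_buffer, then re-sample one
                # prioritized batch of batch_size further below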

            batch_state, batch_action, batch_returns, batch_fullstate, \
                batch_rollout, batch_refresh, weights = ([] for i in range(7))

            # sample from buffer D
            if d_batch_size > 0 and len(sil_memory) > d_batch_size:
                d_sample = sil_memory.sample(d_batch_size, beta=0.4)
                d_index_list, d_batch, d_weights = d_sample
                d_batch_state, d_action, d_batch_returns, \
                    d_batch_fullstate, d_batch_rollout, d_batch_refresh = d_batch
                # update priority of sampled experiences
                self.update_priorities_once(sess, sil_memory, d_index_list,
                                            d_batch_state, d_action,
                                            d_batch_returns)

                if self.temp_buffer is not None:  # when LiDER
                    d_batch_action = convert_onehot_to_a(d_action)
                    self.temp_buffer.extend_one_priority(
                        d_batch_state, d_batch_fullstate, d_batch_action,
                        d_batch_returns, d_batch_rollout, d_batch_refresh)
                else:  # when A3CTBSIL or LiDER-OneBuffer
                    batch_state.extend(d_batch_state)
                    batch_action.extend(d_action)
                    batch_returns.extend(d_batch_returns)
                    batch_fullstate.extend(d_batch_fullstate)
                    batch_rollout.extend(d_batch_rollout)
                    batch_refresh.extend(d_batch_refresh)
                    weights.extend(d_weights)

            # sample from buffer R
            if r_batch_size > 0 and len(rollout_buffer) > r_batch_size:
                r_sample = rollout_buffer.sample(r_batch_size, beta=0.4)
                r_index_list, r_batch, r_weights = r_sample
                r_batch_state, r_action, r_batch_returns, \
                    r_batch_fullstate, r_batch_rollout, r_batch_refresh = r_batch
                # update priority of sampled experiences
                self.update_priorities_once(sess, rollout_buffer, r_index_list,
                                            r_batch_state, r_action,
                                            r_batch_returns)

                if self.temp_buffer is not None:  # when LiDER
                    r_batch_action = convert_onehot_to_a(r_action)
                    self.temp_buffer.extend_one_priority(
                        r_batch_state, r_batch_fullstate, r_batch_action,
                        r_batch_returns, r_batch_rollout, r_batch_refresh)
                else:  # when A3CTBSIL or LiDER-OneBuffer
                    batch_state.extend(r_batch_state)
                    batch_action.extend(r_action)
                    batch_returns.extend(r_batch_returns)
                    batch_fullstate.extend(r_batch_fullstate)
                    batch_rollout.extend(r_batch_rollout)
                    batch_refresh.extend(r_batch_refresh)
                    weights.extend(r_weights)

            # LiDER only: re-sample one batch (batch_size) from the mixed temp_buffer
            # (early on, the batch could come entirely from buffer D since buffer R has
            # no data yet); only sample once temp_buffer holds at least one full batch
            if self.temp_buffer is not None and len(
                    self.temp_buffer) >= self.batch_size:
                sample = self.temp_buffer.sample(self.batch_size, beta=0.4)
                index_list, batch, weights = sample
                # overwrite the initial empty list
                batch_state, batch_action, batch_returns, \
                    batch_fullstate, batch_rollout, batch_refresh = batch

            if self.use_rollout:
                num_rollout_sampled += np.sum(batch_rollout)

            # sil policy update (if one full batch is sampled)
            if len(batch_state) == self.batch_size:
                feed_dict = {
                    self.local_net.s: batch_state,
                    self.local_net.a_sil: batch_action,
                    self.local_net.returns: batch_returns,
                    self.local_net.weights: weights,
                    self.learning_rate_input: cur_learning_rate,
                }
                fetch = [
                    self.local_net.clipped_advs,
                    self.local_net.advs,
                    self.sil_apply_gradients,
                    self.sil_apply_gradients_local,
                ]
                adv_clip, adv, _, _ = sess.run(fetch, feed_dict=feed_dict)
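                # SIL max operator: only transitions with positive advantage
                # (i.e., return exceeding the current value estimate) count as "used"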
                pos_idx = [i for (i, num) in enumerate(adv) if num > 0]
                neg_idx = [i for (i, num) in enumerate(adv) if num <= 0]
                # log number of samples used for SIL updates
                total_used += len(pos_idx)
                num_rollout_used += np.sum(np.take(batch_rollout, pos_idx))
                num_a3c_used += (len(pos_idx) -
                                 np.sum(np.take(batch_rollout, pos_idx)))
                num_old_used += np.sum(np.take(batch_refresh, pos_idx))
                # return for used rollout samples
                rollout_idx = [
                    i for (i, num) in enumerate(batch_rollout) if num > 0
                ]
                pos_rollout_idx = np.intersect1d(rollout_idx, pos_idx)
                if len(pos_rollout_idx) > 0:
                    local_sil_rollout_used_return += np.sum(
                        np.take(adv, pos_rollout_idx))
                # return for used a3c samples
                a3c_idx = [
                    i for (i, num) in enumerate(batch_rollout) if num <= 0
                ]
                pos_a3c_idx = np.intersect1d(a3c_idx, pos_idx)
                if len(pos_a3c_idx) > 0:
                    local_sil_a3c_used_return += np.sum(
                        np.take(batch_returns, pos_a3c_idx))

                local_sil_ctr += 1

        local_sil_a3c_sampled += (self.batch_size * m - num_rollout_sampled)
        local_sil_rollout_sampled += num_rollout_sampled
        local_sil_a3c_used += num_a3c_used
        local_sil_rollout_used += num_rollout_used
        local_sil_old_used += num_old_used

        return local_sil_ctr, local_sil_a3c_sampled, local_sil_a3c_used, \
               local_sil_a3c_used_return, \
               local_sil_rollout_sampled, local_sil_rollout_used, \
               local_sil_rollout_used_return, \
               local_sil_old_used

    def update_priorities_once(self, sess, memory, index_list, batch_state,
                               batch_action, batch_returns):
        """Self-imitation update priorities once."""
        # copy weights from shared to local
        sess.run(self.sync)

        feed_dict = {
            self.local_net.s: batch_state,
            self.local_net.a_sil: batch_action,
            self.local_net.returns: batch_returns,
        }
        fetch = self.local_net.clipped_advs
        adv_clip = sess.run(fetch, feed_dict=feed_dict)
        memory.set_weights(index_list, adv_clip)
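        # A minimal usage sketch (assumptions: `memory` exposes a prioritized
        # sample(batch_size, beta) returning (indices, batch, weights) like
        # temp_buffer above, and the first three entries of `batch` are states,
        # actions and returns):
        #
        #   index_list, batch, _ = memory.sample(self.batch_size, beta=0.4)
        #   batch_state, batch_action, batch_returns = batch[0], batch[1], batch[2]
        #   self.update_priorities_once(sess, memory, index_list,
        #                               batch_state, batch_action, batch_returns)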
Example 4
    def __init__(self,
                 thread_index,
                 global_net,
                 local_net,
                 initial_learning_rate,
                 learning_rate_input,
                 grad_applier,
                 device=None,
                 no_op_max=30):
        """Initialize A3CTrainingThread class."""
        assert self.action_size != -1

        self.is_sil_thread = False
        self.is_refresh_thread = False

        self.thread_idx = thread_index
        self.learning_rate_input = learning_rate_input
        self.local_net = local_net

        self.no_op_max = no_op_max
        self.override_num_noops = 0 if self.no_op_max == 0 else None

        logger.info("===A3C thread_index: {}===".format(self.thread_idx))
        logger.info("device: {}".format(device))
        logger.info("use_sil: {}".format(
            colored(self.use_sil, "green" if self.use_sil else "red")))
        logger.info("local_t_max: {}".format(self.local_t_max))
        logger.info("action_size: {}".format(self.action_size))
        logger.info("entropy_beta: {}".format(self.entropy_beta))
        logger.info("gamma: {}".format(self.gamma))
        logger.info("reward_type: {}".format(self.reward_type))
        logger.info("transformed_bellman: {}".format(
            colored(self.transformed_bellman,
                    "green" if self.transformed_bellman else "red")))
        logger.info("clip_norm: {}".format(self.clip_norm))
        logger.info("use_grad_cam: {}".format(
            colored(self.use_grad_cam,
                    "green" if self.use_grad_cam else "red")))

        reward_clipped = True if self.reward_type == 'CLIP' else False
        local_vars = self.local_net.get_vars

        with tf.device(device):
            self.local_net.prepare_loss(entropy_beta=self.entropy_beta,
                                        critic_lr=0.5)
            var_refs = [v._ref() for v in local_vars()]

            self.gradients = tf.gradients(self.local_net.total_loss, var_refs)

        global_vars = global_net.get_vars

        with tf.device(device):
            if self.clip_norm is not None:
                self.gradients, grad_norm = tf.clip_by_global_norm(
                    self.gradients, self.clip_norm)
            self.gradients = list(zip(self.gradients, global_vars()))
            self.apply_gradients = grad_applier.apply_gradients(self.gradients)

        self.sync = self.local_net.sync_from(global_net)

        self.game_state = GameState(env_id=self.env_id,
                                    display=False,
                                    no_op_max=self.no_op_max,
                                    human_demo=False,
                                    episode_life=True,
                                    override_num_noops=self.override_num_noops)

        self.local_t = 0

        self.initial_learning_rate = initial_learning_rate

        self.episode_reward = 0
        self.episode_steps = 0

        # variable controlling log output
        self.prev_local_t = 0

        with tf.device(device):
            if self.use_grad_cam:
                self.action_meaning = self.game_state.env.unwrapped \
                    .get_action_meanings()
                self.local_net.build_grad_cam_grads()

        if self.use_sil:
            self.episode = SILReplayMemory(
                self.action_size,
                max_len=None,
                gamma=self.gamma,
                clip=reward_clipped,
                height=self.local_net.in_shape[0],
                width=self.local_net.in_shape[1],
                phi_length=self.local_net.in_shape[2],
                reward_constant=self.reward_constant)
Example 5
class A3CTrainingThread(CommonWorker):
    """Asynchronous Actor-Critic Training Thread Class."""
    log_interval = 100
    perf_log_interval = 1000
    local_t_max = 20
    entropy_beta = 0.01
    gamma = 0.99
    shaping_actions = -1  # -1 all actions, 0 exclude noop
    transformed_bellman = False
    clip_norm = 0.5
    use_grad_cam = False
    use_sil = False
    log_idx = 0
    reward_constant = 0

    def __init__(self,
                 thread_index,
                 global_net,
                 local_net,
                 initial_learning_rate,
                 learning_rate_input,
                 grad_applier,
                 device=None,
                 no_op_max=30):
        """Initialize A3CTrainingThread class."""
        assert self.action_size != -1

        self.is_sil_thread = False
        self.is_refresh_thread = False

        self.thread_idx = thread_index
        self.learning_rate_input = learning_rate_input
        self.local_net = local_net

        self.no_op_max = no_op_max
        self.override_num_noops = 0 if self.no_op_max == 0 else None

        logger.info("===A3C thread_index: {}===".format(self.thread_idx))
        logger.info("device: {}".format(device))
        logger.info("use_sil: {}".format(
            colored(self.use_sil, "green" if self.use_sil else "red")))
        logger.info("local_t_max: {}".format(self.local_t_max))
        logger.info("action_size: {}".format(self.action_size))
        logger.info("entropy_beta: {}".format(self.entropy_beta))
        logger.info("gamma: {}".format(self.gamma))
        logger.info("reward_type: {}".format(self.reward_type))
        logger.info("transformed_bellman: {}".format(
            colored(self.transformed_bellman,
                    "green" if self.transformed_bellman else "red")))
        logger.info("clip_norm: {}".format(self.clip_norm))
        logger.info("use_grad_cam: {}".format(
            colored(self.use_grad_cam,
                    "green" if self.use_grad_cam else "red")))

        reward_clipped = True if self.reward_type == 'CLIP' else False
        local_vars = self.local_net.get_vars

        with tf.device(device):
            self.local_net.prepare_loss(entropy_beta=self.entropy_beta,
                                        critic_lr=0.5)
            var_refs = [v._ref() for v in local_vars()]

            self.gradients = tf.gradients(self.local_net.total_loss, var_refs)

        global_vars = global_net.get_vars

        with tf.device(device):
            if self.clip_norm is not None:
                self.gradients, grad_norm = tf.clip_by_global_norm(
                    self.gradients, self.clip_norm)
            self.gradients = list(zip(self.gradients, global_vars()))
            self.apply_gradients = grad_applier.apply_gradients(self.gradients)

        self.sync = self.local_net.sync_from(global_net)
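        # Standard A3C wiring: gradients are computed from the local network's
        # loss but applied to the global network's variables, while sync_from
        # copies the global weights back into the local network at the start of
        # each call to train() below.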

        self.game_state = GameState(env_id=self.env_id,
                                    display=False,
                                    no_op_max=self.no_op_max,
                                    human_demo=False,
                                    episode_life=True,
                                    override_num_noops=self.override_num_noops)

        self.local_t = 0

        self.initial_learning_rate = initial_learning_rate

        self.episode_reward = 0
        self.episode_steps = 0

        # variable controlling log output
        self.prev_local_t = 0

        with tf.device(device):
            if self.use_grad_cam:
                self.action_meaning = self.game_state.env.unwrapped \
                    .get_action_meanings()
                self.local_net.build_grad_cam_grads()

        if self.use_sil:
            self.episode = SILReplayMemory(
                self.action_size,
                max_len=None,
                gamma=self.gamma,
                clip=reward_clipped,
                height=self.local_net.in_shape[0],
                width=self.local_net.in_shape[1],
                phi_length=self.local_net.in_shape[2],
                reward_constant=self.reward_constant)

    def train(self, sess, global_t, train_rewards):
        """Train A3C."""
        states = []
        fullstates = []
        actions = []
        rewards = []
        values = []
        rho = []

        terminal_pseudo = False  # loss of life
        terminal_end = False  # real terminal

        # copy weights from shared to local
        sess.run(self.sync)

        start_local_t = self.local_t

        # t_max times loop
        for i in range(self.local_t_max):
            state = cv2.resize(self.game_state.s_t,
                               self.local_net.in_shape[:-1],
                               interpolation=cv2.INTER_AREA)
            fullstate = self.game_state.clone_full_state()

            pi_, value_, logits_ = self.local_net.run_policy_and_value(
                sess, state)
            action = self.pick_action(logits_)

            states.append(state)
            fullstates.append(fullstate)
            actions.append(action)
            values.append(value_)

            if self.thread_idx == self.log_idx \
               and self.local_t % self.log_interval == 0:
                log_msg1 = "lg={}".format(
                    np.array_str(logits_, precision=4, suppress_small=True))
                log_msg2 = "pi={}".format(
                    np.array_str(pi_, precision=4, suppress_small=True))
                log_msg3 = "V={:.4f}".format(value_)
                logger.debug(log_msg1)
                logger.debug(log_msg2)
                logger.debug(log_msg3)

            # process game
            self.game_state.step(action)

            # receive game result
            reward = self.game_state.reward
            terminal = self.game_state.terminal

            self.episode_reward += reward

            if self.use_sil:
                # save states in episode memory
                self.episode.add_item(self.game_state.s_t, fullstate, action,
                                      reward, terminal)

            if self.reward_type == 'CLIP':
                reward = np.sign(reward)

            rewards.append(reward)

            self.local_t += 1
            self.episode_steps += 1
            global_t += 1

            # s_t1 -> s_t
            self.game_state.update()

            if terminal:
                terminal_pseudo = True

                env = self.game_state.env
                name = 'EpisodicLifeEnv'
                if get_wrapper_by_name(env, name).was_real_done:
                    # reduce log freq
                    if self.thread_idx == self.log_idx:
                        log_msg = "train: worker={} global_t={} local_t={}".format(
                            self.thread_idx, global_t, self.local_t)
                        score_str = colored(
                            "score={}".format(self.episode_reward), "magenta")
                        steps_str = colored(
                            "steps={}".format(self.episode_steps), "blue")
                        log_msg += " {} {}".format(score_str, steps_str)
                        logger.debug(log_msg)

                    train_rewards['train'][global_t] = (self.episode_reward,
                                                        self.episode_steps)
                    self.record_summary(score=self.episode_reward,
                                        steps=self.episode_steps,
                                        episodes=None,
                                        global_t=global_t,
                                        mode='Train')
                    self.episode_reward = 0
                    self.episode_steps = 0
                    terminal_end = True

                self.game_state.reset(hard_reset=False)
                break

        cumsum_reward = 0.0
        if not terminal:
            state = cv2.resize(self.game_state.s_t,
                               self.local_net.in_shape[:-1],
                               interpolation=cv2.INTER_AREA)
            cumsum_reward = self.local_net.run_value(sess, state)

        actions.reverse()
        states.reverse()
        rewards.reverse()
        values.reverse()

        batch_state = []
        batch_action = []
        batch_adv = []
        batch_cumsum_reward = []

        # compute and accumulate gradients
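        # Work backwards through the rollout to build n-step targets:
        #   R_i = r_i + gamma * R_{i+1}  (bootstrapped from V(s_{t+n}) above when
        #                                 the rollout did not hit a terminal state)
        #   A_i = R_i - V(s_i)
        # With transformed_bellman enabled this becomes
        #   R_i = h(r_i + gamma * h_inv(R_{i+1}))  (after shifting r_i by
        #                                           sign(r_i) * reward_constant),
        # where transform_h/transform_h_inv are assumed to be the usual
        # h(z) = sign(z) * (sqrt(|z| + 1) - 1) + eps * z and its inverse
        # (Pohlen et al., 2018); their exact definitions live outside this snippet.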
        for (ai, ri, si, vi) in zip(actions, rewards, states, values):
            if self.transformed_bellman:
                ri = np.sign(ri) * self.reward_constant + ri
                cumsum_reward = transform_h(ri + self.gamma *
                                            transform_h_inv(cumsum_reward))
            else:
                cumsum_reward = ri + self.gamma * cumsum_reward
            advantage = cumsum_reward - vi

            # convert action to one-hot vector
            a = np.zeros([self.action_size])
            a[ai] = 1

            batch_state.append(si)
            batch_action.append(a)
            batch_adv.append(advantage)
            batch_cumsum_reward.append(cumsum_reward)

        cur_learning_rate = self._anneal_learning_rate(
            global_t, self.initial_learning_rate)

        feed_dict = {
            self.local_net.s: batch_state,
            self.local_net.a: batch_action,
            self.local_net.advantage: batch_adv,
            self.local_net.cumulative_reward: batch_cumsum_reward,
            self.learning_rate_input: cur_learning_rate,
        }

        sess.run(self.apply_gradients, feed_dict=feed_dict)

        t = self.local_t - self.prev_local_t
        if (self.thread_idx == self.log_idx and t >= self.perf_log_interval):
            self.prev_local_t += self.perf_log_interval
            elapsed_time = time.time() - self.start_time
            steps_per_sec = global_t / elapsed_time
            logger.info("worker-{}, log_worker-{}".format(
                self.thread_idx, self.log_idx))
            logger.info("Performance : {} STEPS in {:.0f} sec. {:.0f}"
                        " STEPS/sec. {:.2f}M STEPS/hour.".format(
                            global_t, elapsed_time, steps_per_sec,
                            steps_per_sec * 3600 / 1000000.))

        # return advanced local step size
        diff_local_t = self.local_t - start_local_t
        return diff_local_t, terminal_end, terminal_pseudo
Example 6
    def __init__(self, thread_index, action_size, env_id,
                 global_a3c, local_a3c, update_in_rollout, nstep_bc,
                 global_pretrained_model, local_pretrained_model,
                 transformed_bellman=False, no_op_max=0,
                 device='/cpu:0', entropy_beta=0.01, clip_norm=None,
                 grad_applier=None, initial_learn_rate=0.007,
                 learning_rate_input=None):
        """Initialize RolloutThread class."""
        self.is_refresh_thread = True
        self.action_size = action_size
        self.thread_idx = thread_index
        self.transformed_bellman = transformed_bellman
        self.entropy_beta = entropy_beta
        self.clip_norm = clip_norm
        self.initial_learning_rate = initial_learn_rate
        self.learning_rate_input = learning_rate_input

        self.no_op_max = no_op_max
        self.override_num_noops = 0 if self.no_op_max == 0 else None

        logger.info("===REFRESH thread_index: {}===".format(self.thread_idx))
        logger.info("device: {}".format(device))
        logger.info("action_size: {}".format(self.action_size))
        logger.info("reward_type: {}".format(self.reward_type))
        logger.info("transformed_bellman: {}".format(
            colored(self.transformed_bellman,
                    "green" if self.transformed_bellman else "red")))
        logger.info("update in rollout: {}".format(
            colored(update_in_rollout, "green" if update_in_rollout else "red")))
        logger.info("N-step BC: {}".format(nstep_bc))

        self.reward_clipped = True if self.reward_type == 'CLIP' else False

        # setup local a3c
        self.local_a3c = local_a3c
        self.sync_a3c = self.local_a3c.sync_from(global_a3c)
        with tf.device(device):
            local_vars = self.local_a3c.get_vars
            self.local_a3c.prepare_loss(
                entropy_beta=self.entropy_beta, critic_lr=0.5)
            var_refs = [v._ref() for v in local_vars()]
            self.rollout_gradients = tf.gradients(self.local_a3c.total_loss, var_refs)
            global_vars = global_a3c.get_vars
            if self.clip_norm is not None:
                self.rollout_gradients, grad_norm = tf.clip_by_global_norm(
                    self.rollout_gradients, self.clip_norm)
            self.rollout_gradients = list(zip(self.rollout_gradients, global_vars()))
            self.rollout_apply_gradients = grad_applier.apply_gradients(self.rollout_gradients)

        # setup local pretrained model
        self.local_pretrained = None
        if nstep_bc > 0:
            assert local_pretrained_model is not None
            assert global_pretrained_model is not None
            self.local_pretrained = local_pretrained_model
            self.sync_pretrained = self.local_pretrained.sync_from(global_pretrained_model)

        # setup env
        self.rolloutgame = GameState(env_id=env_id, display=False,
                            no_op_max=0, human_demo=False, episode_life=True,
                            override_num_noops=0)
        self.local_t = 0
        self.episode_reward = 0
        self.episode_steps = 0

        self.action_meaning = self.rolloutgame.env.unwrapped.get_action_meanings()

        assert self.local_a3c is not None
        if nstep_bc > 0:
            assert self.local_pretrained is not None

        self.episode = SILReplayMemory(
            self.action_size, max_len=None, gamma=self.gamma,
            clip=self.reward_clipped,
            height=self.local_a3c.in_shape[0],
            width=self.local_a3c.in_shape[1],
            phi_length=self.local_a3c.in_shape[2],
            reward_constant=self.reward_constant)
Example 7
class RefreshThread(CommonWorker):
    """Rollout Thread Class."""
    advice_confidence = 0.8
    gamma = 0.99

    def __init__(self, thread_index, action_size, env_id,
                 global_a3c, local_a3c, update_in_rollout, nstep_bc,
                 global_pretrained_model, local_pretrained_model,
                 transformed_bellman=False, no_op_max=0,
                 device='/cpu:0', entropy_beta=0.01, clip_norm=None,
                 grad_applier=None, initial_learn_rate=0.007,
                 learning_rate_input=None):
        """Initialize RolloutThread class."""
        self.is_refresh_thread = True
        self.action_size = action_size
        self.thread_idx = thread_index
        self.transformed_bellman = transformed_bellman
        self.entropy_beta = entropy_beta
        self.clip_norm = clip_norm
        self.initial_learning_rate = initial_learn_rate
        self.learning_rate_input = learning_rate_input

        self.no_op_max = no_op_max
        self.override_num_noops = 0 if self.no_op_max == 0 else None

        logger.info("===REFRESH thread_index: {}===".format(self.thread_idx))
        logger.info("device: {}".format(device))
        logger.info("action_size: {}".format(self.action_size))
        logger.info("reward_type: {}".format(self.reward_type))
        logger.info("transformed_bellman: {}".format(
            colored(self.transformed_bellman,
                    "green" if self.transformed_bellman else "red")))
        logger.info("update in rollout: {}".format(
            colored(update_in_rollout, "green" if update_in_rollout else "red")))
        logger.info("N-step BC: {}".format(nstep_bc))

        self.reward_clipped = True if self.reward_type == 'CLIP' else False

        # setup local a3c
        self.local_a3c = local_a3c
        self.sync_a3c = self.local_a3c.sync_from(global_a3c)
        with tf.device(device):
            local_vars = self.local_a3c.get_vars
            self.local_a3c.prepare_loss(
                entropy_beta=self.entropy_beta, critic_lr=0.5)
            var_refs = [v._ref() for v in local_vars()]
            self.rollout_gradients = tf.gradients(self.local_a3c.total_loss, var_refs)
            global_vars = global_a3c.get_vars
            if self.clip_norm is not None:
                self.rollout_gradients, grad_norm = tf.clip_by_global_norm(
                    self.rollout_gradients, self.clip_norm)
            self.rollout_gradients = list(zip(self.rollout_gradients, global_vars()))
            self.rollout_apply_gradients = grad_applier.apply_gradients(self.rollout_gradients)

        # setup local pretrained model
        self.local_pretrained = None
        if nstep_bc > 0:
            assert local_pretrained_model is not None
            assert global_pretrained_model is not None
            self.local_pretrained = local_pretrained_model
            self.sync_pretrained = self.local_pretrained.sync_from(global_pretrained_model)

        # setup env
        self.rolloutgame = GameState(env_id=env_id, display=False,
                            no_op_max=0, human_demo=False, episode_life=True,
                            override_num_noops=0)
        self.local_t = 0
        self.episode_reward = 0
        self.episode_steps = 0

        self.action_meaning = self.rolloutgame.env.unwrapped.get_action_meanings()

        assert self.local_a3c is not None
        if nstep_bc > 0:
            assert self.local_pretrained is not None

        self.episode = SILReplayMemory(
            self.action_size, max_len=None, gamma=self.gamma,
            clip=self.reward_clipped,
            height=self.local_a3c.in_shape[0],
            width=self.local_a3c.in_shape[1],
            phi_length=self.local_a3c.in_shape[2],
            reward_constant=self.reward_constant)


    def record_rollout(self, score=0, steps=0, old_return=0, new_return=0,
                       global_t=0, rollout_ctr=0, rollout_added_ctr=0,
                       mode='Rollout', confidence=None, episodes=None):
        """Record rollout summary."""
        summary = tf.Summary()
        summary.value.add(tag='{}/score'.format(mode),
                          simple_value=float(score))
        summary.value.add(tag='{}/old_return_from_s'.format(mode),
                          simple_value=float(old_return))
        summary.value.add(tag='{}/new_return_from_s'.format(mode),
                          simple_value=float(new_return))
        summary.value.add(tag='{}/steps'.format(mode),
                          simple_value=float(steps))
        summary.value.add(tag='{}/all_rollout_ctr'.format(mode),
                          simple_value=float(rollout_ctr))
        summary.value.add(tag='{}/rollout_added_ctr'.format(mode),
                          simple_value=float(rollout_added_ctr))
        if confidence is not None:
            summary.value.add(tag='{}/advice-confidence'.format(mode),
                              simple_value=float(confidence))
        if episodes is not None:
            summary.value.add(tag='{}/episodes'.format(mode),
                              simple_value=float(episodes))
        self.writer.add_summary(summary, global_t)
        self.writer.flush()

    def compute_return_for_state(self, rewards, terminal):
        """Compute expected return."""
        length = np.shape(rewards)[0]
        returns = np.empty_like(rewards, dtype=np.float32)

        if self.reward_clipped:
            rewards = np.clip(rewards, -1., 1.)
        else:
            rewards = np.sign(rewards) * self.reward_constant + rewards
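        # Backward recursion over the trajectory; it assumes the trajectory ends
        # on a terminal step (otherwise returns[i + 1] for the last index would
        # read uninitialized memory). Worked example with clipped rewards
        # [0, 0, 1], terminal [False, False, True], gamma = 0.99:
        #   returns = [0.9801, 0.99, 1.0], so 0.9801 is returned.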

        for i in reversed(range(length)):
            if terminal[i]:
                returns[i] = rewards[i] if self.reward_clipped else transform_h(rewards[i])
            else:
                if self.reward_clipped:
                    returns[i] = rewards[i] + self.gamma * returns[i+1]
                else:
                    # apply transformed expected return
                    exp_r_t = self.gamma * transform_h_inv(returns[i+1])
                    returns[i] = transform_h(rewards[i] + exp_r_t)
        return returns[0]

    def update_a3c(self, sess, actions, states, rewards, values, global_t):
        cumsum_reward = 0.0
        actions.reverse()
        states.reverse()
        rewards.reverse()
        values.reverse()

        batch_state = []
        batch_action = []
        batch_adv = []
        batch_cumsum_reward = []

        # compute and accumulate gradients
        for (ai, ri, si, vi) in zip(actions, rewards, states, values):
            if self.transformed_bellman:
                ri = np.sign(ri) * self.reward_constant + ri
                cumsum_reward = transform_h(
                    ri + self.gamma * transform_h_inv(cumsum_reward))
            else:
                cumsum_reward = ri + self.gamma * cumsum_reward
            advantage = cumsum_reward - vi

            # convert action to one-hot vector
            a = np.zeros([self.action_size])
            a[ai] = 1

            batch_state.append(si)
            batch_action.append(a)
            batch_adv.append(advantage)
            batch_cumsum_reward.append(cumsum_reward)

        cur_learning_rate = self._anneal_learning_rate(
            global_t, self.initial_learning_rate)

        feed_dict = {
            self.local_a3c.s: batch_state,
            self.local_a3c.a: batch_action,
            self.local_a3c.advantage: batch_adv,
            self.local_a3c.cumulative_reward: batch_cumsum_reward,
            self.learning_rate_input: cur_learning_rate,
        }

        sess.run(self.rollout_apply_gradients, feed_dict=feed_dict)

        return batch_adv

    def rollout(self, a3c_sess, folder, pretrain_sess, global_t, past_state,
                add_all_rollout, ep_max_steps, nstep_bc, update_in_rollout):
        """Perform one rollout until terminal."""
        a3c_sess.run(self.sync_a3c)
        if nstep_bc > 0:
            pretrain_sess.run(self.sync_pretrained)

        _, fs, old_a, old_return, _, _ = past_state

        states = []
        actions = []
        rewards = []
        values = []
        terminals = []
        confidences = []

        rollout_ctr, rollout_added_ctr = 0, 0
        rollout_new_return, rollout_old_return = 0, 0

        terminal_pseudo = False  # loss of life
        terminal_end = False  # real terminal
        add = False

        self.rolloutgame.reset(hard_reset=True)
        self.rolloutgame.restore_full_state(fs)
        # check that restoring the full state succeeded
        fs_check = self.rolloutgame.clone_full_state()
        assert np.array_equal(fs_check, fs)
        del fs_check

        start_local_t = self.local_t
        self.rolloutgame.step(0)

        # Cap rollout length: ep_max_steps is set lower than the ALE default, see
        # https://github.com/openai/gym/blob/54f22cf4db2e43063093a1b15d968a57a32b6e90/gym/envs/__init__.py#L635
        # In all games tested, no rollout actually exceeds ep_max_steps.
        while ep_max_steps > 0:
            state = cv2.resize(self.rolloutgame.s_t,
                               self.local_a3c.in_shape[:-1],
                               interpolation=cv2.INTER_AREA)
            fullstate = self.rolloutgame.clone_full_state()

            if nstep_bc > 0: # LiDER-TA or BC
                model_pi = self.local_pretrained.run_policy(pretrain_sess, state)
                action, confidence = self.choose_action_with_high_confidence(
                                          model_pi, exclude_noop=False)
                confidences.append(confidence)  # recorded but not used anywhere
                nstep_bc -= 1
            else: # LiDER, refresh with current policy
                pi_, _, logits_ = self.local_a3c.run_policy_and_value(a3c_sess,
                                                                      state)
                action = self.pick_action(logits_)
                confidences.append(pi_[action])
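            # For the first nstep_bc steps the rollout follows the pretrained
            # (teacher/BC) policy's most confident action; after that it
            # refreshes the trajectory with the current A3C policy, which is
            # the plain LiDER setting when nstep_bc == 0.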

            value_ = self.local_a3c.run_value(a3c_sess, state)
            values.append(value_)
            states.append(state)
            actions.append(action)

            self.rolloutgame.step(action)

            ep_max_steps -= 1

            reward = self.rolloutgame.reward
            terminal = self.rolloutgame.terminal
            terminals.append(terminal)

            self.episode_reward += reward

            self.episode.add_item(self.rolloutgame.s_t, fullstate, action,
                                  reward, terminal, from_rollout=True)

            if self.reward_type == 'CLIP':
                reward = np.sign(reward)
            rewards.append(reward)

            self.local_t += 1
            self.episode_steps += 1
            global_t += 1

            self.rolloutgame.update()

            if terminal:
                terminal_pseudo = True
                env = self.rolloutgame.env
                name = 'EpisodicLifeEnv'
                rollout_ctr += 1
                terminal_end = get_wrapper_by_name(env, name).was_real_done

                new_return = self.compute_return_for_state(rewards, terminals)
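                # LiDER acceptance rule: keep the refreshed trajectory only if
                # its recomputed return from the starting state beats the return
                # stored when that state was originally sampled (old_return),
                # unless add_all_rollout is set.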

                if not add_all_rollout:
                    if new_return > old_return:
                        add = True
                else:
                    add = True

                if add:
                    rollout_added_ctr += 1
                    rollout_new_return += new_return
                    rollout_old_return += old_return
                    # update the policy immediately using a good rollout
                    if update_in_rollout:
                        batch_adv = self.update_a3c(a3c_sess, actions, states, rewards, values, global_t)

                self.episode_reward = 0
                self.episode_steps = 0
                self.rolloutgame.reset(hard_reset=True)
                break

        diff_local_t = self.local_t - start_local_t

        return diff_local_t, terminal_end, terminal_pseudo, rollout_ctr, \
               rollout_added_ctr, add, rollout_new_return, rollout_old_return