Example 1
class A3CTrainingThread(CommonWorker):
    """Asynchronous Actor-Critic Training Thread Class."""
    log_interval = 100
    perf_log_interval = 1000
    local_t_max = 20
    entropy_beta = 0.01
    gamma = 0.99
    shaping_actions = -1  # -1 all actions, 0 exclude noop
    transformed_bellman = False
    clip_norm = 0.5
    use_grad_cam = False
    use_sil = False
    log_idx = 0
    reward_constant = 0

    def __init__(self,
                 thread_index,
                 global_net,
                 local_net,
                 initial_learning_rate,
                 learning_rate_input,
                 grad_applier,
                 device=None,
                 no_op_max=30):
        """Initialize A3CTrainingThread class."""
        assert self.action_size != -1

        self.is_sil_thread = False
        self.is_refresh_thread = False

        self.thread_idx = thread_index
        self.learning_rate_input = learning_rate_input
        self.local_net = local_net

        self.no_op_max = no_op_max
        self.override_num_noops = 0 if self.no_op_max == 0 else None

        logger.info("===A3C thread_index: {}===".format(self.thread_idx))
        logger.info("device: {}".format(device))
        logger.info("use_sil: {}".format(
            colored(self.use_sil, "green" if self.use_sil else "red")))
        logger.info("local_t_max: {}".format(self.local_t_max))
        logger.info("action_size: {}".format(self.action_size))
        logger.info("entropy_beta: {}".format(self.entropy_beta))
        logger.info("gamma: {}".format(self.gamma))
        logger.info("reward_type: {}".format(self.reward_type))
        logger.info("transformed_bellman: {}".format(
            colored(self.transformed_bellman,
                    "green" if self.transformed_bellman else "red")))
        logger.info("clip_norm: {}".format(self.clip_norm))
        logger.info("use_grad_cam: {}".format(
            colored(self.use_grad_cam,
                    "green" if self.use_grad_cam else "red")))

        reward_clipped = (self.reward_type == 'CLIP')
        local_vars = self.local_net.get_vars

        with tf.device(device):
            self.local_net.prepare_loss(entropy_beta=self.entropy_beta,
                                        critic_lr=0.5)
            var_refs = [v._ref() for v in local_vars()]

            self.gradients = tf.gradients(self.local_net.total_loss, var_refs)

        global_vars = global_net.get_vars

        with tf.device(device):
            if self.clip_norm is not None:
                self.gradients, grad_norm = tf.clip_by_global_norm(
                    self.gradients, self.clip_norm)
            self.gradients = list(zip(self.gradients, global_vars()))
            self.apply_gradients = grad_applier.apply_gradients(self.gradients)

        self.sync = self.local_net.sync_from(global_net)

        self.game_state = GameState(env_id=self.env_id,
                                    display=False,
                                    no_op_max=self.no_op_max,
                                    human_demo=False,
                                    episode_life=True,
                                    override_num_noops=self.override_num_noops)

        self.local_t = 0

        self.initial_learning_rate = initial_learning_rate

        self.episode_reward = 0
        self.episode_steps = 0

        # variable controlling log output
        self.prev_local_t = 0

        with tf.device(device):
            if self.use_grad_cam:
                self.action_meaning = self.game_state.env.unwrapped \
                    .get_action_meanings()
                self.local_net.build_grad_cam_grads()

        if self.use_sil:
            self.episode = SILReplayMemory(
                self.action_size,
                max_len=None,
                gamma=self.gamma,
                clip=reward_clipped,
                height=self.local_net.in_shape[0],
                width=self.local_net.in_shape[1],
                phi_length=self.local_net.in_shape[2],
                reward_constant=self.reward_constant)

    def train(self, sess, global_t, train_rewards):
        """Train A3C."""
        states = []
        fullstates = []
        actions = []
        rewards = []
        values = []
        rho = []

        terminal_pseudo = False  # loss of life
        terminal_end = False  # real terminal

        # copy weights from shared to local
        sess.run(self.sync)

        start_local_t = self.local_t

        # t_max times loop
        for i in range(self.local_t_max):
            state = cv2.resize(self.game_state.s_t,
                               self.local_net.in_shape[:-1],
                               interpolation=cv2.INTER_AREA)
            fullstate = self.game_state.clone_full_state()

            pi_, value_, logits_ = self.local_net.run_policy_and_value(
                sess, state)
            action = self.pick_action(logits_)

            states.append(state)
            fullstates.append(fullstate)
            actions.append(action)
            values.append(value_)

            if self.thread_idx == self.log_idx \
               and self.local_t % self.log_interval == 0:
                log_msg1 = "lg={}".format(
                    np.array_str(logits_, precision=4, suppress_small=True))
                log_msg2 = "pi={}".format(
                    np.array_str(pi_, precision=4, suppress_small=True))
                log_msg3 = "V={:.4f}".format(value_)
                logger.debug(log_msg1)
                logger.debug(log_msg2)
                logger.debug(log_msg3)

            # process game
            self.game_state.step(action)

            # receive game result
            reward = self.game_state.reward
            terminal = self.game_state.terminal

            self.episode_reward += reward

            if self.use_sil:
                # save states in episode memory
                self.episode.add_item(self.game_state.s_t, fullstate, action,
                                      reward, terminal)

            if self.reward_type == 'CLIP':
                reward = np.sign(reward)

            rewards.append(reward)

            self.local_t += 1
            self.episode_steps += 1
            global_t += 1

            # s_t1 -> s_t
            self.game_state.update()

            if terminal:
                terminal_pseudo = True

                env = self.game_state.env
                name = 'EpisodicLifeEnv'
                if get_wrapper_by_name(env, name).was_real_done:
                    # reduce log freq
                    if self.thread_idx == self.log_idx:
                        log_msg = "train: worker={} global_t={} local_t={}".format(
                            self.thread_idx, global_t, self.local_t)
                        score_str = colored(
                            "score={}".format(self.episode_reward), "magenta")
                        steps_str = colored(
                            "steps={}".format(self.episode_steps), "blue")
                        log_msg += " {} {}".format(score_str, steps_str)
                        logger.debug(log_msg)

                    train_rewards['train'][global_t] = (self.episode_reward,
                                                        self.episode_steps)
                    self.record_summary(score=self.episode_reward,
                                        steps=self.episode_steps,
                                        episodes=None,
                                        global_t=global_t,
                                        mode='Train')
                    self.episode_reward = 0
                    self.episode_steps = 0
                    terminal_end = True

                self.game_state.reset(hard_reset=False)
                break

        cumsum_reward = 0.0
        if not terminal:
            state = cv2.resize(self.game_state.s_t,
                               self.local_net.in_shape[:-1],
                               interpolation=cv2.INTER_AREA)
            cumsum_reward = self.local_net.run_value(sess, state)

        actions.reverse()
        states.reverse()
        rewards.reverse()
        values.reverse()

        batch_state = []
        batch_action = []
        batch_adv = []
        batch_cumsum_reward = []

        # compute and accumulate gradients
        for (ai, ri, si, vi) in zip(actions, rewards, states, values):
            if self.transformed_bellman:
                ri = np.sign(ri) * self.reward_constant + ri
                cumsum_reward = transform_h(ri + self.gamma *
                                            transform_h_inv(cumsum_reward))
            else:
                cumsum_reward = ri + self.gamma * cumsum_reward
            advantage = cumsum_reward - vi

            # convert action to one-hot vector
            a = np.zeros([self.action_size])
            a[ai] = 1

            batch_state.append(si)
            batch_action.append(a)
            batch_adv.append(advantage)
            batch_cumsum_reward.append(cumsum_reward)

        cur_learning_rate = self._anneal_learning_rate(
            global_t, self.initial_learning_rate)

        feed_dict = {
            self.local_net.s: batch_state,
            self.local_net.a: batch_action,
            self.local_net.advantage: batch_adv,
            self.local_net.cumulative_reward: batch_cumsum_reward,
            self.learning_rate_input: cur_learning_rate,
        }

        sess.run(self.apply_gradients, feed_dict=feed_dict)

        t = self.local_t - self.prev_local_t
        if (self.thread_idx == self.log_idx and t >= self.perf_log_interval):
            self.prev_local_t += self.perf_log_interval
            elapsed_time = time.time() - self.start_time
            steps_per_sec = global_t / elapsed_time
            logger.info("worker-{}, log_worker-{}".format(
                self.thread_idx, self.log_idx))
            logger.info("Performance : {} STEPS in {:.0f} sec. {:.0f}"
                        " STEPS/sec. {:.2f}M STEPS/hour.".format(
                            global_t, elapsed_time, steps_per_sec,
                            steps_per_sec * 3600 / 1000000.))

        # return advanced local step size
        diff_local_t = self.local_t - start_local_t
        return diff_local_t, terminal_end, terminal_pseudo
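
Example 1 calls transform_h and transform_h_inv without showing where they come from. The sketch below is a minimal stand-in, assuming they match the local h / h_inv helpers defined inside Example 2's process() (the transformed Bellman operator of Pohlen et al., 2018); the round-trip assert is a quick sanity check and is not part of the original code.

import numpy as np

def transform_h(z, eps=1e-2):
    """Squash returns: h(z) = sign(z) * (sqrt(|z| + 1) - 1) + eps * z."""
    return np.sign(z) * (np.sqrt(np.abs(z) + 1.) - 1.) + eps * z

def transform_h_inv(z, eps=1e-2):
    """Inverse of transform_h, so transform_h_inv(transform_h(z)) ~= z."""
    return np.sign(z) * (np.square(
        (np.sqrt(1. + 4. * eps * (np.abs(z) + 1. + eps)) - 1.) / (2. * eps)) - 1.)

z = np.array([-100., -1., 0., 1., 100.])
assert np.allclose(transform_h_inv(transform_h(z)), z, atol=1e-5)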
Example 2
class A3CTrainingThread(object):
    log_interval = 100
    performance_log_interval = 1000
    local_t_max = 20
    demo_t_max = 20
    use_lstm = False
    action_size = -1
    entropy_beta = 0.01
    demo_entropy_beta = 0.01
    gamma = 0.99
    use_mnih_2015 = False
    env_id = None
    reward_type = 'CLIP'  # CLIP | LOG | RAW
    finetune_upper_layers_only = False
    shaping_reward = 0.001
    shaping_factor = 1.
    shaping_gamma = 0.85
    advice_confidence = 0.8
    shaping_actions = -1  # -1 all actions, 0 exclude noop
    transformed_bellman = False
    clip_norm = 0.5
    use_grad_cam = False

    def __init__(self,
                 thread_index,
                 global_network,
                 initial_learning_rate,
                 learning_rate_input,
                 grad_applier,
                 max_global_time_step,
                 device=None,
                 pretrained_model=None,
                 pretrained_model_sess=None,
                 advice=False,
                 reward_shaping=False):
        assert self.action_size != -1

        self.thread_index = thread_index
        self.learning_rate_input = learning_rate_input
        self.max_global_time_step = max_global_time_step
        self.use_pretrained_model_as_advice = advice
        self.use_pretrained_model_as_reward_shaping = reward_shaping

        logger.info("thread_index: {}".format(self.thread_index))
        logger.info("local_t_max: {}".format(self.local_t_max))
        logger.info("use_lstm: {}".format(
            colored(self.use_lstm, "green" if self.use_lstm else "red")))
        logger.info("action_size: {}".format(self.action_size))
        logger.info("entropy_beta: {}".format(self.entropy_beta))
        logger.info("gamma: {}".format(self.gamma))
        logger.info("reward_type: {}".format(self.reward_type))
        logger.info("finetune_upper_layers_only: {}".format(
            colored(self.finetune_upper_layers_only,
                    "green" if self.finetune_upper_layers_only else "red")))
        logger.info("use_pretrained_model_as_advice: {}".format(
            colored(
                self.use_pretrained_model_as_advice,
                "green" if self.use_pretrained_model_as_advice else "red")))
        logger.info("use_pretrained_model_as_reward_shaping: {}".format(
            colored(
                self.use_pretrained_model_as_reward_shaping, "green"
                if self.use_pretrained_model_as_reward_shaping else "red")))
        logger.info("transformed_bellman: {}".format(
            colored(self.transformed_bellman,
                    "green" if self.transformed_bellman else "red")))
        logger.info("clip_norm: {}".format(self.clip_norm))
        logger.info("use_grad_cam: {}".format(
            colored(self.use_grad_cam,
                    "green" if self.use_grad_cam else "red")))

        if self.use_lstm:
            GameACLSTMNetwork.use_mnih_2015 = self.use_mnih_2015
            self.local_network = GameACLSTMNetwork(self.action_size,
                                                   thread_index, device)
        else:
            GameACFFNetwork.use_mnih_2015 = self.use_mnih_2015
            self.local_network = GameACFFNetwork(self.action_size,
                                                 thread_index, device)

        with tf.device(device):
            self.local_network.prepare_loss(entropy_beta=self.entropy_beta,
                                            critic_lr=0.5)
            local_vars = self.local_network.get_vars
            if self.finetune_upper_layers_only:
                local_vars = self.local_network.get_vars_upper
            var_refs = [v._ref() for v in local_vars()]

            self.gradients = tf.gradients(self.local_network.total_loss,
                                          var_refs)

        global_vars = global_network.get_vars
        if self.finetune_upper_layers_only:
            global_vars = global_network.get_vars_upper

        with tf.device(device):
            if self.clip_norm is not None:
                self.gradients, grad_norm = tf.clip_by_global_norm(
                    self.gradients, self.clip_norm)
            self.gradients = list(zip(self.gradients, global_vars()))
            self.apply_gradients = grad_applier.apply_gradients(self.gradients)

            #self.apply_gradients = grad_applier.apply_gradients(
            #    global_vars(),
            #    self.gradients)

        self.sync = self.local_network.sync_from(
            global_network, upper_layers_only=self.finetune_upper_layers_only)

        self.game_state = GameState(env_id=self.env_id,
                                    display=False,
                                    no_op_max=30,
                                    human_demo=False,
                                    episode_life=True)

        self.local_t = 0

        self.initial_learning_rate = initial_learning_rate

        self.episode_reward = 0
        self.episode_steps = 0

        # variable controlling log output
        self.prev_local_t = 0

        self.is_demo_thread = False

        with tf.device(device):
            if self.use_grad_cam:
                self.action_meaning = self.game_state.env.unwrapped \
                    .get_action_meanings()
                self.local_network.build_grad_cam_grads()

        self.pretrained_model = pretrained_model
        self.pretrained_model_sess = pretrained_model_sess
        self.psi = 0.9 if self.use_pretrained_model_as_advice else 0.0
        self.advice_ctr = 0
        self.shaping_ctr = 0
        self.last_rho = 0.

        if self.use_pretrained_model_as_advice or self.use_pretrained_model_as_reward_shaping:
            assert self.pretrained_model is not None

    def _anneal_learning_rate(self, global_time_step):
        learning_rate = self.initial_learning_rate * (
            self.max_global_time_step -
            global_time_step) / self.max_global_time_step
        if learning_rate < 0.0:
            learning_rate = 0.0
        return learning_rate

    def choose_action(self, logits):
        """Sample an action via the Gumbel-max trick (equivalent to sampling
        from softmax(logits)); see sample() in
        https://github.com/ppyht2/tf-a2c/blob/master/src/policy.py"""
        noise = np.random.uniform(0, 1, np.shape(logits))
        return np.argmax(logits - np.log(-np.log(noise)))

    def choose_action_with_high_confidence(self, pi_values, exclude_noop=True):
        actions_confidence = []
        # exclude NOOP action
        for action in range(1 if exclude_noop else 0, self.action_size):
            actions_confidence.append(pi_values[action][0][0])
        max_confidence_action = np.argmax(actions_confidence)
        confidence = actions_confidence[max_confidence_action]
        return (max_confidence_action + (1 if exclude_noop else 0)), confidence

    def set_summary_writer(self, writer):
        self.writer = writer

    def record_summary(self,
                       score=0,
                       steps=0,
                       episodes=None,
                       global_t=0,
                       mode='Test'):
        summary = tf.Summary()
        summary.value.add(tag='{}/score'.format(mode),
                          simple_value=float(score))
        summary.value.add(tag='{}/steps'.format(mode),
                          simple_value=float(steps))
        if episodes is not None:
            summary.value.add(tag='{}/episodes'.format(mode),
                              simple_value=float(episodes))
        self.writer.add_summary(summary, global_t)
        self.writer.flush()

    def set_start_time(self, start_time):
        self.start_time = start_time

    def generate_cam(self, sess, test_cam_si, global_t):
        cam_side_img = []
        for i in range(len(test_cam_si)):
            # get max action per demo state
            readout_t = self.local_network.run_policy(sess, test_cam_si[i])
            action = np.argmax(readout_t)

            # convert action to one-hot vector
            action_onehot = [0.] * self.game_state.env.action_space.n
            action_onehot[action] = 1.

            # compute grad cam for conv layer 3
            activations, gradients = self.local_network.evaluate_grad_cam(
                sess, test_cam_si[i], action_onehot)
            cam = grad_cam(activations, gradients)
            cam_img = visualize_cam(cam)

            side_by_side = generate_image_for_cam_video(
                test_cam_si[i], cam_img, global_t, i,
                self.action_meaning[action])

            cam_side_img.append(side_by_side)

        return cam_side_img

    def generate_cam_video(self,
                           sess,
                           time_per_step,
                           global_t,
                           folder,
                           demo_memory_cam,
                           demo_cam_human=False):
        # use one demonstration data to record cam
        # only need to make movie for demo data once
        cam_side_img = self.generate_cam(sess, demo_memory_cam, global_t)

        path = '/frames/demo-cam_side_img'
        if demo_cam_human:
            path += '_human'

        make_movie(cam_side_img,
                   folder + '{}{ep:010d}'.format(path, ep=(global_t)),
                   duration=len(cam_side_img) * time_per_step,
                   true_image=True,
                   salience=False)
        del cam_side_img

    def testing_model(self,
                      sess,
                      max_steps,
                      global_t,
                      folder,
                      demo_memory_cam=None,
                      demo_cam_human=False):
        logger.info("Testing model at global_t={}...".format(global_t))
        # copy weights from shared to local
        sess.run(self.sync)

        if demo_memory_cam is not None:
            self.generate_cam_video(sess, 0.03, global_t, folder,
                                    demo_memory_cam, demo_cam_human)
            return
        else:
            self.game_state.reset(hard_reset=True)
            max_steps += 4
            test_memory = ReplayMemory(
                84,
                84,
                np.random.RandomState(),
                max_steps=max_steps,
                phi_length=4,
                num_actions=self.game_state.env.action_space.n,
                wrap_memory=False,
                full_state_size=self.game_state.clone_full_state().shape[0])
            for _ in range(4):
                test_memory.add(self.game_state.x_t,
                                0,
                                self.game_state.reward,
                                self.game_state.terminal,
                                self.game_state.lives,
                                fullstate=self.game_state.full_state)

        episode_buffer = []
        test_memory_cam = []

        total_reward = 0
        total_steps = 0
        episode_reward = 0
        episode_steps = 0
        n_episodes = 0
        terminal = False
        while True:
            #pi_ = self.local_network.run_policy(sess, self.game_state.s_t)
            test_memory_cam.append(self.game_state.s_t)
            episode_buffer.append(self.game_state.get_screen_rgb())
            pi_, value_, logits_ = self.local_network.run_policy_and_value(
                sess, self.game_state.s_t)
            #action = self.choose_action(logits_)
            action = np.argmax(pi_)

            # take action
            self.game_state.step(action)
            terminal = self.game_state.terminal
            memory_full = episode_steps == max_steps - 5
            terminal_ = terminal or memory_full

            # store the transition to replay memory
            test_memory.add(self.game_state.x_t1,
                            action,
                            self.game_state.reward,
                            terminal_,
                            self.game_state.lives,
                            fullstate=self.game_state.full_state1)

            # update the old values
            episode_reward += self.game_state.reward
            episode_steps += 1

            # s_t = s_t1
            self.game_state.update()

            if terminal_:
                if get_wrapper_by_name(
                        self.game_state.env,
                        'EpisodicLifeEnv').was_real_done or memory_full:
                    time_per_step = 0.03
                    images = np.array(episode_buffer)
                    make_movie(images,
                               folder +
                               '/frames/image{ep:010d}'.format(ep=global_t),
                               duration=len(images) * time_per_step,
                               true_image=True,
                               salience=False)
                    break

                self.game_state.reset(hard_reset=False)
                if self.use_lstm:
                    self.local_network.reset_state()

        total_reward = episode_reward
        total_steps = episode_steps
        log_data = (global_t, self.thread_index, total_reward, total_steps)
        logger.info(
            "test: global_t={} worker={} final score={} final steps={}".format(
                *log_data))

        self.generate_cam_video(sess, 0.03, global_t, folder,
                                np.array(test_memory_cam))
        test_memory.save(name='test_cam', folder=folder, resize=True)

        if self.use_lstm:
            self.local_network.reset_state()

        return

    def testing(self, sess, max_steps, global_t, folder, demo_memory_cam=None):
        logger.info("Evaluate policy at global_t={}...".format(global_t))
        # copy weights from shared to local
        sess.run(self.sync)

        if demo_memory_cam is not None and global_t % 5000000 == 0:
            self.generate_cam_video(sess, 0.03, global_t, folder,
                                    demo_memory_cam)

        episode_buffer = []
        self.game_state.reset(hard_reset=True)
        episode_buffer.append(self.game_state.get_screen_rgb())

        total_reward = 0
        total_steps = 0
        episode_reward = 0
        episode_steps = 0
        n_episodes = 0
        while max_steps > 0:
            #pi_ = self.local_network.run_policy(sess, self.game_state.s_t)
            pi_, value_, logits_ = self.local_network.run_policy_and_value(
                sess, self.game_state.s_t)
            if False:
                action = np.random.choice(range(self.action_size), p=pi_)
            else:
                action = self.choose_action(logits_)

            if self.use_pretrained_model_as_advice:
                psi = self.psi if self.psi > 0.001 else 0.0
                if psi > np.random.rand():
                    model_pi = self.pretrained_model.run_policy(
                        self.pretrained_model_sess, self.game_state.s_t)
                    model_action, confidence = self.choose_action_with_high_confidence(
                        model_pi, exclude_noop=False)
                    if model_action > self.shaping_actions and confidence >= self.advice_confidence:
                        action = model_action

            # take action
            self.game_state.step(action)
            terminal = self.game_state.terminal

            if n_episodes == 0 and global_t % 5000000 == 0:
                episode_buffer.append(self.game_state.get_screen_rgb())

            episode_reward += self.game_state.reward
            episode_steps += 1
            max_steps -= 1

            # s_t = s_t1
            self.game_state.update()

            if terminal:
                if get_wrapper_by_name(self.game_state.env,
                                       'EpisodicLifeEnv').was_real_done:
                    if n_episodes == 0 and global_t % 5000000 == 0:
                        time_per_step = 0.0167
                        images = np.array(episode_buffer)
                        make_movie(
                            images,
                            folder +
                            '/frames/image{ep:010d}'.format(ep=global_t),
                            duration=len(images) * time_per_step,
                            true_image=True,
                            salience=False)
                        episode_buffer = []
                    n_episodes += 1
                    score_str = colored("score={}".format(episode_reward),
                                        "magenta")
                    steps_str = colored("steps={}".format(episode_steps),
                                        "blue")
                    log_data = (global_t, self.thread_index, n_episodes,
                                score_str, steps_str, total_steps)
                    logger.debug(
                        "test: global_t={} worker={} trial={} {} {} total_steps={}"
                        .format(*log_data))
                    total_reward += episode_reward
                    total_steps += episode_steps
                    episode_reward = 0
                    episode_steps = 0

                self.game_state.reset(hard_reset=False)
                if self.use_lstm:
                    self.local_network.reset_state()

        if n_episodes == 0:
            total_reward = episode_reward
            total_steps = episode_steps
        else:
            # (timestep, total sum of rewards, total # of steps before terminating)
            total_reward = total_reward / n_episodes
            total_steps = total_steps // n_episodes

        log_data = (global_t, self.thread_index, total_reward, total_steps,
                    n_episodes)
        logger.info(
            "test: global_t={} worker={} final score={} final steps={} # trials={}"
            .format(*log_data))

        self.record_summary(score=total_reward,
                            steps=total_steps,
                            episodes=n_episodes,
                            global_t=global_t,
                            mode='Test')

        # reset variables used in training
        self.episode_reward = 0
        self.episode_steps = 0
        self.game_state.reset(hard_reset=True)
        self.last_rho = 0.
        if self.is_demo_thread:
            self.replay_mem_reset()

        if self.use_lstm:
            self.local_network.reset_state()
        return total_reward, total_steps, n_episodes

    def pretrain_init(self, demo_memory):
        self.demo_memory_size = len(demo_memory)
        self.demo_memory = demo_memory
        self.replay_mem_reset()

    def replay_mem_reset(self, demo_memory_idx=None):
        if demo_memory_idx is not None:
            self.demo_memory_idx = demo_memory_idx
        else:
            # new random episode
            self.demo_memory_idx = np.random.randint(0, self.demo_memory_size)
        self.demo_memory_count = np.random.randint(
            0,
            len(self.demo_memory[self.demo_memory_idx]) - self.local_t_max)
        # if self.demo_memory_count+self.local_t_max < len(self.demo_memory[self.demo_memory_idx]):
        #           self.demo_memory_max_count = np.random.randint(self.demo_memory_count+self.local_t_max, len(self.demo_memory[self.demo_memory_idx]))
        # else:
        #           self.demo_memory_max_count = len(self.demo_memory[self.demo_memory_idx])
        logger.debug(
            "worker={} mem_reset demo_memory_idx={} demo_memory_start={}".
            format(self.thread_index, self.demo_memory_idx,
                   self.demo_memory_count))
        s_t, action, reward, terminal = self.demo_memory[self.demo_memory_idx][
            self.demo_memory_count]
        self.demo_memory_action = action
        self.demo_memory_reward = reward
        self.demo_memory_terminal = terminal
        if not self.demo_memory[self.demo_memory_idx].imgs_normalized:
            self.demo_memory_s_t = s_t * (1.0 / 255.0)
        else:
            self.demo_memory_s_t = s_t

    def replay_mem_process(self):
        self.demo_memory_count += 1
        s_t, action, reward, terminal = self.demo_memory[self.demo_memory_idx][
            self.demo_memory_count]
        self.demo_memory_next_action = action
        self.demo_memory_reward = reward
        self.demo_memory_terminal = terminal
        if not self.demo_memory[self.demo_memory_idx].imgs_normalized:
            self.demo_memory_s_t1 = s_t * (1.0 / 255.0)
        else:
            self.demo_memory_s_t1 = s_t

    def replay_mem_update(self):
        self.demo_memory_action = self.demo_memory_next_action
        self.demo_memory_s_t = self.demo_memory_s_t1

    def demo_process(self, sess, global_t, demo_memory_idx=None):
        states = []
        actions = []
        rewards = []
        values = []

        demo_ended = False
        terminal_end = False

        # copy weights from shared to local
        sess.run(self.sync)

        start_local_t = self.local_t

        if self.use_lstm:
            reset_lstm_state = False
            start_lstm_state = self.local_network.lstm_state_out

        # t_max times loop
        for i in range(self.demo_t_max):
            pi_, value_, logits_ = self.local_network.run_policy_and_value(
                sess, self.demo_memory_s_t)
            action = self.demo_memory_action
            time.sleep(0.0025)

            states.append(self.demo_memory_s_t)
            actions.append(action)
            values.append(value_)

            if (self.thread_index == 0) and (self.local_t % self.log_interval
                                             == 0):
                log_msg = "lg={}".format(
                    np.array_str(logits_, precision=4, suppress_small=True))
                log_msg += " pi={}".format(
                    np.array_str(pi_, precision=4, suppress_small=True))
                log_msg += " V={:.4f}".format(value_)
                logger.debug(log_msg)

            # process replay memory
            self.replay_mem_process()

            # receive replay memory result
            reward = self.demo_memory_reward
            terminal = self.demo_memory_terminal

            self.episode_reward += reward

            if self.reward_type == 'LOG':
                reward = np.sign(reward) * np.log(1 + np.abs(reward))
            elif self.reward_type == 'CLIP':
                # clip reward
                reward = np.sign(reward)

            rewards.append(reward)

            self.local_t += 1
            self.episode_steps += 1

            # demo_memory_s_t1 -> demo_memory_s_t
            self.replay_mem_update()
            s_t = self.demo_memory_s_t

            if terminal or self.demo_memory_count == len(
                    self.demo_memory[self.demo_memory_idx]):
                logger.debug("worker={} score={}".format(
                    self.thread_index, self.episode_reward))
                demo_ended = True
                if terminal:
                    terminal_end = True
                    if self.use_lstm:
                        self.local_network.reset_state()

                else:
                    # some demo episodes don't reach a terminal state
                    if self.use_lstm:
                        reset_lstm_state = True

                self.episode_reward = 0
                self.episode_steps = 0
                self.replay_mem_reset(demo_memory_idx=demo_memory_idx)
                break

        cumulative_reward = 0.0
        if not terminal_end:
            cumulative_reward = self.local_network.run_value(sess, s_t)

        actions.reverse()
        states.reverse()
        rewards.reverse()
        values.reverse()

        batch_state = []
        batch_action = []
        batch_adv = []
        batch_cumulative_reward = []

        # compute and accumulate gradients
        for (ai, ri, si, vi) in zip(actions, rewards, states, values):
            cumulative_reward = ri + self.gamma * cumulative_reward
            advantage = cumulative_reward - vi

            # convert action to one-hot vector
            a = np.zeros([self.action_size])
            a[ai] = 1

            batch_state.append(si)
            batch_action.append(a)
            batch_adv.append(advantage)
            batch_cumulative_reward.append(cumulative_reward)

        cur_learning_rate = self._anneal_learning_rate(global_t)  #* 0.005

        if self.use_lstm:
            batch_state.reverse()
            batch_action.reverse()
            batch_adv.reverse()
            batch_cumulative_reward.reverse()

            sess.run(self.apply_gradients,
                     feed_dict={
                         self.local_network.s: batch_state,
                         self.local_network.a: batch_action,
                         self.local_network.advantage: batch_adv,
                         self.local_network.cumulative_reward:
                         batch_cumulative_reward,
                         self.local_network.initial_lstm_state:
                         start_lstm_state,
                         self.local_network.step_size: [len(batch_action)],
                         self.learning_rate_input: cur_learning_rate
                     })

            # some demo episodes don't reach a terminal state
            if reset_lstm_state:
                self.local_network.reset_state()
                reset_lstm_state = False
        else:
            sess.run(self.apply_gradients,
                     feed_dict={
                         self.local_network.s: batch_state,
                         self.local_network.a: batch_action,
                         self.local_network.advantage: batch_adv,
                         self.local_network.cumulative_reward:
                         batch_cumulative_reward,
                         self.learning_rate_input: cur_learning_rate
                     })

        if (self.thread_index == 0) and (self.local_t - self.prev_local_t >=
                                         self.performance_log_interval):
            self.prev_local_t += self.performance_log_interval

        # return advanced local step size
        diff_local_t = self.local_t - start_local_t
        return diff_local_t, demo_ended

    def process(self, sess, global_t, train_rewards):
        states = []
        actions = []
        rewards = []
        values = []
        rho = []

        terminal_end = False

        # copy weights from shared to local
        sess.run(self.sync)

        start_local_t = self.local_t

        if self.use_lstm:
            start_lstm_state = self.local_network.lstm_state_out

        # t_max times loop
        for i in range(self.local_t_max):
            pi_, value_, logits_ = self.local_network.run_policy_and_value(
                sess, self.game_state.s_t)
            action = self.choose_action(logits_)

            model_pi = None
            confidence = 0.
            if self.use_pretrained_model_as_advice:
                self.psi = 0.9999 * (0.9999 ** global_t) \
                    if self.psi > 0.001 else 0.0  # 0.99995 works
                if self.psi > np.random.rand():
                    model_pi = self.pretrained_model.run_policy(
                        self.pretrained_model_sess, self.game_state.s_t)
                    model_action, confidence = self.choose_action_with_high_confidence(
                        model_pi, exclude_noop=False)
                    if (model_action > self.shaping_actions
                            and confidence >= self.advice_confidence):
                        action = model_action
                        self.advice_ctr += 1
            if self.use_pretrained_model_as_reward_shaping:
                #if action > 0:
                if model_pi is None:
                    model_pi = self.pretrained_model.run_policy(
                        self.pretrained_model_sess, self.game_state.s_t)
                    confidence = model_pi[action][0][0]
                if (action > self.shaping_actions
                        and confidence >= self.advice_confidence):
                    #rho.append(round(confidence, 5))
                    rho.append(self.shaping_reward)
                    self.shaping_ctr += 1
                else:
                    rho.append(0.)
                #self.shaping_ctr += 1

            states.append(self.game_state.s_t)
            actions.append(action)
            values.append(value_)

            if self.thread_index == 0 and self.local_t % self.log_interval == 0:
                log_msg1 = "lg={}".format(
                    np.array_str(logits_, precision=4, suppress_small=True))
                log_msg2 = "pi={}".format(
                    np.array_str(pi_, precision=4, suppress_small=True))
                log_msg3 = "V={:.4f}".format(value_)
                if self.use_pretrained_model_as_advice:
                    log_msg3 += " psi={:.4f}".format(self.psi)
                logger.debug(log_msg1)
                logger.debug(log_msg2)
                logger.debug(log_msg3)

            # process game
            self.game_state.step(action)

            # receive game result
            reward = self.game_state.reward
            terminal = self.game_state.terminal
            if self.use_pretrained_model_as_reward_shaping:
                if reward != 0:
                    rho[i] = 0.
                    j = i - 1
                    while j > i - 5 and j >= 0:
                        if rewards[j] != 0:
                            break
                        rho[j] = 0.
                        j -= 1
            #     if self.game_state.loss_life:
            #     if self.game_state.gain_life or reward > 0:
            #         rho[i] = 0.
            #         j = i-1
            #         k = 1
            #         while j >= 0:
            #             if rewards[j] != 0:
            #                 rho[j] = self.shaping_reward * (self.gamma ** -1)
            #                 break
            #             rho[j] = self.shaping_reward / k
            #             j -= 1
            #             k += 1

            self.episode_reward += reward

            if self.reward_type == 'LOG':
                reward = np.sign(reward) * np.log(1 + np.abs(reward))
            elif self.reward_type == 'CLIP':
                # clip reward
                reward = np.sign(reward)

            rewards.append(reward)

            self.local_t += 1
            self.episode_steps += 1
            global_t += 1

            # s_t1 -> s_t
            self.game_state.update()

            if terminal:
                if get_wrapper_by_name(self.game_state.env,
                                       'EpisodicLifeEnv').was_real_done:
                    log_msg = "train: worker={} global_t={}".format(
                        self.thread_index, global_t)
                    if self.use_pretrained_model_as_advice:
                        log_msg += " advice_ctr={}".format(self.advice_ctr)
                    if self.use_pretrained_model_as_reward_shaping:
                        log_msg += " shaping_ctr={}".format(self.shaping_ctr)
                    score_str = colored("score={}".format(self.episode_reward),
                                        "magenta")
                    steps_str = colored("steps={}".format(self.episode_steps),
                                        "blue")
                    log_msg += " {} {}".format(score_str, steps_str)
                    logger.debug(log_msg)
                    train_rewards['train'][global_t] = (self.episode_reward,
                                                        self.episode_steps)
                    self.record_summary(score=self.episode_reward,
                                        steps=self.episode_steps,
                                        episodes=None,
                                        global_t=global_t,
                                        mode='Train')
                    self.episode_reward = 0
                    self.episode_steps = 0
                    terminal_end = True

                self.last_rho = 0.
                if self.use_lstm:
                    self.local_network.reset_state()
                self.game_state.reset(hard_reset=False)
                break

        cumulative_reward = 0.0
        if not terminal:
            cumulative_reward = self.local_network.run_value(
                sess, self.game_state.s_t)

        actions.reverse()
        states.reverse()
        rewards.reverse()
        values.reverse()

        batch_state = []
        batch_action = []
        batch_adv = []
        batch_cumulative_reward = []

        if self.use_pretrained_model_as_reward_shaping:
            rho.reverse()
            rho.append(self.last_rho)
            self.last_rho = rho[0]
            i = 0
            # compute and accumulate gradients
            for (ai, ri, si, vi) in zip(actions, rewards, states, values):
                # Wiewiora et al.(2003) Principled Methods for Advising RL agents
                # Look-Back Advice
                #F = rho[i] - (self.shaping_gamma**-1) * rho[i+1]
                #F = rho[i] - self.shaping_gamma * rho[i+1]
                f = (self.shaping_gamma**-1) * rho[i] - rho[i + 1]
                if (i == 0 and terminal) or (f != 0 and ri != 0):
                    #logger.warn("averted additional F in absorbing state")
                    f = 0.
                # if (F < 0. and ri > 0) or (F > 0. and ri < 0):
                #     logger.warn("Negative reward shaping F={} ri={} rho[s]={} rhos[s-1]={}".format(F, ri, rho[i], rho[i+1]))
                #     F = 0.
                cumulative_reward = (ri + f * self.shaping_factor
                                     ) + self.gamma * cumulative_reward
                advantage = cumulative_reward - vi

                a = np.zeros([self.action_size])
                a[ai] = 1

                batch_state.append(si)
                batch_action.append(a)
                batch_adv.append(advantage)
                batch_cumulative_reward.append(cumulative_reward)
                i += 1
        else:

            def h(z, eps=10**-2):
                return (np.sign(z) *
                        (np.sqrt(np.abs(z) + 1.) - 1.)) + (eps * z)

            def h_inv(z, eps=10**-2):
                return np.sign(z) * (np.square(
                    (np.sqrt(1 + 4 * eps *
                             (np.abs(z) + 1 + eps)) - 1) / (2 * eps)) - 1)

            def h_log(z, eps=.6):
                return (np.sign(z) * np.log(1. + np.abs(z)) * eps)

            def h_inv_log(z, eps=.6):
                return np.sign(z) * (np.exp(np.abs(z) / eps) - 1)

            # compute and accumulate gradients
            for (ai, ri, si, vi) in zip(actions, rewards, states, values):
                if self.transformed_bellman:
                    cumulative_reward = h(ri + self.gamma *
                                          h_inv(cumulative_reward))
                else:
                    cumulative_reward = ri + self.gamma * cumulative_reward
                advantage = cumulative_reward - vi

                # convert action to one-hot vector
                a = np.zeros([self.action_size])
                a[ai] = 1

                batch_state.append(si)
                batch_action.append(a)
                batch_adv.append(advantage)
                batch_cumulative_reward.append(cumulative_reward)

        cur_learning_rate = self._anneal_learning_rate(global_t)

        if self.use_lstm:
            batch_state.reverse()
            batch_action.reverse()
            batch_adv.reverse()
            batch_cumulative_reward.reverse()

            sess.run(self.apply_gradients,
                     feed_dict={
                         self.local_network.s: batch_state,
                         self.local_network.a: batch_action,
                         self.local_network.advantage: batch_adv,
                         self.local_network.cumulative_reward:
                         batch_cumulative_reward,
                         self.local_network.initial_lstm_state:
                         start_lstm_state,
                         self.local_network.step_size: [len(batch_action)],
                         self.learning_rate_input: cur_learning_rate
                     })
        else:
            sess.run(self.apply_gradients,
                     feed_dict={
                         self.local_network.s: batch_state,
                         self.local_network.a: batch_action,
                         self.local_network.advantage: batch_adv,
                         self.local_network.cumulative_reward:
                         batch_cumulative_reward,
                         self.learning_rate_input: cur_learning_rate
                     })

        if (self.thread_index == 0) and (self.local_t - self.prev_local_t >=
                                         self.performance_log_interval):
            self.prev_local_t += self.performance_log_interval
            elapsed_time = time.time() - self.start_time
            steps_per_sec = global_t / elapsed_time
            logger.info(
                "Performance : {} STEPS in {:.0f} sec. {:.0f} STEPS/sec. {:.2f}M STEPS/hour"
                .format(global_t, elapsed_time, steps_per_sec,
                        steps_per_sec * 3600 / 1000000.))

        # return advanced local step size
        diff_local_t = self.local_t - start_local_t
        return diff_local_t, terminal_end
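
# The reversed accumulation loops in train() (Example 1) and process() (above)
# compute n-step bootstrapped returns and advantages for one rollout segment.
# Below is a minimal standalone sketch of that computation for the
# untransformed case; n_step_returns is an illustrative name and is not part
# of the original code.
import numpy as np

def n_step_returns(rewards, values, bootstrap_value, gamma=0.99):
    """Return (returns, advantages); rewards/values are in time order and
    bootstrap_value is V of the state after the last transition (0.0 if the
    segment ended at a real terminal)."""
    R = bootstrap_value
    returns, advantages = [], []
    for r, v in zip(reversed(rewards), reversed(values)):
        R = r + gamma * R
        returns.append(R)
        advantages.append(R - v)
    returns.reverse()
    advantages.reverse()
    return np.array(returns), np.array(advantages)

# hand check with gamma=0.5: R2 = 1 + 0.5*2 = 2.0, R1 = 0 + 0.5*2.0 = 1.0,
# R0 = 1 + 0.5*1.0 = 1.5
rets, advs = n_step_returns([1., 0., 1.], [0.2, 0.1, 0.3],
                            bootstrap_value=2., gamma=0.5)
assert np.allclose(rets, [1.5, 1.0, 2.0])
assert np.allclose(advs, [1.3, 0.9, 1.7])
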
class RefreshThread(CommonWorker):
    """Rollout Thread Class."""
    advice_confidence = 0.8
    gamma = 0.99

    def __init__(self, thread_index, action_size, env_id,
                 global_a3c, local_a3c, update_in_rollout, nstep_bc,
                 global_pretrained_model, local_pretrained_model,
                 transformed_bellman=False, no_op_max=0,
                 device='/cpu:0', entropy_beta=0.01, clip_norm=None,
                 grad_applier=None, initial_learn_rate=0.007,
                 learning_rate_input=None):
        """Initialize RolloutThread class."""
        self.is_refresh_thread = True
        self.action_size = action_size
        self.thread_idx = thread_index
        self.transformed_bellman = transformed_bellman
        self.entropy_beta = entropy_beta
        self.clip_norm = clip_norm
        self.initial_learning_rate = initial_learn_rate
        self.learning_rate_input = learning_rate_input

        self.no_op_max = no_op_max
        self.override_num_noops = 0 if self.no_op_max == 0 else None

        logger.info("===REFRESH thread_index: {}===".format(self.thread_idx))
        logger.info("device: {}".format(device))
        logger.info("action_size: {}".format(self.action_size))
        logger.info("reward_type: {}".format(self.reward_type))
        logger.info("transformed_bellman: {}".format(
            colored(self.transformed_bellman,
                    "green" if self.transformed_bellman else "red")))
        logger.info("update in rollout: {}".format(
            colored(update_in_rollout, "green" if update_in_rollout else "red")))
        logger.info("N-step BC: {}".format(nstep_bc))

        self.reward_clipped = (self.reward_type == 'CLIP')

        # setup local a3c
        self.local_a3c = local_a3c
        self.sync_a3c = self.local_a3c.sync_from(global_a3c)
        with tf.device(device):
            local_vars = self.local_a3c.get_vars
            self.local_a3c.prepare_loss(
                entropy_beta=self.entropy_beta, critic_lr=0.5)
            var_refs = [v._ref() for v in local_vars()]
            self.rollout_gradients = tf.gradients(self.local_a3c.total_loss, var_refs)
            global_vars = global_a3c.get_vars
            if self.clip_norm is not None:
                self.rollout_gradients, grad_norm = tf.clip_by_global_norm(
                    self.rollout_gradients, self.clip_norm)
            self.rollout_gradients = list(zip(self.rollout_gradients, global_vars()))
            self.rollout_apply_gradients = grad_applier.apply_gradients(self.rollout_gradients)

        # setup local pretrained model
        self.local_pretrained = None
        if nstep_bc > 0:
            assert local_pretrained_model is not None
            assert global_pretrained_model is not None
            self.local_pretrained = local_pretrained_model
            self.sync_pretrained = self.local_pretrained.sync_from(global_pretrained_model)

        # setup env
        self.rolloutgame = GameState(env_id=env_id, display=False,
                            no_op_max=0, human_demo=False, episode_life=True,
                            override_num_noops=0)
        self.local_t = 0
        self.episode_reward = 0
        self.episode_steps = 0

        self.action_meaning = self.rolloutgame.env.unwrapped.get_action_meanings()

        assert self.local_a3c is not None
        if nstep_bc > 0:
            assert self.local_pretrained is not None

        self.episode = SILReplayMemory(
            self.action_size, max_len=None, gamma=self.gamma,
            clip=self.reward_clipped,
            height=self.local_a3c.in_shape[0],
            width=self.local_a3c.in_shape[1],
            phi_length=self.local_a3c.in_shape[2],
            reward_constant=self.reward_constant)


    def record_rollout(self, score=0, steps=0, old_return=0, new_return=0,
                       global_t=0, rollout_ctr=0, rollout_added_ctr=0,
                       mode='Rollout', confidence=None, episodes=None):
        """Record rollout summary."""
        summary = tf.Summary()
        summary.value.add(tag='{}/score'.format(mode),
                          simple_value=float(score))
        summary.value.add(tag='{}/old_return_from_s'.format(mode),
                          simple_value=float(old_return))
        summary.value.add(tag='{}/new_return_from_s'.format(mode),
                          simple_value=float(new_return))
        summary.value.add(tag='{}/steps'.format(mode),
                          simple_value=float(steps))
        summary.value.add(tag='{}/all_rollout_ctr'.format(mode),
                          simple_value=float(rollout_ctr))
        summary.value.add(tag='{}/rollout_added_ctr'.format(mode),
                          simple_value=float(rollout_added_ctr))
        if confidence is not None:
            summary.value.add(tag='{}/advice-confidence'.format(mode),
                              simple_value=float(confidence))
        if episodes is not None:
            summary.value.add(tag='{}/episodes'.format(mode),
                              simple_value=float(episodes))
        self.writer.add_summary(summary, global_t)
        self.writer.flush()

    def compute_return_for_state(self, rewards, terminal):
        """Compute expected return."""
        length = np.shape(rewards)[0]
        returns = np.empty_like(rewards, dtype=np.float32)

        if self.reward_clipped:
            rewards = np.clip(rewards, -1., 1.)
        else:
            rewards = np.sign(rewards) * self.reward_constant + rewards

        for i in reversed(range(length)):
            if terminal[i]:
                returns[i] = rewards[i] if self.reward_clipped else transform_h(rewards[i])
            else:
                if self.reward_clipped:
                    returns[i] = rewards[i] + self.gamma * returns[i+1]
                else:
                    # apply transformed expected return
                    exp_r_t = self.gamma * transform_h_inv(returns[i+1])
                    returns[i] = transform_h(rewards[i] + exp_r_t)
        return returns[0]

    def update_a3c(self, sess, actions, states, rewards, values, global_t):
        cumsum_reward = 0.0
        actions.reverse()
        states.reverse()
        rewards.reverse()
        values.reverse()

        batch_state = []
        batch_action = []
        batch_adv = []
        batch_cumsum_reward = []

        # compute and accumulate gradients
        for (ai, ri, si, vi) in zip(actions, rewards, states, values):
            if self.transformed_bellman:
                ri = np.sign(ri) * self.reward_constant + ri
                cumsum_reward = transform_h(
                    ri + self.gamma * transform_h_inv(cumsum_reward))
            else:
                cumsum_reward = ri + self.gamma * cumsum_reward
            advantage = cumsum_reward - vi

            # convert action to one-hot vector
            a = np.zeros([self.action_size])
            a[ai] = 1

            batch_state.append(si)
            batch_action.append(a)
            batch_adv.append(advantage)
            batch_cumsum_reward.append(cumsum_reward)

        cur_learning_rate = self._anneal_learning_rate(
            global_t, self.initial_learning_rate)

        feed_dict = {
            self.local_a3c.s: batch_state,
            self.local_a3c.a: batch_action,
            self.local_a3c.advantage: batch_adv,
            self.local_a3c.cumulative_reward: batch_cumsum_reward,
            self.learning_rate_input: cur_learning_rate,
            }

        sess.run(self.rollout_apply_gradients, feed_dict=feed_dict)

        return batch_adv

    def rollout(self, a3c_sess, folder, pretrain_sess, global_t, past_state,
                add_all_rollout, ep_max_steps, nstep_bc, update_in_rollout):
        """Perform one rollout until terminal."""
        a3c_sess.run(self.sync_a3c)
        if nstep_bc > 0:
            pretrain_sess.run(self.sync_pretrained)

        _, fs, old_a, old_return, _, _ = past_state

        states = []
        actions = []
        rewards = []
        values = []
        terminals = []
        confidences = []

        rollout_ctr, rollout_added_ctr = 0, 0
        rollout_new_return, rollout_old_return = 0, 0

        terminal_pseudo = False  # loss of life
        terminal_end = False  # real terminal
        add = False

        self.rolloutgame.reset(hard_reset=True)
        self.rolloutgame.restore_full_state(fs)
        # check if restore successful
        fs_check = self.rolloutgame.clone_full_state()
        assert np.array_equal(fs_check, fs)
        del fs_check

        start_local_t = self.local_t
        self.rolloutgame.step(0)

        # prevent rollout too long, set max_ep_steps to be lower than ALE default
        # see https://github.com/openai/gym/blob/54f22cf4db2e43063093a1b15d968a57a32b6e90/gym/envs/__init__.py#L635
        # but in all games tested, no rollout exceeds ep_max_steps
        while ep_max_steps > 0:
            state = cv2.resize(self.rolloutgame.s_t,
                       self.local_a3c.in_shape[:-1],
                       interpolation=cv2.INTER_AREA)
            fullstate = self.rolloutgame.clone_full_state()

            if nstep_bc > 0: # LiDER-TA or BC
                model_pi = self.local_pretrained.run_policy(pretrain_sess, state)
                action, confidence = self.choose_action_with_high_confidence(
                                          model_pi, exclude_noop=False)
                confidences.append(confidence) # not using "confidences" for anything
                nstep_bc -= 1
            else: # LiDER, refresh with current policy
                pi_, _, logits_ = self.local_a3c.run_policy_and_value(a3c_sess,
                                                                      state)
                action = self.pick_action(logits_)
                confidences.append(pi_[action])

            value_ = self.local_a3c.run_value(a3c_sess, state)
            values.append(value_)
            states.append(state)
            actions.append(action)

            self.rolloutgame.step(action)

            ep_max_steps -= 1

            reward = self.rolloutgame.reward
            terminal = self.rolloutgame.terminal
            terminals.append(terminal)

            self.episode_reward += reward

            self.episode.add_item(self.rolloutgame.s_t, fullstate, action,
                                  reward, terminal, from_rollout=True)

            if self.reward_type == 'CLIP':
                reward = np.sign(reward)
            rewards.append(reward)

            self.local_t += 1
            self.episode_steps += 1
            global_t += 1

            self.rolloutgame.update()

            if terminal:
                terminal_pseudo = True
                env = self.rolloutgame.env
                name = 'EpisodicLifeEnv'
                rollout_ctr += 1
                terminal_end = get_wrapper_by_name(env, name).was_real_done

                new_return = self.compute_return_for_state(rewards, terminals)

                if not add_all_rollout:
                    if new_return > old_return:
                        add = True
                else:
                    add = True

                if add:
                    rollout_added_ctr += 1
                    rollout_new_return += new_return
                    rollout_old_return += old_return
                    # update policy immediately using a good rollout
                    if update_in_rollout:
                        batch_adv = self.update_a3c(a3c_sess, actions, states, rewards, values, global_t)

                self.episode_reward = 0
                self.episode_steps = 0
                self.rolloutgame.reset(hard_reset=True)
                break

        diff_local_t = self.local_t - start_local_t

        return diff_local_t, terminal_end, terminal_pseudo, rollout_ctr, \
               rollout_added_ctr, add, rollout_new_return, rollout_old_return
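
Both examples turn the policy logits into an action by adding Gumbel noise and taking the argmax (choose_action above; pick_action in Example 1 is not shown, but is assumed here to follow the same scheme). The short check below, a sketch rather than part of the original code, confirms that this Gumbel-max sampling is equivalent to drawing from the softmax distribution over the logits.

import numpy as np

def gumbel_max_sample(logits, rng):
    # argmax of (logits + Gumbel(0, 1) noise) ~ Categorical(softmax(logits))
    noise = rng.uniform(0., 1., np.shape(logits))
    return np.argmax(logits - np.log(-np.log(noise)))

rng = np.random.RandomState(0)
logits = np.array([2.0, 0.5, -1.0])
probs = np.exp(logits) / np.sum(np.exp(logits))  # softmax reference
samples = [gumbel_max_sample(logits, rng) for _ in range(20000)]
freqs = np.bincount(samples, minlength=len(logits)) / 20000.
# empirical action frequencies should track the softmax probabilities
assert np.allclose(freqs, probs, atol=0.02)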