Example #1
 def update_batch_priorities(self, priorities):
     """Takes in new priorities (i.e. from the algorithm after a training
     step) and sends them to priority tree as ``priorities ** alpha``; the
     tree internally remembers which indexes were sampled for this batch.
     """
     priorities = numpify_buffer(priorities)
     self.priority_tree.update_batch_priorities(priorities**self.alpha)
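For context, a minimal self-contained sketch (hypothetical TD errors and an assumed alpha hyperparameter, not part of the example above) of how priorities raised to ``alpha`` translate into sampling probabilities in prioritized replay:

import numpy as np

# Hypothetical absolute TD errors from the last sampled batch.
td_errors = np.array([0.5, 1.2, 0.05], dtype=np.float32)
alpha = 0.6  # typical prioritization exponent (assumption)

priorities = np.abs(td_errors) + 1e-6   # keep priorities strictly positive
scaled = priorities ** alpha            # what gets written into the priority tree
probs = scaled / scaled.sum()           # implied sampling distribution over the batch
print(probs)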
Example #2
 def collect_evaluation(self, itr):
     traj_infos = [self.TrajInfoCls() for _ in range(len(self.envs))]
     observations = list()
     for env in self.envs:
         observations.append(env.reset())
     observation = buffer_from_example(observations[0], len(self.envs))
     for b, o in enumerate(observations):
         observation[b] = o
     action = buffer_from_example(self.envs[0].action_space.null_value(),
                                  len(self.envs))
     reward = np.zeros(len(self.envs), dtype="float32")
     obs_pyt, act_pyt, rew_pyt = torchify_buffer(
         (observation, action, reward))
     self.agent.reset()
     self.agent.eval_mode(itr)
     for t in range(self.max_T):
         act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt)
         action = numpify_buffer(act_pyt)
         for b, env in enumerate(self.envs):
             o, r, d, env_info = env.step(action[b])
             traj_infos[b].step(observation[b], action[b], r, d,
                                agent_info[b], env_info)
             if getattr(env_info, "traj_done", d):
                 self.traj_infos_queue.put(traj_infos[b].terminate(o))
                 traj_infos[b] = self.TrajInfoCls()
                 o = env.reset()
             if d:
                 action[b] = 0  # Next prev_action.
                 r = 0
                 self.agent.reset_one(idx=b)
             observation[b] = o
             reward[b] = r
         if self.sync.stop_eval.value:
             break
     self.traj_infos_queue.put(None)  # End sentinel.
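For reference, a minimal consumer-side sketch of the queue protocol used above (names and payloads are hypothetical): completed trajectory infos are drained until the ``None`` end sentinel arrives.

import multiprocessing as mp

def drain_traj_infos(traj_infos_queue):
    """Collect items from the queue until the None end sentinel."""
    completed = []
    while True:
        info = traj_infos_queue.get()
        if info is None:  # end sentinel written by the collector
            break
        completed.append(info)
    return completed

if __name__ == "__main__":
    q = mp.Queue()
    q.put({"Return": 1.0, "Length": 200})  # stand-in for a TrajInfo
    q.put(None)
    print(drain_traj_infos(q))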
Example #3
 def collect_evaluation(self, itr):
     traj_infos = [self.TrajInfoCls() for _ in range(len(self.envs))]
     completed_traj_infos = list()
     observations = list()
     for env in self.envs:
         observations.append(env.reset())
     observation = buffer_from_example(observations[0], len(self.envs))
     action = buffer_from_example(
         self.envs[0].action_space.sample(null=True), len(self.envs))
     reward = np.zeros(len(self.envs), dtype="float32")
     obs_pyt, act_pyt, rew_pyt = torchify_buffer(
         (observation, action, reward))
     self.agent.reset()
     for t in range(self.max_T):
         act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt)
         action = numpify_buffer(act_pyt)
         for b, env in enumerate(self.envs):
             o, r, d, env_info = env.step(action[b])
             traj_infos[b].step(observation[b], action[b], r, d,
                                agent_info[b], env_info)
             if getattr(env_info, "traj_done", d):
                 completed_traj_infos.append(traj_infos[b].terminate(o))
                 traj_infos[b] = self.TrajInfoCls()
                 o = env.reset()
             if d:
                 action[b] = 0  # Prev_action for next step.
                 r = 0
                 self.agent.reset_one(idx=b)
             observation[b] = o
             reward[b] = r
         if (self.max_trajectories is not None
                 and len(completed_traj_infos) >= self.max_trajectories):
             break
     return completed_traj_infos
Example #4
    def collect_batch(self, agent_inputs, traj_infos, itr):
        # Numpy arrays can be written to from numpy arrays or torch tensors
        # (whereas torch tensors can only be written to from torch tensors).
        agent_buf, env_buf = self.samples_np.agent, self.samples_np.env
        completed_infos = list()
        observation, action, reward = agent_inputs
        b = np.where(self.done)[0]
        observation[b] = self.temp_observation[b]
        self.done[:] = False  # Did resets between batches.
        obs_pyt, act_pyt, rew_pyt = torchify_buffer(agent_inputs)
        agent_buf.prev_action[0] = action  # Leading prev_action.

        if env_buf.prev_reward[0].ndim > reward.ndim:
            reward = reward[:, None].repeat(env_buf.prev_reward[0].shape[-1],
                                            -1)
        env_buf.prev_reward[0] = reward

        self.agent.sample_mode(itr)
        for t in range(self.batch_T):
            env_buf.observation[t] = observation
            # Agent inputs and outputs are torch tensors.
            act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt)
            action = numpify_buffer(act_pyt)
            for b, env in enumerate(self.envs):
                if self.done[b]:
                    action[b] = 0  # Record blank.
                    reward[b] = 0
                    if agent_info:
                        agent_info[b] = 0
                    # Leave self.done[b] = True, record that.
                    continue
                # Environment inputs and outputs are numpy arrays.
                o, r, d, env_info = env.step(action[b])
                traj_infos[b].step(observation[b], action[b], r, d,
                                   agent_info[b], env_info)
                if getattr(env_info, "traj_done", d):
                    completed_infos.append(traj_infos[b].terminate(o))
                    traj_infos[b] = self.TrajInfoCls()
                    self.need_reset[b] = True
                if d:
                    self.temp_observation[b] = o
                    o = 0  # Record blank.
                observation[b] = o
                reward[b] = r
                self.done[b] = d
                if env_info:
                    env_buf.env_info[t, b] = env_info
            agent_buf.action[t] = action
            env_buf.reward[t] = reward
            env_buf.done[t] = self.done
            if agent_info:
                agent_buf.agent_info[t] = agent_info

        if "bootstrap_value" in agent_buf:
            # agent.value() should not advance rnn state.
            agent_buf.bootstrap_value[:] = self.agent.value(
                obs_pyt, act_pyt, rew_pyt)

        return AgentInputs(observation, action,
                           reward), traj_infos, completed_infos
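The comment at the top of this example is the key invariant behind ``torchify_buffer`` / ``numpify_buffer``: the torch tensors are views over the same memory as the numpy arrays. A self-contained illustration using plain numpy and torch (not the rlpyt helpers themselves):

import numpy as np
import torch

obs_np = np.zeros((2, 3), dtype=np.float32)
obs_pt = torch.from_numpy(obs_np)  # shares the underlying memory with obs_np

obs_np[0] = 1.0       # a write through the numpy side ...
print(obs_pt[0])      # ... is visible from the torch tensor

obs_pt[1] = 2.0       # a write through the torch side ...
print(obs_np[1])      # ... is visible from the numpy array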
Example #5
    def collect_evaluation(self, itr):
        assert self.max_trajectories == len(self.envs)
        traj_infos = [self.TrajInfoCls() for _ in range(len(self.envs))]
        completed_traj_infos = list()
        observations = list()
        for env in self.envs:
            observations.append(env.reset())
        observation = buffer_from_example(observations[0], len(self.envs))
        for b, o in enumerate(observations):
            observation[b] = o
        action = buffer_from_example(self.envs[0].action_space.null_value(),
                                     len(self.envs))
        reward = np.zeros(len(self.envs), dtype="float32")
        obs_pyt, act_pyt, rew_pyt = torchify_buffer(
            (observation, action, reward))
        self.agent.reset()
        self.agent.eval_mode(itr)
        live_envs = list(range(len(self.envs)))
        for t in range(self.max_T):
            act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt)
            action = numpify_buffer(act_pyt)

            b = 0
            # Use a while loop (not a for loop) because live_envs changes length during iteration.
            while b < len(live_envs):
                env_id = live_envs[b]
                o, r, d, env_info = self.envs[env_id].step(action[b])
                traj_infos[env_id].step(observation[b], action[b], r, d,
                                        agent_info[b], env_info)
                if getattr(env_info, "traj_done", d):
                    completed_traj_infos.append(
                        traj_infos[env_id].terminate(o))

                    observation = delete_ind_from_array(observation, b)
                    reward = delete_ind_from_array(reward, b)
                    action = delete_ind_from_array(action, b)
                    obs_pyt, act_pyt, rew_pyt = torchify_buffer(
                        (observation, action, reward))

                    del live_envs[b]
                    b -= 1  # live_envs[b] is now the next env, so go back one.
                else:
                    observation[b] = o
                    reward[b] = r

                b += 1

                if (self.max_trajectories is not None and
                        len(completed_traj_infos) >= self.max_trajectories):
                    logger.log("Evaluation reached max num trajectories "
                               f"({self.max_trajectories}).")
                    return completed_traj_infos

        if t == self.max_T - 1:
            logger.log("Evaluation reached max num time steps "
                       f"({self.max_T}).")
        return completed_traj_infos
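The helper ``delete_ind_from_array`` used above is not shown in this example. A minimal stand-in consistent with its usage here (removing one index along the leading batch dimension of a plain numpy array) might look like the sketch below; because ``np.delete`` returns a new array, the torch views have to be rebuilt with ``torchify_buffer`` afterwards, as the example does.

import numpy as np

def delete_ind_from_array(array, ind):
    # Drop one index along the leading (batch) dimension; returns a new array.
    return np.delete(array, ind, axis=0)

# Dropping env index 1 from a (3, 4) observation buffer leaves shape (2, 4).
obs = np.arange(12, dtype=np.float32).reshape(3, 4)
print(delete_ind_from_array(obs, 1).shape)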
Example #6
    def obtain_samples(self, itr, mode='sample'):
        agent_buf, env_buf = self.samples_np.agent, self.samples_np.env

        # Reset agent inputs
        observation, action, reward = self.agent_inputs
        obs_pyt, act_pyt, rew_pyt = torchify_buffer(self.agent_inputs)
        action[:] = self.env.action_space.null_value()
        reward[:] = 0

        # reset environment and agent
        observation[:] = self.env.reset()
        self.agent.reset()
        agent_buf.prev_action[0] = action  # Leading prev_action.
        env_buf.prev_reward[0] = reward  # Leading prev_reward.

        # perform episode
        if mode == 'sample':
            self.agent.sample_mode(itr)
        elif mode == 'eval':
            self.agent.eval_mode(itr)
        traj_infos = [
            self.TrajInfoCls(**self.traj_info_kwargs)
            for _ in range(self.batch_spec.B)
        ]
        for t in range(self.batch_spec.T):
            env_buf.observation[t] = observation

            act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt)
            # numpify_buffer returns a numpy view sharing memory with act_pyt,
            # so `action` can be passed to the numpy-based environment API.
            action = numpify_buffer(act_pyt)

            o, r, _, env_info = self.env.step(action)
            d = (t == self.batch_spec.T - 1)
            for b in range(self.batch_spec.B):
                traj_infos[b].step(observation[b], action[b], r[b], d,
                                   agent_info[b], env_info)
                if env_info:
                    env_buf.env_info[t, b] = env_info
            observation[:] = o
            reward[:] = r

            agent_buf.action[t] = action
            env_buf.reward[t] = reward
            if agent_info:
                agent_buf.agent_info[t] = agent_info

        if "bootstrap_value" in agent_buf:
            agent_buf.bootstrap_value[:] = self.agent.value(
                obs_pyt, act_pyt, rew_pyt)

        return self.samples_pyt, traj_infos
Example #7
    def collect_batch(self, agent_inputs, traj_infos, itr):
        # Numpy arrays can be written to from numpy arrays or torch tensors
        # (whereas torch tensors can only be written to from torch tensors).
        agent_buf, env_buf = self.samples_np.agent, self.samples_np.env
        completed_infos = list()
        observation, action, reward = agent_inputs
        obs_pyt, act_pyt, rew_pyt = torchify_buffer(agent_inputs)
        agent_buf.prev_action[0] = action  # Leading prev_action.

        if env_buf.prev_reward[0].ndim > reward.ndim:
            reward = reward[:, None].repeat(env_buf.prev_reward[0].shape[-1],
                                            -1)
        env_buf.prev_reward[0] = reward

        self.agent.sample_mode(itr)
        for t in range(self.batch_T):
            env_buf.observation[t] = observation
            # Agent inputs and outputs are torch tensors.
            act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt)
            action = numpify_buffer(act_pyt)
            for b, env in enumerate(self.envs):
                # Environment inputs and outputs are numpy arrays.
                o, r, d, env_info = env.step(action[b])
                traj_infos[b].step(observation[b], action[b], r, d,
                                   agent_info[b], env_info)
                if getattr(env_info, "traj_done", d):
                    completed_infos.append(traj_infos[b].terminate(o))
                    traj_infos[b] = self.TrajInfoCls()
                    o = env.reset()
                if d:
                    self.agent.reset_one(idx=b)
                observation[b] = o
                reward[b] = r
                env_buf.done[t, b] = d
                # if d or  getattr(env_info, "traj_done", d):
                #     import pdb; pdb.set_trace()
                if env_info:
                    env_buf.env_info[t, b] = env_info
            agent_buf.action[t] = action
            env_buf.reward[t] = reward
            if agent_info:
                agent_buf.agent_info[t] = agent_info

        if "bootstrap_value" in agent_buf:
            # agent.value() should not advance rnn state.
            agent_buf.bootstrap_value[:] = self.agent.value(
                obs_pyt, act_pyt, rew_pyt)

        return AgentInputs(observation, action,
                           reward), traj_infos, completed_infos
Example #8
 def collect_evaluation(self, itr, include_observations=False):
     traj_infos = [
         self.TrajInfoCls(include_observations=include_observations)
         for _ in range(len(self.envs))
     ]
     completed_traj_infos = list()
     observations = list()
     for env in self.envs:
         observations.append(env.reset())
     observation = buffer_from_example(observations[0], len(self.envs))
     action = buffer_from_example(self.envs[0].action_space.null_value(),
                                  len(self.envs))
     reward = np.zeros(len(self.envs), dtype="float32")
     obs_pyt, act_pyt, rew_pyt = torchify_buffer(
         (observation, action, reward))
     self.agent.reset()
     self.agent.eval_mode(itr)
     for t in range(self.max_T):
         act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt)
         action = numpify_buffer(act_pyt)
         for b, env in enumerate(self.envs):
             o, r, d, env_info = env.step(action[b])
             if include_observations:
                 traj_infos[b].step(env.render(), action[b], r, d,
                                    agent_info[b], env_info)
             else:
                 traj_infos[b].step(observation[b], action[b], r, d,
                                    agent_info[b], env_info)
             if getattr(env_info, "traj_done", d):
                 completed_traj_infos.append(traj_infos[b].terminate(o))
                 traj_infos[b] = self.TrajInfoCls(
                     include_observations=include_observations)
                 o = env.reset()
             if d:
                 action[b] = 0  # Prev_action for next step.
                 r = 0
                 self.agent.reset_one(idx=b)
             observation[b] = o
             reward[b] = r
         if (self.max_trajectories is not None
                 and len(completed_traj_infos) >= self.max_trajectories):
             logger.log("Evaluation reached max num trajectories "
                        f"({self.max_trajectories}).")
             break
     if t == self.max_T - 1:
         logger.log("Evaluation reached max num time steps "
                    f"({self.max_T}).")
     return completed_traj_infos
Example #9
    def collect_batch(self, agent_inputs, traj_infos, itr):
        # Numpy arrays can be written to from numpy arrays or torch tensors
        # (whereas torch tensors can only be written to from torch tensors).
        agent_buf, env_buf = self.samples_np.agent, self.samples_np.env
        completed_infos = list()
        observation, action, reward = agent_inputs
        obs_pyt, act_pyt, rew_pyt = torchify_buffer(agent_inputs)
        agent_buf.prev_action[0] = action  # Leading prev_action.
        env_buf.prev_reward[0] = reward
        self.agent.sample_mode(itr)
        for t in range(self.batch_T):
            env_buf.observation[t] = observation
            # Agent inputs and outputs are torch tensors.
            act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt)
            action = numpify_buffer(act_pyt)
            for b, env in enumerate(self.envs):
                # Environment inputs and outputs are numpy arrays.
                o, r, d, env_info = env.step(action[b])
                traj_infos[b].step(observation[b], action[b], r, d,
                                   agent_info[b], env_info)
                if getattr(env_info, "traj_done", d):
                    completed_infos.append(traj_infos[b].terminate(o))
                    traj_infos[b] = self.TrajInfoCls()
                    o = env.reset()
                if d:
                    self.agent.reset_one(idx=b)
                observation[b] = o
                reward[b] = r
                env_buf.done[t, b] = d
                if env_info:
                    env_buf.env_info[t, b] = env_info
            agent_buf.action[t] = action
            env_buf.reward[t] = reward
            if agent_info:
                agent_buf.agent_info[t] = agent_info
            if "next_observation" in env_buf:  # Modified to include next_obs
                env_buf.next_observation[t] = observation

        # Modified to include int_bootstrap_value
        if "bootstrap_value" in agent_buf:
            bootstraps = self.agent.value(obs_pyt, act_pyt, rew_pyt)
            agent_buf.bootstrap_value[:] = bootstraps.ext_value
            agent_buf.int_bootstrap_value[:] = bootstraps.int_value

        return AgentInputs(observation, action,
                           reward), traj_infos, completed_infos
Example #10
def simulate_policy(env, agent):
    # snapshot = torch.load(path_to_params, map_location=torch.device('cpu'))
    # agent_state_dict = snapshot['agent_state_dict']
    # env = GymEnvWrapper(gym.make(env_id, render=True))
    # env = gym.make('HopperPyBulletEnv-v0')
    # env.render(mode='human')
    # env = GymEnvWrapper(env)
    # agent_kwargs = dict(ModelCls=PiMcpVisionModel, QModelCls=QofMcpVisionModel)
    # agent = SacAgent(**agent_kwargs)
    # agent = SacAgent(model_kwargs=dict(hidden_sizes=[512,256, 256]), q_model_kwargs=dict(hidden_sizes=[512, 256, 256]))
    # agent = MujocoFfAgent(ModelCls=PPOMcpModel)
    # agent.initialize(env_spaces=env.spaces)
    # agent.load_state_dict(agent_state_dict)
    # agent.eval_mode(0)
    obs = env.reset()
    observation = buffer_from_example(obs, 1)
    loop_time = 0.04
    while True:
        observation[0] = env.reset()
        action = buffer_from_example(env.action_space.null_value(), 1)
        reward = np.zeros(1, dtype="float32")
        obs_pyt, act_pyt, rew_pyt = torchify_buffer(
            (observation, action, reward))
        done = False
        step = 0
        reward_sum = 0
        env.render()
        # time.sleep(5)
        while not done:
            loop_start = time.time()
            step += 1
            act_pyt, agent_info = agent.step(obs_pyt, act_pyt, rew_pyt)
            action = numpify_buffer(act_pyt)
            obs, reward, done, info = env.step(action[0])
            reward_sum += reward
            observation[0] = obs
            rew_pyt[0] = reward
            sleep_time = loop_time - (time.time() - loop_start)
            sleep_time = 0 if (sleep_time < 0) else sleep_time
            time.sleep(sleep_time)
            env.render(mode='human')
        print('return: ' + str(reward_sum) + '  num_steps: ' + str(step))
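A minimal driver sketch based on the commented-out setup at the top of this function. The snapshot path, env id, and agent class are placeholders; they have to match the experiment whose parameters are being replayed.

import torch
import gym

from rlpyt.envs.gym import GymEnvWrapper
from rlpyt.agents.pg.mujoco import MujocoFfAgent

def main():
    snapshot = torch.load("path/to/params.pkl", map_location=torch.device("cpu"))
    env = GymEnvWrapper(gym.make("HopperPyBulletEnv-v0"))  # placeholder env id (needs pybullet_envs registered)
    agent = MujocoFfAgent()  # placeholder agent class; must match the saved model
    agent.initialize(env_spaces=env.spaces)
    agent.load_state_dict(snapshot["agent_state_dict"])
    agent.eval_mode(0)
    simulate_policy(env, agent)

if __name__ == "__main__":
    main()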
Example #11
    def _collect_batch(self, itr):
        """Collect batch of experience from environment (collector.collect_batch)"""
        agent_buf, env_buf = self.samples_np.agent, self.samples_np.env
        completed_infos = list()
        o, a, r = self.agent_inputs  # Previous last inputs
        o_p, a_p, r_p = torchify_buffer(self.agent_inputs)
        self.agent.sample_mode(itr)
        agent_buf.prev_action[0] = a  # Store previous action
        env_buf.prev_reward[0] = r  # Store previous reward
        for t in range(self.batch_spec.T):
            env_buf.observation[t] = o  # Store observation
            # Agent inputs and outputs are torch tensors.
            a_p, agent_info = self.agent.step(o_p, a_p, r_p)
            a = numpify_buffer(a_p)
            o[:], r[:], d, info = self.env.step(a)
            self.traj_infos.step(o,
                                 a,
                                 r,
                                 d,
                                 agent_info,
                                 info,
                                 reset_dones=False)
            # Get completed infos (non-tensor). Environment auto-resets
            completed_infos += self.traj_infos.terminate(d)
            if np.sum(d): self.agent.reset_multiple(indexes=d)
            env_buf.done[t] = d
            env_buf.reward[t] = r
            agent_buf.action[t] = a
            if info:
                env_buf.env_info[t] = info
            if agent_info:
                agent_buf.agent_info[t] = agent_info

        if "bootstrap_value" in agent_buf:
            # agent.value() should not advance rnn state.
            agent_buf.bootstrap_value[:] = self.agent.value(o_p, a_p, r_p)

        return AgentInputs(o, a, r), self.traj_infos, completed_infos
Example #12
def simulate_policy(env, agent, render):
    static_decoder_path = './qec/referee_decoders/nn_d5_DP_p5'
    static_decoder = load_model(static_decoder_path, compile=True)
    obs = env.reset()
    observation = buffer_from_example(obs, 1)
    loop_time = 0.01
    returns = []
    mses = []
    lifetimes = []
    while True:
        observation[0] = env.reset()
        action = buffer_from_example(env.action_space.null_value(), 1)
        reward = np.zeros(1, dtype="float32")
        obs_pyt, act_pyt, rew_pyt = torchify_buffer(
            (observation, action, reward))
        agent.reset()
        done = False
        step = 0
        reward_sum = 0
        while not done:
            loop_start = time.time()
            step += 1
            act_pyt, agent_info = agent.step(obs_pyt, act_pyt, rew_pyt)
            action = numpify_buffer(act_pyt)[0]
            obs, reward, done, info = env.step(action)
            # done = np.argmax(static_decoder(info.static_decoder_input)[0]) != info.correct_label
            reward_sum += reward
            observation[0] = obs
            rew_pyt[0] = float(reward)

        returns.append(reward_sum)
        lifetimes.append(info.lifetime)
        print('avg return: ' + str(sum(returns) / len(returns)) + ' return: ' +
              str(reward_sum) + '  num_steps: ' + str(step))
        print(
            f'average lifetime: {sum(lifetimes)/len(lifetimes)} lifetime: {info.lifetime}'
        )
Example #13
    def collect_batch(self, agent_inputs, traj_infos, itr):
        # Numpy arrays can be written to from numpy arrays or torch tensors
        # (whereas torch tensors can only be written to from torch tensors).
        agent_buf, env_buf = self.samples_np.agent, self.samples_np.env
        completed_infos = list()
        observation, action, reward = agent_inputs
        obs_pyt, act_pyt, rew_pyt = torchify_buffer(agent_inputs)
        agent_buf.prev_action[0] = action  # Leading prev_action.
        env_buf.prev_reward[0] = reward
        self.agent.sample_mode(itr)

        if not hasattr(self, 'last_itr_envs_updated'):
            self.last_itr_envs_updated = [itr] * len(self.envs)

        for t in range(self.batch_T):
            env_buf.observation[t] = observation
            # Agent inputs and outputs are torch tensors.
            act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt)
            action = numpify_buffer(act_pyt)
            for b, env in enumerate(self.envs):
                # Environment inputs and outputs are numpy arrays.
                o, r, d, env_info = env.step(action[b])
                traj_infos[b].step(observation[b], action[b], r, d, agent_info[b],
                                   env_info)
                if getattr(env_info, "traj_done", d):
                    completed_infos.append(traj_infos[b].terminate(o))
                    traj_infos[b] = self.TrajInfoCls()
                    o = env.reset()

                    # Prioritized level replay related
                    if self.prioritized_level_replay:
                        seed = self.sync.seeds[b]
                        env.set_env_seed(seed)
                        o = env.reset()

                    # Curriculum related:
                    if 'smooth' in self.curriculum and self.curriculum['smooth']:
                        env_reward_threshold = self.curriculum['threshold']
                        if self.sync.glob_average_return.value > env_reward_threshold and (itr - 2) >= \
                                self.last_itr_envs_updated[b]:
                            self.last_itr_envs_updated[b] = itr
                            env.change_curriculum_phase_smooth()
                            self.sync.difficulty.value = env.difficulty
                            o = env.reset()
                    else:
                        _, env_reward_threshold, min_itrs, next_env = self.curriculum[env.get_env_name()]
                        if self.sync.glob_average_return.value > env_reward_threshold or (itr - min_itrs) > \
                                self.last_itr_envs_updated[b]:
                            self.last_itr_envs_updated[b] = itr
                            curriculum_id, *_ = self.curriculum[next_env]
                            self.sync.curriculum_stage.value = curriculum_id
                            env.change_curriculum_phase(next_env)
                            o = env.reset()
                if d:
                    self.agent.reset_one(idx=b)
                observation[b] = o
                reward[b] = r
                env_buf.done[t, b] = d
                if env_info:
                    env_buf.env_info[t, b] = env_info

            agent_buf.action[t] = action
            env_buf.reward[t] = reward
            if agent_info:
                agent_buf.agent_info[t] = agent_info

        if "bootstrap_value" in agent_buf:
            # agent.value() should not advance rnn state.
            agent_buf.bootstrap_value[:] = self.agent.value(obs_pyt, act_pyt, rew_pyt)

        return AgentInputs(observation, action, reward), traj_infos, completed_infos
Example #14
    def collect_evaluation(self, itr):
        traj_infos = [self.TrajInfoCls() for _ in range(len(self.envs))]
        observations = list()
        for env in self.envs:
            observations.append(env.reset())
        observation = buffer_from_example(observations[0], len(self.envs))
        for b, o in enumerate(observations):
            observation[b] = o
        action = buffer_from_example(self.envs[0].action_space.null_value(),
                                     len(self.envs))
        reward = np.zeros(len(self.envs), dtype="float32")
        obs_pyt, act_pyt, rew_pyt = torchify_buffer(
            (observation, action, reward))
        self.agent.reset()
        self.agent.eval_mode(itr)

        #* Modifying the eval logic here: always return traj for each env of a worker
        # obs_pyt: num_eval_env_per x obs_dim(3); act_pyt: num_eval_env_per x act_dim(1); rew_pyt: num_eval_env_per
        envs_done_flag = np.zeros((len(self.envs)))
        for t in range(self.max_T):  # max_T=100, not eval_max_steps
            act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt)
            action = numpify_buffer(act_pyt)

            # Go through each env in a worker
            for b, env in enumerate(self.envs):
                o, r, d, env_info = env.step(action[b])
                traj_infos[b].step(observation[b], action[b], r, d,
                                   agent_info[b], env_info)

                # Right now this one is never activated since our custom env (pendulum) does not return any info at each step
                if getattr(env_info, "traj_done", d):
                    self.traj_infos_queue.put(traj_infos[b].terminate(o))
                    traj_infos[b] = self.TrajInfoCls()
                    o = env.reset()
                    envs_done_flag[b] = 1

                # Right now this one is never activated since our custom env (pendulum) does not say done
                if d:
                    action[b] = 0  # Next prev_action.
                    r = 0
                    self.agent.reset_one(
                        idx=b)  # this does not do anything right now
                    envs_done_flag[b] = 1

                # Save saliency
                if t == 10 and b == 0 and self.agent.saliency_dir is not None:
                    saliency(img=o,
                             model=self.agent.model,
                             save_path=self.agent.saliency_dir + str(itr) +
                             '.png')

                observation[b] = o
                reward[b] = r
            if self.sync.stop_eval.value:
                break

        # Regardless, add to queue TODO: need to tell traj_info the global index of envs (like which image was used)
        for b in range(len(self.envs)):
            if envs_done_flag[b] < 1e-4:
                self.traj_infos_queue.put(traj_infos[b].terminate(o))

        self.traj_infos_queue.put(None)  # End sentinel.
Example #15
    def collect_batch(self, agent_inputs, traj_infos, itr):
        # Numpy arrays can be written to from numpy arrays or torch tensors
        # (whereas torch tensors can only be written to from torch tensors).
        """
        收集(即采样)一批数据。这个函数在Sampler类(例如SerialSampler)的obtain_samples()函数中会被调用。
        这里面会发生推断action的过程(NN的前向传播)。
        整个过程中,会发生几种step(步进)事件 :在agent中step(),在environment中step(),在trajectory中step()。

        :param agent_inputs: 上一次收集到的数据,类型为AgentInputs(一个namedarraytuple),包含observation, action, reward的信息。
        :param traj_infos: TrajInfo类对象组成的一个list,包含trajectory的一些统计信息。
        :param itr: 第几次迭代。
        :return: AgentInputs, list(TrajInfo对象), list(TrajInfo对象)
        """
        agent_buf, env_buf = self.samples_np.agent, self.samples_np.env  # self.samples_np is initialized in the Sampler's initialize().
        completed_infos = list()
        observation, action, reward = agent_inputs  # RHS: a namedarraytuple; see AgentInputs in rlpyt/agents/base.py.
        obs_pyt, act_pyt, rew_pyt = torchify_buffer(agent_inputs)  # Convert to torch.Tensor format.
        agent_buf.prev_action[0] = action  # Leading prev_action.
        env_buf.prev_reward[0] = reward
        self.agent.sample_mode(itr)
        """
        下面这段代码有两层loop,第一层是对人为指定的time step进行loop,第二层是对所有environment进行loop,最后的效果就是每个environment
        走time step步。
        # TODO:
        """
        for t in range(
                self.batch_T):  # batch_T:在采集数据的时候,每个environment走多少个time step
            env_buf.observation[t] = observation
            # Agent inputs and outputs are torch tensors.
            act_pyt, agent_info = self.agent.step(
                obs_pyt, act_pyt, rew_pyt)  # Select an action; the policy network's forward pass happens here.
            action = numpify_buffer(act_pyt)  # Convert the action from torch.Tensor to numpy array.
            for b, env in enumerate(self.envs):
                # Environment inputs and outputs are numpy arrays.
                o, r, d, env_info = env.step(action[b])  # Compute reward, gather environment info, etc.
                traj_infos[b].step(observation[b], action[b], r, d,
                                   agent_info[b], env_info)  # Accumulate trajectory statistics.
                # traj_done=True in EnvInfo does not necessarily mean the level was
                # cleared; it can also mean the episode ended in a game over.
                if getattr(env_info, "traj_done", d):  # The environment records traj_done (trajectory finished) in step().
                    completed_infos.append(
                        traj_infos[b].terminate(o))  # The observation argument is effectively unused here.
                    traj_infos[b] = self.TrajInfoCls()  # Fresh TrajInfo object.
                    o = env.reset()
                if d:  # Done flag; for games this includes a game over and also running out of lives (TODO: confirm).
                    self.agent.reset_one(idx=b)  # Only meaningful for RecurrentAgentMixin (resets that env's recurrent state).
                observation[b] = o
                reward[b] = r
                env_buf.done[t, b] = d
                if env_info:
                    env_buf.env_info[t, b] = env_info
            agent_buf.action[t] = action
            env_buf.reward[t] = reward
            if agent_info:
                agent_buf.agent_info[t] = agent_info

        if "bootstrap_value" in agent_buf:
            # agent.value() should not advance rnn state.
            agent_buf.bootstrap_value[:] = self.agent.value(
                obs_pyt, act_pyt, rew_pyt)

        return AgentInputs(observation, action,
                           reward), traj_infos, completed_infos
Example #16
    def collect_batch(self, agent_inputs, traj_infos, itr):
        # Numpy arrays can be written to from numpy arrays or torch tensors
        # (whereas torch tensors can only be written to from torch tensors).
        agent_buf, env_buf = self.samples_np.agent, self.samples_np.env
        completed_infos = list()
        observation, action, reward = agent_inputs
        """
        以array形式输出满足条件(即非0)元素,即已经done的env的indices。最后的[0]并不是表示取所有indices里的第1个元素,而是恰恰是指取所有
        done的indices,因为np.where()返回的是一个array,里面只有一个元素,它是一个array,保存的就是所有indices。
        """
        b = np.where(self.done)[0]
        observation[b] = self.temp_observation[b]
        self.done[:] = False  # Did resets between batches.
        obs_pyt, act_pyt, rew_pyt = torchify_buffer(agent_inputs)
        agent_buf.prev_action[0] = action  # Leading prev_action.
        env_buf.prev_reward[0] = reward
        self.agent.sample_mode(itr)
        for t in range(self.batch_T):
            env_buf.observation[t] = observation
            # Agent inputs and outputs are torch tensors.
            act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt)
            action = numpify_buffer(act_pyt)
            for b, env in enumerate(self.envs):
                if self.done[b]:
                    action[b] = 0  # Record blank.
                    reward[b] = 0
                    if agent_info:
                        agent_info[b] = 0
                    # Leave self.done[b] = True, record that.
                    continue
                # Environment inputs and outputs are numpy arrays.
                o, r, d, env_info = env.step(action[b])
                traj_infos[b].step(observation[b], action[b], r, d,
                                   agent_info[b], env_info)
                # traj_done=True in EnvInfo does not necessarily mean the level was
                # cleared; it can also mean the episode ended in a game over.
                if getattr(env_info, "traj_done", d):
                    completed_infos.append(traj_infos[b].terminate(o))
                    traj_infos[b] = self.TrajInfoCls()
                    self.need_reset[b] = True  # Checked later by reset_if_needed() to reset the relevant state.
                if d:  # Done flag.
                    self.temp_observation[b] = o
                    o = 0  # Record blank.
                observation[b] = o
                reward[b] = r
                self.done[b] = d
                if env_info:
                    env_buf.env_info[t, b] = env_info
            agent_buf.action[t] = action
            env_buf.reward[t] = reward
            env_buf.done[t] = self.done
            if agent_info:
                agent_buf.agent_info[t] = agent_info

        if "bootstrap_value" in agent_buf:
            # agent.value() should not advance rnn state.
            agent_buf.bootstrap_value[:] = self.agent.value(
                obs_pyt, act_pyt, rew_pyt)

        return AgentInputs(observation, action,
                           reward), traj_infos, completed_infos
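A quick self-contained check of the ``np.where`` indexing described in the comment above: for a 1-D boolean mask, ``np.where`` returns a one-element tuple, and ``[0]`` picks out the array of indices of the environments that finished.

import numpy as np

done = np.array([False, True, False, True])
b = np.where(done)[0]
print(b)  # [1 3] -> indices of the done envs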
Example #17
    def collect_evaluation(self, itr):
        if isinstance(self.envs[0], CWTO_EnvWrapperAtari):
            observer_traj_infos = [
                self.TrajInfoCls(n_obs=env.window_size, serial=env.serial)
                for env in self.envs
            ]
        else:
            observer_traj_infos = [
                self.TrajInfoCls(n_obs=env.obs_size, serial=env.serial)
                for env in self.envs
            ]
        player_traj_infos = [self.TrajInfoCls() for _ in range(len(self.envs))]
        observer_observations = list()
        player_observations = list()
        for env in self.envs:
            observer_observations.append(env.reset())
            player_observations.append(
                env.player_observation_space.null_value())
        observer_observation = buffer_from_example(observer_observations[0],
                                                   len(self.envs))
        player_observation = buffer_from_example(player_observations[0],
                                                 len(self.envs))
        observer_reward = np.zeros(len(self.envs), dtype="float32")
        player_reward = np.zeros(len(self.envs), dtype="float32")
        for b, o in enumerate(observer_observations):
            observer_observation[b] = o
        observer_action = buffer_from_example(
            self.envs[0].observer_action_space.null_value(), len(self.envs))
        player_action = buffer_from_example(
            self.envs[0].player_action_space.null_value(), len(self.envs))
        observer_obs_pyt, observer_act_pyt, observer_rew_pyt = torchify_buffer(
            (observer_observation, observer_action, observer_reward))
        player_obs_pyt, player_act_pyt, player_rew_pyt = torchify_buffer(
            (player_observation, player_action, player_reward))
        self.agent.reset()
        self.agent.eval_mode(itr)
        observer_agent_info = [{} for _ in range(len(self.envs))]
        player_agent_info = [{} for _ in range(len(self.envs))]
        prev_reset = np.zeros(len(self.envs), dtype=bool)
        for t in range(self.max_T):
            for _ in range(2):
                if self.envs[0].player_turn:
                    player_act_pyt, player_agent_info = self.agent.step(
                        player_obs_pyt, player_act_pyt, player_rew_pyt)
                    player_action = numpify_buffer(player_act_pyt)
                    for b, env in enumerate(self.envs):
                        o, r, d, env_info = env.step(player_action[b])
                        # if d and (env.observer_reward_shaping is not None):
                        r_obs, cost_obs = env.observer_reward_shaping(
                            r, env.last_obs_act)
                        # else:
                        #     r_obs = r
                        observer_traj_infos[b].step(observer_observation[b],
                                                    observer_action[b],
                                                    r_obs,
                                                    d,
                                                    observer_agent_info[b],
                                                    env_info,
                                                    cost=cost_obs,
                                                    obs_act=env.last_obs_act)
                        if getattr(env_info, "traj_done", d):
                            self.observer_traj_infos_queue.put(
                                observer_traj_infos[b].terminate(o))
                            observer_traj_infos[b] = self.TrajInfoCls(
                                n_obs=env.obs_size, serial=env.serial)
                            o = env.reset()
                            prev_reset[b] = True
                            # if env.player_reward_shaping is not None:
                            r_ply, cost_ply = env.player_reward_shaping(
                                r, env.last_obs_act)
                            # else:
                            #     r_ply = r
                            if self.log_full_obs:
                                obs_to_log = env.last_obs
                            else:
                                obs_to_log = player_observation[b]
                            player_traj_infos[b].step(obs_to_log,
                                                      player_action[b], r_ply,
                                                      d, player_agent_info[b],
                                                      env_info, cost_ply)
                            self.player_traj_infos_queue.put(
                                player_traj_infos[b].terminate(
                                    env.player_observation_space.null_value()))
                            player_traj_infos[b] = self.TrajInfoCls()
                        if d:
                            observer_action[b] = 0
                            player_action[b] = 0
                            r_obs = 0
                            r_ply = 0
                            player_reward[b] = r_ply
                            self.agent.reset_one(idx=b)
                        observer_observation[b] = o
                        observer_reward[b] = r_obs
                else:
                    while not self.envs[0].player_turn:
                        pturn = self.envs[0].player_turn
                        observer_act_pyt, observer_agent_info = self.agent.step(
                            observer_obs_pyt, observer_act_pyt,
                            observer_rew_pyt)
                        observer_action = numpify_buffer(observer_act_pyt)
                        for b, env in enumerate(self.envs):
                            assert pturn == env.player_turn
                            o, r, d, env_info = env.step(observer_action[b])
                            assert not d
                            if env.player_turn:
                                r_ply, cost_ply = env.player_reward_shaping(
                                    r, env.last_obs_act)
                                if prev_reset[b]:
                                    prev_reset[b] = False
                                else:
                                    player_reward[b] = r_ply
                                    if self.log_full_obs:
                                        obs_to_log = env.last_obs
                                    else:
                                        obs_to_log = player_observation[b]
                                    player_traj_infos[b].step(
                                        obs_to_log, player_action[b], r_ply, d,
                                        player_agent_info[b], env_info,
                                        cost_ply)
                                player_observation[b] = o

                            else:
                                observer_traj_infos[b].step(
                                    observer_observation[b],
                                    observer_action[b],
                                    r,
                                    d,
                                    observer_agent_info[b],
                                    env_info,
                                    cost=0)
                                observer_observation[b] = o
                                observer_reward[b] = r
                if self.player_sync.stop_eval.value:
                    break

        self.observer_traj_infos_queue.put(None)  # End sentinel.
        self.player_traj_infos_queue.put(None)  # End sentinel.
Example #18
    def _collect_batch(self, agent_inputs, traj_infos, itr):
        # Numpy arrays can be written to from numpy arrays or torch tensors
        # (whereas torch tensors can only be written to from torch tensors).
        agent_buf, env_buf = self.samples_np.agent, self.samples_np.env
        completed_infos = list()
        observation, action, reward = agent_inputs
        obs_pyt, act_pyt, rew_pyt = torchify_buffer(agent_inputs)
        agent_buf.prev_action[0] = action  # Leading prev_action.
        env_buf.prev_reward[0] = reward
        self.agent.sample_mode(itr)
        for t in range(self.batch_T):
            if (t * len(self.envs)) % 400 == 0:
                self.agent.recv_shared_memory()
            env_buf.observation[t] = observation
            # Agent inputs and outputs are torch tensors.
            act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt)
            action = numpify_buffer(act_pyt)
            static_decoder_inputs = []
            correct_labels = []
            env_infos = []
            done = []
            for b, env in enumerate(self.envs):
                # Environment inputs and outputs are numpy arrays.
                o, r, d, env_info = env.step(action[b])
                done.append(d)
                observation[b] = o
                reward[b] = r
                env_infos.append(env_info)
                static_decoder_inputs.append(env_info.static_decoder_input)
                correct_labels.append(env_info.correct_label)

            static_decoder_inputs = np.stack(static_decoder_inputs)
            correct_labels = np.stack(correct_labels)
            label_prediction = np.argmax(
                self.static_decoder(static_decoder_inputs),
                axis=-1).squeeze(axis=1)
            done = label_prediction != correct_labels

            for b, env in enumerate(self.envs):
                traj_infos[b].step(observation[b], action[b], reward[b],
                                   done[b], agent_info[b], env_infos[b])
                if getattr(env_info, "traj_done", done[b]):
                    completed_infos.append(traj_infos[b].terminate(
                        observation[b]))
                    traj_infos[b] = self.TrajInfoCls()
                    observation[b] = env.reset()
                if done[b]:
                    self.agent.reset_one(idx=b)
                env_buf.done[t, b] = done[b]
                if env_infos[b]:
                    env_buf.env_info[t, b] = env_infos[b]
            agent_buf.action[t] = action
            env_buf.reward[t] = reward
            if agent_info:
                agent_buf.agent_info[t] = agent_info

        if "bootstrap_value" in agent_buf:
            # agent.value() should not advance rnn state.
            agent_buf.bootstrap_value[:] = self.agent.value(
                obs_pyt, act_pyt, rew_pyt)

        return AgentInputs(observation, action,
                           reward), traj_infos, completed_infos
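The referee logic in this example in isolation: an episode is treated as done when the static decoder's predicted label disagrees with the environment's correct label. A small shape-for-shape sketch with hypothetical decoder outputs:

import numpy as np

decoder_out = np.array([[[0.1, 0.9]],
                        [[0.8, 0.2]]])        # shape (B, 1, n_labels), hypothetical scores
correct_labels = np.array([1, 1])             # shape (B,)
label_prediction = np.argmax(decoder_out, axis=-1).squeeze(axis=1)  # shape (B,)
done = label_prediction != correct_labels
print(done)  # [False  True]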
Example #19
 def update_batch_priorities(self, priorities):
     with self.rw_lock.write_lock:
         priorities = numpify_buffer(priorities)
         self.default_priority = max(priorities)
         self.priority_tree.update_batch_priorities(priorities**self.alpha)
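This variant guards the update with a writer lock so concurrent samplers never read a half-updated tree, and tracks the largest priority seen (presumably used as the default for newly added samples). A minimal sketch of the locking idea, assuming a flat priority array instead of a sum tree and a plain ``threading.Lock`` in place of the reader-writer lock:

import threading
import numpy as np

class LockedPriorities:
    def __init__(self, size, alpha=0.6):
        self.alpha = alpha
        self.priorities = np.ones(size, dtype=np.float32)
        self.lock = threading.Lock()
        self.last_indices = None

    def sample_indices(self, batch_size):
        with self.lock:  # readers also take the lock in this simplified sketch
            probs = self.priorities / self.priorities.sum()
            self.last_indices = np.random.choice(len(probs), size=batch_size, p=probs)
            return self.last_indices

    def update_batch_priorities(self, new_priorities):
        with self.lock:  # writer holds the lock while rewriting priorities
            new_priorities = np.asarray(new_priorities, dtype=np.float32)
            self.priorities[self.last_indices] = new_priorities ** self.alpha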
Example #20
 def update_batch_priorities(self, priorities):
     priorities = numpify_buffer(priorities)
     self.priority_tree.update_batch_priorities(priorities**self.alpha)
Example #21
    def collect_batch(self, player_agent_inputs, observer_agent_inputs,
                      player_traj_infos, observer_traj_infos, itr):
        # Numpy arrays can be written to from numpy arrays or torch tensors
        # (whereas torch tensors can only be written to from torch tensors).
        player_agent_buf, player_env_buf = self.player_samples_np.agent, self.player_samples_np.env
        observer_agent_buf, observer_env_buf = self.observer_samples_np.agent, self.observer_samples_np.env
        player_completed_infos = list()
        observer_completed_infos = list()
        observer_observation, observer_action, observer_reward = observer_agent_inputs
        player_observation, player_action, player_reward = player_agent_inputs
        observer_obs_pyt, observer_act_pyt, observer_rew_pyt = torchify_buffer(
            observer_agent_inputs)
        player_obs_pyt, player_act_pyt, player_rew_pyt = torchify_buffer(
            player_agent_inputs)
        observer_agent_buf.prev_action[0] = np.reshape(
            observer_action,
            observer_agent_buf.prev_action[0].shape)  # Leading prev_action.
        observer_env_buf.prev_reward[0] = observer_reward
        player_agent_buf.prev_action[0] = np.reshape(
            player_action,
            player_agent_buf.prev_action[0].shape)  # Leading prev_action.
        player_env_buf.prev_reward[0] = player_reward
        self.agent.sample_mode(itr)
        observer_agent_info = [{} for _ in range(len(self.envs))]
        player_agent_info = [{} for _ in range(len(self.envs))]
        ser_count = 0
        t = 0
        prev_reset = np.zeros(len(self.envs), dtype=bool)
        player_done = np.zeros(len(self.envs), dtype=bool)
        player_env_info = [None for _ in range(len(self.envs))]
        observer_done = np.zeros(len(self.envs), dtype=bool)
        observer_env_info = [None for _ in range(len(self.envs))]
        while t < self.batch_T:
            # all envs must be in the same player_turn status!
            if self.envs[0].player_turn:
                player_env_buf.observation[t] = player_observation
                player_env_buf.reward[t] = player_reward
                for ee in range(len(self.envs)):
                    player_env_buf.done[t, ee] = player_done[ee]
                    if player_env_info[ee] is not None:
                        player_env_buf.env_info[t, ee] = player_env_info[ee]
                player_done = np.zeros(len(self.envs), dtype=bool)
                player_env_info = [None for _ in range(len(self.envs))]
                player_act_pyt, player_agent_info = self.agent.step(
                    player_obs_pyt, player_act_pyt.float(), player_rew_pyt)
                player_action = numpify_buffer(player_act_pyt)
                for b, env in enumerate(self.envs):
                    o, r, d, env_info = env.step(player_action[b])
                    # if d and (env.observer_reward_shaping is not None):
                    r_obs, cost_obs = env.observer_reward_shaping(
                        r, env.last_obs_act)
                    # else:
                    #     r_obs = r
                    observer_traj_infos[b].step(observer_observation[b],
                                                observer_action[b],
                                                r_obs,
                                                d,
                                                observer_agent_info[b],
                                                env_info,
                                                cost=cost_obs,
                                                obs_act=env.last_obs_act)
                    if getattr(env_info, "traj_done", d):
                        observer_completed_infos.append(
                            observer_traj_infos[b].terminate(o))
                        if isinstance(env, CWTO_EnvWrapperAtari):
                            observer_traj_infos[b] = self.TrajInfoCls(
                                n_obs=env.window_size, serial=env.serial)
                        else:
                            observer_traj_infos[b] = self.TrajInfoCls(
                                n_obs=env.obs_size, serial=env.serial)
                        # if env.player_reward_shaping is not None:
                        r_ply, cost_ply = env.player_reward_shaping(
                            r, env.last_obs_act)
                        # else:
                        #     r_ply = r
                        player_traj_infos[b].step(player_observation[b],
                                                  player_action[b], r_ply, d,
                                                  player_agent_info[b],
                                                  env_info, cost_ply)
                        player_completed_infos.append(
                            player_traj_infos[b].terminate(
                                env.player_observation_space.null_value()))
                        player_traj_infos[b] = self.TrajInfoCls()
                        prev_reset[b] = True
                        o = env.reset()
                    if d:
                        self.agent.reset_one(idx=b)
                        player_reward[b] = r_ply
                        player_done[b] = d

                        if env_info:
                            player_env_buf.env_info[t, b] = env_info
                            player_env_info[b] = env_info

                    observer_observation[b] = o
                    observer_reward[b] = r_obs

                    observer_done[b] = d
                    if env_info:
                        observer_env_info[b] = env_info

                player_agent_buf.action[t] = player_action
                if player_agent_info:
                    player_agent_buf.agent_info[t] = player_agent_info
                t += 1
            else:
                while not self.envs[0].player_turn:
                    pturn = self.envs[0].player_turn
                    observer_env_buf.observation[
                        ser_count] = observer_observation
                    observer_env_buf.reward[ser_count] = observer_reward
                    for ee in range(len(self.envs)):
                        observer_env_buf.done[ser_count,
                                              ee] = observer_done[ee]
                        if observer_env_info[ee] is not None:
                            observer_env_buf.env_info[
                                ser_count, ee] = observer_env_info[ee]
                    observer_done = np.zeros(len(self.envs), dtype=bool)
                    observer_env_info = [None for _ in range(len(self.envs))]
                    observer_act_pyt, observer_agent_info = self.agent.step(
                        observer_obs_pyt, observer_act_pyt, observer_rew_pyt)
                    observer_action = numpify_buffer(observer_act_pyt)
                    for b, env in enumerate(self.envs):
                        assert pturn == env.player_turn
                        o, r, d, env_info = env.step(observer_action[b])
                        assert not d
                        if env.player_turn:
                            if prev_reset[b]:
                                prev_reset[b] = False
                            else:
                                r_ply, cost_ply = env.player_reward_shaping(
                                    r, env.last_obs_act)
                                player_traj_infos[b].step(
                                    player_observation[b], player_action[b],
                                    r_ply, d, player_agent_info[b], env_info,
                                    cost_ply)
                                player_reward[b] = r_ply
                                player_done[b] = d
                                if env_info:
                                    player_env_info[b] = env_info

                            player_observation[b] = o
                        else:
                            # no shaping here - it will be included in "player_turn", also, cost = 0
                            observer_traj_infos[b].step(
                                observer_observation[b],
                                observer_action[b],
                                r,
                                d,
                                observer_agent_info[b],
                                env_info,
                                cost=0)
                            observer_observation[b] = o
                            observer_reward[b] = r
                            observer_env_buf.done[ser_count, b] = d
                            observer_done[b] = d
                            if env_info:
                                observer_env_buf.env_info[ser_count,
                                                          b] = env_info
                                observer_env_info[b] = env_info

                    observer_agent_buf.action[ser_count] = observer_action
                    if observer_agent_info:
                        observer_agent_buf.agent_info[
                            ser_count] = observer_agent_info
                    ser_count += 1

        if "bootstrap_value" in player_agent_buf:
            # agent.value() should not advance rnn state.
            player_agent_buf.bootstrap_value[:] = self.agent.value(
                player_obs_pyt, player_act_pyt, player_rew_pyt, is_player=True)
        if "bootstrap_value" in observer_agent_buf:
            # agent.value() should not advance rnn state.
            observer_agent_buf.bootstrap_value[:] = self.agent.value(
                observer_obs_pyt,
                observer_act_pyt,
                observer_rew_pyt,
                is_player=False)

        return AgentInputs(
            player_observation, player_action, player_reward
        ), player_traj_infos, player_completed_infos, AgentInputs(
            observer_observation, observer_action,
            observer_reward), observer_traj_infos, observer_completed_infos
Example #22
    def collect_evaluation(self, itr, max_episodes=1):
        assert len(self.envs) == 1, ("qec eval collector needs max 1 env, "
                                     "otherwise evaluation will be biased")
        traj_infos = [self.TrajInfoCls() for _ in range(len(self.envs))]
        observations = list()
        for env in self.envs:
            observations.append(env.reset())
        observation = buffer_from_example(observations[0], len(self.envs))
        for b, o in enumerate(observations):
            observation[b] = o
        action = buffer_from_example(self.envs[0].action_space.null_value(),
                                     len(self.envs))
        reward = np.zeros(len(self.envs), dtype="float32")
        obs_pyt, act_pyt, rew_pyt = torchify_buffer(
            (observation, action, reward))
        self.agent.reset()
        self.agent.eval_mode(itr)
        num_completed_episodes = 0
        for t in range(self.max_T):
            act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt)
            action = numpify_buffer(act_pyt)
            static_decoder_inputs = []
            correct_labels = []
            env_infos = []
            done = []
            for b, env in enumerate(self.envs):
                o, r, d, env_info = env.step(action[b])
                done.append(d)
                observation[b] = o
                reward[b] = r
                env_infos.append(env_info)
                static_decoder_inputs.append(env_info.static_decoder_input)
                correct_labels.append(env_info.correct_label)

            static_decoder_inputs = np.stack(static_decoder_inputs)
            correct_labels = np.stack(correct_labels)
            label_prediction = np.argmax(
                self.static_decoder(static_decoder_inputs),
                axis=-1).squeeze(axis=1)
            done = label_prediction != correct_labels

            for b, env in enumerate(self.envs):
                traj_infos[b].step(observation[b], action[b], reward[b],
                                   done[b], agent_info[b], env_infos[b])
                if getattr(env_infos[b], "traj_done", done[b]):
                    self.traj_infos_queue.put(traj_infos[b].terminate(
                        observation[b]))
                    traj_infos[b] = self.TrajInfoCls()
                    observation[b] = env.reset()
                if done[b]:
                    action[b] = 0  # Next prev_action.
                    reward[b] = 0
                    self.agent.reset_one(idx=b)
                    num_completed_episodes += 1
            if num_completed_episodes >= max_episodes:
                print('reached max episodes')
                break
            if self.sync.stop_eval.value:
                print(f'sync stop')
                break
        self.traj_infos_queue.put(None)  # End sentinel.
Example #23
    def collect_batch(self, agent_inputs, traj_infos, itr):
        # Numpy arrays can be written to from numpy arrays or torch tensors
        # (whereas torch tensors can only be written to from torch tensors).
        agent_buf, env_buf = self.samples_np.agent, self.samples_np.env
        completed_infos = list()

        observation, action, reward_tot = agent_inputs

        b = np.where(self.done)[0]
        # observation[b] = self.temp_observation[b]
        self.done[:] = False  # Did resets between batches.

        # torchifying syncs components of agent_inputs (observation, action, reward_tot)
        # with obs_pyt, act_pyt, rew_tot_pyt. Pytorch tensors point to the original numpy
        # array so updating observation will update obs_pyt etc.
        obs_pyt, act_pyt, rew_tot_pyt = torchify_buffer(agent_inputs)

        agent_buf.prev_action[0] = action  # Leading prev_action
        env_buf.prev_reward[0] = reward_tot  # Leading previous total reward
        self.agent.sample_mode(itr)
        for t in range(self.batch_T):

            env_buf.observation[t] = observation  # [0 : T]
            # Agent inputs and outputs are torch tensors.
            act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt,
                                                  rew_tot_pyt)
            action = numpify_buffer(act_pyt)
            for b, env in enumerate(self.envs):
                if self.done[b]:
                    action[b] = 0  # Record blank.
                    reward_tot[b] = 0
                    if agent_info:
                        agent_info[b] = 0
                    # Leave self.done[b] = True, record that.
                    continue

                # Environment inputs and outputs are numpy arrays.
                o, r_ext, d, env_info = env.step(action[b])

                r_ext_log = r_ext  # to ensure r_ext gets recorded regardless
                if self.no_extrinsic:
                    r_ext = 0.0

                traj_infos[b].step(observation[b], action[b], r_ext_log, d,
                                   agent_info[b], env_info)
                if getattr(env_info, "traj_done", d):
                    completed_infos.append(traj_infos[b].terminate())
                    traj_infos[b] = self.TrajInfoCls()
                    self.need_reset[b] = True
                if d:
                    self.temp_observation[b] = o
                    o = 0  # Record blank.
                self.done[b] = d

                observation[b] = o
                reward_tot[b] = r_ext

                if env_info:
                    env_buf.env_info[t, b] = env_info

            agent_buf.action[t] = action
            env_buf.reward[t] = reward_tot
            env_buf.next_observation[t] = observation  # [1 : T+1]
            env_buf.done[t] = self.done
            if agent_info:
                agent_buf.agent_info[t] = agent_info

        if "bootstrap_value" in agent_buf:
            # agent.value() should not advance rnn state.
            agent_buf.bootstrap_value[:] = self.agent.value(
                obs_pyt, act_pyt, rew_tot_pyt)

        if "int_bootstrap_value" in agent_buf:
            agent_buf.int_bootstrap_value[:] = self.agent.value(obs_pyt,
                                                                act_pyt,
                                                                rew_tot_pyt,
                                                                ret_int=True)

        return AgentInputs(observation, action,
                           reward_tot), traj_infos, completed_infos