Example #1
    def split(self):
        """Split an EpisodeBatch into a list of EpisodeBatches.

        The opposite of concatenate.

        Returns:
            list[EpisodeBatch]: A list of EpisodeBatches, with one
                episode per batch.

        """
        episodes = []
        start = 0
        for i, length in enumerate(self.lengths):
            stop = start + length
            eps = EpisodeBatch(env_spec=self.env_spec,
                               observations=self.observations[start:stop],
                               last_observations=np.asarray(
                                   [self.last_observations[i]]),
                               actions=self.actions[start:stop],
                               rewards=self.rewards[start:stop],
                               env_infos=tensor_utils.slice_nested_dict(
                                   self.env_infos, start, stop),
                               agent_infos=tensor_utils.slice_nested_dict(
                                   self.agent_infos, start, stop),
                               step_types=self.step_types[start:stop],
                               lengths=np.asarray([length]))
            episodes.append(eps)
            start = stop
        return episodes
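
A minimal usage sketch, assuming `batch` is a garage `EpisodeBatch` holding several episodes (the variable name is illustrative, not from the source):

# `split` yields one single-episode EpisodeBatch per entry in `lengths`.
episodes = batch.split()
assert len(episodes) == len(batch.lengths)
for ep in episodes:
    assert len(ep.lengths) == 1                    # one episode per piece
    assert ep.observations.shape[0] == ep.lengths[0]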
Example #2
    def split(self):
        """Split a TrajectoryBatch into a list of TrajectoryBatches.

        The opposite of concatenate.

        Returns:
            list[TrajectoryBatch]: A list of TrajectoryBatches, with one
                trajectory per batch.

        """
        trajectories = []
        start = 0
        for i, length in enumerate(self.lengths):
            stop = start + length
            traj = TrajectoryBatch(
                env_spec=self.env_spec,
                observations=self.observations[start:stop],
                last_observations=np.asarray([self.last_observations[i]]),
                actions=self.actions[start:stop],
                rewards=self.rewards[start:stop],
                terminals=self.terminals[start:stop],
                env_infos=tensor_utils.slice_nested_dict(
                    self.env_infos, start, stop),
                agent_infos=tensor_utils.slice_nested_dict(
                    self.agent_infos, start, stop),
                lengths=np.asarray([length]),
            )
            trajectories.append(traj)
            start = stop
        return trajectories
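
The docstring calls split the opposite of concatenate; a hedged round-trip sketch, assuming `batch` is a `TrajectoryBatch` built elsewhere and that the class exposes a `concatenate` classmethod as the docstring implies (the `batch` name is illustrative):

# One single-trajectory TrajectoryBatch per entry in `lengths`.
trajectories = batch.split()
# Re-concatenating the pieces should reproduce the original lengths.
rebuilt = TrajectoryBatch.concatenate(*trajectories)
assert (rebuilt.lengths == batch.lengths).all()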
Example #3
    def split(self):
        """Split a SkillTrajectoryBatch into a list of SkillTrajectoryBatches.

        The opposite of concatenate.

        Returns:
            list[SkillTrajectoryBatch]: A list of SkillTrajectoryBatches,
                with one trajectory per batch.

        """
        trajectories = []
        start = 0
        for i, length in enumerate(self.lengths):
            stop = start + length
            traj = SkillTrajectoryBatch(
                env_spec=self.env_spec,
                num_skills=self.num_skills,
                skills=self.skills[start:stop],
                states=self.states[start:stop],
                last_states=np.asarray([self.last_states[i]]),
                actions=self.actions[start:stop],
                env_rewards=self.env_rewards[start:stop],
                self_rewards=self.self_rewards[start:stop],
                terminals=self.terminals[start:stop],
                env_infos=tensor_utils.slice_nested_dict(
                    self.env_infos, start, stop),
                agent_infos=tensor_utils.slice_nested_dict(
                    self.agent_infos, start, stop),
                lengths=np.asarray([length]))
            trajectories.append(traj)
            start = stop
        return trajectories
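
All three variants share the same offset bookkeeping: `lengths` partitions each flat, time-major array, and `start`/`stop` walk that partition. A self-contained toy illustration of just that slicing, with made-up numbers:

import numpy as np

# lengths [3, 2] partition a flat axis of 5 steps into [0:3] and [3:5].
lengths = np.asarray([3, 2])
rewards = np.arange(5)
pieces = []
start = 0
for length in lengths:
    stop = start + length
    pieces.append(rewards[start:stop])
    start = stop
assert [p.tolist() for p in pieces] == [[0, 1, 2], [3, 4]]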