Esempio n. 1
0
    def _generate_time_steps(self, trajectory_list):
        """A generator to yield single time-steps from a list of trajectories.

        Trajectories with at most one time-step are skipped, since such a
        trajectory could just be a repeated reset.

        Args:
          trajectory_list: iterable of `trajectory.Trajectory` instances
            (checked with an `isinstance` assert).

        Yields:
          One dict per time-step, keyed by TIMESTEP_FIELD, ACTION_FIELD,
          RAW_REWARD_FIELD, PROCESSED_REWARD_FIELD, DONE_FIELD and
          OBSERVATION_FIELD. Each value is a list. The action and observation
          are encoded with `gym_spaces_utils.gym_space_encode`.
        """
        # py3: to_example cannot handle numpy scalar types (np.int64, ...).
        # Decide once, from the action space's dtype, which Python scalar type
        # the encoded action elements are cast to. Every numpy integer /
        # floating dtype is covered, not only int32/int64/float32/float64.
        action_cast = None
        if six.PY3:
            action_dtype = getattr(self.action_space, "dtype", None)
            # np.dtype(None) is float64, so a missing dtype must be excluded
            # explicitly rather than passed to np.issubdtype.
            if action_dtype is not None:
                if np.issubdtype(action_dtype, np.integer):
                    action_cast = int
                elif np.issubdtype(action_dtype, np.floating):
                    action_cast = float

        for single_trajectory in trajectory_list:
            assert isinstance(single_trajectory, trajectory.Trajectory)

            # Skip writing trajectories that have only a single time-step --
            # this could just be a repeated reset.
            if single_trajectory.num_time_steps <= 1:
                continue

            for index, time_step in enumerate(single_trajectory.time_steps):

                # The first time-step doesn't have reward/processed_reward.
                # Any falsy value (e.g. None) is replaced by 0.0 / 0.
                raw_reward = time_step.raw_reward
                if not raw_reward:
                    raw_reward = 0.0

                processed_reward = time_step.processed_reward
                if not processed_reward:
                    processed_reward = 0

                action = time_step.action
                if action is None:
                    # The last time-step doesn't have an action, and that
                    # action is never used. Sample a placeholder from the
                    # action space instead.
                    action = self.action_space.sample()
                action = gym_spaces_utils.gym_space_encode(
                    self.action_space, action)

                if six.PY3:
                    if action_cast is not None:
                        action = [action_cast(a) for a in action]

                    # Same numpy-scalar problem with processed_reward.
                    processed_reward = int(processed_reward)

                assert time_step.observation is not None

                yield {
                    TIMESTEP_FIELD: [index],
                    ACTION_FIELD:
                    action,
                    # to_example errors on np.float32
                    RAW_REWARD_FIELD: [float(raw_reward)],
                    PROCESSED_REWARD_FIELD: [processed_reward],
                    # to_example doesn't know bools
                    DONE_FIELD: [int(time_step.done)],
                    OBSERVATION_FIELD:
                    gym_spaces_utils.gym_space_encode(self.observation_space,
                                                      time_step.observation),
                }
 def test_box_space_encode(self):
     """Encoding an int64 Box value yields a plain Python list of its items."""
     space = Box(low=0, high=10, shape=[2], dtype=np.int64)
     encoded = gym_spaces_utils.gym_space_encode(space, np.array([2, 3]))
     self.assertListEqual([2, 3], encoded)
 def test_discrete_space_encode(self):
     """Encoding a Discrete sample yields a one-element list holding it."""
     space = Discrete(100)
     sampled = space.sample()
     self.assertListEqual(
         [sampled], gym_spaces_utils.gym_space_encode(space, sampled))