def _generate_time_steps(self, trajectory_list):
    """A generator to yield single time-steps from a list of trajectories.

    Trajectories with at most one time-step are skipped. Every yielded value
    is wrapped in a list and converted to plain Python scalars where needed,
    because (per the comments below) `to_example` rejects numpy scalars and
    bools.

    Args:
      trajectory_list: iterable of `trajectory.Trajectory` instances.

    Yields:
      One dict per time-step, keyed by TIMESTEP_FIELD, ACTION_FIELD,
      RAW_REWARD_FIELD, PROCESSED_REWARD_FIELD, DONE_FIELD and
      OBSERVATION_FIELD.
    """
    for single_trajectory in trajectory_list:
        assert isinstance(single_trajectory, trajectory.Trajectory)

        # Skip writing trajectories that have only a single time-step -- this
        # could just be a repeated reset.
        if single_trajectory.num_time_steps <= 1:
            continue

        for index, time_step in enumerate(single_trajectory.time_steps):
            # The first time-step doesn't have reward/processed_reward; if so,
            # just setting it to 0.0 / 0 should be OK.
            raw_reward = time_step.raw_reward
            if not raw_reward:
                raw_reward = 0.0

            processed_reward = time_step.processed_reward
            if not processed_reward:
                processed_reward = 0

            action = time_step.action
            if action is None:
                # The last time-step doesn't have an action, and this action
                # shouldn't be used. gym's spaces have a `sample` function, so
                # just sample an action and use that.
                action = self.action_space.sample()
            action = gym_spaces_utils.gym_space_encode(self.action_space, action)

            if six.PY3:
                # py3 complains that to_example cannot handle numpy scalars
                # such as np.int64, so convert each element to a Python
                # int/float. Checking the dtype *kind* instead of a fixed list
                # of four dtypes also covers e.g. int8/uint8/int16/float16.
                action_dtype = self.action_space.dtype
                # np.issubdtype(None, ...) interprets None as float64, so a
                # space without a dtype must be left untouched explicitly.
                if action_dtype is not None:
                    if np.issubdtype(action_dtype, np.integer):
                        action = list(map(int, action))
                    elif np.issubdtype(action_dtype, np.floating):
                        action = list(map(float, action))
                # Same with processed_reward.
                processed_reward = int(processed_reward)

            assert time_step.observation is not None

            yield {
                TIMESTEP_FIELD: [index],
                ACTION_FIELD: action,
                # to_example errors on np.float32
                RAW_REWARD_FIELD: [float(raw_reward)],
                PROCESSED_REWARD_FIELD: [processed_reward],
                # to_example doesn't know bools
                DONE_FIELD: [int(time_step.done)],
                OBSERVATION_FIELD: gym_spaces_utils.gym_space_encode(
                    self.observation_space, time_step.observation),
            }
def test_box_space_encode(self):
    # An int64 Box value must encode to a plain list holding the same entries.
    int_box = Box(low=0, high=10, shape=[2], dtype=np.int64)
    encoded = gym_spaces_utils.gym_space_encode(int_box, np.array([2, 3]))
    self.assertListEqual([2, 3], encoded)
def test_discrete_space_encode(self):
    # A Discrete value must encode to a one-element list containing it.
    space = Discrete(100)
    sampled = space.sample()
    self.assertListEqual(
        [sampled], gym_spaces_utils.gym_space_encode(space, sampled))