def test_dict_observation(self):
    """The profiling wrapper appends exactly one entry to a dict observation.

    Wraps cartpole swingup with ``mujoco_profiling.Wrapper``. It confirms the
    wrapped spec keeps every original key, in order, followed by the profiling
    key. It then resets the env and checks the new observation against its
    spec entry: it must validate, have shape ``(2,)`` and dtype ``np.double``.
    """
    profile_key = 'mjprofile'
    base_env = cartpole.swingup()

    # Guard: this test is only meaningful if the base env emits a dict spec.
    base_spec = base_env.observation_spec()
    self.assertIsInstance(base_spec, collections.OrderedDict)

    profiled_env = mujoco_profiling.Wrapper(base_env, observation_key=profile_key)
    profiled_spec = profiled_env.observation_spec()
    self.assertIsInstance(profiled_spec, collections.OrderedDict)

    # Original keys must survive unchanged, with the profile key appended last.
    self.assertLen(profiled_spec, len(base_spec) + 1)
    self.assertEqual([*base_spec.keys(), profile_key],
                     list(profiled_spec.keys()))

    # The emitted observation must agree with the spec entry that was added.
    profile_obs = profiled_env.reset().observation[profile_key]
    profiled_spec[profile_key].validate(profile_obs)
    self.assertEqual(profile_obs.shape, (2, ))
    self.assertEqual(profile_obs.dtype, np.double)
def test_dynamic(self):
    """Renders frames from cartpole viewed through a dynamic distracting camera.

    Builds cartpole swingup and wraps it in ``camera.DistractingCameraEnv``
    (``dynamic=True``, ``scale=0.1``). It then adds ``pixels.Wrapper`` rendering
    from camera 0. The test takes up to 10 steps with uniformly sampled actions,
    or fewer if the episode ends. Every rendered frame must have shape
    ``(240, 320, 3)``.
    """
    camera_kwargs = get_camera_params(
        domain_name='cartpole', scale=0.1, dynamic=True)
    env = cartpole.swingup()
    env = camera.DistractingCameraEnv(env, camera_id=0, **camera_kwargs)
    env = pixels.Wrapper(env, render_kwargs={'camera_id': 0})
    action_spec = env.action_spec()

    # A local, fixed-seed generator makes the sampled action sequence
    # reproducible and independent of global numpy RNG state.
    random_state = np.random.RandomState(0)

    time_step = env.reset()
    frames = []
    while not time_step.last() and len(frames) < 10:
        action = random_state.uniform(
            action_spec.minimum, action_spec.maximum, size=action_spec.shape)
        time_step = env.step(action)
        frames.append(time_step.observation['pixels'])

    # Fail with a clear assertion (not an IndexError) if no step was taken.
    self.assertGreater(len(frames), 0, 'No frames were rendered.')
    # Check every collected frame, not just the first one.
    for frame in frames:
        self.assertEqual(frame.shape, (240, 320, 3))
def test_dict_observation(self, pixels_only):
    """pixels.Wrapper exposes an 'rgb' entry in a dict observation spec.

    Args:
      pixels_only: Forwarded to ``pixels.Wrapper``.
        - When true, the wrapped spec must contain only the 'rgb' key.
        - Otherwise it must hold the original keys, in order, followed by 'rgb'.

    After a reset, the 'rgb' observation must validate against its spec entry.
    It must also be a ``(240, 320, 3)`` ``uint8`` array, matching the requested
    render size.
    """
    rgb_key = 'rgb'
    base_env = cartpole.swingup()

    # Guard: this test is only meaningful if the base env emits a dict spec.
    base_spec = base_env.observation_spec()
    self.assertIsInstance(base_spec, collections.OrderedDict)

    frame_width, frame_height = 320, 240
    rendered_env = pixels.Wrapper(
        base_env,
        observation_key=rgb_key,
        pixels_only=pixels_only,
        render_kwargs={'width': frame_width, 'height': frame_height},
    )
    rendered_spec = rendered_env.observation_spec()
    self.assertIsInstance(rendered_spec, collections.OrderedDict)

    # With pixels_only the original entries are dropped; otherwise the rgb key
    # is appended after them.
    if pixels_only:
        expected_count, expected_keys = 1, [rgb_key]
    else:
        expected_count = len(base_spec) + 1
        expected_keys = list(base_spec.keys()) + [rgb_key]
    self.assertEqual(expected_count, len(rendered_spec))
    self.assertEqual(expected_keys, list(rendered_spec.keys()))

    # The emitted frame must agree with the spec entry that was added.
    frame = rendered_env.reset().observation[rgb_key]
    rendered_spec[rgb_key].validate(frame)
    self.assertEqual(frame.shape, (frame_height, frame_width, 3))
    self.assertEqual(frame.dtype, np.uint8)