コード例 #1
0
    def test_dict_observation(self):
        """Checks that the profiling wrapper appends exactly one observation.

        The wrapped spec must keep every original key, in order, followed by
        the profiling key. The observation returned by `reset()` under that key
        must validate against its spec entry and be a float64 vector of shape
        (2,).
        """
        profile_key = 'mjprofile'
        base_env = cartpole.swingup()

        # Guard: the test is only meaningful for an env with a dict spec.
        base_spec = base_env.observation_spec()
        self.assertIsInstance(base_spec, collections.OrderedDict)

        profiled_env = mujoco_profiling.Wrapper(
            base_env, observation_key=profile_key)
        profiled_spec = profiled_env.observation_spec()
        self.assertIsInstance(profiled_spec, collections.OrderedDict)

        # Exactly one extra entry, appended after the original ones.
        self.assertLen(profiled_spec, len(base_spec) + 1)
        wanted_keys = [*base_spec.keys(), profile_key]
        self.assertEqual(wanted_keys, list(profiled_spec.keys()))

        # The emitted observation must agree with its advertised spec.
        first_step = profiled_env.reset()
        timing = first_step.observation[profile_key]
        profiled_spec[profile_key].validate(timing)

        self.assertEqual(timing.shape, (2, ))
        self.assertEqual(timing.dtype, np.double)
コード例 #2
0
 def test_dynamic(self):
   """Rolls out a dynamic distracting-camera env and checks rendered frames.

   Up to 10 steps are taken with uniformly random actions. Every collected
   `pixels` observation must have shape (240, 320, 3).
   """
   camera_kwargs = get_camera_params(
       domain_name='cartpole', scale=0.1, dynamic=True)
   env = cartpole.swingup()
   env = camera.DistractingCameraEnv(env, camera_id=0, **camera_kwargs)
   env = pixels.Wrapper(env, render_kwargs={'camera_id': 0})
   action_spec = env.action_spec()
   # A local, seeded generator makes the action sequence reproducible and
   # keeps action sampling from consuming the process-wide np.random state
   # that other tests may share.
   rng = np.random.RandomState(0)
   time_step = env.reset()
   frames = []
   while not time_step.last() and len(frames) < 10:
     action = rng.uniform(
         action_spec.minimum, action_spec.maximum, size=action_spec.shape)
     time_step = env.step(action)
     frames.append(time_step.observation['pixels'])
   # Report an empty rollout as a test failure rather than an IndexError.
   self.assertGreater(len(frames), 0)
   for frame in frames:
     self.assertEqual(frame.shape, (240, 320, 3))
コード例 #3
0
    def test_dict_observation(self, pixels_only):
        """Checks the keys the pixels wrapper exposes and the frame it adds.

        Args:
          pixels_only: When True, the wrapped spec must contain only the pixel
            entry. Otherwise the pixel entry must be appended after all of the
            original entries, in order.
        """
        pixel_key = 'rgb'
        frame_width, frame_height = 320, 240

        base_env = cartpole.swingup()

        # Guard: cartpole must expose a dict-valued observation spec.
        base_spec = base_env.observation_spec()
        self.assertIsInstance(base_spec, collections.OrderedDict)

        # The wrapper should only add one observation.
        render_kwargs = {'width': frame_width, 'height': frame_height}
        pixel_env = pixels.Wrapper(
            base_env,
            observation_key=pixel_key,
            pixels_only=pixels_only,
            render_kwargs=render_kwargs)

        pixel_spec = pixel_env.observation_spec()
        self.assertIsInstance(pixel_spec, collections.OrderedDict)

        # Both modes reduce to: "the spec keys are exactly expected_keys".
        if pixels_only:
            expected_keys = [pixel_key]
        else:
            expected_keys = list(base_spec.keys()) + [pixel_key]
        self.assertEqual(len(expected_keys), len(pixel_spec))
        self.assertEqual(expected_keys, list(pixel_spec.keys()))

        # The rendered frame must satisfy its spec and the requested size.
        first_step = pixel_env.reset()
        frame = first_step.observation[pixel_key]
        pixel_spec[pixel_key].validate(frame)

        self.assertEqual(frame.shape, (frame_height, frame_width, 3))
        self.assertEqual(frame.dtype, np.uint8)