Exemplo n.º 1
0
 def ctor(scene, env_id=None):
     """Construct a DMLab environment for ``scene`` without a process wrapper.

     The gym wrapper list handed to ``suite_dmlab.load`` is
     ``[FrameGrayScale, FrameResize, FrameStack]`` from ``gym_wrappers``.

     Args:
         scene: DMLab level name, forwarded unchanged to ``suite_dmlab.load``.
         env_id: accepted but never read here; presumably kept so this
             function matches a constructor signature that passes an
             environment id — TODO confirm against callers.

     Returns:
         The object returned by ``suite_dmlab.load``.
     """
     wrapper_chain = [
         gym_wrappers.FrameGrayScale,
         gym_wrappers.FrameResize,
         gym_wrappers.FrameStack,
     ]
     return suite_dmlab.load(scene=scene,
                             gym_env_wrappers=wrapper_chain,
                             wrap_with_process=False)
Exemplo n.º 2
0
 def test_dmlab_env(self):
     """Check the observation shape of two parallel 'lt_chasm' environments.

     Two identical constructors are given to ``ParallelPyEnvironment``; the
     result is wrapped in ``TFPyEnvironment``. Its observation spec must have
     shape (84, 84, 4).
     """
     def make_env():
         # Same wrapper list as the other tf_agents-style constructor:
         # grayscale, resize, then frame stacking.
         return suite_dmlab.load(
             scene='lt_chasm',
             gym_env_wrappers=[
                 wrappers.FrameGrayScale,
                 wrappers.FrameResize,
                 wrappers.FrameStack,
             ],
             wrap_with_process=False)

     self._env = parallel_py_environment.ParallelPyEnvironment(
         [make_env, make_env])
     tf_env = tf_py_environment.TFPyEnvironment(self._env)
     expected_shape = (84, 84, 4)
     self.assertEqual(expected_shape, tf_env.observation_spec().shape)
Exemplo n.º 3
0
    def test_process_env(self):
        """Load 'lt_chasm' with ``wrap_with_process=True`` and step it 10 times.

        Asserts that the loaded object is an ``AlfEnvironment`` and that its
        observation spec has shape (4, 84, 84). This is the channel-first
        layout, unlike the (84, 84, 4) asserted by the tf_agents-based test.
        It then issues ten steps with actions sampled from the action spec.
        """
        wrapper_chain = [
            gym_wrappers.FrameGrayScale,
            gym_wrappers.FrameResize,
            gym_wrappers.FrameStack,
        ]
        self._env = suite_dmlab.load(scene='lt_chasm',
                                     gym_env_wrappers=wrapper_chain,
                                     wrap_with_process=True)
        self.assertIsInstance(self._env, alf_environment.AlfEnvironment)
        self.assertEqual((4, 84, 84), self._env.observation_spec().shape)

        # No explicit reset() before stepping; presumably the environment
        # handles the first step itself — TODO confirm for AlfEnvironment.
        step_count = 10
        for _ in range(step_count):
            self._env.step(self._env.action_spec().sample())
Exemplo n.º 4
0
    def test_dmlab_env_run(self, scene):
        """Drive four parallel ``scene`` environments with a random policy.

        The environments are built with only ``FrameResize`` in the wrapper
        list. The test asserts that the batched TF environment reports an
        (84, 84, 3) observation shape. It then runs a ``DynamicStepDriver``
        for 10 steps (``num_steps=10``, ``maximum_iterations=10``) with no
        observers.

        Args:
            scene: DMLab level name, supplied by the test parameterization.
        """
        def make_env():
            return suite_dmlab.load(scene=scene,
                                    gym_env_wrappers=[wrappers.FrameResize],
                                    wrap_with_process=False)

        self._env = parallel_py_environment.ParallelPyEnvironment(
            [make_env for _ in range(4)])
        tf_env = tf_py_environment.TFPyEnvironment(self._env)
        self.assertEqual((84, 84, 3), tf_env.observation_spec().shape)

        policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(),
                                                 tf_env.action_spec())
        step_driver = dynamic_step_driver.DynamicStepDriver(
            env=tf_env, policy=policy, observers=None, num_steps=10)
        step_driver.run(maximum_iterations=10)