def test_thread_env(self):
    """Wrap a Mario environment in a ``ThreadEnvironment`` and exercise it.

    Verifies that the wrapped object is an ``AlfEnvironment``, that its
    observation spec has dtype ``torch.uint8`` and shape ``(4, 84, 84)``,
    and then steps it 10 times with a single sampled action.
    """
    game = 'SuperMarioBros-Nes'

    def make_env():
        # Constructor handed to ThreadEnvironment; the environment itself is
        # created without an extra process wrapper.
        return suite_mario.load(
            game=game, state='Level1-1', wrap_with_process=False)

    self._env = thread_environment.ThreadEnvironment(make_env)
    env = self._env

    self.assertIsInstance(env, alf_environment.AlfEnvironment)
    self.assertEqual(torch.uint8, env.observation_spec().dtype)
    self.assertEqual((4, 84, 84), env.observation_spec().shape)

    # The same sampled action is replayed on every step; only the fact that
    # stepping succeeds is being checked, so the result is not kept.
    actions = env.action_spec().sample()
    for _ in range(10):
        env.step(actions)
def test_mario_env(self):
    """Drive a random policy over four parallel Mario environments.

    Builds a ``ParallelPyEnvironment`` from four identical constructors,
    wraps it in a ``TFPyEnvironment``, checks that the observation spec has
    dtype ``np.uint8`` and shape ``(84, 84, 4)``, and then runs a
    ``DynamicStepDriver`` with a ``RandomTFPolicy`` and a set of metrics.
    """
    num_envs = 4
    num_steps = 10000

    def ctor():
        # Each parallel worker builds its own environment from this factory.
        return suite_mario.load(
            'SuperMarioBros-Nes', 'Level1-1', wrap_with_process=False)

    self._env = parallel_py_environment.ParallelPyEnvironment(
        [ctor] * num_envs)
    env = tf_py_environment.TFPyEnvironment(self._env)

    self.assertEqual(np.uint8, env.observation_spec().dtype)
    self.assertEqual((84, 84, 4), env.observation_spec().shape)

    random_policy = random_tf_policy.RandomTFPolicy(
        env.time_step_spec(), env.action_spec())

    # The batched metrics are sized to the number of parallel environments.
    metrics = [
        AverageReturnMetric(batch_size=num_envs),
        AverageEpisodeLengthMetric(batch_size=num_envs),
        EnvironmentSteps(),
        NumberOfEpisodes(),
    ]

    # NOTE(review): observers and num_steps are passed by keyword on purpose.
    # Assuming this is the TF-Agents DynamicStepDriver, some releases place
    # `transition_observers` before `num_steps`, so a fourth positional
    # argument would not be interpreted as the step count there.
    driver = dynamic_step_driver.DynamicStepDriver(
        env, random_policy, observers=metrics, num_steps=num_steps)
    driver.run(maximum_iterations=num_steps)
def ctor(game, env_id=None):
    """Create a Mario environment for ``game`` starting at ``'Level1-1'``.

    Args:
        game: Name of the game, forwarded to ``suite_mario.load``.
        env_id: Accepted but not used by this function. Presumably it is
            part of the constructor signature expected by whatever calls
            this factory -- TODO confirm against the caller.

    Returns:
        The object returned by ``suite_mario.load``.
    """
    load_kwargs = {
        'game': game,
        'state': 'Level1-1',
        'wrap_with_process': False,
    }
    return suite_mario.load(**load_kwargs)