コード例 #1
0
    def test_thread_env(self):
        """Check that a Mario env wrapped in ``ThreadEnvironment`` can be stepped.

        The ``SuperMarioBros-Nes`` game at state ``Level1-1`` is loaded through
        ``suite_mario.load`` (with ``wrap_with_process=False``) inside a
        ``ThreadEnvironment``. The test checks that the wrapper is an
        ``AlfEnvironment`` and that its observation spec is ``torch.uint8`` with
        shape ``(4, 84, 84)``. It then steps the environment ten times, reusing
        a single action sampled once from the action spec.
        """
        game_name = 'SuperMarioBros-Nes'

        def make_env():
            # Constructor handed to ThreadEnvironment; it closes over game_name.
            return suite_mario.load(
                game=game_name, state='Level1-1', wrap_with_process=False)

        self._env = thread_environment.ThreadEnvironment(make_env)
        env = self._env
        self.assertIsInstance(env, alf_environment.AlfEnvironment)

        # Observation spec is queried separately for each property, as before.
        self.assertEqual(torch.uint8, env.observation_spec().dtype)
        self.assertEqual((4, 84, 84), env.observation_spec().shape)

        # One sampled action is reused for every step.
        sampled_action = env.action_spec().sample()
        step_count = 10
        for _step_index in range(step_count):
            env.step(sampled_action)
コード例 #2
0
ファイル: suite_mario_test.py プロジェクト: ruizhaogit/alf
    def test_mario_env(self):
        """Drive a random policy over four parallel Mario envs for 10000 steps.

        Four ``SuperMarioBros-Nes`` / ``Level1-1`` environments (each built with
        ``wrap_with_process=False``) are combined in a ``ParallelPyEnvironment``
        and then wrapped as a ``TFPyEnvironment``. The test checks that the
        observation spec is ``np.uint8`` with shape ``(84, 84, 4)``. It then runs
        a ``RandomTFPolicy`` through a ``DynamicStepDriver`` with 10000 for both
        ``num_steps`` and ``maximum_iterations``, feeding average-return,
        average-episode-length, environment-step and episode-count metrics.
        """
        num_envs = 4
        num_steps = 10000

        def make_env():
            return suite_mario.load(
                'SuperMarioBros-Nes', 'Level1-1', wrap_with_process=False)

        # The same constructor is shared by every parallel slot.
        self._env = parallel_py_environment.ParallelPyEnvironment(
            [make_env for _ in range(num_envs)])
        tf_env = tf_py_environment.TFPyEnvironment(self._env)
        self.assertEqual(np.uint8, tf_env.observation_spec().dtype)
        self.assertEqual((84, 84, 4), tf_env.observation_spec().shape)

        policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Batched metrics are sized to the number of parallel environments.
        observers = [
            AverageReturnMetric(batch_size=num_envs),
            AverageEpisodeLengthMetric(batch_size=num_envs),
            EnvironmentSteps(),
            NumberOfEpisodes(),
        ]
        driver = dynamic_step_driver.DynamicStepDriver(
            tf_env, policy, observers, num_steps)
        driver.run(maximum_iterations=num_steps)
コード例 #3
0
 def ctor(game, env_id=None):
     """Load ``game`` at state ``Level1-1`` via ``suite_mario.load``.

     The environment is loaded with ``wrap_with_process=False``.

     Args:
         game: game name forwarded to ``suite_mario.load``.
         env_id: accepted but not used. Presumably it is there so this function
             matches a constructor signature its caller expects -- TODO confirm.

     Returns:
         Whatever ``suite_mario.load`` returns.
     """
     env = suite_mario.load(
         game=game, state='Level1-1', wrap_with_process=False)
     return env