Ejemplo n.º 1
0
    def test_thread_env(self):
        """Smoke-test a ThreadEnvironment wrapping the highway suite.

        Verifies the wrapped env exposes the AlfEnvironment interface,
        produces float32 observations, and can be stepped repeatedly with
        a sampled action without raising.
        """
        self._env = thread_environment.ThreadEnvironment(
            lambda: suite_highway.load(environment_name='highway-v0'))
        self.assertIsInstance(self._env, alf_environment.AlfEnvironment)
        self.assertEqual(torch.float32, self._env.observation_spec().dtype)

        actions = self._env.action_spec().sample()
        for _ in range(10):
            # Result intentionally discarded: this only checks that
            # stepping the threaded env does not raise.
            self._env.step(actions)
Ejemplo n.º 2
0
    def test_thread_env(self):
        """Smoke-test a ThreadEnvironment wrapping the safety_gym suite.

        Verifies the AlfEnvironment interface, float32 observations, the
        vector-reward shape, and that stepping with a sampled action does
        not raise.
        """
        self._env = thread_environment.ThreadEnvironment(
            lambda: suite_safety_gym.load(environment_name=
                                          'Safexp-PointGoal1-v0'))
        self.assertIsInstance(self._env, alf_environment.AlfEnvironment)
        self.assertEqual(torch.float32, self._env.observation_spec().dtype)
        self.assertEqual((suite_safety_gym.VectorReward.REWARD_DIMENSION, ),
                         self._env.reward_spec().shape)

        actions = self._env.action_spec().sample()
        for _ in range(10):
            # Result intentionally discarded: this only checks that
            # stepping the threaded env does not raise.
            self._env.step(actions)
Ejemplo n.º 3
0
    def test_thread_env(self):
        """Smoke-test a ThreadEnvironment wrapping the Mario suite.

        Verifies the AlfEnvironment interface, the uint8 dtype and
        (4, 84, 84) shape of stacked frame observations, and that stepping
        with a sampled action does not raise.
        """
        game = 'SuperMarioBros-Nes'
        self._env = thread_environment.ThreadEnvironment(
            lambda: suite_mario.load(
                game=game, state='Level1-1', wrap_with_process=False))
        self.assertIsInstance(self._env, alf_environment.AlfEnvironment)
        self.assertEqual(torch.uint8, self._env.observation_spec().dtype)
        self.assertEqual((4, 84, 84), self._env.observation_spec().shape)

        actions = self._env.action_spec().sample()
        for _ in range(10):
            # Result intentionally discarded: this only checks that
            # stepping the threaded env does not raise.
            self._env.step(actions)
Ejemplo n.º 4
0
    def test_thread_env(self):
        """Smoke-test a ThreadEnvironment wrapping the SocialBot suite.

        Verifies the AlfEnvironment interface, the float32 dtype and shapes
        of the observation and action specs, and that stepping with a
        sampled action does not raise.
        """
        env_name = 'SocialBot-CartPole-v0'
        self._env = thread_environment.ThreadEnvironment(
            lambda: suite_socialbot.load(environment_name=env_name,
                                         wrap_with_process=False))
        # Consistency with the other suite tests: the wrapper must expose
        # the AlfEnvironment interface.
        self.assertIsInstance(self._env, alf_environment.AlfEnvironment)
        self.assertEqual(torch.float32, self._env.observation_spec().dtype)
        self.assertEqual((4, ), self._env.observation_spec().shape)
        self.assertEqual(torch.float32, self._env.action_spec().dtype)
        self.assertEqual((1, ), self._env.action_spec().shape)

        actions = self._env.action_spec().sample()
        for _ in range(10):
            # Result intentionally discarded: this only checks that
            # stepping the threaded env does not raise.
            self._env.step(actions)
Ejemplo n.º 5
0
    def test_thread_env(self):
        """Load a DMLab scene inside a ThreadEnvironment and step it.

        Confirms the AlfEnvironment interface and the (4, 84, 84) stacked
        grayscale observation shape, then steps ten times with freshly
        sampled actions.
        """
        frame_wrappers = [
            gym_wrappers.FrameGrayScale,
            gym_wrappers.FrameResize,
            gym_wrappers.FrameStack,
        ]
        scene = 'lt_chasm'
        self._env = thread_environment.ThreadEnvironment(
            lambda: suite_dmlab.load(
                scene=scene,
                gym_env_wrappers=frame_wrappers,
                wrap_with_process=False))
        self.assertIsInstance(self._env, alf_environment.AlfEnvironment)
        self.assertEqual((4, 84, 84), self._env.observation_spec().shape)

        for _ in range(10):
            self._env.step(self._env.action_spec().sample())
Ejemplo n.º 6
0
Archivo: utils.py Proyecto: zhuboli/alf
def create_environment(env_name='CartPole-v0',
                       env_load_fn=suite_gym.load,
                       num_parallel_environments=30,
                       nonparallel=False,
                       seed=None):
    """Create a batched environment.

    Args:
        env_name (str): env name
        env_load_fn (Callable): callable that creates an environment.
            If ``env_load_fn`` has attribute ``batched`` and it is True,
            ``env_load_fn(env_name, batch_size=num_parallel_environments)``
            will be used to create the batched environment. Otherwise, a
            ``ParallelAlfEnvironment`` will be created.
        num_parallel_environments (int): num of parallel environments
        nonparallel (bool): force to create a single env in the current
            process. Used for correctly exposing game gin confs to tensorboard.
        seed (int): if None, each environment is seeded with an independent
            random int32; otherwise environment ``i`` is seeded with
            ``seed + i`` so runs are reproducible while individual
            environments still behave differently.

    Returns:
        AlfEnvironment:
    """

    if hasattr(env_load_fn, 'batched') and env_load_fn.batched:
        # The loader batches internally; only the batch size differs
        # between the nonparallel and parallel cases.
        batch_size = 1 if nonparallel else num_parallel_environments
        return env_load_fn(env_name, batch_size=batch_size)

    if nonparallel:
        # Each time we can only create one unwrapped env at most

        # Create and step the env in a separate thread. env `step` and `reset` must
        #   run in the same thread which the env is created in for some simulation
        #   environments such as social_bot(gazebo)
        alf_env = thread_environment.ThreadEnvironment(
            lambda: env_load_fn(env_name))
        if seed is None:
            alf_env.seed(np.random.randint(0, np.iinfo(np.int32).max))
        else:
            alf_env.seed(seed)
    else:
        # flatten=True will use flattened action and time_step in
        #   process environments to reduce communication overhead.
        alf_env = parallel_environment.ParallelAlfEnvironment(
            [functools.partial(env_load_fn, env_name)] *
            num_parallel_environments,
            flatten=True)

        if seed is None:
            alf_env.seed([
                np.random.randint(0,
                                  np.iinfo(np.int32).max)
                for i in range(num_parallel_environments)
            ])
        else:
            # We want deterministic behaviors for each environment, but different
            # behaviors among different individual environments (to increase the
            # diversity of environment data)!
            alf_env.seed([seed + i for i in range(num_parallel_environments)])

    return alf_env