Example #1
0
    def __init__(self,
                 env_constructors,
                 start_serially=True,
                 blocking=False,
                 flatten=True):
        """Create one ``ProcessEnvironment`` per constructor and start them.

        The specs of the first environment are used as the reference specs of
        this parallel environment. Every environment's action spec and
        time_step spec must equal them.

        Args:
            env_constructors (list[Callable]): a list of callable environment creators.
            start_serially (bool): whether to start environments serially or in parallel.
            blocking (bool): whether to step environments one after another.
            flatten (bool): whether to use flatten action and time_steps during
                communication to reduce overhead.

        Raises:
            ValueError: If ``env_constructors`` is empty, or if the action
                specs or time_step specs of the environments don't match. In
                the mismatch case every environment is closed before raising.
        """
        super(ParallelAlfEnvironment, self).__init__()
        if not env_constructors:
            # Without this check an empty list fails later with an opaque
            # IndexError on ``self._envs[0]``.
            raise ValueError('env_constructors must not be empty.')
        self._envs = []
        self._env_ids = []
        for env_id, ctor in enumerate(env_constructors):
            env = ProcessEnvironment(ctor, env_id=env_id, flatten=flatten)
            self._envs.append(env)
            self._env_ids.append(env_id)
        self._num_envs = len(env_constructors)
        self._blocking = blocking
        self._start_serially = start_serially
        self._flatten = flatten
        self.start()

        # Environment 0 provides the reference specs for the whole batch.
        first_env = self._envs[0]
        self._action_spec = first_env.action_spec()
        self._observation_spec = first_env.observation_spec()
        self._reward_spec = first_env.reward_spec()
        self._time_step_spec = first_env.time_step_spec()
        self._env_info_spec = first_env.env_info_spec()
        self._num_tasks = first_env.num_tasks
        self._task_names = first_env.task_names
        # Same as the time_step spec, but with its ``env_info`` field replaced
        # by the env_info spec reported by the environment.
        self._time_step_with_env_info_spec = self._time_step_spec._replace(
            env_info=self._env_info_spec)
        self._parallel_execution = True

        error = None
        if any(env.action_spec() != self._action_spec for env in self._envs):
            error = 'All environments must have the same action spec.'
        elif any(env.time_step_spec() != self._time_step_spec
                 for env in self._envs):
            error = 'All environments must have the same time_step_spec.'
        if error is not None:
            # ``self.start()`` has already run, so close every environment
            # before raising instead of leaving them (and, presumably, their
            # worker processes) alive after the constructor fails.
            for env in self._envs:
                env.close()
            raise ValueError(error)
 def test_close_no_hang_after_init(self):
     """Start a ProcessEnvironment and close it straight away.

     Per the test name, ``close()`` right after ``start()`` must not hang.
     """
     tensor_spec = ts.TensorSpec((3, 3), torch.float32)
     bounded_spec = ts.BoundedTensorSpec(
         [1], torch.float32, minimum=-1.0, maximum=1.0)
     env_ctor = functools.partial(
         RandomAlfEnvironment,
         tensor_spec,
         bounded_spec,
         episode_end_probability=0,
         min_duration=2,
         max_duration=2)
     process_env = ProcessEnvironment(env_ctor)
     process_env.start()
     process_env.close()
 def test_reraise_exception_in_step(self):
     """An exception raised by the worker's ``step`` is re-raised to the caller.

     ``MockEnvironmentCrashInStep`` is built with ``crash_at_step``; the first
     ``crash_at_step - 1`` steps are expected to succeed and the next one to
     raise.
     """
     crash_at_step = 3
     constructor = functools.partial(MockEnvironmentCrashInStep,
                                     crash_at_step)
     env = ProcessEnvironment(constructor)
     env.start()
     # Always shut the environment down, even if an assertion below fails,
     # so the test does not leak a subprocess.
     self.addCleanup(env.close)
     env.reset()
     action_spec = env.action_spec()
     # Derive the number of successful steps from crash_at_step so the two
     # cannot drift apart.
     for _ in range(crash_at_step - 1):
         env.step(action_spec.sample())
     with self.assertRaises(Exception):
         env.step(action_spec.sample())
 def test_reraise_exception_in_reset(self):
     """An exception raised by the worker's ``reset`` is re-raised to the caller."""
     constructor = MockEnvironmentCrashInReset
     env = ProcessEnvironment(constructor)
     env.start()
     # Always shut the environment down, even if the assertion fails, so the
     # test does not leak a subprocess.
     self.addCleanup(env.close)
     with self.assertRaises(Exception):
         env.reset()