Ejemplo n.º 1
0
    def __init__(self, env_fns, start_method=None):
        """Start one worker process per environment factory and fetch the spaces.

        Args:
            env_fns: sequence of callables; each one is wrapped in
                ``CloudpickleWrapper`` and handed to its own ``_worker`` process.
            start_method: name of the multiprocessing start method. When
                ``None``, "forkserver" is used if the platform lists it,
                otherwise "spawn".
        """
        self.waiting = False
        self.closed = False
        n_envs = len(env_fns)

        if start_method is None:
            # "fork" is deliberately not the default: it is not thread safe
            # (see issue #217), even though it is more user friendly (it does
            # not require wrapping the code in `if __name__ == "__main__":`).
            available_methods = multiprocessing.get_all_start_methods()
            if "forkserver" in available_methods:
                start_method = "forkserver"
            else:
                start_method = "spawn"
        ctx = multiprocessing.get_context(start_method)

        # One pipe per environment: first end kept here, second end given to the worker.
        pipes = [ctx.Pipe() for _ in range(n_envs)]
        self.remotes, self.work_remotes = zip(*pipes)
        self.processes = []
        for child_end, parent_end, env_fn in zip(self.work_remotes, self.remotes, env_fns):
            # daemon=True: if the main process crashes, we should not cause things to hang
            worker = ctx.Process(
                target=_worker,
                args=(child_end, parent_end, CloudpickleWrapper(env_fn)),
                daemon=True,
            )  # pytype:disable=attribute-error
            worker.start()
            self.processes.append(worker)
            # This process no longer needs its handle on the worker's pipe end.
            child_end.close()

        # Ask the first worker for the spaces and use them for the whole vec env.
        first_remote = self.remotes[0]
        first_remote.send(("get_spaces", None))
        observation_space, action_space = first_remote.recv()
        VecEnv.__init__(self, n_envs, observation_space, action_space)
Ejemplo n.º 2
0
    def __init__(self,
                 env_fns: List[Tuple[Callable[[], Union[str, gym.Env]],
                                     Callable[[torch.Tensor],
                                              Union[torch.Tensor, np.ndarray,
                                                    list]]]]):
        """Create the wrapped environments and allocate the per-env buffers.

        Args:
            env_fns: a list of pairs of functions. The first function of each
                pair creates one environment; the second is the mapper that
                converts the action outputs of the model into what the env
                object accepts. See
                VectorizedMultiAgentEnvWrapper.MultiAgentEnvWrapper for details.
        """
        wrapped_envs = []
        for fn_pair in env_fns:
            wrapped_envs.append(
                self.MultiAgentEnvWrapper(fn_pair[0](), fn_pair[1]))
        self.envs = wrapped_envs

        # The first environment supplies the spaces, agent count and metadata.
        first_env = self.envs[0]
        VecEnv.__init__(self, len(env_fns), first_env.observation_space,
                        first_env.action_space)
        self.keys, shapes, dtypes = obs_space_info(first_env.observation_space)

        # One zero-filled array per observation key, with a leading num_envs axis.
        obs_buffers = OrderedDict()
        for key in self.keys:
            obs_buffers[key] = np.zeros((self.num_envs, *shapes[key]),
                                        dtype=dtypes[key])
        self.buf_obs = obs_buffers

        done_shape = (self.num_envs, 1)
        reward_shape = (self.num_envs, 1, first_env.agent_num)
        self.buf_dones = np.zeros(done_shape, dtype=bool)
        self.buf_rews = np.zeros(reward_shape, dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]
        self.actions = None
        self.metadata = first_env.metadata
Ejemplo n.º 3
0
    def __init__(self, env_fns: List[Callable[[], gym.Env]]):
        """Instantiate every environment and allocate the per-env buffers.

        Args:
            env_fns: list of zero-argument callables, each returning one
                environment. The first environment created supplies the
                observation space, action space and metadata.
        """
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
        obs_space = env.observation_space
        self.keys, shapes, dtypes = obs_space_info(obs_space)

        # One zero-filled array per observation key, with a leading num_envs axis.
        self.buf_obs = OrderedDict(
            (k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys
        )
        # Use the builtin `bool`: the `np.bool` alias was deprecated in
        # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]
        self.actions = None
        self.metadata = env.metadata
Ejemplo n.º 4
0
    def __init__(self, env_fns):
        """Instantiate every environment and allocate the per-env buffers.

        Args:
            env_fns: sequence of zero-argument callables, each returning one
                environment. The first environment created supplies the
                observation space passed to ``graph_obs_space`` and the
                metadata.
        """
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        action_space = graph_action_space()
        # Make observation space 1 x d, representing a single node.
        # NOTE(review): the helper is defined elsewhere; confirm the shape there.
        observation_space = graph_obs_space(env.observation_space)
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)

        # Use the builtin `bool`: the `np.bool` alias was deprecated in
        # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
        self.buf_dones = np.zeros((self.num_envs, ), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs, ), dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]
        self.actions = None
        self.metadata = env.metadata
Ejemplo n.º 5
0
    def __init__(
            self,
            venv: VecEnv,
            filename: Optional[str] = None,
            info_keywords: Tuple[str, ...] = (),
    ):
        """Wrap ``venv`` and set up episode bookkeeping and optional result logging.

        Args:
            venv: the vectorized environment to wrap.
            filename: when truthy, a ``ResultsWriter`` is created for it with a
                header holding the start time and the environment id; otherwise
                no writer is created.
            info_keywords: stored on the instance and forwarded to the
                ``ResultsWriter`` as ``extra_keys``.

        Warns:
            UserWarning: if ``venv`` reports that it is already wrapped with a
                ``Monitor``.
        """
        # Avoid circular import
        from stable_baselines3.common.monitor import Monitor, ResultsWriter

        # This check is not valid for special `VecEnv`
        # like the ones created by Procgen, which do not completely follow
        # the `VecEnv` interface
        try:
            is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0]
        except AttributeError:
            is_wrapped_with_monitor = False

        if is_wrapped_with_monitor:
            # Adjacent literals are concatenated verbatim, so each fragment
            # needs its own trailing space to keep the words apart.
            warnings.warn(
                "The environment is already wrapped with a `Monitor` wrapper "
                "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be "
                "overwritten by the `VecMonitor` ones.",
                UserWarning,
            )

        VecEnvWrapper.__init__(self, venv)
        self.episode_returns = None
        self.episode_lengths = None
        self.episode_count = 0
        self.t_start = time.time()

        # The env id is only available when the wrapped env exposes a non-None `spec`.
        spec = getattr(venv, "spec", None)
        env_id = spec.id if spec is not None else None

        if filename:
            self.results_writer = ResultsWriter(filename,
                                                header={
                                                    "t_start": self.t_start,
                                                    "env_id": env_id
                                                },
                                                extra_keys=info_keywords)
        else:
            self.results_writer = None
        self.info_keywords = info_keywords