def update_policy(self, state_dict, process_id=None, policy_id=0):
    if process_id is None:
        # Broadcast the new weights for the given policy to every worker process.
        for remote in self.remotes:
            remote.send(("update_policy", (CloudpickleWrapper(state_dict), policy_id)))
        results = [remote.recv() for remote in self.remotes]
    else:
        # Update only the worker identified by process_id.
        self.remotes[process_id].send(("update_policy", (CloudpickleWrapper(state_dict), policy_id)))
        results = self.remotes[process_id].recv()
    return results
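# --- Hedged usage sketch (not part of the original source) -------------------
# Assumes `update_policy` above is a method of the parallel game-manager
# wrapper (`parallel_manager` below, a hypothetical name) and that `policy`
# is a torch.nn.Module whose weights should be shipped to the rollout workers.
def _example_broadcast_policy(parallel_manager, policy):
    # Broadcast the latest weights for policy 0 to every worker process.
    acks = parallel_manager.update_policy(policy.state_dict(), policy_id=0)
    # Refresh only worker 2, e.g. after resetting that process.
    ack = parallel_manager.update_policy(policy.state_dict(), process_id=2, policy_id=0)
    return acks, ack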
def _worker(
    remote: mp.connection.Connection,
    parent_remote: mp.connection.Connection,
    manager_fn_wrapper: CloudpickleWrapper,
) -> None:
    parent_remote.close()
    torch.set_num_threads(1)
    evaluation_manager = manager_fn_wrapper.var()
    while True:
        try:
            cmd, data = remote.recv()
            if cmd == "run_eval_episodes":
                # Run the requested number of evaluation games and collect per-game stats.
                num_episodes = data
                winners = []
                num_game_steps = []
                victory_points_all = []
                policy_steps = []
                for ep in range(num_episodes):
                    winner, victory_points, total_steps, policy_decisions = \
                        evaluation_manager.run_evaluation_game()
                    winners.append(winner)
                    num_game_steps.append(total_steps)
                    victory_points_all.append(victory_points)
                    policy_steps.append(policy_decisions)
                remote.send((winners, num_game_steps, victory_points_all, policy_steps))
            elif cmd == "update_policies":
                state_dicts = data.var
                evaluation_manager._update_policies(state_dicts)
                remote.send(True)
        except EOFError:
            # The parent process closed the pipe; shut the worker down.
            break
def __init__(self, env_fns, start_method=None):
    self.waiting = False
    self.closed = False
    n_envs = len(env_fns)

    if start_method is None:
        # Fork is not a thread safe method (see issue #217)
        # but is more user friendly (does not require to wrap the code in
        # a `if __name__ == "__main__":`)
        forkserver_available = "forkserver" in multiprocessing.get_all_start_methods()
        start_method = "forkserver" if forkserver_available else "spawn"
    ctx = multiprocessing.get_context(start_method)

    self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
    self.processes = []
    for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
        args = (work_remote, remote, CloudpickleWrapper(env_fn))
        # daemon=True: if the main process crashes, we should not cause things to hang
        process = ctx.Process(target=_worker, args=args, daemon=True)  # pytype:disable=attribute-error
        process.start()
        self.processes.append(process)
        work_remote.close()

    self.remotes[0].send(("get_spaces", None))
    observation_space, action_space = self.remotes[0].recv()
    VecEnv.__init__(self, len(env_fns), observation_space, action_space)
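# --- Hedged usage sketch (not part of the original source) -------------------
# Assumes the __init__ above belongs to a SubprocVecEnv-style class (called
# `SubprocVecEnv` here). `make_env` is a hypothetical factory; each entry in
# env_fns must be a zero-argument callable that builds a fresh environment
# inside its worker. With the "spawn"/"forkserver" start methods this code
# must run under an `if __name__ == "__main__":` guard.
def _example_build_vec_env(n_envs=4):
    import gym

    def make_env(rank):
        def _init():
            env = gym.make("CartPole-v1")
            env.seed(rank)
            return env
        return _init

    vec_env = SubprocVecEnv([make_env(rank) for rank in range(n_envs)])
    observations = vec_env.reset()
    return vec_env, observations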
def _worker(
    remote: mp.connection.Connection,
    parent_remote: mp.connection.Connection,
    env_fn_wrapper: CloudpickleWrapper,
    render: bool,
    render_mode: str,
) -> None:
    # Import here to avoid a circular import
    from stable_baselines3.common.env_util import is_wrapped

    parent_remote.close()
    env = env_fn_wrapper.var()
    while True:
        try:
            cmd, data = remote.recv()
            if cmd == "step":
                observation, reward, done, info = env.step(data)
                if render:
                    env.render(mode=render_mode)
                if done:
                    # save final observation where user can get it, then reset
                    info["terminal_observation"] = observation
                    observation = env.reset()
                    if render:
                        env.render(mode=render_mode)
                remote.send((observation, reward, done, info))
            elif cmd == "seed":
                remote.send(env.seed(data))
            elif cmd == "reset":
                observation = env.reset()
                if render:
                    env.render(mode=render_mode)
                remote.send(observation)
            elif cmd == "render":
                remote.send(env.render(data))
            elif cmd == "close":
                env.close()
                remote.close()
                break
            elif cmd == "get_spaces":
                remote.send((env.observation_space, env.action_space))
            elif cmd == "env_method":
                method = getattr(env, data[0])
                remote.send(method(*data[1], **data[2]))
            elif cmd == "get_attr":
                remote.send(getattr(env, data))
            elif cmd == "set_attr":
                remote.send(setattr(env, data[0], data[1]))
            elif cmd == "is_wrapped":
                remote.send(is_wrapped(env, data))
            else:
                raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
        except EOFError:
            break
def _worker(
    remote: mp.connection.Connection,
    parent_remote: mp.connection.Connection,
    manager_fn_wrapper: CloudpickleWrapper,
) -> None:
    parent_remote.close()
    torch.set_num_threads(1)
    game_manager = manager_fn_wrapper.var()
    while True:
        try:
            cmd, data = remote.recv()
            if cmd == "gather_rollouts":
                # Collect a batch of rollout data and ship it back wrapped with cloudpickle.
                observations, hidden_states, rewards, actions, action_masks, \
                    action_log_probs, dones = game_manager.gather_rollouts()
                game_manager._after_rollouts()
                remote.send(
                    CloudpickleWrapper(
                        (observations, hidden_states, rewards, actions, action_masks,
                         action_log_probs, dones)
                    )
                )
            elif cmd == "reset":
                game_manager.reset()
                remote.send(True)
            elif cmd == "close":
                remote.close()
                break
            elif cmd == "update_policy":
                state_dict = data[0].var
                policy_id = data[1]
                game_manager._update_policy(state_dict, policy_id)
                remote.send(True)
            elif cmd == "seed":
                # Seed all random number generators used inside this worker.
                np.random.seed(data)
                random.seed(data)
                torch.manual_seed(data)
                remote.send(True)
            else:
                raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
        except EOFError:
            break
def __init__(self, game_manager_fns, start_method=None):
    self.waiting = False
    self.closed = False
    n_processes = len(game_manager_fns)

    if start_method is None:
        forkserver_available = "forkserver" in mp.get_all_start_methods()
        start_method = "forkserver" if forkserver_available else "spawn"
    ctx = mp.get_context(start_method)

    self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_processes)])
    self.processes = []
    for work_remote, remote, game_manager_fn in zip(self.work_remotes, self.remotes, game_manager_fns):
        args = (work_remote, remote, CloudpickleWrapper(game_manager_fn))
        # daemon=True: if the main process crashes, we should not cause things to hang
        process = ctx.Process(target=_worker, args=args, daemon=True)
        process.start()
        self.processes.append(process)
        work_remote.close()
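# --- Hedged usage sketch (not part of the original source) -------------------
# Assumes the __init__ above belongs to a SubprocVecEnv-style wrapper around
# per-process game managers (called `SubprocGameManager` here, a hypothetical
# name) and that `GameManager` is the class each worker should construct.
# Every entry in game_manager_fns must be a zero-argument callable so the
# manager is built inside its own worker process rather than pickled across.
def _example_build_parallel_managers(n_processes=4):
    manager_fns = [lambda: GameManager() for _ in range(n_processes)]
    parallel_manager = SubprocGameManager(manager_fns)
    return parallel_manager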
def update_policies(self, state_dicts):
    for remote in self.remotes:
        remote.send(("update_policies", CloudpickleWrapper(state_dicts)))
    results = [remote.recv() for remote in self.remotes]
    return results