def _worker(remote: mp.connection.Connection,
            parent_remote: mp.connection.Connection,
            manager_fn_wrapper: CloudpickleWrapper) -> None:
    """Serve evaluation commands received on ``remote`` until told to stop.

    ``parent_remote`` is closed first, torch is limited to one thread via
    ``torch.set_num_threads(1)``, and the evaluation manager is built by
    calling ``manager_fn_wrapper.var()``. The function then loops on
    ``remote.recv()``, which must yield ``(cmd, data)`` pairs:

    * ``"run_eval_episodes"`` -- ``data`` is the number of games to play.
      Replies with a tuple of four lists
      ``(winners, num_game_steps, victory_points_all, policy_steps)``,
      one entry per game.
    * ``"update_policies"`` -- ``data.var`` is handed to
      ``evaluation_manager._update_policies``. Replies with ``True``.
    * ``"close"`` -- closes ``remote`` and returns without replying.

    The loop also ends when an ``EOFError`` is raised while handling a
    command.

    :param remote: connection end this worker receives on and replies to.
    :param parent_remote: the other connection end; closed immediately.
    :param manager_fn_wrapper: object whose ``var`` attribute is a
        zero-argument callable returning the evaluation manager.
    :raises NotImplementedError: if an unknown command is received.
    """
    parent_remote.close()

    torch.set_num_threads(1)
    evaluation_manager = manager_fn_wrapper.var()

    while True:
        try:
            cmd, data = remote.recv()
            if cmd == "run_eval_episodes":
                winners = []
                num_game_steps = []
                victory_points_all = []
                policy_steps = []
                for _ in range(data):
                    (winner, victory_points, total_steps,
                     policy_decisions) = evaluation_manager.run_evaluation_game()
                    winners.append(winner)
                    num_game_steps.append(total_steps)
                    victory_points_all.append(victory_points)
                    policy_steps.append(policy_decisions)
                remote.send((winners, num_game_steps, victory_points_all,
                             policy_steps))
            elif cmd == "update_policies":
                evaluation_manager._update_policies(data.var)
                remote.send(True)
            elif cmd == "close":
                remote.close()
                break
            else:
                # Fail loudly: silently dropping an unknown command sends no
                # reply, so a sender waiting for one would never get it.
                raise NotImplementedError(
                    f"`{cmd}` is not implemented in the worker")
        except EOFError:
            break
def _worker(
    remote: mp.connection.Connection,
    parent_remote: mp.connection.Connection,
    env_fn_wrapper: CloudpickleWrapper,
    render: bool,
    render_mode: str,
) -> None:
    """Run an environment and answer commands arriving on ``remote``.

    ``parent_remote`` is closed, the environment is created by calling
    ``env_fn_wrapper.var()``, and ``(cmd, data)`` pairs are then read from
    ``remote`` until a ``"close"`` command arrives or an ``EOFError`` is
    raised while handling a command.

    Supported commands: ``"step"``, ``"seed"``, ``"reset"``, ``"render"``,
    ``"close"``, ``"get_spaces"``, ``"env_method"``, ``"get_attr"``,
    ``"set_attr"`` and ``"is_wrapped"``. Every command except ``"close"``
    sends exactly one reply on ``remote``.

    :param remote: connection end used to receive commands and send replies.
    :param parent_remote: the other connection end; closed immediately.
    :param env_fn_wrapper: object whose ``var`` attribute is a zero-argument
        callable returning the environment.
    :param render: when true, ``env.render(mode=render_mode)`` is called after
        every step and every reset performed by this worker.
    :param render_mode: value passed as ``mode`` to those render calls.
    :raises NotImplementedError: if an unknown command is received.
    """
    # Deferred to call time: a module-level import would be circular.
    from stable_baselines3.common.env_util import is_wrapped

    def show_frame() -> None:
        # Automatic rendering, enabled only through the ``render`` flag.
        if render:
            env.render(mode=render_mode)

    parent_remote.close()
    env = env_fn_wrapper.var()
    try:
        while True:
            cmd, data = remote.recv()
            if cmd == "step":
                obs, rew, finished, extras = env.step(data)
                show_frame()
                if finished:
                    # Keep the last observation of the episode reachable
                    # through the info dict, then start a new episode.
                    extras["terminal_observation"] = obs
                    obs = env.reset()
                    show_frame()
                remote.send((obs, rew, finished, extras))
            elif cmd == "seed":
                remote.send(env.seed(data))
            elif cmd == "reset":
                first_obs = env.reset()
                show_frame()
                remote.send(first_obs)
            elif cmd == "render":
                remote.send(env.render(data))
            elif cmd == "close":
                env.close()
                remote.close()
                return
            elif cmd == "get_spaces":
                remote.send((env.observation_space, env.action_space))
            elif cmd == "env_method":
                # data is (method name, positional args, keyword args)
                remote.send(getattr(env, data[0])(*data[1], **data[2]))
            elif cmd == "get_attr":
                remote.send(getattr(env, data))
            elif cmd == "set_attr":
                # setattr returns None, so the reply is always None.
                remote.send(setattr(env, data[0], data[1]))
            elif cmd == "is_wrapped":
                remote.send(is_wrapped(env, data))
            else:
                raise NotImplementedError(
                    f"`{cmd}` is not implemented in the worker")
    except EOFError:
        # An EOFError while handling a command simply ends the worker.
        pass
# Example no. 3
# 0
def _worker(
        remote: mp.connection.Connection, parent_remote: mp.connection.Connection, manager_fn_wrapper: CloudpickleWrapper
) -> None:
    """Run a game manager and answer commands arriving on ``remote``.

    ``parent_remote`` is closed, torch is limited to one thread via
    ``torch.set_num_threads(1)``, and the game manager is created by calling
    ``manager_fn_wrapper.var()``. ``(cmd, data)`` pairs are then read from
    ``remote`` until a ``"close"`` command arrives or an ``EOFError`` is
    raised while handling a command.

    Commands:

    * ``"gather_rollouts"`` -- calls ``gather_rollouts()`` and then
      ``_after_rollouts()`` on the manager; replies with the seven rollout
      items as a tuple wrapped in ``CloudpickleWrapper``.
    * ``"reset"`` -- calls ``reset()`` on the manager; replies ``True``.
    * ``"close"`` -- closes ``remote`` and returns without replying.
    * ``"update_policy"`` -- ``data`` is ``(wrapped_state_dict, policy_id)``;
      passes ``data[0].var`` and ``data[1]`` to the manager's
      ``_update_policy``; replies ``True``.
    * ``"seed"`` -- seeds numpy, ``random`` and torch with ``data``; replies
      ``True``.

    :param remote: connection end used to receive commands and send replies.
    :param parent_remote: the other connection end; closed immediately.
    :param manager_fn_wrapper: object whose ``var`` attribute is a
        zero-argument callable returning the game manager.
    :raises NotImplementedError: if an unknown command is received.
    """
    parent_remote.close()

    torch.set_num_threads(1)
    game_manager = manager_fn_wrapper.var()

    try:
        while True:
            cmd, data = remote.recv()
            if cmd == "gather_rollouts":
                (obs, hidden, rews, acts, masks,
                 log_probs, dones) = game_manager.gather_rollouts()
                game_manager._after_rollouts()
                payload = (obs, hidden, rews, acts, masks, log_probs, dones)
                remote.send(CloudpickleWrapper(payload))
            elif cmd == "reset":
                game_manager.reset()
                remote.send(True)
            elif cmd == "close":
                remote.close()
                return
            elif cmd == "update_policy":
                game_manager._update_policy(data[0].var, data[1])
                remote.send(True)
            elif cmd == "seed":
                # Seed every random number generator with the same value.
                for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
                    seed_fn(data)
                remote.send(True)
            else:
                raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
    except EOFError:
        # An EOFError while handling a command simply ends the worker.
        pass