Example #1
0
def make_eval_env(all_args):
    """Build the vectorized evaluation environment for the MPE scenario.

    Args:
        all_args: config namespace; reads ``env_name``, ``seed`` and
            ``n_eval_rollout_threads`` and is passed whole to ``MPEEnv``.

    Returns:
        ``DummyVecEnv`` over a single env factory when
        ``n_eval_rollout_threads == 1``, otherwise ``SubprocVecEnv`` with one
        env factory per evaluation thread.

    Raises:
        NotImplementedError: if ``all_args.env_name`` is not ``"MPE"``.
    """
    # Validate before any factory is handed to a vector-env wrapper, so an
    # unsupported name fails here in the caller rather than wherever the
    # wrapper invokes the factory (possibly a worker process -- not visible
    # from this function).
    if all_args.env_name != "MPE":
        raise NotImplementedError(
            f"Can not support the {all_args.env_name} environment.")

    def get_env_fn(rank):
        def init_env():
            env = MPEEnv(all_args)
            # Seed derived from the base seed and the thread rank; distinct
            # ranks within one call get distinct seeds.
            env.seed(all_args.seed * 50000 + rank * 10000)
            return env
        return init_env

    if all_args.n_eval_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv(
        [get_env_fn(i) for i in range(all_args.n_eval_rollout_threads)])
Example #2
0
def make_train_env(all_args):
    """Build the vectorized training environment for the Highway scenario.

    The Highway env hosts multiple vehicles controlled by trained agents, a
    value-iteration-based RL agent, the training agent, and rule-based agents
    (Intelligent Driver Model, IDM).

    Args:
        all_args: config namespace; reads ``env_name``, ``seed`` and
            ``n_rollout_threads`` and is passed whole to ``HighwayEnv``.

    Returns:
        ``DummyVecEnv`` over a single env factory when
        ``n_rollout_threads == 1``, otherwise ``SubprocVecEnv`` with one env
        factory per rollout thread.

    Raises:
        NotImplementedError: if ``all_args.env_name`` is not ``"Highway"``.
    """
    # Validate before any factory is handed to a vector-env wrapper, so an
    # unsupported name fails here in the caller rather than wherever the
    # wrapper invokes the factory (possibly a worker process -- not visible
    # from this function).
    if all_args.env_name != "Highway":
        raise NotImplementedError(
            f"Can not support the {all_args.env_name} environment.")

    def get_env_fn(rank):
        def init_env():
            env = HighwayEnv(all_args)
            # Offset by rank so each rollout thread gets a distinct seed.
            env.seed(all_args.seed + rank * 1000)
            return env
        return init_env

    if all_args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv(
        [get_env_fn(i) for i in range(all_args.n_rollout_threads)])
Example #3
0
def make_train_env(all_args):
    """Build the vectorized training environment for Harvest or Cleanup.

    Args:
        all_args: config namespace; reads ``env_name``, ``seed`` and
            ``n_rollout_threads`` and is passed whole to the env constructor.

    Returns:
        ``DummyVecEnv`` over a single env factory when
        ``n_rollout_threads == 1``, otherwise ``SubprocVecEnv`` with one env
        factory per rollout thread.

    Raises:
        NotImplementedError: if ``all_args.env_name`` is neither
            ``"Harvest"`` nor ``"Cleanup"``.
    """
    supported = ("Harvest", "Cleanup")
    # Validate before any factory is handed to a vector-env wrapper, so an
    # unsupported name fails here in the caller rather than wherever the
    # wrapper invokes the factory (possibly a worker process -- not visible
    # from this function).
    if all_args.env_name not in supported:
        raise NotImplementedError(
            f"Can not support the {all_args.env_name} environment.")

    def get_env_fn(rank):
        def init_env():
            if all_args.env_name == "Harvest":
                env = HarvestEnv(all_args)
            else:  # "Cleanup" -- the only other supported name
                env = CleanupEnv(all_args)
            # Offset by rank so each rollout thread gets a distinct seed.
            env.seed(all_args.seed + rank * 1000)
            return env

        return init_env

    if all_args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv(
        [get_env_fn(i) for i in range(all_args.n_rollout_threads)])
Example #4
0
def make_train_env(all_args):
    """Build the vectorized SMARTS training environment.

    Each rollout thread gets its own ``smarts.env:hiway-v0`` gym env in which
    ``all_args.num_agents`` agents (ids ``"0"``, ``"1"``, ...) share one
    ``AgentType.Laner`` spec; the env is then wrapped in ``SmartWrapper``.

    Note: ``all_args.env_name`` is not consulted -- a SMARTS env is always
    built.

    Args:
        all_args: config namespace; reads ``episode_length``, ``num_agents``,
            ``scenarios``, ``headless``, ``seed``, ``auth_key`` and
            ``n_rollout_threads``.

    Returns:
        ``DummyVecEnv`` over a single env factory when
        ``n_rollout_threads == 1``, otherwise ``SubprocVecEnv`` with one env
        factory per rollout thread.
    """
    def get_env_fn(rank):
        def init_env():
            agent_spec = AgentSpec(
                interface=AgentInterface.from_type(
                    AgentType.Laner, max_episode_steps=all_args.episode_length
                )
            )
            agent_ids = [str(i) for i in range(all_args.num_agents)]

            env = gym.make(
                "smarts.env:hiway-v0",
                scenarios=all_args.scenarios,
                agent_specs={agent_id: agent_spec for agent_id in agent_ids},
                headless=all_args.headless,
                visdom=False,
                timestep_sec=0.1,
                sumo_headless=True,
                # Seeding happens here through gym.make rather than a later
                # env.seed() call; the rank offset gives each rollout thread
                # a distinct seed.
                seed=all_args.seed + rank * 1000,
                auth_key=all_args.auth_key,
            )
            return SmartWrapper(env, all_args.num_agents)

        return init_env

    if all_args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv(
        [get_env_fn(i) for i in range(all_args.n_rollout_threads)])
Example #5
0
if __name__ == "__main__":
    from onpolicy.config import get_config
    from onpolicy.envs.gfootball.gfootball_env import GoogleFootballEnv
    from onpolicy.envs.env_wrappers import SubprocVecEnv
    import gym
    gym.logger.set_level(gym.logger.ERROR)

    args = get_config().parse_known_args()[0]
    config = {
        'all_args':
        args,
        'envs':
        SubprocVecEnv([
            lambda: GoogleFootballEnv(num_of_left_agents=3,
                                      env_name='test_example_multiagent',
                                      representation="simple115v2",
                                      channel_dimensions=(48, 36))
            for i in range(args.n_rollout_threads)
        ]),
        'eval_envs':
        SubprocVecEnv([
            lambda: GoogleFootballEnv(num_of_left_agents=3,
                                      env_name='test_example_multiagent',
                                      representation="simple115v2",
                                      channel_dimensions=(48, 36))
            for i in range(args.n_eval_rollout_threads)
        ]),
        'device':
        None,
        'num_agents':
        3,