示例#1
0
def make_eval_env(all_args):
    """Build the vectorized environment used for evaluation rollouts.

    Only ``all_args.env_name == "MPE"`` is supported. The environment for
    worker ``rank`` is seeded with ``all_args.seed * 50000 + rank * 10000``.

    Args:
        all_args: Config namespace. Reads ``env_name``, ``seed`` and
            ``n_eval_rollout_threads``; the whole namespace is also passed to
            ``MPEEnv``.

    Returns:
        A ``DummyVecEnv`` over a single environment when
        ``n_eval_rollout_threads == 1``, otherwise a ``SubprocVecEnv`` with
        one environment per evaluation rollout thread.

    Raises:
        NotImplementedError: From the per-worker constructor when
            ``env_name`` is not ``"MPE"``.
    """
    def get_env_fn(rank):
        # Return a zero-argument constructor; the vec-env wrapper decides when
        # to call it. ``rank`` is captured here so each worker gets its own seed.
        def init_env():
            if all_args.env_name == "MPE":
                env = MPEEnv(all_args)
            else:
                # Put the reason in the exception itself so it survives into
                # the traceback (a bare raise loses it).
                raise NotImplementedError(
                    "Can not support the {} environment.".format(all_args.env_name))
            env.seed(all_args.seed * 50000 + rank * 10000)
            return env
        return init_env
    if all_args.n_eval_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    else:
        return SubprocVecEnv([get_env_fn(i) for i in range(all_args.n_eval_rollout_threads)])
示例#2
0
def make_train_env(all_args):
    """Build the vectorized Highway training environment.

    The Highway env hosts multiple vehicles controlled by trained agents, a
    Value-Iteration-based RL agent, the training agent, and rule-based agents
    (Intelligent Driver Model, IDM).

    Only ``all_args.env_name == "Highway"`` is supported. The environment for
    worker ``rank`` is seeded with ``all_args.seed + rank * 1000``.

    Args:
        all_args: Config namespace. Reads ``env_name``, ``seed`` and
            ``n_rollout_threads``; the whole namespace is also passed to
            ``HighwayEnv``.

    Returns:
        A ``DummyVecEnv`` over a single environment when
        ``n_rollout_threads == 1``, otherwise a ``SubprocVecEnv`` with one
        environment per rollout thread.

    Raises:
        NotImplementedError: From the per-worker constructor when
            ``env_name`` is not ``"Highway"``.
    """
    def get_env_fn(rank):
        # Return a zero-argument constructor; ``rank`` is captured so each
        # worker gets a distinct seed.
        def init_env():
            if all_args.env_name == "Highway":
                env = HighwayEnv(all_args)
            else:
                # Put the reason in the exception so it appears in the traceback.
                raise NotImplementedError(
                    "Can not support the {} environment.".format(all_args.env_name))
            env.seed(all_args.seed + rank * 1000)
            return env
        return init_env
    if all_args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    else:
        return SubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])
示例#3
0
def make_parallel_env(args):
    """Build a vectorized Cleanup or Harvest environment.

    ``args.env_name`` selects ``CleanupEnv`` (``"cleanup"``) or
    ``HarvestEnv`` (``"harvest"``); both are constructed with
    ``num_agents=args.num_agents``. The environment for worker ``rank`` is
    seeded with ``args.seed + rank * 1000``.

    Args:
        args: Config namespace. Reads ``env_name``, ``num_agents``, ``seed``
            and ``n_rollout_threads``.

    Returns:
        A ``DummyVecEnv`` over a single environment when
        ``n_rollout_threads == 1``, otherwise a ``SubprocVecEnv`` with one
        environment per rollout thread.

    Raises:
        NotImplementedError: From the per-worker constructor when
            ``env_name`` is neither ``"cleanup"`` nor ``"harvest"``.
    """
    def get_env_fn(rank):
        # Return a zero-argument constructor; ``rank`` is captured so each
        # worker gets a distinct seed.
        def init_env():
            if args.env_name == "cleanup":
                env = CleanupEnv(num_agents=args.num_agents)
            elif args.env_name == "harvest":
                env = HarvestEnv(num_agents=args.num_agents)
            else:
                # Put the reason in the exception so it appears in the traceback.
                raise NotImplementedError(
                    "Can not support the {} environment.".format(args.env_name))
            env.seed(args.seed + rank * 1000)
            return env

        return init_env

    if args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    else:
        return SubprocVecEnv(
            [get_env_fn(i) for i in range(args.n_rollout_threads)])
示例#4
0
def make_train_env(all_args):
    """Build the vectorized SMARTS ``hiway-v0`` training environment.

    All agents share one ``AgentSpec`` built from ``AgentType.Laner`` with
    ``max_episode_steps=all_args.episode_length``. Agent ids are the strings
    ``"0"`` .. ``str(num_agents - 1)``. ``all_args.env_name`` is not consulted:
    this factory always builds the SMARTS environment and wraps it in
    ``SmartWrapper``.

    Args:
        all_args: Config namespace. Reads ``num_agents``, ``episode_length``,
            ``scenarios``, ``headless``, ``seed``, ``auth_key`` and
            ``n_rollout_threads``.

    Returns:
        A ``DummyVecEnv`` over a single environment when
        ``n_rollout_threads == 1``, otherwise a ``SubprocVecEnv`` with one
        environment per rollout thread.
    """
    def get_env_fn(rank):
        # Return a zero-argument constructor; ``rank`` is captured so each
        # worker gets a distinct seed.
        def init_env():
            agent_spec = AgentSpec(
                interface=AgentInterface.from_type(
                    AgentType.Laner, max_episode_steps=all_args.episode_length
                )
            )
            agent_ids = [str(i) for i in range(all_args.num_agents)]

            env = gym.make(
                "smarts.env:hiway-v0",
                scenarios=all_args.scenarios,
                agent_specs={agent_id: agent_spec for agent_id in agent_ids},
                headless=all_args.headless,
                visdom=False,
                timestep_sec=0.1,
                sumo_headless=True,
                # The per-worker seed is passed to gym.make here; no separate
                # env.seed() call is made afterwards.
                seed=all_args.seed + rank * 1000,
                auth_key=all_args.auth_key,
                # Optional settings, currently disabled:
                #   zoo_workers=[("143.110.210.157", 7432)],  # distribute social agents across these workers
                #   envision_record_data_replay_path="./data_replay",
            )
            return SmartWrapper(env, all_args.num_agents)

        return init_env

    if all_args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    else:
        return SubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])