def make_eval_env(all_args):
    """Build the vectorized evaluation environment for the MPE benchmark.

    Args:
        all_args: Run configuration object. This function reads
            ``env_name``, ``seed`` and ``n_eval_rollout_threads``; the whole
            object is also forwarded to ``MPEEnv``.

    Returns:
        A ``DummyVecEnv`` wrapping a single environment when
        ``n_eval_rollout_threads == 1``, otherwise a ``SubprocVecEnv`` with
        one environment factory per evaluation thread.

    Raises:
        NotImplementedError: If ``all_args.env_name`` is not ``"MPE"``.
    """
    # Validate up front: with SubprocVecEnv the factories below run inside
    # worker processes, so an unsupported name would otherwise fail there,
    # far from the caller.
    if all_args.env_name != "MPE":
        raise NotImplementedError(
            f"Can not support the {all_args.env_name} environment."
        )

    def get_env_fn(rank):
        """Return a zero-argument factory building the env for worker ``rank``."""

        def init_env():
            env = MPEEnv(all_args)
            # Each rank gets a distinct seed derived from the base seed.
            env.seed(all_args.seed * 50000 + rank * 10000)
            return env

        return init_env

    n_threads = all_args.n_eval_rollout_threads
    if n_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv([get_env_fn(rank) for rank in range(n_threads)])
def make_train_env(all_args):
    """Build the vectorized training environment for the Highway scenario.

    Per the original author's description, the Highway env contains
    multiple vehicles controlled by trained agents, a Value-Iteration-based
    RL agent, the training agent, and rule-based agents (Intelligent Driver
    Model, IDM). That composition lives in ``HighwayEnv`` and is not visible
    here.

    Args:
        all_args: Run configuration object. This function reads
            ``env_name``, ``seed`` and ``n_rollout_threads``; the whole
            object is also forwarded to ``HighwayEnv``.

    Returns:
        A ``DummyVecEnv`` wrapping a single environment when
        ``n_rollout_threads == 1``, otherwise a ``SubprocVecEnv`` with one
        environment factory per rollout thread.

    Raises:
        NotImplementedError: If ``all_args.env_name`` is not ``"Highway"``.
    """
    # Validate up front: with SubprocVecEnv the factories below run inside
    # worker processes, so an unsupported name would otherwise fail there.
    if all_args.env_name != "Highway":
        raise NotImplementedError(
            f"Can not support the {all_args.env_name} environment."
        )

    def get_env_fn(rank):
        """Return a zero-argument factory building the env for worker ``rank``."""

        def init_env():
            env = HighwayEnv(all_args)
            # Offset the base seed per rank so parallel workers differ.
            env.seed(all_args.seed + rank * 1000)
            return env

        return init_env

    n_threads = all_args.n_rollout_threads
    if n_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv([get_env_fn(rank) for rank in range(n_threads)])
def make_parallel_env(args):
    """Build the vectorized environment for the Cleanup or Harvest game.

    Args:
        args: Run configuration object. This function reads ``env_name``
            (``"cleanup"`` or ``"harvest"``), ``num_agents``, ``seed`` and
            ``n_rollout_threads``.

    Returns:
        A ``DummyVecEnv`` wrapping a single environment when
        ``n_rollout_threads == 1``, otherwise a ``SubprocVecEnv`` with one
        environment factory per rollout thread.

    Raises:
        NotImplementedError: If ``args.env_name`` is neither ``"cleanup"``
            nor ``"harvest"``.
    """
    # Validate up front: with SubprocVecEnv the factories below run inside
    # worker processes, so an unsupported name would otherwise fail there.
    if args.env_name not in ("cleanup", "harvest"):
        raise NotImplementedError(
            f"Can not support the {args.env_name} environment."
        )

    def get_env_fn(rank):
        """Return a zero-argument factory building the env for worker ``rank``."""

        def init_env():
            env_cls = CleanupEnv if args.env_name == "cleanup" else HarvestEnv
            env = env_cls(num_agents=args.num_agents)
            # Offset the base seed per rank so parallel workers differ.
            env.seed(args.seed + rank * 1000)
            return env

        return init_env

    n_threads = args.n_rollout_threads
    if n_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv([get_env_fn(rank) for rank in range(n_threads)])
def make_train_env(all_args):
    """Build the vectorized SMARTS ``hiway-v0`` training environment.

    Every agent id ``"0"`` .. ``str(num_agents - 1)`` is registered with the
    same ``AgentType.Laner`` agent spec, and the raw gym env is wrapped in
    ``SmartWrapper``. ``all_args.env_name`` is not consulted.

    Args:
        all_args: Run configuration object. This function reads
            ``episode_length``, ``num_agents``, ``scenarios``, ``headless``,
            ``seed``, ``auth_key`` and ``n_rollout_threads``.

    Returns:
        A ``DummyVecEnv`` wrapping a single environment when
        ``n_rollout_threads == 1``, otherwise a ``SubprocVecEnv`` with one
        environment factory per rollout thread.
    """

    def get_env_fn(rank):
        """Return a zero-argument factory building the env for worker ``rank``."""

        def init_env():
            agent_spec = AgentSpec(
                interface=AgentInterface.from_type(
                    AgentType.Laner, max_episode_steps=all_args.episode_length
                )
            )
            agent_ids = [str(i) for i in range(all_args.num_agents)]
            env = gym.make(
                "smarts.env:hiway-v0",
                scenarios=all_args.scenarios,
                agent_specs={agent_id: agent_spec for agent_id in agent_ids},
                headless=all_args.headless,
                visdom=False,
                timestep_sec=0.1,
                sumo_headless=True,
                # The per-rank seed is supplied at construction time; the
                # env is not re-seeded afterwards via env.seed().
                seed=all_args.seed + rank * 1000,
                auth_key=all_args.auth_key,
            )
            return SmartWrapper(env, all_args.num_agents)

        return init_env

    n_threads = all_args.n_rollout_threads
    if n_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv([get_env_fn(rank) for rank in range(n_threads)])