def make_eval_env(all_args):
    """Build the vectorized evaluation environment for MPE."""
    def get_env_fn(rank):
        def init_env():
            if all_args.env_name == "MPE":
                env = MPEEnv(all_args)
            else:
                print("Can not support the " + all_args.env_name + " environment.")
                raise NotImplementedError
            # Offset the seed so eval envs never share an RNG stream with training.
            env.seed(all_args.seed * 50000 + rank * 10000)
            return env
        return init_env

    if all_args.n_eval_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    else:
        return SubprocVecEnv([get_env_fn(i) for i in range(all_args.n_eval_rollout_threads)])
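# This factory, like the ones below, feeds one of two vector wrappers:
# DummyVecEnv runs its environments serially in-process (the
# n_*_rollout_threads == 1 case), while SubprocVecEnv forks one worker per
# rank. For intuition, a stripped-down serial analogue (a sketch only; the
# real classes live in onpolicy.envs.env_wrappers):
class SerialVecEnv:
    def __init__(self, env_fns):
        # Each factory closure is called exactly once, so per-rank state
        # (e.g. the seed baked in by get_env_fn) is fixed at construction.
        self.envs = [fn() for fn in env_fns]

    def reset(self):
        return [env.reset() for env in self.envs]

    def step(self, actions):
        # Step every env with its own action, then transpose the results.
        results = [env.step(a) for env, a in zip(self.envs, actions)]
        obs, rews, dones, infos = map(list, zip(*results))
        return obs, rews, dones, infos

    def close(self):
        for env in self.envs:
            env.close()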
def make_train_env(all_args):
    """
    Wrapper that instantiates the Highway env, whose vehicles are controlled by
    a mix of pre-trained agents, a Value Iteration based RL agent, the agent
    being trained, and rule-based agents (Intelligent Driver Model, IDM).
    """
    def get_env_fn(rank):
        def init_env():
            if all_args.env_name == "Highway":
                env = HighwayEnv(all_args)
            else:
                print("Can not support the " + all_args.env_name + " environment.")
                raise NotImplementedError
            # Distinct seed per rollout thread.
            env.seed(all_args.seed + rank * 1000)
            return env
        return init_env

    if all_args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    else:
        return SubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])
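# Note the deliberate offset between training and evaluation seeds: train
# ranks use seed + rank * 1000, while eval ranks (make_eval_env above) use
# seed * 50000 + rank * 10000, keeping the two families of RNG streams far
# apart for typical thread counts. A quick illustration:
def show_seeds(seed, n_train, n_eval):
    # Mirrors the formulas in make_train_env / make_eval_env above.
    train = [seed + r * 1000 for r in range(n_train)]
    evals = [seed * 50000 + r * 10000 for r in range(n_eval)]
    return train, evals

# show_seeds(1, 4, 2) -> ([1, 1001, 2001, 3001], [50000, 60000])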
def make_train_env(all_args):
    """Build the vectorized training environment for Harvest or Cleanup."""
    def get_env_fn(rank):
        def init_env():
            if all_args.env_name == "Harvest":
                env = HarvestEnv(all_args)
            elif all_args.env_name == "Cleanup":
                env = CleanupEnv(all_args)
            else:
                print("Can not support the " + all_args.env_name + " environment.")
                raise NotImplementedError
            env.seed(all_args.seed + rank * 1000)
            return env
        return init_env

    if all_args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    else:
        return SubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])
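# The same if/elif dispatch recurs in each factory above. A table-driven
# variant (an alternative sketch, not what the repo does) keeps the error
# path in one place:
ENV_REGISTRY = {"Harvest": HarvestEnv, "Cleanup": CleanupEnv}

def build_env(all_args):
    try:
        return ENV_REGISTRY[all_args.env_name](all_args)
    except KeyError:
        raise NotImplementedError(
            "Can not support the " + all_args.env_name + " environment.")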
def make_train_env(all_args):
    """Build the vectorized training environment for SMARTS."""
    def get_env_fn(rank):
        def init_env():
            if all_args.env_name == "SMARTS":
                agent_spec = AgentSpec(
                    interface=AgentInterface.from_type(
                        AgentType.Laner, max_episode_steps=all_args.episode_length
                    )
                )
                AGENT_ID = [str(i) for i in range(all_args.num_agents)]
                env = gym.make(
                    "smarts.env:hiway-v0",
                    scenarios=all_args.scenarios,
                    agent_specs={i: agent_spec for i in AGENT_ID},
                    headless=all_args.headless,
                    visdom=False,
                    timestep_sec=0.1,
                    sumo_headless=True,
                    seed=all_args.seed + rank * 1000,
                    # zoo_workers=[("143.110.210.157", 7432)],  # Distribute social agents across these workers
                    auth_key=all_args.auth_key,
                    # envision_record_data_replay_path="./data_replay",
                )
                env = SmartWrapper(env, all_args.num_agents)
            else:
                print("Can not support the " + all_args.env_name + " environment.")
                raise NotImplementedError
            # env.seed(all_args.seed + rank * 1000)  # not needed: the seed is
            # already passed per rank through the gym.make call above.
            return env
        return init_env

    if all_args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    else:
        return SubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])
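# SMARTS' hiway-v0 env speaks in per-agent dicts keyed by agent id, while the
# vec wrappers expect per-agent lists; SmartWrapper bridges the two. A
# hypothetical minimal stand-in (the real class ships with the repo, and every
# detail below is an assumption about its shape):
import gym

class DictToListWrapper(gym.Wrapper):
    """Hypothetical SmartWrapper stand-in: dict-keyed SMARTS I/O -> lists."""

    def __init__(self, env, num_agents):
        super().__init__(env)
        self.agent_ids = [str(i) for i in range(num_agents)]

    def reset(self):
        obs = self.env.reset()
        return [obs[a] for a in self.agent_ids]

    def step(self, actions):
        obs, rews, dones, infos = self.env.step(
            {a: act for a, act in zip(self.agent_ids, actions)})
        return ([obs[a] for a in self.agent_ids],
                [rews[a] for a in self.agent_ids],
                [dones[a] for a in self.agent_ids],
                [infos[a] for a in self.agent_ids])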
if __name__ == "__main__":
    from onpolicy.config import get_config
    from onpolicy.envs.gfootball.gfootball_env import GoogleFootballEnv
    from onpolicy.envs.env_wrappers import SubprocVecEnv
    import gym

    gym.logger.set_level(gym.logger.ERROR)
    args = get_config().parse_known_args()[0]
    config = {
        'all_args': args,
        'envs': SubprocVecEnv([
            lambda: GoogleFootballEnv(num_of_left_agents=3,
                                      env_name='test_example_multiagent',
                                      representation="simple115v2",
                                      channel_dimensions=(48, 36))
            for i in range(args.n_rollout_threads)
        ]),
        'eval_envs': SubprocVecEnv([
            lambda: GoogleFootballEnv(num_of_left_agents=3,
                                      env_name='test_example_multiagent',
                                      representation="simple115v2",
                                      channel_dimensions=(48, 36))
            for i in range(args.n_eval_rollout_threads)
        ]),
        'device': None,
        'num_agents': 3,
    }
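# One subtlety in the comprehensions above: the lambdas never reference the
# loop variable i, so Python's late-binding closure pitfall is harmless here
# and every thread builds an identical environment. If a per-thread setting
# were ever needed, i must be bound at definition time, e.g.:
env_fns = [lambda i=i: GoogleFootballEnv(num_of_left_agents=3,
                                         env_name='test_example_multiagent',
                                         representation="simple115v2",
                                         channel_dimensions=(48, 36))
           # e.g. a hypothetical per-rank kwarg such as seed=args.seed + i * 1000
           for i in range(args.n_rollout_threads)]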