def vectorized(env_id, seed, num_processes, log_dir, add_timestep,
               sensors={DEFAULT_SENSOR_NAME: None},
               addl_repeat_count=0,
               preprocessing_fn=None,
               env_specific_kwargs={},
               vis_interval=20,
               visdom_name='main',
               visdom_log_file=None,
               visdom_server='localhost',
               visdom_port='8097',
               num_val_processes=0,
               gae_gamma=None):
    '''Returns vectorized environment.

    Either the simulator implements vectorization itself (habitat) or this
    function builds one environment constructor per rank through
    ``EnvFactory.call_to_run`` and wraps the resulting list.

    Args:
        env_id: String of the form ``'<simulator>_<scenario>'``. It is split
            on ``'_'`` and unpacked into exactly two names, so an id with
            zero or several underscores raises ``ValueError``.
        seed: Forwarded to the habitat builder or to ``call_to_run``.
        num_processes: Number of environments; ranks are
            ``range(num_processes)``.
        log_dir: Forwarded to the habitat builder or to ``call_to_run``.
        add_timestep, sensors, addl_repeat_count: Forwarded to
            ``call_to_run`` only; not used on the habitat path.
        preprocessing_fn: Forwarded on both paths.
        env_specific_kwargs: Extra keyword arguments. Expanded with ``**``
            on the habitat path; on the other path every rank receives its
            own shallow copy (``None`` is passed through unchanged).
        vis_interval, visdom_name, visdom_log_file, visdom_server,
        visdom_port: Forwarded on both paths.
        num_val_processes: Forwarded on both paths.
        gae_gamma: Only forwarded to ``DistributedEnv.new``, i.e. only used
            on the non-habitat path when ``num_processes != 1``.

    Returns:
        The value of ``make_habitat_vector_env`` for habitat; otherwise
        ``DummyVecEnv(envs)`` when ``num_processes == 1`` and
        ``DistributedEnv.new(...)`` in every other case.

    Raises:
        ValueError: If ``env_id.split('_')`` does not yield two parts.
    '''
    # NOTE: the mutable defaults in the signature are kept so the public
    # interface does not change. This function never writes to them, and the
    # per-rank copies below keep downstream code from writing to them either.
    simulator, scenario = env_id.split('_')

    if simulator.lower() == 'habitat':
        # These simulators internally handle vectorization/distribution
        return make_habitat_vector_env(
            scenario=scenario,
            num_processes=num_processes,
            preprocessing_fn=preprocessing_fn,
            log_dir=log_dir,
            num_val_processes=num_val_processes,
            vis_interval=vis_interval,
            visdom_name=visdom_name,
            visdom_log_file=visdom_log_file,
            visdom_server=visdom_server,
            visdom_port=visdom_port,
            seed=seed,
            **env_specific_kwargs)

    # These simulators must be manually vectorized
    envs = []
    for rank in range(num_processes):
        # The per-rank constructor may write rank-specific entries into its
        # kwargs, so hand every rank a private shallow copy instead of the
        # caller's dict (or the shared ``{}`` default).
        if env_specific_kwargs is None:
            rank_kwargs = None
        else:
            rank_kwargs = dict(env_specific_kwargs)
        envs.append(
            EnvFactory.call_to_run(
                env_id, seed, rank, log_dir, add_timestep,
                sensors=sensors,
                addl_repeat_count=addl_repeat_count,
                preprocessing_fn=preprocessing_fn,
                env_specific_kwargs=rank_kwargs,
                vis_interval=vis_interval,
                visdom_name=visdom_name,
                visdom_log_file=visdom_log_file,
                visdom_server=visdom_server,
                visdom_port=visdom_port,
                num_val_processes=num_val_processes,
                num_processes=num_processes))

    if num_processes == 1:
        return DummyVecEnv(envs)
    return DistributedEnv.new(
        envs,
        gae_gamma=gae_gamma,
        distribution_method=DistributedEnv.distribution_schemes.vectorize)
def _thunk():
    '''Build, wrap, seed and return the environment for one rank.

    Takes no arguments: ``env_id``, ``rank``, ``seed``, ``log_dir``,
    ``sensors``, ``env_specific_kwargs`` and the other names used below are
    free variables that must be provided by the enclosing scope.

    The simulator is chosen from the prefix of ``env_id`` ("dm", "Gibson",
    "DummyGibson", "Doom", "Habitat"; anything else goes to ``gym.make``).
    A Habitat environment is returned as-is. Every other environment is then
    optionally wrapped, in this order: Atari setup, monitoring, DeepMind
    Atari wrappers, frame skipping, sensor wrapper, observation
    preprocessing. Finally it is seeded with ``seed + rank``.

    Returns:
        The constructed (and possibly wrapped) environment.

    Raises:
        NotImplementedError: If ``add_timestep`` is truthy (non-Habitat).
        AssertionError: If ``sensors`` does not contain exactly one entry
            and the environment does not expose ``is_embodied``.
    '''
    preprocessing_fn_implemented_inside_env = False
    logging_implemented_inside_env = False
    already_distributed = False

    if env_id.startswith("dm"):
        _, domain, task = env_id.split('.')
        env = dm_control2gym.make(domain_name=domain, task_name=task)
    elif env_id.startswith("Gibson"):
        env = GibsonEnv(env_id=env_id,
                        gibson_config=gibson_config,
                        blind=blind,
                        blank_sensor=blank_sensor,
                        start_locations_file=start_locations_file,
                        target_dim=target_dim,
                        **env_specific_kwargs)
    elif env_id.startswith("DummyGibson"):
        env = DummyGibsonEnv(env_id=env_id,
                             gibson_config=gibson_config,
                             blind=blind,
                             blank_sensor=blank_sensor,
                             start_locations_file=start_locations_file,
                             target_dim=target_dim,
                             **env_specific_kwargs)
    elif env_id.startswith("Doom"):
        # Work on a private copy: writing into the captured dict would leak
        # these rank-specific entries to every other user of that dict.
        doom_kwargs = dict(env_specific_kwargs)
        doom_kwargs['repeat_count'] = addl_repeat_count + 1
        num_train_processes = num_processes - num_val_processes
        # 1 (train only), 2 test only
        doom_kwargs['randomize_textures'] = 1 if rank < num_train_processes else 2
        # SECURITY NOTE(review): eval() executes whatever text precedes the
        # first '.' in `scenario`. Only safe while env ids come from trusted
        # configuration; consider an explicit name -> class registry.
        vizdoom_class = eval(scenario.split('.')[0])
        env = vizdoom_class(**doom_kwargs)
    elif env_id.startswith("Habitat"):
        # NOTE(review): `rank` is passed as `num_processes` here -- confirm
        # this is intended.
        env = make_habitat_vector_env(
            num_processes=rank,
            target_dim=target_dim,
            preprocessing_fn=preprocessing_fn,
            log_dir=log_dir,
            num_val_processes=num_val_processes,
            visdom_name=visdom_name,
            visdom_log_file=visdom_log_file,
            visdom_server=visdom_server,
            visdom_port=visdom_port,
            seed=seed,
            **env_specific_kwargs)
        already_distributed = True
        preprocessing_fn_implemented_inside_env = True
        logging_implemented_inside_env = True
    else:
        env = gym.make(env_id)

    if already_distributed:
        # Env is now responsible for logging, preprocessing, repeat_count
        return env

    is_atari = hasattr(gym.envs, 'atari') and isinstance(
        env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
    if is_atari:
        env = make_atari(env_id)

    if add_timestep:
        # The AddTimestep wrapping that used to follow this guard was
        # conditioned on the same flag, so it could never execute; it has
        # been removed rather than kept as dead code.
        raise NotImplementedError("AddTimestep not implemented for SensorDict")

    if not (logging_implemented_inside_env or log_dir is None):
        monitor_dir = os.path.join(log_dir, visdom_name)
        os.makedirs(monitor_dir, exist_ok=True)
        print("Visdom log file", visdom_log_file)
        first_val_process = num_processes - num_val_processes
        # Only rank 0 and the first validation rank use the Visdom monitor.
        if (rank == 0 or rank == first_val_process) and visdom_log_file is not None:
            env = VisdomMonitor(env,
                                directory=monitor_dir,
                                video_callable=lambda x: x % vis_interval == 0,
                                uid=str(rank),
                                server=visdom_server,
                                port=visdom_port,
                                visdom_log_file=visdom_log_file,
                                visdom_env=visdom_name)
        else:
            print("Not using visdom")
            env = wrappers.Monitor(env, directory=monitor_dir, uid=str(rank))

    if is_atari:
        env = wrap_deepmind(env)

    if addl_repeat_count > 0:
        # Skip the wrapper when the env (or its unwrapped core) already
        # exposes its own `repeat_count`.
        if not hasattr(env, 'repeat_count') and not hasattr(env.unwrapped, 'repeat_count'):
            # Previously `SkipWrapper(repeat_count)`: no `repeat_count` is
            # bound in this function, while the guard above and the Doom
            # branch both use `addl_repeat_count`.
            # NOTE(review): if the enclosing scope does define
            # `repeat_count`, confirm it carries the same value.
            env = SkipWrapper(addl_repeat_count)(env)

    if sensors is not None:
        if hasattr(env, 'is_embodied') or hasattr(env.unwrapped, 'is_embodied'):
            pass
        else:
            assert len(sensors) == 1, 'Can only handle one sensor'
            sensor_name = list(sensors.keys())[0]
            env = SensorEnvWrapper(env, name=sensor_name)

    if not (preprocessing_fn_implemented_inside_env or preprocessing_fn is None):
        transform, space = preprocessing_fn(env.observation_space)
        env = ProcessObservationWrapper(env, transform, space)

    # Offset the seed by rank so each process gets a distinct seed.
    env.seed(seed + rank)
    return env