def make_env_func(cfg, env_config):
    """Create the env named by ``cfg.env`` and wrap it into a uniform interface.

    Two wrappers may be applied, in this order:
      * ``MultiAgentWrapper`` when ``is_multiagent_env`` reports a single-agent env;
      * ``DictObservationsWrapper`` when the (possibly already wrapped) env's
        observation space is not a ``spaces.Dict``.

    :param cfg: experiment config; ``cfg.env`` selects the environment.
    :param env_config: per-env config forwarded to ``create_env``.
    :return: the created env, possibly wrapped.
    """
    env = create_env(cfg.env, cfg=cfg, env_config=env_config)

    single_agent = not is_multiagent_env(env)
    if single_agent:
        env = MultiAgentWrapper(env)

    # checked on the env as it is now, i.e. after the optional multi-agent wrapping
    dict_obs = isinstance(env.observation_space, spaces.Dict)
    if not dict_obs:
        env = DictObservationsWrapper(env)

    return env
def initialize(self):
    """Create (but do not start) one ``multiprocessing.Process`` per configured worker.

    Each process targets ``self.sample`` with its worker index and is appended to
    ``self.processes``; this method does not call ``start()`` on them.
    """
    # Creating (and immediately closing) an environment in the main process tends to
    # fix some very weird issues further down the line, see:
    # https://stackoverflow.com/questions/60963839/importing-opencv-after-importing-pytorch-messes-with-cpu-affinity
    # Do not delete this unless you know what you're doing.
    warmup_env = create_env(self.cfg.env, cfg=self.cfg, env_config=None)
    warmup_env.close()

    for worker_idx in range(self.cfg.num_workers):
        worker = multiprocessing.Process(target=self.sample, args=(worker_idx, ))
        self.processes.append(worker)
def main():
    """Render one episode of ``doom_battle`` driven by random actions, then close the env."""
    env_name = 'doom_battle'
    env = create_env(env_name, cfg=default_cfg(env=env_name))
    try:
        env.reset()
        done = False
        while not done:
            env.render()
            _, _, done, _ = env.step(env.action_space.sample())
    finally:
        # release the env's resources even if rendering/stepping raises
        env.close()

    log.info('Done!')
def test_minigrid_env(self):
    """Smoke-test a MiniGrid env: 1000 random steps, resetting whenever an episode ends.

    Logs the action/observation spaces and the accumulated reward; the env is
    always closed, even if a step fails.
    """
    env_name = 'MiniGrid-Empty-Random-5x5-v0'
    env = create_env(env_name, cfg=default_cfg(env=env_name))
    try:
        log.info('Env action space: %r', env.action_space)
        log.info('Env obs space: %r', env.observation_space)

        env.reset()
        total_rew = 0
        for _ in range(1000):
            _, rew, done, _ = env.step(env.action_space.sample())
            total_rew += rew
            if done:
                env.reset()

        # previously accumulated but never reported; log it like test_voxel_env does
        log.info('Total rew: %.3f', total_rew)
    finally:
        env.close()
def test_voxel_env(self):
    """Smoke-test a multi-agent voxel env: 1000 steps with one random action per agent.

    Logs the action/observation spaces and the reward summed over all agents and
    steps; the env is always closed, even if a step fails.
    """
    env_name = 'voxel_env_Sokoban'
    env = create_env(env_name, cfg=default_cfg(env=env_name))
    try:
        log.info('Env action space: %r', env.action_space)
        log.info('Env obs space: %r', env.observation_space)

        env.reset()
        total_rew = 0
        for _ in range(1000):
            # NOTE(review): no manual reset on episode end here; presumably the env
            # resets finished agents itself - confirm.
            actions = [env.action_space.sample() for _ in range(env.num_agents)]
            _, rew, _, _ = env.step(actions)
            total_rew += sum(rew)  # rew holds one reward per agent

        log.info('Total rew: %.3f', total_rew)
    finally:
        env.close()
def forward_pass(device_type):
    """Benchmark the forward pass of an APPO actor-critic built for ``atari_breakout``.

    The model uses shared actor/critic weights, an RNN core and hidden size 128.
    It is fed a fixed random batch ``n`` times on ``device_type``, and the
    ``Timing`` breakdown (input creation, forward passes, total) is logged.

    :param device_type: any value accepted by ``torch.device`` (e.g. ``'cpu'``, ``'cuda'``).
    """
    env_name = 'atari_breakout'
    cfg = default_cfg(algo='APPO', env=env_name)
    cfg.actor_critic_share_weights = True
    cfg.hidden_size = 128
    cfg.use_rnn = True
    cfg.env_framestack = 4

    env = create_env(env_name, cfg=cfg)
    try:
        torch.set_num_threads(1)
        torch.backends.cudnn.benchmark = True
        actor_critic = create_actor_critic(cfg, env.observation_space, env.action_space)
    finally:
        # the env is only needed for its observation/action spaces
        env.close()

    device = torch.device(device_type)
    actor_critic.to(device)

    def _sync():
        # CUDA work is launched asynchronously; without a sync the timers would only
        # measure kernel-launch overhead, not the actual computation
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    timing = Timing()
    with timing.timeit('all'):
        batch = 128
        with timing.add_time('input'):
            # frame stack comes from cfg; the 84x84 resolution is still hardcoded and
            # assumed to match the Atari preprocessing - TODO confirm
            observations = dict(
                obs=torch.rand([batch, cfg.env_framestack, 84, 84]).to(device))
            rnn_states = torch.rand([batch, get_hidden_size(cfg)]).to(device)
            _sync()

        n = 200
        for i in range(n):
            with timing.add_time('forward'):
                actor_critic(observations, rnn_states)
                _sync()

            log.debug('Progress %d/%d', i, n)

    log.debug('Timing: %s', timing)
def make_env_func(env_config):
    """Build an env for ``cfg.env`` with the given per-env ``env_config``.

    ``cfg`` is not a parameter: it is resolved from an enclosing or module scope
    (presumably a closure over the caller's config - confirm).
    """
    env_name = cfg.env
    return create_env(env_name, cfg=cfg, env_config=env_config)
def sample(self, proc_idx):
    """Worker-process loop: step one env with random actions, resetting after every step.

    The step-then-reset loop is used to populate the DMLab level cache.
    Messages put on ``self.report_queue``:
      * ``finished_reset`` once the env is initialized;
      * periodic ``env_frames`` counts, every ``self.report_every_sec`` seconds;
      * ``crash`` if the sampling loop raises.

    The loop runs until ``self.terminate.value`` is set or
    ``self.cfg.sample_env_frames_per_worker`` steps have been taken. The env is
    always closed on exit.

    :param proc_idx: index of this worker; determines the env id and seed
        (``num_envs_per_worker`` is asserted to be 1).
    """
    # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    timing = Timing()
    psutil.Process().nice(10)

    num_envs = len(DMLAB30_LEVELS_THAT_USE_LEVEL_CACHE)
    assert self.cfg.num_workers % num_envs == 0, f'should have an integer number of workers per env, e.g. {1 * num_envs}, {2 * num_envs}, etc...'
    assert self.cfg.num_envs_per_worker == 1, 'use populate_cache with 1 env per worker'

    with timing.timeit('env_init'):
        env_key = 'env'
        env_desired_num_levels = 0

        global_env_id = proc_idx * self.cfg.num_envs_per_worker
        env_config = AttrDict(worker_index=proc_idx, vector_index=0, env_id=global_env_id)
        env = create_env(self.cfg.env, cfg=self.cfg, env_config=env_config)
        env.seed(global_env_id)

        # this is to track the performance for individual DMLab levels
        # NOTE(review): env_num_levels_generated is only assigned in this branch; an env
        # without `level_name` would fail in the loop below - confirm that cannot happen.
        if hasattr(env.unwrapped, 'level_name'):
            env_key = env.unwrapped.level_name
            env_level = env.unwrapped.level

            approx_num_episodes_per_1b_frames = DMLAB30_APPROX_NUM_EPISODES_PER_BILLION_FRAMES[env_key]
            num_billions = DESIRED_TRAINING_LENGTH / int(1e9)
            num_workers_for_env = self.cfg.num_workers // num_envs
            # this worker's share of the levels needed for the desired training length
            env_desired_num_levels = int((approx_num_episodes_per_1b_frames * num_billions) / num_workers_for_env)

            # this worker's share of the seeds already present in the global cache
            env_num_levels_generated = len(dmlab_level_cache.DMLAB_GLOBAL_LEVEL_CACHE[0].all_seeds[env_level]) // num_workers_for_env

            log.warning('Worker %d (env %s) generated %d/%d levels!', proc_idx, env_key, env_num_levels_generated, env_desired_num_levels)
            time.sleep(4)

        env.reset()
        env_uses_level_cache = env.unwrapped.env_uses_level_cache

    self.report_queue.put(dict(proc_idx=proc_idx, finished_reset=True))

    self.start_event.wait()

    try:
        with timing.timeit('work'):
            last_report = last_report_frames = total_env_frames = 0
            while not self.terminate.value and total_env_frames < self.cfg.sample_env_frames_per_worker:
                action = env.action_space.sample()
                with timing.add_time(f'{env_key}.step'):
                    env.step(action)

                total_env_frames += 1

                # every reset is counted as one more generated level
                with timing.add_time(f'{env_key}.reset'):
                    env.reset()
                    env_num_levels_generated += 1
                    log.debug('Env %s done %d/%d resets', env_key, env_num_levels_generated, env_desired_num_levels)

                if env_num_levels_generated >= env_desired_num_levels:
                    log.debug('%s finished %d/%d resets, sleeping...', env_key, env_num_levels_generated, env_desired_num_levels)
                    time.sleep(30)  # free up CPU time for other envs

                # if env does not use level cache, there is no need to run it
                # let other workers proceed
                if not env_uses_level_cache:
                    log.debug('Env %s does not require cache, sleeping...', env_key)
                    time.sleep(200)

                with timing.add_time('report'):
                    now = time.time()
                    if now - last_report > self.report_every_sec:
                        last_report = now
                        frames_since_last_report = total_env_frames - last_report_frames
                        last_report_frames = total_env_frames
                        self.report_queue.put(dict(proc_idx=proc_idx, env_frames=frames_since_last_report))

                        # query once so the logged value is the one that triggered the check
                        free_space_mb = get_free_disk_space_mb(self.cfg)
                        if free_space_mb < 3 * 1024:
                            log.error('Not enough disk space! %d', free_space_mb)
                            time.sleep(200)
    except Exception:
        # not a bare `except:` - SystemExit/KeyboardInterrupt must still propagate
        log.exception('Unknown exception')
        log.error('Unknown exception in worker %d, terminating...', proc_idx)
        self.report_queue.put(dict(proc_idx=proc_idx, crash=True))
    finally:
        # delay grows with proc_idx (presumably to stagger worker shutdown);
        # the env is released on every exit path
        time.sleep(proc_idx * 0.1 + 0.1)
        log.info('Process %d finished sampling. Timing: %s', proc_idx, timing)
        env.close()
def make_env_func(env_config):
    """Build an ``ENV_NAME`` env from ``common_config`` plus the per-env ``env_config``.

    ``ENV_NAME`` and ``common_config`` are module-level names defined outside this function.
    """
    env = create_env(ENV_NAME, cfg=common_config, env_config=env_config)
    return env