示例#1
0
def make_env_func(cfg, env_config):
    """Create an environment and wrap it into the multi-agent / dict-observation form.

    The environment is built by ``create_env`` from ``cfg.env``. It is wrapped
    in ``MultiAgentWrapper`` unless ``is_multiagent_env`` already accepts it,
    and then in ``DictObservationsWrapper`` unless its observation space is
    already a ``spaces.Dict``.

    Args:
        cfg: configuration object; ``cfg.env`` selects the environment.
        env_config: per-instance settings forwarded unchanged to ``create_env``.

    Returns:
        The created environment, possibly wrapped once or twice.
    """
    environment = create_env(cfg.env, cfg=cfg, env_config=env_config)

    needs_multiagent_wrapper = not is_multiagent_env(environment)
    if needs_multiagent_wrapper:
        environment = MultiAgentWrapper(environment)

    # The observation space is inspected on the (possibly already wrapped) env.
    has_dict_observations = isinstance(environment.observation_space, spaces.Dict)
    if not has_dict_observations:
        environment = DictObservationsWrapper(environment)

    return environment
示例#2
0
    def initialize(self):
        """Warm up the main process with a throwaway env and create the worker processes.

        One ``multiprocessing.Process`` targeting ``self.sample`` is created per
        configured worker and appended to ``self.processes``. The processes are
        only constructed here; this method does not start them.
        """
        # NOTE: keep this throwaway environment. Creating an environment in the main
        # process tends to fix some very weird issues further down the line, see
        # https://stackoverflow.com/questions/60963839/importing-opencv-after-importing-pytorch-messes-with-cpu-affinity
        # Do not delete this unless you know what you're doing.
        warmup_env = create_env(self.cfg.env, cfg=self.cfg, env_config=None)
        warmup_env.close()

        for worker_idx in range(self.cfg.num_workers):
            worker = multiprocessing.Process(target=self.sample, args=(worker_idx,))
            self.processes.append(worker)
示例#3
0
def main():
    """Play one 'doom_battle' episode with random actions, rendering before every step."""
    env_name = 'doom_battle'
    cfg = default_cfg(env=env_name)
    env = create_env(env_name, cfg=cfg)

    env.reset()
    while True:
        env.render()
        # Only the episode-termination flag is needed from the step result.
        _, _, episode_over, _ = env.step(env.action_space.sample())
        if episode_over:
            break

    log.info('Done!')
示例#4
0
    def test_minigrid_env(self):
        """Smoke-test the MiniGrid env with 1000 random actions.

        The env is reset whenever a step reports the episode as done, and the
        reward accumulated over all steps is logged at the end.
        """
        env_name = 'MiniGrid-Empty-Random-5x5-v0'
        env = create_env(env_name, cfg=default_cfg(env=env_name))
        log.info('Env action space: %r', env.action_space)
        log.info('Env obs space: %r', env.observation_space)

        env.reset()
        total_rew = 0
        for _ in range(1000):
            _, rew, done, _ = env.step(env.action_space.sample())
            total_rew += rew
            if done:
                env.reset()

        # Previously the total was accumulated and then silently dropped;
        # report it so the run leaves a visible result.
        log.info('Total rew: %.3f', total_rew)
    def test_voxel_env(self):
        """Smoke-test the voxel Sokoban env with 1000 random multi-agent steps.

        Every step passes one sampled action per agent (``env.num_agents``) and
        the per-agent rewards are summed into a running total that is logged
        at the end.
        """
        env_name = 'voxel_env_Sokoban'
        cfg = default_cfg(env=env_name)
        env = create_env(env_name, cfg=cfg)
        log.info('Env action space: %r', env.action_space)
        log.info('Env obs space: %r', env.observation_space)

        env.reset()
        reward_sum = 0
        for _ in range(1000):
            actions = [env.action_space.sample() for _agent in range(env.num_agents)]
            _, rewards, _, _ = env.step(actions)
            reward_sum += sum(rewards)

        log.info('Total rew: %.3f', reward_sum)
示例#6
0
    def forward_pass(device_type):
        """Time 200 actor-critic forward passes on random input for 'atari_breakout'.

        An APPO config with shared actor/critic weights, an RNN, hidden size 128
        and framestack 4 is used to create the env and the model. The model is
        moved to the requested device and called repeatedly on one fixed random
        batch; timings are collected in a ``Timing`` object and logged.

        Args:
            device_type: value passed to ``torch.device`` to select the device.
        """
        env_name = 'atari_breakout'
        cfg = default_cfg(algo='APPO', env=env_name)
        cfg.actor_critic_share_weights = True
        cfg.hidden_size = 128
        cfg.use_rnn = True
        cfg.env_framestack = 4

        # The env is only needed for its observation and action spaces.
        env = create_env(env_name, cfg=cfg)

        torch.set_num_threads(1)
        torch.backends.cudnn.benchmark = True

        obs_space, action_space = env.observation_space, env.action_space
        actor_critic = create_actor_critic(cfg, obs_space, action_space)
        device = torch.device(device_type)
        actor_critic.to(device)

        timing = Timing()
        with timing.timeit('all'):
            batch_size = 128
            with timing.add_time('input'):
                # better avoid hardcoding here...
                obs_shape = [batch_size, 4, 84, 84]
                observations = {'obs': torch.rand(obs_shape).to(device)}
                rnn_shape = [batch_size, get_hidden_size(cfg)]
                rnn_states = torch.rand(rnn_shape).to(device)

            num_passes = 200
            for step in range(num_passes):
                with timing.add_time('forward'):
                    # The model output is not inspected; only the call is timed.
                    actor_critic(observations, rnn_states)

                log.debug('Progress %d/%d', step, num_passes)

        log.debug('Timing: %s', timing)
示例#7
0
 def make_env_func(env_config):
     """Create the env named by ``cfg.env`` for the given ``env_config``.

     ``cfg`` is not a parameter; it is resolved from the surrounding scope,
     which is not visible in this snippet.
     """
     created_env = create_env(cfg.env, cfg=cfg, env_config=env_config)
     return created_env
示例#8
0
    def sample(self, proc_idx):
        """Worker entry point: step and reset a single env in a loop, counting resets as levels.

        The worker creates one environment, puts a ``finished_reset`` message on
        ``self.report_queue``, waits for ``self.start_event`` and then loops until
        ``self.terminate.value`` is set or ``self.cfg.sample_env_frames_per_worker``
        frames have been stepped. Each iteration takes one random action and then
        resets the env; every reset increments the generated-levels counter.
        Frame counts are reported periodically through ``self.report_queue``.
        An exception inside the work loop is logged and reported as
        ``crash=True``; the env is closed afterwards.

        Args:
            proc_idx: index of this worker. Used as the env's worker index, to
                derive the global env id (also used as the seed), and to tag
                every message put on the report queue.

        Raises:
            AssertionError: if ``num_workers`` is not a multiple of the number of
                level-cache levels, or ``num_envs_per_worker`` is not 1.
        """
        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        timing = Timing()

        psutil.Process().nice(10)

        num_envs = len(DMLAB30_LEVELS_THAT_USE_LEVEL_CACHE)
        assert self.cfg.num_workers % num_envs == 0, f'should have an integer number of workers per env, e.g. {1 * num_envs}, {2 * num_envs}, etc...'
        assert self.cfg.num_envs_per_worker == 1, 'use populate_cache with 1 env per worker'

        with timing.timeit('env_init'):
            env_key = 'env'
            env_desired_num_levels = 0
            # Must be bound even when the env exposes no `level_name`: the work loop
            # below increments and compares it unconditionally. Previously it was only
            # assigned inside the `hasattr` branch, so such envs hit a NameError.
            env_num_levels_generated = 0

            global_env_id = proc_idx * self.cfg.num_envs_per_worker
            env_config = AttrDict(worker_index=proc_idx, vector_index=0, env_id=global_env_id)
            env = create_env(self.cfg.env, cfg=self.cfg, env_config=env_config)
            env.seed(global_env_id)

            # this is to track the performance for individual DMLab levels
            if hasattr(env.unwrapped, 'level_name'):
                env_key = env.unwrapped.level_name
                env_level = env.unwrapped.level

                # Target number of levels for this worker: expected episodes for the
                # desired training length, split across the workers sharing this env.
                approx_num_episodes_per_1b_frames = DMLAB30_APPROX_NUM_EPISODES_PER_BILLION_FRAMES[env_key]
                num_billions = DESIRED_TRAINING_LENGTH / int(1e9)
                num_workers_for_env = self.cfg.num_workers // num_envs
                env_desired_num_levels = int((approx_num_episodes_per_1b_frames * num_billions) / num_workers_for_env)

                # Start counting from this worker's share of the seeds already in the cache.
                env_num_levels_generated = len(dmlab_level_cache.DMLAB_GLOBAL_LEVEL_CACHE[0].all_seeds[env_level]) // num_workers_for_env

                log.warning('Worker %d (env %s) generated %d/%d levels!', proc_idx, env_key, env_num_levels_generated, env_desired_num_levels)
                time.sleep(4)

            env.reset()
            env_uses_level_cache = env.unwrapped.env_uses_level_cache

            self.report_queue.put(dict(proc_idx=proc_idx, finished_reset=True))

        self.start_event.wait()

        try:
            with timing.timeit('work'):
                last_report = last_report_frames = total_env_frames = 0
                while not self.terminate.value and total_env_frames < self.cfg.sample_env_frames_per_worker:
                    action = env.action_space.sample()
                    with timing.add_time(f'{env_key}.step'):
                        env.step(action)

                    total_env_frames += 1

                    # The env is reset after every single step; each reset is
                    # counted as one more generated level.
                    with timing.add_time(f'{env_key}.reset'):
                        env.reset()
                        env_num_levels_generated += 1
                        log.debug('Env %s done %d/%d resets', env_key, env_num_levels_generated, env_desired_num_levels)

                    if env_num_levels_generated >= env_desired_num_levels:
                        log.debug('%s finished %d/%d resets, sleeping...', env_key, env_num_levels_generated, env_desired_num_levels)
                        time.sleep(30)  # free up CPU time for other envs

                    # if env does not use level cache, there is no need to run it
                    # let other workers proceed
                    if not env_uses_level_cache:
                        log.debug('Env %s does not require cache, sleeping...', env_key)
                        time.sleep(200)

                    with timing.add_time('report'):
                        now = time.time()
                        if now - last_report > self.report_every_sec:
                            last_report = now
                            frames_since_last_report = total_env_frames - last_report_frames
                            last_report_frames = total_env_frames
                            self.report_queue.put(dict(proc_idx=proc_idx, env_frames=frames_since_last_report))

                            # Query once so the logged value is the one that was checked.
                            free_disk_space_mb = get_free_disk_space_mb(self.cfg)
                            if free_disk_space_mb < 3 * 1024:
                                log.error('Not enough disk space! %d', free_disk_space_mb)
                                time.sleep(200)
        except Exception:
            # Top-level boundary of the worker: log, report the crash, then fall
            # through to the common shutdown below. `Exception` (rather than a bare
            # `except`) lets SystemExit and other BaseExceptions propagate.
            log.exception('Unknown exception')
            log.error('Unknown exception in worker %d, terminating...', proc_idx)
            self.report_queue.put(dict(proc_idx=proc_idx, crash=True))

        # NOTE(review): the per-worker delay presumably staggers the final log lines — confirm.
        time.sleep(proc_idx * 0.1 + 0.1)
        log.info('Process %d finished sampling. Timing: %s', proc_idx, timing)

        env.close()
示例#9
0
 def make_env_func(env_config):
     """Create the ``ENV_NAME`` env with ``common_config`` and the given ``env_config``.

     ``ENV_NAME`` and ``common_config`` are resolved from the surrounding scope,
     which is not visible in this snippet.
     """
     created_env = create_env(ENV_NAME, cfg=common_config, env_config=env_config)
     return created_env