Example #1
    def reset(self):
        if self._episode_recording_dir is not None and self._record_id > 0:
            # save recorded actions to a JSON file
            with open(join(self._episode_recording_dir, 'actions.json'),
                      'w') as actions_file:
                json.dump(self._recorded_actions, actions_file)

            # rename previous episode dir
            reward = self._recorded_episode_reward + self._recorded_episode_shaping_reward
            new_dir_name = self._episode_recording_dir + f'_r{reward:.2f}'
            os.rename(self._episode_recording_dir, new_dir_name)
            log.info(
                'Finished recording %s (rew %.3f, shaping %.3f)',
                new_dir_name,
                reward,
                self._recorded_episode_shaping_reward,
            )

        dir_name = f'ep_{self._record_id:03d}_p{self._player_id}'
        self._episode_recording_dir = join(self._record_to, dir_name)
        ensure_dir_exists(self._episode_recording_dir)

        self._record_id += 1
        self._frame_id = 0
        self._recorded_episode_reward = 0
        self._recorded_episode_shaping_reward = 0

        self._recorded_actions = []

        return self.env.reset()
Example #2
def safe_get(q, timeout=1e6, msg='Queue timeout'):
    """Using queue.get() with timeout is necessary, otherwise KeyboardInterrupt is not handled."""
    while True:
        try:
            return q.get(timeout=timeout)
        except Empty:
            log.info('Queue timed out (%s), timeout %.3f', msg, timeout)
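
A minimal usage sketch (not part of the original snippet), assuming safe_get's own imports (queue.Empty and the logger) are available in the module:

from multiprocessing import Queue

q = Queue()
q.put('hello')
# returns 'hello' immediately; on an empty queue the call keeps polling with the
# given timeout, logging each miss, and stays responsive to KeyboardInterrupt
msg = safe_get(q, timeout=0.5, msg='demo queue')
print(msg)  # 'hello'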
Example #3
    def doom_multiagent(make_multi_env, worker_index, num_steps=1000):
        env_config = AttrDict({'worker_index': worker_index, 'vector_index': 0, 'safe_init': False})
        multi_env = make_multi_env(env_config)

        obs = multi_env.reset()

        visualize = False
        start = time.time()

        for i in range(num_steps):
            actions = [multi_env.action_space.sample()] * len(obs)
            obs, rew, dones, infos = multi_env.step(actions)

            if visualize:
                multi_env.render()

            if i % 100 == 0 or any(dones):
                log.info('Rew %r done %r info %r', rew, dones, infos)

            if all(dones):
                multi_env.reset()

        took = time.time() - start
        log.info('Took %.3f seconds for %d steps', took, num_steps)
        log.info('Server steps per second: %.1f', num_steps / took)
        log.info('Observations fps: %.1f', num_steps * multi_env.num_agents / took)
        log.info('Environment fps: %.1f', num_steps * multi_env.num_agents * multi_env.skip_frames / took)

        multi_env.close()
Example #4
    def __init__(self, cfg):
        self.cfg = cfg

        if self.cfg.seed is not None:
            log.info('Setting fixed seed %d', self.cfg.seed)
            torch.manual_seed(self.cfg.seed)
            np.random.seed(self.cfg.seed)

        self.device = torch.device('cuda')

        self.train_step = self.env_steps = 0

        self.total_train_seconds = 0
        self.last_training_step = time.time()

        self.best_avg_reward = math.nan

        self.summary_rate_decay = LinearDecay([(0, 100), (1000000, 2000),
                                               (10000000, 10000)])
        self.last_summary_written = -1e9
        self.save_rate_decay = LinearDecay([(0, self.cfg.initial_save_rate),
                                            (1000000, 5000)],
                                           staircase=100)

        summary_dir = summaries_dir(experiment_dir(cfg=self.cfg))
        self.writer = SummaryWriter(summary_dir, flush_secs=10)
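
LinearDecay itself is not shown in these examples. Below is a minimal sketch, assuming it interpolates between the given (step, value) breakpoints and that staircase rounds the result to a multiple of the given granularity; the class and method names (LinearDecaySketch, at) are illustrative only, not the real API:

class LinearDecaySketch:
    def __init__(self, breakpoints, staircase=None):
        # breakpoints: list of (step, value) pairs, e.g. [(0, 100), (1000000, 2000)]
        self.breakpoints = sorted(breakpoints)
        self.staircase = staircase  # optional rounding granularity

    def at(self, step):
        points = self.breakpoints
        if step <= points[0][0]:
            value = points[0][1]
        elif step >= points[-1][0]:
            value = points[-1][1]
        else:
            value = points[-1][1]
            for (x0, y0), (x1, y1) in zip(points, points[1:]):
                if x0 <= step <= x1:
                    frac = (step - x0) / (x1 - x0)
                    value = y0 + frac * (y1 - y0)
                    break
        if self.staircase:
            value = round(value / self.staircase) * self.staircase
        return value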
Example #5
    def _load_state(self, checkpoint_dict):
        self.train_step = checkpoint_dict['train_step']
        self.env_steps = checkpoint_dict['env_steps']
        self.best_avg_reward = checkpoint_dict['best_avg_reward']
        self.total_train_seconds = checkpoint_dict['total_train_seconds']
        log.info(
            'Loaded experiment state at training iteration %d, env step %d',
            self.train_step, self.env_steps)
Example #6
    def close(self):
        log.info('Stopping multi env wrapper...')

        for worker in self.workers:
            worker.task_queue.put((None, MsgType.TERMINATE))
            time.sleep(0.1)
        for worker in self.workers:
            worker.process.join()
Example #7
    def _init(self, init_info):
        log.info('Initializing env for player %d, init_info: %r...',
                 self.player_id, init_info)
        env = init_multiplayer_env(self.make_env_func, self.player_id,
                                   self.env_config, init_info)
        if self.reset_on_init:
            env.reset()
        return env
Example #8
    def _on_finished_training(self):
        """This is called after normal termination, e.g. number of training steps reached."""
        log.info(
            'Finished training at train_steps %d, env_steps %d, seconds %d',
            self.train_step,
            self.env_steps,
            self.total_train_seconds,
        )
        self._save()
Example #9
    def _vizdoom_variables_bug_workaround(self, info, done):
        """Some variables don't get reset to zero on game.new_episode(). This fixes it (also check overflow?)."""
        if done and 'DAMAGECOUNT' in info:
            log.info('DAMAGECOUNT value on done: %r', info.get('DAMAGECOUNT'))

        if self._last_episode_info is not None:
            bugged_vars = ['DEATHCOUNT', 'HITCOUNT', 'DAMAGECOUNT']
            for v in bugged_vars:
                if v in info:
                    info[v] -= self._last_episode_info.get(v, 0)
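
A worked illustration of the subtraction above, with made-up numbers:

# if the previous episode ended with DEATHCOUNT == 3 and the engine carries the
# counter over, a raw value of 5 in the new episode means only 2 new deaths
last_episode_info = {'DEATHCOUNT': 3}
info = {'DEATHCOUNT': 5}
info['DEATHCOUNT'] -= last_episode_info.get('DEATHCOUNT', 0)
assert info['DEATHCOUNT'] == 2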
Example #10
def run(cfg):
    cfg = maybe_load_from_checkpoint(cfg)

    algo = DmlabLevelGenerator(cfg)
    algo.initialize()
    status = algo.run()
    algo.finalize()

    log.info('Exit...')
    return status
Example #11
    def _set_game_mode(self, mode):
        if mode == 'replay':
            self.game.set_mode(Mode.PLAYER)
        else:
            if self.async_mode:
                log.info(
                    'Starting in async mode! Use this only for testing, otherwise PLAYER mode is much faster'
                )
                self.game.set_mode(Mode.ASYNC_PLAYER)
            else:
                self.game.set_mode(Mode.PLAYER)
Example #12
    def _terminate(self, real_envs, imagined_envs):
        if self._verbose:
            log.info('Stop worker %s...', list_to_string(self.env_indices))
        for e in real_envs:
            e.close()
        if imagined_envs is not None:
            for imagined_env in imagined_envs:
                imagined_env.close()

        if self._verbose:
            log.info('Worker %s terminated!', list_to_string(self.env_indices))
Example #13
def main():
    env_name = 'doom_battle'
    env = create_env(env_name, cfg=default_cfg(env=env_name))

    env.reset()
    done = False
    while not done:
        env.render()
        obs, rew, done, info = env.step(env.action_space.sample())

    log.info('Done!')
Example #14
    def step(self, actions, reset=None):
        if reset is None:
            results = self.await_tasks(actions, MsgType.STEP_REAL)
        else:
            results = self.await_tasks(list(zip(actions, reset)),
                                       MsgType.STEP_REAL_RESET)

        log.info('After await tasks')
        observations, rewards, dones, infos = zip(*results)

        self._update_stats(rewards, dones, infos)
        return observations, rewards, dones, infos
Example #15
    def test_doom_multiagent_parallel(self):
        num_workers = 16
        workers = []

        for i in range(num_workers):
            log.info('Starting worker #%d', i)
            worker = Process(target=self.doom_multiagent, args=(self.make_standard_dm, i, 200))
            worker.start()
            workers.append(worker)
            time.sleep(0.01)

        for i in range(num_workers):
            workers[i].join()
Example #16
    def _init(self, envs):
        log.info('Initializing envs %s...', list_to_string(self.env_indices))
        worker_index = self.env_indices[0] // len(self.env_indices)
        for i in self.env_indices:
            env_config = AttrDict({
                'worker_index': worker_index,
                'vector_index': i - self.env_indices[0]
            })
            env = self.make_env_func(env_config)
            env.seed(i)
            env.reset()
            if hasattr(env, 'num_agents') and env.num_agents > 1:
                self.is_multiagent = True
            envs.append(env)
            time.sleep(0.01)
Example #17
    def _maybe_print(self, avg_rewards, avg_length, fps, t):
        log.info('<====== Step %d, env step %.2fM ======>', self.train_step,
                 self.env_steps / 1e6)
        log.info('Avg FPS: %.1f', fps)
        log.info('Timing: %s', t)

        if math.isnan(avg_rewards) or math.isnan(avg_length):
            return

        log.info('Avg. %d episode length: %.3f', self.cfg.stats_episodes,
                 avg_length)
        best_reward_str = '' if math.isnan(
            self.best_avg_reward) else f'(best: {self.best_avg_reward:.3f})'
        log.info('Avg. %d episode reward: %.3f %s', self.cfg.stats_episodes,
                 avg_rewards, best_reward_str)
Example #18
def dmlab_ensure_global_cache_initialized(experiment_dir,
                                          all_levels_for_experiment,
                                          num_policies):
    global DMLAB_GLOBAL_LEVEL_CACHE

    assert multiprocessing.current_process().name == 'MainProcess', \
        'make sure you initialize DMLab cache before child processes are forked'

    DMLAB_GLOBAL_LEVEL_CACHE = []
    for policy_id in range(num_policies):
        # the level cache is of course shared between independently training policies;
        # keeping a separate cache object per policy is simply the easiest way to achieve this

        log.info('Initializing level cache for policy %d...', policy_id)
        cache = DmlabLevelCacheGlobal(LEVEL_CACHE_DIR, experiment_dir,
                                      all_levels_for_experiment, policy_id)
        DMLAB_GLOBAL_LEVEL_CACHE.append(cache)
Example #19
    def __init__(self, env, initial_difficulty=None):
        super().__init__(env)

        self._min_difficulty = 0
        self._max_difficulty = 150
        self._difficulty_step = 10
        self._curr_difficulty = 20 if initial_difficulty is None else initial_difficulty
        self._difficulty_std = 10

        log.info('Starting with bot difficulty %d', self._curr_difficulty)

        self._adaptive_curriculum = True
        if initial_difficulty == self._max_difficulty:
            log.debug(
                'Starting at max difficulty, disable adaptive skill curriculum'
            )
            self._adaptive_curriculum = False
Example #20
    def _save(self):
        checkpoint = self._get_checkpoint_dict()
        assert checkpoint is not None

        filepath = join(
            self._checkpoint_dir(),
            f'checkpoint_{self.train_step:09d}_{self.env_steps}.pth')
        log.info('Saving %s...', filepath)
        torch.save(checkpoint, filepath)

        while len(self._get_checkpoints(
                self._checkpoint_dir())) > self.cfg.keep_checkpoints:
            oldest_checkpoint = self._get_checkpoints(
                self._checkpoint_dir())[0]
            if os.path.isfile(oldest_checkpoint):
                log.debug('Removing %s', oldest_checkpoint)
                os.remove(oldest_checkpoint)

        self._save_cfg()
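
_get_checkpoints is not shown in these examples; a plausible helper consistent with how it is used above (index 0 is the oldest checkpoint) might look like the sketch below. The function name and glob pattern are assumptions:

import glob
from os.path import join

def get_checkpoints_sketch(checkpoint_dir):
    # checkpoint_{train_step:09d}_{env_steps}.pth filenames sort chronologically
    # because train_step is zero-padded, so the first entry is the oldest
    return sorted(glob.glob(join(checkpoint_dir, 'checkpoint_*.pth')))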
Example #21
    def replay(env, rec_path):
        doom = env.unwrapped
        doom.mode = 'replay'
        doom._ensure_initialized()
        doom.game.replay_episode(rec_path)

        episode_reward = 0
        start = time.time()

        while not doom.game.is_episode_finished():
            doom.game.advance_action()
            r = doom.game.get_last_reward()
            episode_reward += r
            log.info('Episode reward: %.3f, time so far: %.1f s',
                     episode_reward,
                     time.time() - start)

        log.info('Finishing replay')
        doom.close()
Example #22
    def __init__(self, num_agents, make_env_func, env_config, skip_frames):
        self.num_agents = num_agents
        log.debug('Multi agent env, num agents: %d', self.num_agents)
        self.skip_frames = skip_frames  # number of frames to skip (1 = no skip)

        # temporary env used only to query the action/observation spaces
        env = make_env_func(player_id=-1)
        self.action_space = env.action_space
        self.observation_space = env.observation_space

        # we can probably do this in a more generic way, but good enough for now
        self.default_reward_shaping = None
        if hasattr(env.unwrapped, '_reward_shaping_wrapper'):
            # noinspection PyProtectedMember
            self.default_reward_shaping = env.unwrapped._reward_shaping_wrapper.reward_shaping_scheme

        env.close()

        self.make_env_func = make_env_func

        self.safe_init = env_config is not None and env_config.get(
            'safe_init', False)

        if self.safe_init:
            sleep_seconds = env_config.worker_index * 1.0
            log.info(
                'Sleeping %.3f seconds to avoid creating all envs at once',
                sleep_seconds)
            time.sleep(sleep_seconds)
            log.info('Done sleeping at %d', env_config.worker_index)

        self.env_config = env_config
        self.workers = None

        # only needed when rendering
        self.enable_rendering = False
        self.last_obs = None

        self.reset_on_init = True

        self.initialized = False
Example #23
    def __init__(self,
                 num_envs,
                 num_workers,
                 make_env_func,
                 stats_episodes,
                 use_multiprocessing=True):
        self._verbose = False

        if num_workers > num_envs or num_envs % num_workers != 0:
            raise Exception('num_envs should be a multiple of num_workers')

        # create a temp env to query information
        env = make_env_func(None)
        self.action_space = env.action_space
        self.observation_space = env.observation_space
        env.close()
        del env

        self.num_envs = num_envs
        self.num_workers = num_workers
        self.workers = []

        envs = np.split(np.arange(num_envs), num_workers)
        self.workers = [
            _MultiEnvWorker(envs[i].tolist(), make_env_func,
                            use_multiprocessing) for i in range(num_workers)
        ]

        for worker in self.workers:
            worker.task_queue.put((None, MsgType.INIT))
            time.sleep(0.1)  # just in case
        for worker in self.workers:
            worker.task_queue.join()
        log.info('Envs initialized!')

        self.curr_episode_reward = [0] * self._num_actors()
        self.episode_rewards = [[] for _ in range(self._num_actors())]

        self.curr_episode_duration = [0] * self._num_actors()
        self.episode_lengths = [[] for _ in range(self._num_actors())]

        self.stats_episodes = stats_episodes
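
A quick illustration of how the constructor partitions env indices across workers, e.g. with 8 envs and 4 workers:

import numpy as np

splits = np.split(np.arange(8), 4)
# -> [array([0, 1]), array([2, 3]), array([4, 5]), array([6, 7])]
# np.split raises ValueError for an unequal split, which is why the constructor
# insists that num_envs is a multiple of num_workers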
Example #24
    def _parse_info(self, info, done):
        if self.reward_shaping_scheme is None:
            # skip reward calculation
            return 0.0

        # by default these are negative values if no weapon is selected
        selected_weapon = info.get('SELECTED_WEAPON', 0.0)
        selected_weapon = int(max(0, selected_weapon))
        selected_weapon_ammo = float(
            max(0.0, info.get('SELECTED_WEAPON_AMMO', 0.0)))
        self.selected_weapon.append(selected_weapon)

        was_dead = self.prev_dead
        is_alive = not info.get('DEAD', 0.0)
        just_respawned = was_dead and is_alive

        shaping_reward = 0.0
        if not done and not just_respawned:
            shaping_reward, deltas = self._delta_rewards(info)

            shaping_reward += self._selected_weapon_rewards(
                selected_weapon,
                selected_weapon_ammo,
                deltas,
            )

            if abs(shaping_reward) > 2.5 and not self.print_once:
                log.info('Large shaping reward %.3f for %r', shaping_reward,
                         deltas)
                self.print_once = True

        if done and 'FRAGCOUNT' in self.reward_structure:
            sorted_rew = sorted(self.reward_structure.items(),
                                key=operator.itemgetter(1))
            sum_rew = sum(r for key, r in sorted_rew)
            sorted_rew = {key: f'{r:.3f}' for key, r in sorted_rew}
            log.info('Sum rewards: %.3f, reward structure: %r', sum_rew,
                     sorted_rew)

        return shaping_reward
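
_delta_rewards is not shown in these examples. A minimal sketch of what it might compute, assuming reward_shaping_scheme['delta'] maps game-variable names to per-unit weights (the function name and scheme layout are assumptions, not the actual implementation):

def delta_rewards_sketch(info, prev_vars, reward_shaping_scheme):
    reward = 0.0
    deltas = []
    for var_name, weight in reward_shaping_scheme['delta'].items():
        # reward the change in each tracked game variable since the previous step
        delta = info.get(var_name, 0.0) - prev_vars.get(var_name, 0.0)
        if delta != 0:
            reward += weight * delta
            deltas.append((var_name, delta))
    return reward, deltas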
Example #25
    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        if obs is None:
            return obs, rew, done, info

        self.orig_env_reward += rew

        shaping_rew = self._parse_info(info, done)
        rew += shaping_rew
        self.total_shaping_reward += shaping_rew

        if self.verbose:
            log.info('Original env reward before shaping: %.3f',
                     self.orig_env_reward)
            player_id = 1
            if hasattr(self.env.unwrapped, 'player_id'):
                player_id = self.env.unwrapped.player_id

            log.info(
                'Total shaping reward is %.3f for %d (done %d)',
                self.total_shaping_reward,
                player_id,
                done,
            )

        # remember new variable values
        for var_name in self.reward_shaping_scheme['delta'].keys():
            self.prev_vars[var_name] = info.get(var_name, 0.0)

        self.prev_dead = bool(info.get('DEAD', 0.0))  # float -> bool

        if done:
            if self.true_reward_func is None:
                true_reward = self.orig_env_reward
            else:
                true_reward = self.true_reward_func(info)

            info['true_reward'] = true_reward

        return obs, rew, done, info
Example #26
    def test_doom_multiagent_multi_env(self):
        agents_per_env = 6
        num_envs = 2
        num_workers = 2

        skip_frames = 2  # hardcoded

        multi_env = MultiAgentEnvAggregator(
            num_envs=num_envs,
            num_workers=num_workers,
            make_env_func=self.make_standard_dm,
            stats_episodes=10,
            use_multiprocessing=True,
        )
        log.info('Before reset...')
        multi_env.reset()
        log.info('After reset...')

        actions = [multi_env.action_space.sample()] * (agents_per_env * num_envs)
        obs, rew, done, info = multi_env.step(actions)
        log.info('Rewards: %r', rew)

        start = time.time()
        num_steps = 300
        for i in range(num_steps):
            obs, rew, done, info = multi_env.step(actions)
            if i % 50 == 0:
                log.debug('Steps %d, rew: %r', i, rew)

        took = time.time() - start
        log.debug('Took %.3f sec to run %d steps, steps/sec: %.1f', took, num_steps, num_steps / took)
        log.debug('Observations fps: %.1f', num_steps * multi_env.num_agents * num_envs / took)
        log.debug('Environment fps: %.1f', num_steps * multi_env.num_agents * num_envs * skip_frames / took)

        multi_env.close()
        log.info('Done!')
Example #27
    def sample(self, proc_idx):
        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        timing = Timing()

        psutil.Process().nice(10)

        num_envs = len(DMLAB30_LEVELS_THAT_USE_LEVEL_CACHE)
        assert self.cfg.num_workers % num_envs == 0, f'should have an integer number of workers per env, e.g. {1 * num_envs}, {2 * num_envs}, etc...'
        assert self.cfg.num_envs_per_worker == 1, 'use populate_cache with 1 env per worker'

        with timing.timeit('env_init'):
            env_key = 'env'
            env_desired_num_levels = 0
            env_num_levels_generated = 0  # updated below if the env reports a level name

            global_env_id = proc_idx * self.cfg.num_envs_per_worker
            env_config = AttrDict(worker_index=proc_idx,
                                  vector_index=0,
                                  env_id=global_env_id)
            env = create_env(self.cfg.env, cfg=self.cfg, env_config=env_config)
            env.seed(global_env_id)

            # this is to track the performance for individual DMLab levels
            if hasattr(env.unwrapped, 'level_name'):
                env_key = env.unwrapped.level_name
                env_level = env.unwrapped.level

                approx_num_episodes_per_1b_frames = DMLAB30_APPROX_NUM_EPISODES_PER_BILLION_FRAMES[
                    env_key]
                num_billions = DESIRED_TRAINING_LENGTH / int(1e9)
                num_workers_for_env = self.cfg.num_workers // num_envs
                env_desired_num_levels = int(
                    (approx_num_episodes_per_1b_frames * num_billions) /
                    num_workers_for_env)

                env_num_levels_generated = len(
                    dmlab_level_cache.DMLAB_GLOBAL_LEVEL_CACHE[0].
                    all_seeds[env_level]) // num_workers_for_env

                log.warning('Worker %d (env %s) generated %d/%d levels!',
                            proc_idx, env_key, env_num_levels_generated,
                            env_desired_num_levels)
                time.sleep(4)

            env.reset()
            env_uses_level_cache = env.unwrapped.env_uses_level_cache

            self.report_queue.put(dict(proc_idx=proc_idx, finished_reset=True))

        self.start_event.wait()

        try:
            with timing.timeit('work'):
                last_report = last_report_frames = total_env_frames = 0
                while not self.terminate.value and total_env_frames < self.cfg.sample_env_frames_per_worker:
                    action = env.action_space.sample()
                    with timing.add_time(f'{env_key}.step'):
                        env.step(action)

                    total_env_frames += 1

                    with timing.add_time(f'{env_key}.reset'):
                        env.reset()
                        env_num_levels_generated += 1
                        log.debug('Env %s done %d/%d resets', env_key,
                                  env_num_levels_generated,
                                  env_desired_num_levels)

                    if env_num_levels_generated >= env_desired_num_levels:
                        log.debug('%s finished %d/%d resets, sleeping...',
                                  env_key, env_num_levels_generated,
                                  env_desired_num_levels)
                        time.sleep(30)  # free up CPU time for other envs

                    # if env does not use level cache, there is no need to run it
                    # let other workers proceed
                    if not env_uses_level_cache:
                        log.debug('Env %s does not require cache, sleeping...',
                                  env_key)
                        time.sleep(200)

                    with timing.add_time('report'):
                        now = time.time()
                        if now - last_report > self.report_every_sec:
                            last_report = now
                            frames_since_last_report = total_env_frames - last_report_frames
                            last_report_frames = total_env_frames
                            self.report_queue.put(
                                dict(proc_idx=proc_idx,
                                     env_frames=frames_since_last_report))

                            if get_free_disk_space_mb() < 3 * 1024:
                                log.error('Not enough disk space! %d',
                                          get_free_disk_space_mb())
                                time.sleep(200)
        except Exception:
            log.exception('Unknown exception')
            log.error('Unknown exception in worker %d, terminating...',
                      proc_idx)
            self.report_queue.put(dict(proc_idx=proc_idx, crash=True))

        time.sleep(proc_idx * 0.1 + 0.1)
        log.info('Process %d finished sampling. Timing: %s', proc_idx, timing)

        env.close()
Example #28
    def play_human_mode(env, skip_frames=1, num_episodes=3, num_actions=None):
        from pynput.keyboard import Listener

        doom = env.unwrapped
        doom.skip_frames = 1  # handled by this script separately

        # noinspection PyProtectedMember
        def start_listener():
            with Listener(on_press=doom._keyboard_on_press,
                          on_release=doom._keyboard_on_release) as listener:
                listener.join()

        listener_thread = Thread(target=start_listener)
        listener_thread.start()

        for episode in range(num_episodes):
            doom.mode = 'human'
            env.reset()
            last_render_time = time.time()
            time_between_frames = 1.0 / 35.0

            total_rew = 0.0

            while not doom.game.is_episode_finished() and not doom._terminate:
                num_actions = 14 if num_actions is None else num_actions
                turn_delta_action_idx = num_actions - 1

                actions = [0] * num_actions
                for action in doom._current_actions:
                    if isinstance(action, int):
                        # 1 for buttons currently pressed, 0 otherwise
                        actions[action] = 1
                    else:
                        if action == 'turn_left':
                            actions[turn_delta_action_idx] = -doom.delta_actions_scaling_factor
                        elif action == 'turn_right':
                            actions[turn_delta_action_idx] = doom.delta_actions_scaling_factor

                for frame in range(skip_frames):
                    doom._actions_flattened = actions
                    _, rew, _, _ = env.step(actions)

                    new_total_rew = total_rew + rew
                    if new_total_rew != total_rew:
                        log.info('Reward: %.3f, total: %.3f', rew,
                                 new_total_rew)
                    total_rew = new_total_rew
                    state = doom.game.get_state()

                    verbose = True
                    if state is not None and verbose:
                        info = doom.get_info()
                        print(
                            'Health:',
                            info['HEALTH'],
                            # 'Weapon:', info['SELECTED_WEAPON'],
                            # 'ready:', info['ATTACK_READY'],
                            # 'ammo:', info['SELECTED_WEAPON_AMMO'],
                            # 'pc:', info['PLAYER_COUNT'],
                            # 'dmg:', info['DAMAGECOUNT'],
                        )

                    time_since_last_render = time.time() - last_render_time
                    time_wait = time_between_frames - time_since_last_render

                    if doom.show_automap and state is not None and state.automap_buffer is not None:
                        map_ = state.automap_buffer
                        map_ = np.swapaxes(map_, 0, 2)
                        map_ = np.swapaxes(map_, 0, 1)
                        cv2.imshow('ViZDoom Automap Buffer', map_)
                        if time_wait > 0:
                            # wait in milliseconds so the window stays responsive
                            cv2.waitKey(int(time_wait * 1000))
                    else:
                        if time_wait > 0:
                            time.sleep(time_wait)

                    last_render_time = time.time()

            if doom.show_automap:
                cv2.destroyAllWindows()

        log.debug('Press ESC to exit...')
        listener_thread.join()
Example #29
def make_doom_env_impl(
        doom_spec,
        cfg=None,
        env_config=None,
        skip_frames=None,
        episode_horizon=None,
        player_id=None, num_agents=None, max_num_players=None, num_bots=0,  # for multi-agent
        custom_resolution=None,
        **kwargs,
):
    skip_frames = skip_frames if skip_frames is not None else cfg.env_frameskip

    fps = cfg.fps if 'fps' in cfg else None
    async_mode = fps == 0

    if player_id is None:
        env = VizdoomEnv(
            doom_spec.action_space, doom_spec.env_spec_file, skip_frames=skip_frames, async_mode=async_mode,
        )
    else:
        timelimit = cfg.timelimit if cfg.timelimit is not None else doom_spec.timelimit

        from seed_rl.envs.doom.multiplayer.doom_multiagent import VizdoomEnvMultiplayer
        env = VizdoomEnvMultiplayer(
            doom_spec.action_space, doom_spec.env_spec_file,
            player_id=player_id, num_agents=num_agents, max_num_players=max_num_players, num_bots=num_bots,
            skip_frames=skip_frames,
            async_mode=async_mode,
            respawn_delay=doom_spec.respawn_delay,
            timelimit=timelimit,
        )

    record_to = cfg.record_to if 'record_to' in cfg else None
    should_record = False
    if env_config is None:
        should_record = True
    elif env_config.worker_index == 0 and env_config.vector_index == 0 and (player_id is None or player_id == 0):
        should_record = True

    if record_to is not None and should_record:
        env = RecordingWrapper(env, record_to, player_id)

    env = MultiplayerStatsWrapper(env)

    if num_bots > 0:
        bot_difficulty = cfg.start_bot_difficulty if 'start_bot_difficulty' in cfg else None
        env = BotDifficultyWrapper(env, bot_difficulty)

    resolution = custom_resolution
    if resolution is None:
        resolution = '256x144' if cfg.wide_aspect_ratio else '160x120'

    assert resolution in resolutions
    env = SetResolutionWrapper(env, resolution)  # default (wide aspect ratio)

    h, w, channels = env.observation_space.shape
    if w != cfg.res_w or h != cfg.res_h:
        env = ResizeWrapper(env, cfg.res_w, cfg.res_h, grayscale=False)

    log.info('Doom resolution: %s, resize resolution: %r', resolution, (cfg.res_w, cfg.res_h))

    # randomly vary episode duration to somewhat decorrelate the experience
    timeout = doom_spec.default_timeout
    if episode_horizon is not None and episode_horizon > 0:
        timeout = episode_horizon
    if timeout > 0:
        env = TimeLimitWrapper(env, limit=timeout, random_variation_steps=0)

    pixel_format = cfg.pixel_format if 'pixel_format' in cfg else 'HWC'
    if pixel_format == 'CHW':
        env = PixelFormatChwWrapper(env)

    if doom_spec.extra_wrappers is not None:
        for wrapper_cls, wrapper_kwargs in doom_spec.extra_wrappers:
            env = wrapper_cls(env, **wrapper_kwargs)

    if doom_spec.reward_scaling != 1.0:
        env = RewardScalingWrapper(env, doom_spec.reward_scaling)

    return env
Example #30
def dmlab_register_models():
    log.info('Adding model class %r to registry', DmlabEncoder)
    ENCODER_REGISTRY['dmlab_instructions'] = DmlabEncoder