コード例 #1
0
def get_model_info(name: str) -> Tuple[Callable, Dict, str, str]:
    '''
    Look up a registered algorithm and assemble its class and default config.

    Args:
        name: name of the algorithm; must be a key of `algos`.
    Return:
        model: the attribute named by the entry's `algo_class`, taken from
            the module `rls.algos.<policy_type>.<name>`.
        algo_config: default config of the specified algorithm, merged from
            the `<name>`, policy-mode and `general` sections of
            `rls/algos/config.yaml`.
        policy_mode: the entry's `policy_mode` value (presumably
            `on-policy` or `off-policy` -- confirm against `algos`).
        policy_type: the entry's `policy_type` value, also used as the
            sub-package name when importing the algorithm module.
    Raises:
        NotImplementedError: if `name` is not registered in `algos`.
    '''
    if name not in algos:
        raise NotImplementedError(name)

    algo_info = algos[name]
    class_name = algo_info['algo_class']
    policy_mode = algo_info['policy_mode']
    policy_type = algo_info['policy_type']

    model = getattr(
        importlib.import_module(f'rls.algos.{policy_type}.{name}'),
        class_name)

    # Parse the YAML file once instead of once per section.
    all_configs = load_yaml('rls/algos/config.yaml')
    algo_config = all_configs[name]
    # Later updates win: policy-mode values override the algorithm's own
    # values, and `general` values override both.
    # NOTE(review): this precedence looks inverted (generic beats specific)
    # -- confirm it is intended before relying on per-algorithm overrides.
    algo_config.update(all_configs[policy_mode.replace('-', '_')])
    algo_config.update(all_configs['general'])
    return model, algo_config, policy_mode, policy_type
コード例 #2
0
def get_model_info(name: str) -> Tuple[Callable, Dict, str, str]:
    '''
    Resolve an algorithm through the registry and build its default config.

    Args:
        name: name of the algorithm, passed to `registry.get_model_info`.
    Return:
        model: the attribute named by the entry's `algo_class`, taken from
            the module `rls.algos.<policy_type>.<name>`.
        algo_config: default config of the specified algorithm, merged from
            the `general`, policy-mode and `<name>` sections of
            `rls/algos/config.yaml` (later sections override earlier ones).
        policy_mode: the entry's `policy_mode` value (presumably
            `on-policy` or `off-policy` -- confirm against the registry).
        policy_type: the entry's `policy_type` value, also used as the
            sub-package name when importing the algorithm module.
    '''
    algo_info = registry.get_model_info(name)
    class_name = algo_info['algo_class']
    policy_mode = algo_info['policy_mode']
    policy_type = algo_info['policy_type']
    # The logo is optional; an empty string is logged when it is absent.
    logo = algo_info.get('logo', '')
    logger.info(colorize(logo, color='green'))

    model = getattr(
        importlib.import_module(f'rls.algos.{policy_type}.{name}'),
        class_name)

    # Parse the YAML file once instead of once per section.
    all_configs = load_yaml('rls/algos/config.yaml')
    algo_config = {}
    # Most generic section first so algorithm-specific values win.
    for section in ('general', policy_mode.replace('-', '_'), name):
        algo_config.update(all_configs[section])
    return model, algo_config, policy_mode, policy_type
コード例 #3
0
def build_env(config: Dict):
    '''
    Create a gym environment and wrap it according to `config`.

    Args:
        config: environment options. `env_name` is required. Optional keys
            and their defaults: `action_skip` (False), `skip` (4),
            `obs_stack` (False), `stack` (4), `noop` (False),
            `noop_max` (30), `obs_grayscale` (False), `obs_resize` (False),
            `resize` ([84, 84]), `obs_scale` (False),
            `max_episode_steps` (None).
    Return:
        The wrapped environment.
    '''
    gym_env_name = config['env_name']
    action_skip = bool(config.get('action_skip', False))
    skip = int(config.get('skip', 4))
    obs_stack = bool(config.get('obs_stack', False))
    stack = int(config.get('stack', 4))

    noop = bool(config.get('noop', False))
    noop_max = int(config.get('noop_max', 30))
    obs_grayscale = bool(config.get('obs_grayscale', False))
    obs_resize = bool(config.get('obs_resize', False))
    resize = config.get('resize', [84, 84])
    obs_scale = bool(config.get('obs_scale', False))
    max_episode_steps = config.get('max_episode_steps', None)

    env_type = get_env_type(gym_env_name)
    env = gym.make(gym_env_name)
    env = BaseEnv(env)

    if env_type == 'atari':
        # Atari environments ignore the options parsed above and use the
        # `atari` section of the config.yaml next to this module instead.
        assert 'NoFrameskip' in env.spec.id
        from rls.common.yaml_ops import load_yaml

        default_config = load_yaml(
            f'{os.path.dirname(__file__)}/config.yaml')['atari']
        env = make_atari(env, default_config)
    else:
        if gym_env_name.split('-')[0] == 'MiniGrid':
            env = gym_minigrid.wrappers.RGBImgPartialObsWrapper(
                env)  # Get pixel observations, or RGBImgObsWrapper
            env = gym_minigrid.wrappers.ImgObsWrapper(
                env)  # Get rid of the 'mission' field
        # No-op resets are only applied to Box observations with 3 dims.
        if noop and isinstance(env.observation_space, Box) and len(
                env.observation_space.shape) == 3:
            env = NoopResetEnv(env, noop_max=noop_max)
        if action_skip:
            # Honour the configured `skip`; it was previously read from the
            # config but a hard-coded 4 was passed here.
            env = MaxAndSkipEnv(env, skip=skip)
        if isinstance(env.observation_space, Box):
            if len(env.observation_space.shape) == 3:
                if obs_grayscale or obs_resize:
                    env = GrayResizeEnv(env,
                                        resize=obs_resize,
                                        grayscale=obs_grayscale,
                                        width=resize[0],
                                        height=resize[-1])
                if obs_scale:
                    env = ScaleEnv(env)
            if obs_stack:
                env = StackEnv(env, stack=stack)
        else:
            # Non-Box observation spaces are wrapped with OneHotObsEnv.
            env = OneHotObsEnv(env)
        env = TimeLimit(env, max_episode_steps)

    if isinstance(env.action_space, Box) and len(env.action_space.shape) == 1:
        env = BoxActEnv(env)

    # Everything except 3-dim Box observations goes through DtypeEnv.
    if not (isinstance(env.observation_space, Box)
            and len(env.observation_space.shape) == 3):
        env = DtypeEnv(env)
    return env
コード例 #4
0
def main():
    '''
    Command-line entry point.

    Parses the docopt options, builds the env/buffer/train argument sets and
    then either evaluates, runs the apex trainer, runs a single training
    trial in this process, or spawns one process per trial.
    '''
    options = get_options(dict(docopt(__doc__)))
    show_dict(options.to_dict)

    n_trials = options.models
    assert n_trials > 0, '--models must greater than 0.'

    env_args, buffer_args, train_args = parse_options(
        options, default_config=load_yaml('./config.yaml'))

    # Inference mode: evaluate and stop.
    if options.inference:
        Trainer(env_args, buffer_args, train_args).evaluate()
        return

    # Apex mode: extend the training args with the apex config first.
    if options.apex is not None:
        train_args.update(load_yaml('./rls/distribute/apex/config.yaml'))
        Trainer(env_args, buffer_args, train_args).apex()
        return

    if n_trials == 1:
        agent_run(env_args, buffer_args, train_args)
    elif n_trials > 1:
        workers = []
        for idx in range(n_trials):
            # Each trial gets independent copies of the argument sets.
            trial_env_args = deepcopy(env_args)
            trial_buffer_args = deepcopy(buffer_args)
            trial_train_args = deepcopy(train_args)

            trial_train_args.seed += idx * 10
            trial_train_args.name += f'/{idx}'
            # NOTE: set this could block other processes' print function
            trial_train_args.allow_print = True
            if trial_env_args.type == 'unity':
                # Offset the port by the trial index.
                trial_env_args.port = env_args.port + idx

            worker = Process(
                target=agent_run,
                args=(trial_env_args, trial_buffer_args, trial_train_args))
            worker.start()
            # Pause 10 seconds between consecutive process launches.
            time.sleep(10)
            workers.append(worker)

        for worker in workers:
            worker.join()
コード例 #5
0
ファイル: wrappers.py プロジェクト: wyz1074152339/RLs
 def __init__(self, env_args):
     '''
     Set up the side channels and create the UnityEnvironment.

     Args:
         env_args: mapping of environment options. Reads `train_mode`,
             `file_path`, `env_seed` and `reset_config`; depending on the
             branch taken also `train_time_scale` or `width`, `height`,
             `quality_level`, `inference_time_scale`, `target_frame_rate`,
             and, when `file_path` is set, `port`, `render`, plus the
             optional `env_name` (default 'Roller') and `env_num`
             (default 1).
     '''
     self.engine_configuration_channel = EngineConfigurationChannel()
     # Training only sets the time scale; inference also sets the window
     # size, quality level and target frame rate.
     if env_args['train_mode']:
         engine_params = dict(time_scale=env_args['train_time_scale'])
     else:
         engine_params = dict(
             width=env_args['width'],
             height=env_args['height'],
             quality_level=env_args['quality_level'],
             time_scale=env_args['inference_time_scale'],
             target_frame_rate=env_args['target_frame_rate'])
     self.engine_configuration_channel.set_configuration_parameters(
         **engine_params)

     self.float_properties_channel = EnvironmentParametersChannel()
     channels = [
         self.engine_configuration_channel,
         self.float_properties_channel,
     ]

     if env_args['file_path'] is None:
         # No executable given: no file name is passed and the fixed base
         # port 5004 is used.
         self._env = UnityEnvironment(base_port=5004,
                                      seed=env_args['env_seed'],
                                      side_channels=channels)
     else:
         # Table used to look up the '--scene' argument from `env_name`.
         scene_table = load_yaml(
             f'{os.getcwd()}/rls/envs/unity_env_dict.yaml')
         launch_kwargs = dict(
             file_name=env_args['file_path'],
             base_port=env_args['port'],
             no_graphics=not env_args['render'],
             seed=env_args['env_seed'],
             side_channels=channels)
         # Unknown environment names fall back to the literal string 'None'.
         scene = scene_table.get(env_args.get('env_name', 'Roller'), 'None')
         launch_kwargs['additional_args'] = [
             '--scene', str(scene),
             '--n_agents', str(env_args.get('env_num', 1)),
         ]
         self._env = UnityEnvironment(**launch_kwargs)
     self.reset_config = env_args['reset_config']
コード例 #6
0
ファイル: wrappers.py プロジェクト: zhijie-ai/RLs
    def __init__(self, kwargs):
        '''
        Create the UnityEnvironment, reset it and initialize the wrapper.

        Args:
            kwargs: mapping of environment options. Reads `env_seed`,
                `worker_id`, `timeout_wait` and `file_name`; when
                `file_name` is not None also `port`, `render` and the
                optional `env_name` (default '3DBall').
        '''
        self._side_channels = self.initialize_all_side_channels(kwargs)

        unity_kwargs = {
            'seed': int(kwargs['env_seed']),
            'worker_id': int(kwargs['worker_id']),
            'timeout_wait': int(kwargs['timeout_wait']),
            # Register every side channel that was initialized above.
            'side_channels': list(self._side_channels.values()),
        }
        if kwargs['file_name'] is not None:
            # Table used to look up the '--scene' argument from `env_name`.
            scene_table = load_yaml(
                f'{os.getcwd()}/rls/envs/unity_env_dict.yaml')
            unity_kwargs.update({
                'file_name': kwargs['file_name'],
                'base_port': kwargs['port'],
                'no_graphics': not kwargs['render'],
                # Unknown names fall back to the literal string 'None'.
                'additional_args': [
                    '--scene',
                    str(scene_table.get(kwargs.get('env_name', '3DBall'),
                                        'None')),
                ],
            })
        self.env = UnityEnvironment(**unity_kwargs)
        self.env.reset()
        self.initialize_environment()
コード例 #7
0
ファイル: utils.py プロジェクト: zhijie-ai/RLs
def build_env(config: Dict, index: int = 0):
    '''
    Create a gym environment and wrap it according to `config`.

    Args:
        config: environment options. `env_name` is required. `inference`
            (default False) is read for 'pybullet_envs.bullet' environments.
            For non-Atari environments the optional keys and defaults are:
            `action_skip` (False), `skip` (4), `obs_stack` (False),
            `stack` (4), `noop` (False), `noop_max` (30),
            `obs_grayscale` (False), `obs_resize` (False),
            `resize` ([84, 84]), `obs_scale` (False),
            `max_episode_steps` (None).
        index: offset added to the donkeycar `port` and appended to its
            `car_name`; only used for 'gym_donkeycar.envs.donkey_env'.
    Return:
        The wrapped environment.
    '''
    gym_env_name = config['env_name']

    env_type = get_env_type(gym_env_name)

    env_params = {}
    if env_type == 'pybullet_envs.bullet':
        env_params.update({'renders': bool(config.get('inference', False))})

    elif env_type == 'gym_donkeycar.envs.donkey_env':
        _donkey_conf = load_yaml(
            f'{os.path.dirname(__file__)}/config.yaml')['donkey']
        import uuid
        # [120, 160, 3]  (presumably the observation image shape -- confirm)
        # Give each indexed instance its own port, car name and guid.
        _donkey_conf['port'] += index
        _donkey_conf['car_name'] += str(index)
        _donkey_conf['guid'] = str(uuid.uuid4())
        env_params['conf'] = _donkey_conf

    env = gym.make(gym_env_name, **env_params)
    env = BaseEnv(env)

    if env_type == 'gym.envs.atari':
        assert 'NoFrameskip' in env.spec.id, 'env id should contain NoFrameskip.'

        # Atari wrappers are configured from the `atari` section of the
        # config.yaml next to this module, not from `config`.
        default_config = load_yaml(
            f'{os.path.dirname(__file__)}/config.yaml')['atari']
        env = make_atari(env, default_config)
    else:
        action_skip = bool(config.get('action_skip', False))
        skip = int(config.get('skip', 4))
        obs_stack = bool(config.get('obs_stack', False))
        stack = int(config.get('stack', 4))
        noop = bool(config.get('noop', False))
        noop_max = int(config.get('noop_max', 30))
        obs_grayscale = bool(config.get('obs_grayscale', False))
        obs_resize = bool(config.get('obs_resize', False))
        resize = config.get('resize', [84, 84])
        obs_scale = bool(config.get('obs_scale', False))
        max_episode_steps = config.get('max_episode_steps', None)

        if gym_env_name.split('-')[0] == 'MiniGrid':
            env = gym_minigrid.wrappers.RGBImgPartialObsWrapper(
                env)  # Get pixel observations, or RGBImgObsWrapper
            env = gym_minigrid.wrappers.ImgObsWrapper(
                env)  # Get rid of the 'mission' field
        # No-op resets are only applied to Box observations with 3 dims.
        if noop and isinstance(env.observation_space, Box) and len(
                env.observation_space.shape) == 3:
            env = NoopResetEnv(env, noop_max=noop_max)
        if action_skip:
            # Honour the configured `skip`; it was previously read from the
            # config but a hard-coded 4 was passed here.
            env = MaxAndSkipEnv(env, skip=skip)
        if isinstance(env.observation_space, Box):
            if len(env.observation_space.shape) == 3:
                if obs_grayscale or obs_resize:
                    env = GrayResizeEnv(env,
                                        resize=obs_resize,
                                        grayscale=obs_grayscale,
                                        width=resize[0],
                                        height=resize[-1])
                if obs_scale:
                    env = ScaleEnv(env)
            if obs_stack:
                env = StackEnv(env, stack=stack)
        else:
            # Non-Box observation spaces are wrapped with OneHotObsEnv.
            env = OneHotObsEnv(env)
        env = TimeLimit(env, max_episode_steps)

    if isinstance(env.action_space, Box) and len(env.action_space.shape) == 1:
        env = BoxActEnv(env)

    # Everything except 3-dim Box observations goes through DtypeEnv.
    if not (isinstance(env.observation_space, Box)
            and len(env.observation_space.shape) == 3):
        env = DtypeEnv(env)
    return env