Example #1
def get_model_info(name: str):
    '''
    Args:
        name: name of the algorithm
    Returns:
        the class of the algorithm model named `name`,
        the default config of the specified algorithm,
        the policy mode, `on-policy` or `off-policy`.
    '''
    if name not in algos:
        raise NotImplementedError(name)
    else:
        class_name = algos[name]['class']
        policy_mode = algos[name]['policy']
        model_file = importlib.import_module('algos.tf2algos.' + name)
        model = getattr(model_file, class_name)
        # read the shared config file once instead of re-loading it for every key
        config = load_yaml('algos/config.yaml')
        algo_general_config = config['general']
        if policy_mode == 'on-policy':
            algo_policy_config = config['on_policy']
        elif policy_mode == 'off-policy':
            algo_policy_config = config['off_policy']
        else:
            raise ValueError(f'unknown policy mode: {policy_mode}')
        algo_config = config[name]
        algo_config.update(algo_policy_config)
        algo_config.update(algo_general_config)
        return model, algo_config, policy_mode
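A minimal usage sketch (assuming `'dqn'` is a registered key in `algos`, and that the returned model class accepts its merged config as keyword arguments):

model_class, algo_config, policy_mode = get_model_info('dqn')  # 'dqn' is an assumed registry key
agent = model_class(**algo_config)  # assumes the class takes the merged config as kwargs
print(policy_mode)                  # 'on-policy' or 'off-policy'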
Example #2
def build_env(config: Dict):
    gym_env_name = config['env_name']
    action_skip = bool(config.get('action_skip', False))
    skip = int(config.get('skip', 4))
    obs_stack = bool(config.get('obs_stack', False))
    stack = int(config.get('stack', 4))

    noop = bool(config.get('noop', False))
    noop_max = int(config.get('noop_max', 30))
    obs_grayscale = bool(config.get('obs_grayscale', False))
    obs_resize = bool(config.get('obs_resize', False))
    resize = config.get('resize', [84, 84])
    obs_scale = bool(config.get('obs_scale', False))
    max_episode_steps = config.get('max_episode_steps', None)

    env_type = get_env_type(gym_env_name)
    env = gym.make(gym_env_name)
    env = BaseEnv(env)

    if env_type == 'atari':
        assert 'NoFrameskip' in env.spec.id
        from common.yaml_ops import load_yaml

        default_config = load_yaml(
            f'{os.path.dirname(__file__)}/config.yaml')['atari']
        env = make_atari(env, default_config)
    else:
        if gym_env_name.split('-')[0] == 'MiniGrid':
            env = gym_minigrid.wrappers.RGBImgPartialObsWrapper(
                env)  # Get pixel observations, or RGBImgObsWrapper
            env = gym_minigrid.wrappers.ImgObsWrapper(
                env)  # Get rid of the 'mission' field
        if noop and isinstance(env.observation_space, Box) and len(
                env.observation_space.shape) == 3:
            env = NoopResetEnv(env, noop_max=noop_max)
        if action_skip:
            env = MaxAndSkipEnv(env, skip=skip)  # use the configured skip value instead of a hard-coded 4
        if isinstance(env.observation_space, Box):
            if len(env.observation_space.shape) == 3:
                if obs_grayscale or obs_resize:
                    env = GrayResizeEnv(env,
                                        resize=obs_resize,
                                        grayscale=obs_grayscale,
                                        width=resize[0],
                                        height=resize[-1])
                if obs_scale:
                    env = ScaleEnv(env)
            if obs_stack:
                env = StackEnv(env, stack=stack)
        else:
            env = OneHotObsEnv(env)
        env = TimeLimit(env, max_episode_steps)

    if isinstance(env.action_space, Box) and len(env.action_space.shape) == 1:
        env = BoxActEnv(env)

    if not (isinstance(env.observation_space, Box)
            and len(env.observation_space.shape) == 3):
        env = DtypeEnv(env)
    return env
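A usage sketch with a minimal config (the keys mirror those read above; 'CartPole-v0' is just an illustrative gym id, which takes the non-Atari branch):

config = {
    'env_name': 'CartPole-v0',
    'obs_stack': True,         # wrap with StackEnv to stack the last `stack` observations
    'stack': 4,
    'max_episode_steps': 200,  # enforced by the TimeLimit wrapper
}
env = build_env(config)
obs = env.reset()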
Example #3
File: wrappers.py  Project: yyht/RLs
def __init__(self, env_args):
    self.engine_configuration_channel = EngineConfigurationChannel()
    if env_args['train_mode']:
        self.engine_configuration_channel.set_configuration_parameters(
            time_scale=env_args['train_time_scale'])
    else:
        self.engine_configuration_channel.set_configuration_parameters(
            width=env_args['width'],
            height=env_args['height'],
            quality_level=env_args['quality_level'],
            time_scale=env_args['inference_time_scale'],
            target_frame_rate=env_args['target_frame_rate'])
    self.float_properties_channel = EnvironmentParametersChannel()
    if env_args['file_path'] is None:
        self._env = UnityEnvironment(base_port=5004,
                                     seed=env_args['env_seed'],
                                     side_channels=[
                                         self.engine_configuration_channel,
                                         self.float_properties_channel
                                     ])
    else:
        unity_env_dict = load_yaml(
            os.path.dirname(__file__) + '/../../unity_env_dict.yaml')
        self._env = UnityEnvironment(
            file_name=env_args['file_path'],
            base_port=env_args['port'],
            no_graphics=not env_args['render'],
            seed=env_args['env_seed'],
            side_channels=[
                self.engine_configuration_channel,
                self.float_properties_channel
            ],
            additional_args=[
                '--scene',
                str(unity_env_dict.get(env_args.get('env_name', 'Roller'),
                                       'None')),
                '--n_agents',
                str(env_args.get('env_num', 1))
            ])
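A sketch of the `env_args` this constructor expects when training against the Unity Editor (keys taken from the code above; `UnityWrapper` is a hypothetical name, since the listing only shows `__init__` from wrappers.py):

env_args = {
    'train_mode': True,      # use the training time scale
    'train_time_scale': 20,  # illustrative value
    'file_path': None,       # None attaches to the running Editor on base_port 5004
    'env_seed': 0,
}
env = UnityWrapper(env_args)  # hypothetical class name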
Example #4
File: run.py  Project: yyht/RLs
def run():
    if sys.platform.startswith('win'):
        import win32api
        import win32con
        import _thread

        def _win_handler(event, hook_sigint=_thread.interrupt_main):
            if event == 0:
                hook_sigint()
                return 1
            return 0
        # Add the _win_handler function to the windows console's handler function list
        win32api.SetConsoleCtrlHandler(_win_handler, 1)

    options = docopt(__doc__)
    options = get_options(dict(options))
    print(options)

    default_config = load_yaml('config.yaml')
    # gym > unity > unity_env
    model_args = Config(**default_config['model'])
    train_args = Config(**default_config['train'])
    env_args = Config()
    buffer_args = Config(**default_config['buffer'])

    model_args.algo = options.algo
    model_args.use_rnn = options.use_rnn
    model_args.algo_config = options.algo_config
    model_args.seed = options.seed
    model_args.load = options.load

    env_args.env_num = options.n_copys
    if options.gym:
        train_args.add_dict(default_config['gym']['train'])
        train_args.update({'render_episode': options.render_episode})
        env_args.add_dict(default_config['gym']['env'])
        env_args.type = 'gym'
        env_args.env_name = options.gym_env
        env_args.env_seed = options.gym_env_seed
    else:
        train_args.add_dict(default_config['unity']['train'])
        env_args.add_dict(default_config['unity']['env'])
        env_args.type = 'unity'
        env_args.port = options.port
        env_args.sampler_path = options.sampler
        env_args.env_seed = options.unity_env_seed
        if options.unity:
            env_args.file_path = None
            env_args.env_name = 'unity'
        else:
            env_args.update({'file_path': options.env})
            if os.path.exists(env_args.file_path):
                env_args.env_name = options.unity_env or os.path.join(
                    *os.path.split(env_args.file_path)[0].replace('\\', '/').replace(r'//', r'/').split('/')[-2:]
                )
                if 'visual' in env_args.env_name.lower():
                    # if training with visual input but not rendering the environment, all-zero observations would be passed.
                    options.graphic = True
            else:
                raise Exception('cannot find this file.')
        if options.inference:
            env_args.train_mode = False
            env_args.render = True
        else:
            env_args.train_mode = True
            env_args.render = options.graphic

    train_args.index = 0
    train_args.name = NAME
    train_args.use_wandb = options.use_wandb
    train_args.inference = options.inference
    train_args.prefill_choose = options.prefill_choose
    train_args.base_dir = os.path.join(options.store_dir or BASE_DIR, env_args.env_name, model_args.algo)
    train_args.update(
        dict([
            ['name', options.name],
            ['max_step_per_episode', options.max_step_per_episode],
            ['max_train_step', options.max_train_step],
            ['max_train_frame', options.max_train_frame],
            ['max_train_episode', options.max_train_episode],
            ['save_frequency', options.save_frequency],
            ['pre_fill_steps', options.prefill_steps],
            ['info', options.info]
        ])
    )

    if options.inference:
        Agent(env_args, model_args, buffer_args, train_args).evaluate()
        return

    trials = options.models
    if trials == 1:
        agent_run(env_args, model_args, buffer_args, train_args)
    elif trials > 1:
        processes = []
        for i in range(trials):
            _env_args = deepcopy(env_args)
            _model_args = deepcopy(model_args)
            _model_args.seed += i * 10
            _buffer_args = deepcopy(buffer_args)
            _train_args = deepcopy(train_args)
            _train_args.index = i
            if _env_args.type == 'unity':
                _env_args.port = env_args.port + i
            p = Process(target=agent_run, args=(_env_args, _model_args, _buffer_args, _train_args))
            p.start()
            time.sleep(10)
            processes.append(p)
        for p in processes:
            p.join()
    else:
        raise Exception('trials must be greater than 0.')
Example #5
def run():
    if sys.platform.startswith('win'):
        import win32api
        import win32con
        import _thread

        def _win_handler(event, hook_sigint=_thread.interrupt_main):
            if event == 0:
                hook_sigint()
                return 1
            return 0

        # Add the _win_handler function to the windows console's handler function list
        win32api.SetConsoleCtrlHandler(_win_handler, 1)

    options = docopt(__doc__)
    print(options)

    default_config = load_yaml('config.yaml')
    # gym > unity > unity_env
    env_args, model_args, train_args = {}, {}, {}
    unity_args = default_config['unity']
    gym_args = default_config['gym']
    buffer_args = default_config['buffer']

    model_args['algo'] = str(options['--algorithm'])
    model_args['algo_config'] = None if options[
        '--config-file'] == 'None' else str(options['--config-file'])
    model_args['seed'] = int(options['--seed'])
    model_args['load'] = None if options['--load'] == 'None' else str(
        options['--load'])
    model_args['logger2file'] = default_config['logger2file']

    train_args['index'] = 0
    train_args['all_learner_print'] = default_config['all_learner_print']
    train_args['name'] = NAME if options['--name'] == 'None' else str(
        options['--name'])
    train_args['max_step'] = default_config['max_step'] if options[
        '--max-step'] == 'None' else int(options['--max-step'])
    train_args['max_episode'] = default_config['max_episode'] if options[
        '--max-episode'] == 'None' else int(options['--max-episode'])
    train_args['save_frequency'] = default_config['save_frequency'] if options[
        '--save-frequency'] == 'None' else int(options['--save-frequency'])
    train_args['inference'] = bool(options['--inference'])
    train_args['fill_in'] = bool(options['--fill-in'])
    train_args['no_op_choose'] = bool(options['--noop-choose'])
    train_args['info'] = default_config['info'] if options[
        '--info'] == 'None' else str(options['--info'])

    if options['--gym']:
        env_args['type'] = 'gym'
        env_args['env_name'] = str(options['--gym-env'])
        env_args['env_num'] = int(options['--gym-agents'])
        env_args['env_seed'] = int(options['--gym-env-seed'])
        env_args['render_mode'] = gym_args['render_mode']
        env_args['action_skip'] = gym_args['action_skip']
        env_args['skip'] = gym_args['skip']
        env_args['obs_stack'] = gym_args['obs_stack']
        env_args['stack'] = gym_args['stack']
        env_args['obs_grayscale'] = gym_args['obs_grayscale']
        env_args['obs_resize'] = gym_args['obs_resize']
        env_args['resize'] = gym_args['resize']
        env_args['obs_scale'] = gym_args['obs_scale']

        train_args['render_episode'] = gym_args['render_episode'] if options[
            '--render-episode'] == 'None' else int(options['--render-episode'])
        train_args['no_op_steps'] = gym_args['random_steps']
        train_args['render'] = gym_args['render']
        train_args['eval_while_train'] = gym_args['eval_while_train']
        train_args['max_eval_episode'] = gym_args['max_eval_episode']
    else:
        env_args['type'] = 'unity'
        if options['--unity']:
            env_args['file_path'] = None
            env_args['env_name'] = 'unity'
        else:
            env_args['file_path'] = unity_args['exe_file'] if options[
                '--env'] == 'None' else str(options['--env'])
            if os.path.exists(env_args['file_path']):
                env_args['env_name'] = os.path.join(
                    *os.path.split(env_args['file_path'])[0].replace(
                        '\\', '/').replace(r'//', r'/').split('/')[-2:])
            else:
                raise Exception('cannot find this file.')
        if bool(options['--inference']):
            env_args['train_mode'] = False
        else:
            env_args['train_mode'] = True

        env_args['port'] = int(options['--port'])
        env_args['render'] = bool(options['--graphic'])
        env_args['sampler_path'] = None if options[
            '--sampler'] == 'None' else str(options['--sampler'])
        env_args['reset_config'] = unity_args['reset_config']

        train_args['no_op_steps'] = unity_args['no_op_steps']

    train_args['base_dir'] = os.path.join(
        BASE_DIR if options['--store-dir'] == 'None' else str(
            options['--store-dir']), env_args['env_name'], model_args['algo'])

    if bool(options['--inference']):
        Agent(env_args, model_args, buffer_args, train_args).evaluate()
        return  # do not fall through to training when only evaluating

    trials = int(options['--modes'])
    if trials == 1:
        agent_run(env_args, model_args, buffer_args, train_args)
    elif trials > 1:
        processes = []
        for i in range(trials):
            _env_args = deepcopy(env_args)
            _model_args = deepcopy(model_args)
            _model_args['seed'] += i * 10
            _buffer_args = deepcopy(buffer_args)
            _train_args = deepcopy(train_args)
            _train_args['index'] = i
            if _env_args['type'] == 'unity':
                _env_args['port'] = env_args['port'] + i
            p = Process(target=agent_run,
                        args=(_env_args, _model_args, _buffer_args,
                              _train_args))
            p.start()
            time.sleep(10)
            processes.append(p)
        for p in processes:
            p.join()
    else:
        raise Exception('trials must be greater than 0.')
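Given the docopt flags read above, a launch might look like the following (flag names come from the options parsed in the code; any flags the usage string marks as required but not shown here would also be needed):

python run.py --gym --gym-env CartPole-v0 --gym-agents 1 --gym-env-seed 0 --algorithm dqn --seed 42 --modes 1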