Esempio n. 1
0
def get_buffer(buffer_args: Config) -> Optional[ReplayBuffer]:
    '''
    Build the experience-replay buffer described by the buffer configuration.
    params:
        buffer_args: configuration of the replay buffer. Reads 'buffer_size',
            'type', 'batch_size', and a sub-config stored under the key given
            by 'type' (its `to_dict` is passed as extra keyword arguments).
    return:
        An instance of the class registered in `BufferDict` for the configured
        type, looked up in `rls.memories.replay_buffer`; or None when
        'buffer_size' is not positive (e.g. on-policy algorithms, which keep
        their data inside the model class) or when the type is not registered.
    '''
    if buffer_args.get('buffer_size', 0) <= 0:
        logger.info(
            'This algorithm does not need to specify a data buffer outside the model.'
        )
        return None

    buffer_type = buffer_args.get('type', 'None')
    logger.info(buffer_type)

    if buffer_type not in BufferDict:
        # Keep the original "no buffer" result, but don't fail silently.
        logger.warning(
            f'Unknown replay buffer type: {buffer_type}; no buffer will be created.')
        return None

    # BufferDict maps the configured type name to a class name in the replay_buffer module.
    buffer_cls = getattr(
        importlib.import_module('rls.memories.replay_buffer'),
        BufferDict[buffer_type])
    return buffer_cls(batch_size=buffer_args['batch_size'],
                      capacity=buffer_args['buffer_size'],
                      **buffer_args[buffer_type].to_dict)
Esempio n. 2
0
File: run.py Progetto: zhijie-ai/RLs
def get_options(options: Dict) -> Config:
    '''
    Resolve parsed command-line arguments into a Config.
    params:
        options: mapping from command-line flag (e.g. '--algorithm') to its raw
            value; presumably produced by a docopt-style parser (TODO confirm).
    return:
        op: an instance of Config class that contains the parameters
    '''
    def _opt(key, cast):
        # Optional flags: a missing value (None) or the literal string 'None'
        # means "not given"; anything else is converted with `cast`.
        value = options[key]
        if value is None or value == 'None':
            return None
        return cast(value)

    op = Config()
    op.add_dict({
        'inference': bool(options['--inference']),
        'algo': str(options['--algorithm']),
        'use_rnn': bool(options['--rnn']),
        'algo_config': _opt('--config-file', str),
        'env': _opt('--env', str),
        'port': int(options['--port']),
        'unity': bool(options['--unity']),
        'graphic': bool(options['--graphic']),
        'name': _opt('--name', str),
        'save_frequency': _opt('--save-frequency', int),
        'models': int(options['--models']),
        'store_dir': _opt('--store-dir', str),
        'seed': int(options['--seed']),
        'unity_env_seed': int(options['--unity-env-seed']),
        'max_step_per_episode': _opt('--max-step', int),
        'max_train_step': _opt('--train-step', int),
        'max_train_frame': _opt('--train-frame', int),
        'max_train_episode': _opt('--train-episode', int),
        'load': _opt('--load', str),
        'prefill_steps': _opt('--prefill-steps', int),
        'prefill_choose': bool(options['--prefill-choose']),
        'gym': bool(options['--gym']),
        'n_copys': int(options['--copys']),
        'gym_env': str(options['--gym-env']),
        'gym_env_seed': int(options['--gym-env-seed']),
        'render_episode': _opt('--render-episode', int),
        'info': _opt('--info', str),
        'unity_env': _opt('--unity-env', str),
        'apex': _opt('--apex', str),
        'hostname': bool(options['--hostname']),
        'no_save': bool(options['--no-save']),
    })
    return op
Esempio n. 3
0
def parse_options(options: Config, default_config: Dict) -> Tuple[Config, Config, Config]:
    '''
    Combine resolved command-line options with the default configuration.
    params:
        options: Config of command-line parameters
        default_config: nested default configuration; this function reads its
            'gym', 'unity', 'train' and 'buffer' sections
    return:
        (env_args, buffer_args, train_args)
    raises:
        FileNotFoundError: when not using gym or the Unity editor and the
            executable path given by options.env does not exist.
    '''
    # Environment selection priority: gym > unity (editor) > unity executable
    env_args = Config()
    env_args.env_num = options.n_copys  # Environment copies for vectorized training.
    env_args.inference = options.inference
    if options.gym:
        env_args.type = 'gym'
        env_args.add_dict(default_config['gym']['env'])
        env_args.env_name = options.gym_env
        env_args.env_seed = options.gym_env_seed
    else:
        env_args.type = 'unity'
        env_args.add_dict(default_config['unity']['env'])
        if env_args.initialize_config.env_copys <= 1:
            env_args.initialize_config.env_copys = options.n_copys
        env_args.port = options.port
        env_args.env_seed = options.unity_env_seed
        env_args.render = options.graphic or options.inference

        if options.unity:
            env_args.file_name = None
            env_args.env_name = 'unity'
        else:
            env_args.update({'file_name': options.env})
            if not os.path.exists(env_args.file_name):
                raise FileNotFoundError(
                    f'can not find the executable file: {env_args.file_name}')
            # Default env name: the last two components of the executable's directory.
            env_args.env_name = options.unity_env or os.path.join(
                *os.path.split(env_args.file_name)[0].replace(
                    '\\', '/').replace('//', '/').split('/')[-2:])
            if 'visual' in env_args.env_name.lower():
                # If training with visual input without rendering the environment,
                # all-zero observations would be passed, so force rendering.
                env_args.render = True

    train_args = Config(**default_config['train'])
    if options.gym:
        train_args.add_dict(default_config['gym']['train'])
        train_args.render_episode = abs(
            train_args.render_episode) or sys.maxsize
        train_args.update({'render_episode': options.render_episode})
    else:
        train_args.add_dict(default_config['unity']['train'])
    train_args.index = 0
    train_args.name = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
    # Normalize configured limits: 0 becomes sys.maxsize (unbounded), and a
    # negative value is replaced by its absolute value.
    # NOTE(review): 'max_frame_step' is normalized here, but the CLI override
    # below writes 'max_train_frame'; confirm which key the trainer reads.
    for key in ('max_step_per_episode', 'max_train_step', 'max_frame_step',
                'max_train_episode', 'inference_episode'):
        setattr(train_args, key, abs(getattr(train_args, key)) or sys.maxsize)

    train_args.algo = options.algo
    train_args.apex = options.apex
    train_args.use_rnn = options.use_rnn
    train_args.algo_config = options.algo_config
    train_args.seed = options.seed
    train_args.inference = options.inference
    train_args.prefill_choose = options.prefill_choose
    train_args.load_model_path = options.load
    train_args.no_save = options.no_save
    train_args.base_dir = os.path.join(options.store_dir or BASE_DIR,
                                       env_args.env_name, train_args.algo)
    # If the load path does not exist as given, treat it as relative to base_dir.
    if train_args.load_model_path is not None and not os.path.exists(
            train_args.load_model_path):
        train_args.load_model_path = os.path.join(train_args.base_dir,
                                                  train_args.load_model_path)
    # Command-line values override the defaults (presumably Config.update
    # skips None values — confirm against Config).
    train_args.update({
        'name': options.name,
        'max_step_per_episode': options.max_step_per_episode,
        'max_train_step': options.max_train_step,
        'max_train_frame': options.max_train_frame,
        'max_train_episode': options.max_train_episode,
        'save_frequency': options.save_frequency,
        'pre_fill_steps': options.prefill_steps,
        'info': options.info,
    })
    if options.apex is not None:
        train_args.name = f'{options.apex}/' + train_args.name
    if options.hostname:
        import socket
        train_args.name += ('-' + str(socket.gethostname()))

    buffer_args = Config(**default_config['buffer'])
    return env_args, buffer_args, train_args