def get_buffer(buffer_args: Config) -> Optional[ReplayBuffer]:
    '''
    Parse the replay-buffer configuration and build the corresponding buffer.
    params:
        buffer_args: configuration of the replay buffer
    return:
        The matching experience-replay buffer. On-policy algorithms do not need to
        specify a replay buffer outside the model class, so None is returned for them.
    '''
    if buffer_args.get('buffer_size', 0) <= 0:
        logger.info('This algorithm does not need to specify a data buffer outside the model.')
        return None

    _buffer_type = buffer_args.get('type', 'None')
    logger.info(_buffer_type)

    if _buffer_type in BufferDict.keys():
        # Look up the buffer class by name and instantiate it with the type-specific kwargs.
        Buffer = getattr(importlib.import_module('rls.memories.replay_buffer'),
                         BufferDict[_buffer_type])
        return Buffer(batch_size=buffer_args['batch_size'],
                      capacity=buffer_args['buffer_size'],
                      **buffer_args[_buffer_type].to_dict)
    else:
        return None
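# --- Illustrative usage sketch (not part of the original source) -----------------------
# A minimal example of how get_buffer could be driven by a hand-written config.
# Assumptions: 'ER' is a valid key of BufferDict, and Config.add_dict wraps the nested
# per-type section into a Config so that `.to_dict` works as used above; adjust the keys
# to match the buffer section of the project's default YAML.
def _example_get_buffer() -> Optional[ReplayBuffer]:
    buffer_args = Config()
    buffer_args.add_dict({
        'type': 'ER',          # assumed buffer type name
        'batch_size': 128,
        'buffer_size': 10000,
        'ER': {},              # type-specific kwargs forwarded to the buffer constructor
    })
    return get_buffer(buffer_args)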
def get_options(options: Dict) -> Config:
    '''
    Resolves command-line arguments
    params:
        options: dictionary of command-line arguments
    return:
        op: an instance of Config class that contains the parameters
    '''

    def f(k, t):
        return None if options[k] == 'None' else t(options[k])

    op = Config()
    op.add_dict(
        dict([
            ['inference', bool(options['--inference'])],
            ['algo', str(options['--algorithm'])],
            ['use_rnn', bool(options['--rnn'])],
            ['algo_config', f('--config-file', str)],
            ['env', f('--env', str)],
            ['port', int(options['--port'])],
            ['unity', bool(options['--unity'])],
            ['graphic', bool(options['--graphic'])],
            ['name', f('--name', str)],
            ['save_frequency', f('--save-frequency', int)],
            ['models', int(options['--models'])],
            ['store_dir', f('--store-dir', str)],
            ['seed', int(options['--seed'])],
            ['unity_env_seed', int(options['--unity-env-seed'])],
            ['max_step_per_episode', f('--max-step', int)],
            ['max_train_step', f('--train-step', int)],
            ['max_train_frame', f('--train-frame', int)],
            ['max_train_episode', f('--train-episode', int)],
            ['load', f('--load', str)],
            ['prefill_steps', f('--prefill-steps', int)],
            ['prefill_choose', bool(options['--prefill-choose'])],
            ['gym', bool(options['--gym'])],
            ['n_copys', int(options['--copys'])],
            ['gym_env', str(options['--gym-env'])],
            ['gym_env_seed', int(options['--gym-env-seed'])],
            ['render_episode', f('--render-episode', int)],
            ['info', f('--info', str)],
            ['unity_env', f('--unity-env', str)],
            ['apex', f('--apex', str)],
            ['hostname', bool(options['--hostname'])],
            ['no_save', bool(options['--no-save'])],
        ]))
    return op
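# --- Illustrative usage sketch (not part of the original source) -----------------------
# get_options expects a docopt-style dictionary keyed by the raw flag names; every key
# read above must be present. The values below are placeholders for a gym run, and any
# value left as the string 'None' is converted to None by the local helper f().
def _example_get_options() -> Config:
    cli = {
        '--inference': False, '--algorithm': 'ppo', '--rnn': False,
        '--config-file': 'None', '--env': 'None', '--port': 5005,
        '--unity': False, '--graphic': False, '--name': 'None',
        '--save-frequency': 100, '--models': 1, '--store-dir': 'None',
        '--seed': 42, '--unity-env-seed': 0, '--max-step': 'None',
        '--train-step': 'None', '--train-frame': 'None', '--train-episode': 'None',
        '--load': 'None', '--prefill-steps': 'None', '--prefill-choose': False,
        '--gym': True, '--copys': 1, '--gym-env': 'CartPole-v0',
        '--gym-env-seed': 0, '--render-episode': 'None', '--info': 'None',
        '--unity-env': 'None', '--apex': 'None', '--hostname': False,
        '--no-save': False,
    }
    return get_options(cli)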
def parse_options(options: Config, default_config: Dict) -> Tuple[Config, Config, Config]:
    # Environment selection priority: gym > unity > unity_env
    env_args = Config()
    env_args.env_num = options.n_copys  # number of environment copies for vectorized training
    env_args.inference = options.inference

    if options.gym:
        env_args.type = 'gym'
        env_args.add_dict(default_config['gym']['env'])
        env_args.env_name = options.gym_env
        env_args.env_seed = options.gym_env_seed
    else:
        env_args.type = 'unity'
        env_args.add_dict(default_config['unity']['env'])
        if env_args.initialize_config.env_copys <= 1:
            env_args.initialize_config.env_copys = options.n_copys
        env_args.port = options.port
        env_args.env_seed = options.unity_env_seed
        env_args.render = options.graphic or options.inference
        if options.unity:
            env_args.file_name = None
            env_args.env_name = 'unity'
        else:
            env_args.update({'file_name': options.env})
            if os.path.exists(env_args.file_name):
                env_args.env_name = options.unity_env or os.path.join(
                    *os.path.split(env_args.file_name)[0].replace(
                        '\\', '/').replace(r'//', r'/').split('/')[-2:])
                if 'visual' in env_args.env_name.lower():
                    # If training with visual input but not rendering the environment,
                    # all-zero observations would be passed, so force rendering.
                    env_args.render = True
            else:
                raise Exception('Cannot find the executable file.')

    train_args = Config(**default_config['train'])
    if options.gym:
        train_args.add_dict(default_config['gym']['train'])
        train_args.render_episode = abs(train_args.render_episode) or sys.maxsize
        train_args.update({'render_episode': options.render_episode})
    else:
        train_args.add_dict(default_config['unity']['train'])
    train_args.index = 0
    train_args.name = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime(time.time()))
    train_args.max_step_per_episode = abs(train_args.max_step_per_episode) or sys.maxsize
    train_args.max_train_step = abs(train_args.max_train_step) or sys.maxsize
    train_args.max_frame_step = abs(train_args.max_frame_step) or sys.maxsize
    train_args.max_train_episode = abs(train_args.max_train_episode) or sys.maxsize
    train_args.inference_episode = abs(train_args.inference_episode) or sys.maxsize

    train_args.algo = options.algo
    train_args.apex = options.apex
    train_args.use_rnn = options.use_rnn
    train_args.algo_config = options.algo_config
    train_args.seed = options.seed
    train_args.inference = options.inference
    train_args.prefill_choose = options.prefill_choose
    train_args.load_model_path = options.load
    train_args.no_save = options.no_save
    train_args.base_dir = os.path.join(options.store_dir or BASE_DIR,
                                       env_args.env_name, train_args.algo)
    if train_args.load_model_path is not None and not os.path.exists(train_args.load_model_path):
        # If the load path is not an existing absolute path, join it with the training base directory.
        train_args.load_model_path = os.path.join(train_args.base_dir,
                                                  train_args.load_model_path)
    train_args.update(
        dict([['name', options.name],
              ['max_step_per_episode', options.max_step_per_episode],
              ['max_train_step', options.max_train_step],
              ['max_train_frame', options.max_train_frame],
              ['max_train_episode', options.max_train_episode],
              ['save_frequency', options.save_frequency],
              ['pre_fill_steps', options.prefill_steps],
              ['info', options.info]]))
    if options.apex is not None:
        train_args.name = f'{options.apex}/' + train_args.name
    if options.hostname:
        import socket
        train_args.name += ('-' + str(socket.gethostname()))

    buffer_args = Config(**default_config['buffer'])

    return env_args, buffer_args, train_args
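# --- Illustrative usage sketch (not part of the original source) -----------------------
# Typical wiring: the docopt dictionary goes through get_options, and the resulting
# Config is combined with the project's default configuration (normally loaded from the
# shipped YAML files). The inline default_config below is an assumption that mirrors only
# the keys parse_options reads on the gym path; a unity run would additionally need the
# 'unity' env/train sections. Zero values stand for "unlimited" and are mapped to
# sys.maxsize above; a zero buffer_size makes get_buffer return None (on-policy case).
def _example_parse_options() -> Tuple[Config, Config, Config]:
    options = _example_get_options()
    default_config = {
        'train': {
            'max_step_per_episode': 0,
            'max_train_step': 0,
            'max_frame_step': 0,
            'max_train_episode': 0,
            'inference_episode': 0,
        },
        'gym': {
            'env': {},
            'train': {'render_episode': 0},
        },
        'buffer': {'buffer_size': 0},
    }
    return parse_options(options, default_config)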