def get_options(options: Dict):
    """Translate a raw docopt options mapping into a ``Config`` object.

    Parameters
    ----------
    options : Dict
        The docopt result dictionary, keyed by CLI flag (e.g. ``'--port'``).
        Value-taking flags arrive as strings, with the literal string
        ``'None'`` meaning "not supplied"; switch flags arrive as booleans.

    Returns
    -------
    Config
        A config populated with normalized, properly-typed option values.
    """
    def _opt(key, cast):
        # docopt renders an unset value-flag as the literal string 'None';
        # map that to a real None, otherwise cast the string to its type.
        return None if options[key] == 'None' else cast(options[key])

    op = Config()
    op.add_dict({
        'inference': bool(options['--inference']),
        'algo': str(options['--algorithm']),
        'algo_config': _opt('--config-file', str),
        'env': _opt('--env', str),
        'port': int(options['--port']),
        'unity': bool(options['--unity']),
        'graphic': bool(options['--graphic']),
        'name': _opt('--name', str),
        'save_frequency': _opt('--save-frequency', int),
        'models': int(options['--models']),
        'store_dir': _opt('--store-dir', str),
        'seed': int(options['--seed']),
        'max_step': _opt('--max-step', int),
        'max_episode': _opt('--max-episode', int),
        'sampler': _opt('--sampler', str),
        'load': _opt('--load', str),
        'fill_in': bool(options['--fill-in']),
        'prefill_choose': bool(options['--prefill-choose']),
        'gym': bool(options['--gym']),
        'gym_agents': int(options['--gym-agents']),
        'gym_env': str(options['--gym-env']),
        'gym_env_seed': int(options['--gym-env-seed']),
        'render_episode': _opt('--render-episode', int),
        'info': _opt('--info', str),
    })
    return op
def run():
    """CLI entry point: parse options, build configs, then train or evaluate.

    Reads the docopt usage string (``__doc__``), merges it with
    ``config.yaml`` defaults, and dispatches to a single ``agent_run``,
    multiple training ``Process`` instances, or ``Agent(...).evaluate()``.
    Relies on module-level names defined elsewhere in this file:
    ``docopt``, ``load_yaml``, ``Config``, ``NAME``, ``BASE_DIR``,
    ``Agent``, ``agent_run``.
    """
    if sys.platform.startswith('win'):
        import win32api
        import win32con
        import _thread

        def _win_handler(event, hook_sigint=_thread.interrupt_main):
            # Translate a Windows console Ctrl-C event (event code 0) into a
            # KeyboardInterrupt on the main thread; return 1 to signal the
            # event was handled, 0 to let Windows apply default handling.
            if event == 0:
                hook_sigint()
                return 1
            return 0
        # Register _win_handler in the Windows console's handler list so
        # Ctrl-C behaves like a normal SIGINT on this platform.
        win32api.SetConsoleCtrlHandler(_win_handler, 1)

    options = docopt(__doc__)
    options = get_options(dict(options))
    print(options)
    default_config = load_yaml(f'config.yaml')
    # Config precedence: gym > unity > unity_env
    model_args = Config(**default_config['model'])
    train_args = Config(**default_config['train'])
    env_args = Config()
    buffer_args = Config(**default_config['buffer'])

    # NOTE(review): several attributes read below (use_rnn, n_copys,
    # unity_env_seed, max_step_per_episode, max_train_step, max_train_frame,
    # max_train_episode, prefill_steps, use_wandb, unity_env) are not set by
    # get_options above — presumably Config returns a default (e.g. None) for
    # missing keys; confirm against the Config implementation.
    model_args.algo = options.algo
    model_args.use_rnn = options.use_rnn
    model_args.algo_config = options.algo_config
    model_args.seed = options.seed
    model_args.load = options.load
    env_args.env_num = options.n_copys

    if options.gym:
        # Gym branch: layer gym-specific defaults over the shared ones.
        train_args.add_dict(default_config['gym']['train'])
        train_args.update({'render_episode': options.render_episode})
        env_args.add_dict(default_config['gym']['env'])
        env_args.type = 'gym'
        env_args.env_name = options.gym_env
        env_args.env_seed = options.gym_env_seed
    else:
        # Unity branch: either the live Unity editor (--unity) or a built
        # executable whose path was passed via --env.
        train_args.add_dict(default_config['unity']['train'])
        env_args.add_dict(default_config['unity']['env'])
        env_args.type = 'unity'
        env_args.port = options.port
        env_args.sampler_path = options.sampler
        env_args.env_seed = options.unity_env_seed
        if options.unity:
            # No file path means "connect to the running Unity editor".
            env_args.file_path = None
            env_args.env_name = 'unity'
        else:
            env_args.update({'file_path': options.env})
            if os.path.exists(env_args.file_path):
                # Derive a display name from the last two path components of
                # the executable's directory, normalizing path separators.
                env_args.env_name = options.unity_env or os.path.join(
                    *os.path.split(env_args.file_path)[0].replace('\\', '/').replace(r'//', r'/').split('/')[-2:]
                )
                if 'visual' in env_args.env_name.lower():
                    # If training with visual input but not rendering the
                    # environment, all-zero observations would be passed, so
                    # force graphics on.
                    options.graphic = True
            else:
                raise Exception('can not find this file.')
    if options.inference:
        # Inference always renders and disables training mode.
        env_args.train_mode = False
        env_args.render = True
    else:
        env_args.train_mode = True
        env_args.render = options.graphic

    train_args.index = 0
    train_args.name = NAME
    train_args.use_wandb = options.use_wandb
    train_args.inference = options.inference
    train_args.prefill_choose = options.prefill_choose
    # Checkpoints/logs go under <store_dir or BASE_DIR>/<env_name>/<algo>.
    train_args.base_dir = os.path.join(options.store_dir or BASE_DIR, env_args.env_name, model_args.algo)
    train_args.update(
        dict([
            ['name', options.name],
            ['max_step_per_episode', options.max_step_per_episode],
            ['max_train_step', options.max_train_step],
            ['max_train_frame', options.max_train_frame],
            ['max_train_episode', options.max_train_episode],
            ['save_frequency', options.save_frequency],
            ['pre_fill_steps', options.prefill_steps],
            ['info', options.info]
        ])
    )

    if options.inference:
        Agent(env_args, model_args, buffer_args, train_args).evaluate()
        return

    trails = options.models
    if trails == 1:
        # Single trial: run in-process.
        agent_run(env_args, model_args, buffer_args, train_args)
    elif trails > 1:
        # Multiple trials: one subprocess each, with per-trial deep copies so
        # the processes cannot share mutable config state.
        processes = []
        for i in range(trails):
            _env_args = deepcopy(env_args)
            _model_args = deepcopy(model_args)
            _model_args.seed += i * 10  # decorrelate trials via distinct seeds
            _buffer_args = deepcopy(buffer_args)
            _train_args = deepcopy(train_args)
            _train_args.index = i
            if _env_args.type == 'unity':
                # Each Unity instance needs its own communication port.
                _env_args.port = env_args.port + i
            p = Process(target=agent_run, args=(_env_args, _model_args, _buffer_args, _train_args))
            p.start()
            # Stagger process start-up — presumably to avoid environment
            # launch contention; confirm whether 10s is still required.
            time.sleep(10)
            processes.append(p)
        [p.join() for p in processes]
    else:
        raise Exception('trials must be greater than 0.')