Example #1
File: run.py  Project: rigo93acosta/RLs
from typing import Dict

# Config is the RLs project's attribute-style configuration container.


def get_options(options: Dict):
    # map a docopt value to type `t`, treating the literal string 'None' as None
    f = lambda k, t: None if options[k] == 'None' else t(options[k])
    op = Config()
    op.add_dict({
        'inference': bool(options['--inference']),
        'algo': str(options['--algorithm']),
        'algo_config': f('--config-file', str),
        'env': f('--env', str),
        'port': int(options['--port']),
        'unity': bool(options['--unity']),
        'graphic': bool(options['--graphic']),
        'name': f('--name', str),
        'save_frequency': f('--save-frequency', int),
        'models': int(options['--models']),
        'store_dir': f('--store-dir', str),
        'seed': int(options['--seed']),
        'max_step': f('--max-step', int),
        'max_episode': f('--max-episode', int),
        'sampler': f('--sampler', str),
        'load': f('--load', str),
        'fill_in': bool(options['--fill-in']),
        'prefill_choose': bool(options['--prefill-choose']),
        'gym': bool(options['--gym']),
        'gym_agents': int(options['--gym-agents']),
        'gym_env': str(options['--gym-env']),
        'gym_env_seed': int(options['--gym-env-seed']),
        'render_episode': f('--render-episode', int),
        'info': f('--info', str),
    })
    return op
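
A minimal usage sketch (not from the RLs project): get_options consumes the dictionary produced by docopt, where flags arrive as booleans and value options as strings, with the literal string 'None' standing for an unset option. The stub Config below is a hypothetical stand-in for the project's real class.

# hypothetical stand-in for the RLs Config class, for illustration only
class Config:
    def add_dict(self, d):
        self.__dict__.update(d)  # expose every key as an attribute

# hand-written stand-in for a docopt result
options = {
    '--inference': False, '--algorithm': 'ppo', '--config-file': 'None',
    '--env': 'None', '--port': '5005', '--unity': False, '--graphic': False,
    '--name': 'None', '--save-frequency': '100', '--models': '1',
    '--store-dir': 'None', '--seed': '42', '--max-step': 'None',
    '--max-episode': 'None', '--sampler': 'None', '--load': 'None',
    '--fill-in': False, '--prefill-choose': False, '--gym': True,
    '--gym-agents': '1', '--gym-env': 'CartPole-v0',
    '--gym-env-seed': '0', '--render-episode': 'None', '--info': 'None',
}

op = get_options(options)
print(op.algo, op.port, op.max_step)  # -> ppo 5005 None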
Example #2
File: run.py  Project: yyht/RLs
import os
import sys
import time
from copy import deepcopy
from multiprocessing import Process

from docopt import docopt

# Config, load_yaml, Agent, agent_run, NAME and BASE_DIR are defined
# elsewhere in the RLs project.


def run():
    if sys.platform.startswith('win'):
        import win32api
        import win32con
        import _thread

        def _win_handler(event, hook_sigint=_thread.interrupt_main):
            # translate console CTRL_C into KeyboardInterrupt in the main thread
            if event == win32con.CTRL_C_EVENT:
                hook_sigint()
                return 1
            return 0
        # register _win_handler in the Windows console's handler list
        win32api.SetConsoleCtrlHandler(_win_handler, 1)

    options = docopt(__doc__)
    options = get_options(dict(options))
    print(options)

    default_config = load_yaml('config.yaml')
    # config precedence: gym > unity > unity_env
    model_args = Config(**default_config['model'])
    train_args = Config(**default_config['train'])
    env_args = Config()
    buffer_args = Config(**default_config['buffer'])

    model_args.algo = options.algo
    model_args.use_rnn = options.use_rnn
    model_args.algo_config = options.algo_config
    model_args.seed = options.seed
    model_args.load = options.load

    env_args.env_num = options.n_copys
    if options.gym:
        train_args.add_dict(default_config['gym']['train'])
        train_args.update({'render_episode': options.render_episode})
        env_args.add_dict(default_config['gym']['env'])
        env_args.type = 'gym'
        env_args.env_name = options.gym_env
        env_args.env_seed = options.gym_env_seed
    else:
        train_args.add_dict(default_config['unity']['train'])
        env_args.add_dict(default_config['unity']['env'])
        env_args.type = 'unity'
        env_args.port = options.port
        env_args.sampler_path = options.sampler
        env_args.env_seed = options.unity_env_seed
        if options.unity:
            env_args.file_path = None
            env_args.env_name = 'unity'
        else:
            env_args.update({'file_path': options.env})
            if os.path.exists(env_args.file_path):
                # derive env_name from the last two components of the file path
                env_args.env_name = options.unity_env or os.path.join(
                    *os.path.split(env_args.file_path)[0].replace('\\', '/').replace(r'//', r'/').split('/')[-2:]
                )
                if 'visual' in env_args.env_name.lower():
                    # if training with visual input but not rendering the environment,
                    # all-zero observations would be passed
                    options.graphic = True
            else:
                raise Exception('cannot find this file.')
        if options.inference:
            env_args.train_mode = False
            env_args.render = True
        else:
            env_args.train_mode = True
            env_args.render = options.graphic

    train_args.index = 0
    train_args.name = NAME
    train_args.use_wandb = options.use_wandb
    train_args.inference = options.inference
    train_args.prefill_choose = options.prefill_choose
    train_args.base_dir = os.path.join(options.store_dir or BASE_DIR, env_args.env_name, model_args.algo)
    train_args.update({
        'name': options.name,
        'max_step_per_episode': options.max_step_per_episode,
        'max_train_step': options.max_train_step,
        'max_train_frame': options.max_train_frame,
        'max_train_episode': options.max_train_episode,
        'save_frequency': options.save_frequency,
        'pre_fill_steps': options.prefill_steps,
        'info': options.info,
    })

    if options.inference:
        Agent(env_args, model_args, buffer_args, train_args).evaluate()
        return

    trials = options.models
    if trials == 1:
        agent_run(env_args, model_args, buffer_args, train_args)
    elif trials > 1:
        processes = []
        for i in range(trials):
            _env_args = deepcopy(env_args)
            _model_args = deepcopy(model_args)
            _model_args.seed += i * 10  # distinct seed per trial
            _buffer_args = deepcopy(buffer_args)
            _train_args = deepcopy(train_args)
            _train_args.index = i
            if _env_args.type == 'unity':
                _env_args.port = env_args.port + i  # one Unity port per trial
            p = Process(target=agent_run, args=(_env_args, _model_args, _buffer_args, _train_args))
            p.start()
            time.sleep(10)  # stagger startup so environments do not collide
            processes.append(p)
        for p in processes:
            p.join()
    else:
        raise Exception('trials must be greater than 0.')
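
The multi-model branch above follows a common multiprocessing pattern: deep-copy the shared argument objects, give each trial a distinct seed, index and (for Unity) port, then launch every trial in its own process with a staggered start. A stripped-down sketch of that pattern, with a hypothetical worker standing in for agent_run:

# minimal sketch of the multi-trial launch pattern;
# `worker` is a hypothetical stand-in for agent_run
import time
from copy import deepcopy
from multiprocessing import Process

def worker(args):
    print(f"trial {args['index']} running with seed {args['seed']} on port {args['port']}")

if __name__ == '__main__':
    base = {'index': 0, 'seed': 42, 'port': 5005}
    processes = []
    for i in range(3):
        args = deepcopy(base)
        args['index'] = i
        args['seed'] += i * 10  # distinct seed per trial
        args['port'] += i       # distinct port per trial (Unity case)
        p = Process(target=worker, args=(args,))
        p.start()
        time.sleep(1)           # stagger startup, as the original does
        processes.append(p)
    for p in processes:
        p.join()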