Beispiel #1
0
def run_experiment(mode="local", keys=None, params=None):
    """Build an experiment from the global flags and launch it.

    A deep copy of ``FLAGS.__flags`` is taken and the entries named in
    ``params`` are overridden. The environment, policy, baseline,
    Q-function, exploration strategy and algorithm are then built from the
    resulting flags and handed to ``run_experiment_lite``.

    Args:
        mode: Forwarded unchanged to ``run_experiment_lite`` as ``mode``.
        keys: Forwarded to ``get_annotations_string`` when building the
            experiment name.
        params: Optional mapping of flag name -> replacement value. ``None``
            (the default) means "no overrides". Each key is looked up in the
            flags before it is replaced, so an unknown flag name fails at
            that lookup.

    Side effects:
        Prints every override. When the log directory already exists and
        ``flags["overwrite"]`` is falsy, prompts on stdin and calls
        ``sys.exit(0)`` unless the answer is exactly ``yes``.
    """
    # ``None`` stands in for the former shared ``dict()`` default argument.
    overrides = {} if params is None else params

    # Work on a private copy so overrides never leak into the global flags.
    flags = deepcopy(FLAGS.__flags)

    for k, v in overrides.items():
        print('Modifying flags.%s from %r to %r' % (k, flags[k], v))
        flags[k] = v

    n_episodes = flags["max_episode"]  # max episodes before termination
    env = get_env(record_video=False, record_log=False, **flags)
    policy = get_policy(env=env, **flags)
    baseline = get_baseline(env=env, **flags)
    qf = get_qf(env=env, **flags)
    es = get_es(env=env, **flags)

    info, _ = get_env_info(**flags)
    max_path_length = info['horizon']
    # Enough iterations for n_episodes full-length episodes at batch_size
    # steps per iteration, rounded up.
    n_itr = int(
        np.ceil(float(n_episodes * max_path_length) / flags['batch_size']))

    algo = get_algo(n_itr=n_itr,
                    env=env,
                    policy=policy,
                    baseline=baseline,
                    qf=qf,
                    es=es,
                    max_path_length=max_path_length,
                    **flags)

    exp_prefix = '%s' % (flags["exp"])
    exp_name = get_annotations_string(keys=keys, **flags)
    # The 'norm' suffix is appended only after env/algo construction above,
    # so it affects the experiment name and log directory only.
    if flags["normalize_obs"]:
        flags["env_name"] += 'norm'
    exp_name = '%s-%d--' % (flags["env_name"], flags["batch_size"]) + exp_name
    log_dir = config.LOG_DIR + "/local/" + exp_prefix.replace(
        "_", "-") + "/" + exp_name
    if flags["seed"] is not None:
        log_dir += '--s-%d' % flags["seed"]
    if not flags["overwrite"] and osp.exists(log_dir):
        ans = input("Overwrite %s?: (yes/no)" % log_dir)
        if ans != 'yes':
            sys.exit(0)

    # NOTE(review): ``algo.train()`` is evaluated here and its result is what
    # run_experiment_lite receives -- confirm this is the intended
    # (stubbed-call) usage rather than an eager training run.
    run_experiment_lite(
        algo.train(),
        exp_prefix=exp_prefix,
        exp_name=exp_name,
        # Number of parallel workers for sampling
        n_parallel=1,
        snapshot_mode="last_best",
        # Specifies the seed for the experiment. If this is not provided, a
        # random seed will be used
        seed=flags["seed"],
        terminate_machine=True,
        sync_s3_pkl=True,
        periodic_sync_interval=1200,
        mode=mode,
    )
Beispiel #2
0
def set_experiment(mode="local", keys=None, params=None):
    """Build the algorithm and launch settings for one experiment.

    The entries of ``FLAGS.__flags`` are unwrapped to their ``.value`` in
    place, a deep copy is taken, and the entries named in ``params`` are
    overridden on that copy. Environment, policy, baseline, Q-function,
    exploration strategy and algorithm are then built from the copy.

    Args:
        mode: Returned unchanged under the ``mode`` key of the settings dict.
        keys: Forwarded to ``get_annotations_string`` when building the
            experiment name.
        params: Optional mapping of flag name -> replacement value. ``None``
            (the default) means "no overrides". Each key is looked up in the
            flags before it is replaced, so an unknown flag name fails at
            that lookup.

    Returns:
        A pair ``(algo, settings)`` where ``settings`` is a dict with the
        keys ``exp_prefix``, ``exp_name``, ``mode`` and ``seed``.

    Side effects:
        Rewrites the entries of ``FLAGS.__flags`` in place, prints the flags
        and a few derived values, and -- when the log directory exists and
        ``flags["overwrite"]`` is falsy -- prompts on stdin and calls
        ``sys.exit(0)`` unless the answer is exactly ``yes``.
    """
    # ``None`` stands in for the former shared ``dict()`` default argument.
    overrides = {} if params is None else params

    flags = FLAGS.__flags
    # VGG: fix error handling flags after Tensorflow 1.4
    # Entries are replaced by their ``.value`` in the shared mapping itself.
    # The getattr fallback keeps an entry that has no ``.value`` (e.g. one
    # already unwrapped by an earlier call) instead of raising
    # AttributeError on repeated calls.
    for name, entry in flags.items():
        flags[name] = getattr(entry, "value", entry)
        print('{}: {}'.format(name, flags[name]))
    # From here on work on a private copy so overrides stay local.
    flags = deepcopy(flags)

    for k, v in overrides.items():
        print('Modifying flags.%s from %r to %r' % (k, flags[k], v))
        flags[k] = v

    n_episodes = flags["max_episode"]  # max episodes before termination
    info, _ = get_env_info(**flags)
    # Hard-coded instead of info['horizon'].
    # VGG TODO: fix the Gym env so that it includes the horizon.
    max_path_length = 200

    print('n_episodes: ', n_episodes)
    print('max_path_length: ', max_path_length)
    print('flags[batch_size]: ', flags['batch_size'])
    print('flags[obs_space]: ', info['obs_space'])

    # Enough iterations for n_episodes full-length episodes at batch_size
    # steps per iteration, rounded up.
    n_itr = int(
        np.ceil(float(n_episodes * max_path_length) / flags['batch_size']))

    exp_prefix = '%s' % (flags["exp"])
    exp_name = get_annotations_string(keys=keys, **flags)
    # The suffixed env_name is also what get_env and the other builders
    # below receive through **flags.
    if flags["normalize_obs"]:
        flags["env_name"] += 'norm'
    exp_name = '%s-%d--' % (flags["env_name"], flags["batch_size"]) + exp_name
    log_dir = config.LOG_DIR + "/local/" + exp_prefix.replace(
        "_", "-") + "/" + exp_name
    if flags["seed"] is not None:
        log_dir += '--s-%d' % flags["seed"]
    if not flags["overwrite"] and osp.exists(log_dir):
        ans = input("Overwrite %s?: (yes/no)" % log_dir)
        if ans != 'yes':
            sys.exit(0)

    env = get_env(record_video=False, record_log=False, **flags)
    policy = get_policy(env=env, info=info, **flags)
    baseline = get_baseline(env=env, **flags)
    qf = get_qf(env=env, info=info, **flags)
    es = get_es(env=env, info=info, **flags)

    algo = get_algo(n_itr=n_itr,
                    env=env,
                    policy=policy,
                    baseline=baseline,
                    qf=qf,
                    es=es,
                    max_path_length=max_path_length,
                    plot=True,
                    **flags)
    return algo, dict(
        exp_prefix=exp_prefix,
        exp_name=exp_name,
        mode=mode,
        seed=flags["seed"],
    )
Beispiel #3
0
def set_experiment(mode="local", keys=None, params=None):
    """Build the algorithm and launch settings for one experiment.

    A deep copy of ``FLAGS.__flags`` is taken and the entries named in
    ``params`` are overridden. Environment, policy, baseline, Q-function,
    exploration strategy, ``pf`` (from ``get_pf``) and the algorithm are
    then built from the resulting flags.

    Args:
        mode: Returned unchanged under the ``mode`` key of the settings dict.
        keys: Forwarded to ``get_annotations_string`` when building the
            experiment name.
        params: Optional mapping of flag name -> replacement value. ``None``
            (the default) means "no overrides". Each key is looked up in the
            flags before it is replaced, so an unknown flag name fails at
            that lookup.

    Returns:
        A pair ``(algo, settings)`` where ``settings`` is a dict with the
        keys ``exp_prefix``, ``exp_name``, ``mode`` and ``seed``.
    """
    # ``None`` stands in for the former shared ``dict()`` default argument.
    overrides = {} if params is None else params

    # Work on a private copy so overrides never leak into the global flags.
    flags = deepcopy(FLAGS.__flags)

    for k, v in overrides.items():
        print('Modifying flags.%s from %r to %r' % (k, flags[k], v))
        flags[k] = v

    n_episodes = flags["max_episode"]  # max episodes before termination
    info, _ = get_env_info(**flags)
    max_path_length = info['horizon']
    # Enough iterations for n_episodes full-length episodes at batch_size
    # steps per iteration, rounded up.
    n_itr = int(
        np.ceil(float(n_episodes * max_path_length) / flags['batch_size']))

    exp_prefix = '%s' % (flags["exp"])
    exp_name = get_annotations_string(keys=keys, **flags)
    # The suffixed env_name is also what get_env and the other builders
    # below receive through **flags.
    if flags["normalize_obs"]:
        flags["env_name"] += 'norm'
    exp_name = '%s-%d--' % (flags["env_name"], flags["batch_size"]) + exp_name
    # This function performs no overwrite check, so no log directory path
    # is assembled here.

    env = get_env(record_video=False, record_log=False, **flags)
    policy = get_policy(env=env, info=info, **flags)
    baseline = get_baseline(env=env, **flags)
    qf = get_qf(env=env, info=info, **flags)
    es = get_es(env=env, info=info, **flags)
    pf = get_pf(env=env, info=info, **flags)

    algo = get_algo(n_itr=n_itr,
                    env=env,
                    policy=policy,
                    baseline=baseline,
                    qf=qf,
                    es=es,
                    pf=pf,
                    max_path_length=max_path_length,
                    **flags)
    return algo, dict(
        exp_prefix=exp_prefix,
        exp_name=exp_name,
        mode=mode,
        seed=flags["seed"],
    )
Beispiel #4
0
def main(argv=None):
    """Pretty-print the environment info derived from the global flags.

    Args:
        argv: Accepted but not read by this function.
    """
    flag_values = FLAGS.__flags
    # get_env_info yields a pair; only its first element is printed.
    info, _ = get_env_info(**flag_values)
    pprint(info)