def main(exp_prefix='exp', algo='exact', layers=(32, 32), repeat=0,
        env_name='grid1', **alg_args):
    """Run exact weighted FQI on a grid/tabular environment.

    Args:
        exp_prefix: Prefix for the experiment's log directory.
        algo: Algorithm tag ('exact'); recorded by the sweep harness.
        layers: Hidden-layer sizes for the FC Q-network, or the string
            'tabular' to use a tabular Q-network instead.
        repeat: Seed/repeat index (recorded by the sweep harness; unused here).
        env_name: Environment name resolved through env_suite.get_env.
        **alg_args: Extra keyword arguments forwarded to WeightedExactFQI.
    """
    env = env_suite.get_env(env_name)

    # 'tabular' selects one parameter per state-action pair; otherwise a
    # fully-connected network with the given hidden-layer sizes.
    if layers == 'tabular':
        network = q_networks.TabularNetwork(env)
    else:
        network = q_networks.FCNetwork(env, layers=layers)
    ptu.initialize_network(network)

    alg_args.update({
        'min_project_steps': 10,
        'max_project_steps': 300,
        'lr': 5e-3,
        'discount': 0.95,
        'n_steps': 1,
        'backup_mode': 'exact',
        'stop_modes': (stopping.AtolStop(), stopping.RtolStop()),
        'time_limit': env.time_limit,
        'env_name': env_name,
        'layers': str(layers),
    })
    fqi = exact_fqi.WeightedExactFQI(env, network, log_proj_qstar=True, **alg_args)
    with log_utils.setup_logger(algo=fqi, exp_prefix=exp_prefix, log_base_dir='./data') as log_dir:
        print('Logging to %s' % log_dir)
        try:
            for k in range(300):
                fqi.update(step=k)
        except Exception:
            # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
            # still propagate; failures are recorded best-effort.
            log_utils.save_exception()
# Example 2
def main(exp_prefix='exp',
         layers=(32, 32),
         repeat=0,
         env_name='grid1',
         sampling_type=None,
         **alg_args):
    """Run sampling-backup FQI using a replay buffer or direct sampling.

    Args:
        exp_prefix: Prefix for the experiment's log directory.
        layers: Hidden-layer sizes for the FC Q-network, or 'tabular' for a
            tabular Q-network.
        repeat: Seed/repeat index (recorded by the sweep harness; unused here).
        env_name: Environment name resolved through env_suite.get_env.
        sampling_type: Required string of the form '<mode>...<count>', e.g.
            'buffer512' or 'sample64'. The prefix picks the algorithm variant
            and the numeric suffix picks the per-backup sample count.
        **alg_args: Extra keyword arguments forwarded to the FQI constructor.

    Raises:
        ValueError: If sampling_type is missing or has an unrecognized
            prefix or suffix.
    """
    # Fail fast with a clear message instead of an AttributeError on
    # sampling_type.endswith(...) below.
    if sampling_type is None:
        raise ValueError(
            "sampling_type is required, e.g. 'buffer512' or 'sample64'")

    env = env_suite.get_env(env_name)

    if layers == 'tabular':
        network = q_networks.TabularNetwork(env)
    else:
        network = q_networks.FCNetwork(env, layers=layers)
    ptu.initialize_network(network)
    stop_mode = tuple()  # no early-stopping criteria for this variant

    # Per-backup sample count is encoded as the suffix of sampling_type.
    if sampling_type.endswith('512'):
        num_samples = 512
    elif sampling_type.endswith('64'):
        num_samples = 64
    elif sampling_type.endswith('32'):
        num_samples = 32
    else:
        # Previously an unmatched suffix surfaced later as a NameError.
        raise ValueError('Unrecognized sample count in sampling_type: %r'
                         % (sampling_type,))

    alg_args.update({
        'min_project_steps': 10,
        'lr': 5e-3,
        'discount': 0.95,
        'n_steps': 1,
        'batch_size': 128,
        'stop_modes': stop_mode,
        'time_limit': env.time_limit,
        'env_name': env_name,
        'backup_mode': 'sampling',
        'layers': str(layers),
        'num_samples': num_samples,
        'log_sampling_type': sampling_type
    })

    # The prefix of sampling_type selects the algorithm variant.
    if sampling_type.startswith('buffer'):
        fqi = replay_buffer_fqi.TabularBufferDQN(env,
                                                 network,
                                                 log_proj_qstar=False,
                                                 **alg_args)
    elif sampling_type.startswith('sample'):
        fqi = sampling_fqi.WeightedSamplingFQI(env,
                                               network,
                                               log_proj_qstar=False,
                                               **alg_args)
    else:
        # Previously an unmatched prefix surfaced later as a NameError on fqi.
        raise ValueError('Unrecognized mode in sampling_type: %r'
                         % (sampling_type,))

    with log_utils.setup_logger(algo=fqi,
                                exp_prefix=exp_prefix,
                                log_base_dir='./data') as log_dir:
        print('Logging to %s' % log_dir)
        try:
            for k in range(600):
                fqi.update(step=k)
        except Exception:
            # Narrowed from a bare `except:`; failures are logged best-effort.
            log_utils.save_exception()
# Example 3
def run(output_dir='/tmp',
        env_name='pointmass_empty',
        gpu=True,
        seed=0,
        **kwargs):
    """Train GCSL on the named environment, logging under output_dir.

    Args:
        output_dir: Base directory for experiment logs.
        env_name: Environment key understood by gcsl.envs.create_env.
        gpu: Whether to use the GPU (forwarded to ptu.set_gpu).
        seed: Random seed applied to both torch and numpy.
        **kwargs: Accepted but ignored. NOTE(review): silently dropped —
            confirm callers do not expect these to reach GCSL.
    """

    # Heavy dependencies are imported at call time, so importing this module
    # does not pull in gym/torch/gcsl.
    import gym
    import numpy as np
    from rlutil.logging import log_utils, logger

    import rlutil.torch as torch
    import rlutil.torch.pytorch_util as ptu

    # Envs

    from gcsl import envs
    from gcsl.envs.env_utils import DiscretizedActionEnv

    # Algo
    from gcsl.algo import buffer, gcsl, variants, networks

    ptu.set_gpu(gpu)
    if not gpu:
        print('Not using GPU. Will be slow.')

    # Seed before constructing env/policy so initialization is reproducible.
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = envs.create_env(env_name)
    env_params = envs.get_env_params(env_name)
    print(env_params)

    # variants.get_params may wrap/replace env and builds the policy, replay
    # buffer, and keyword arguments for the GCSL trainer.
    env, policy, replay_buffer, gcsl_kwargs = variants.get_params(
        env, env_params)
    algo = gcsl.GCSL(env, policy, replay_buffer, **gcsl_kwargs)

    exp_prefix = 'example/%s/gcsl/' % (env_name, )

    with log_utils.setup_logger(exp_prefix=exp_prefix,
                                log_base_dir=output_dir):
        algo.train()
def main(exp_prefix='exp',
         validation_stop=True,
         layers=(32, 32),
         repeat=0,
         env_name='grid1',
         sampling_type=None,
         **alg_args):
    """Run sampling-backup FQI under a fixed total-sample budget.

    Updates run until 128 * 200 samples have been consumed or 2000 iterations
    have elapsed, whichever comes first.

    Args:
        exp_prefix: Prefix for the experiment's log directory.
        validation_stop: Accepted for sweep compatibility; not used by this
            variant (stopping is always Atol+Rtol here).
        layers: Hidden-layer sizes for the FC Q-network, or 'tabular' for a
            tabular Q-network.
        repeat: Seed/repeat index (recorded by the sweep harness; unused here).
        env_name: Environment name resolved through env_suite.get_env.
        sampling_type: Required string of the form '<mode>...<count>', e.g.
            'buffer512' or 'sample64'. The prefix picks the algorithm variant
            and the numeric suffix picks the per-backup sample count.
        **alg_args: Extra keyword arguments forwarded to the FQI constructor.

    Raises:
        ValueError: If sampling_type is missing or has an unrecognized
            prefix or suffix.
    """
    # Fail fast with a clear message instead of an AttributeError on
    # sampling_type.endswith(...) below.
    if sampling_type is None:
        raise ValueError(
            "sampling_type is required, e.g. 'buffer512' or 'sample64'")

    env = env_suite.get_env(env_name)

    if layers == 'tabular':
        network = q_networks.TabularNetwork(env)
    else:
        network = q_networks.FCNetwork(env, layers=layers)
    ptu.initialize_network(network)

    stop_mode = (stopping.AtolStop(), stopping.RtolStop())

    # Per-backup sample count is encoded as the suffix of sampling_type.
    if sampling_type.endswith('512'):
        num_samples = 512
    elif sampling_type.endswith('256'):
        num_samples = 256
    elif sampling_type.endswith('128'):
        num_samples = 128
    elif sampling_type.endswith('64'):
        num_samples = 64
    elif sampling_type.endswith('32'):
        num_samples = 32
    elif sampling_type.endswith('16'):
        num_samples = 16
    elif sampling_type.endswith('8'):
        num_samples = 8
    elif sampling_type.endswith('4'):
        num_samples = 4
    else:
        # Previously an unmatched suffix surfaced later as a NameError.
        raise ValueError('Unrecognized sample count in sampling_type: %r'
                         % (sampling_type,))

    alg_args.update({
        'min_project_steps': 10,
        'max_project_steps': 200,
        'lr': 5e-3,
        'discount': 0.95,
        'n_steps': 1,
        'batch_size': 128,
        'stop_modes': stop_mode,
        'time_limit': env.time_limit,
        'env_name': env_name,
        'backup_mode': 'sampling',
        'layers': str(layers),
        'num_samples': num_samples,
        'log_sampling_type': sampling_type
    })

    # The prefix of sampling_type selects the algorithm variant.
    if sampling_type.startswith('buffer'):
        fqi = replay_buffer_fqi.TabularBufferDQN(env,
                                                 network,
                                                 log_proj_qstar=False,
                                                 **alg_args)
    elif sampling_type.startswith('sample'):
        fqi = sampling_fqi.WeightedSamplingFQI(env,
                                               network,
                                               log_proj_qstar=False,
                                               **alg_args)
    else:
        # Previously an unmatched prefix surfaced later as a NameError on fqi.
        raise ValueError('Unrecognized mode in sampling_type: %r'
                         % (sampling_type,))

    with log_utils.setup_logger(algo=fqi,
                                exp_prefix=exp_prefix,
                                log_base_dir='./data') as log_dir:
        print('Logging to %s' % log_dir)
        # Equalize total environment samples across per-backup sample counts:
        # stop at a fixed sample budget, with an iteration cap as a backstop.
        total_num_samples = 0
        k = 0
        try:
            while (total_num_samples < 128 * 200) and (k <= 2000):
                fqi.update(step=k)
                total_num_samples += num_samples
                k += 1
        except Exception:
            # Narrowed from a bare `except:`; failures are logged best-effort.
            log_utils.save_exception()
# Example 5
def main(exp_prefix='exp', validation_stop=True, layers=(32, 32), repeat=0, env_name='grid1', sampling_type=None, **alg_args):
    """Run sampling-backup FQI with a configurable stopping criterion.

    Args:
        exp_prefix: Prefix for the experiment's log directory.
        validation_stop: 'returns' (unimplemented), 'bellman' (validation-loss
            stopping with a tighter projection cap), or anything else for the
            default Atol+Rtol stopping.
        layers: Hidden-layer sizes for the FC Q-network, or 'tabular' for a
            tabular Q-network.
        repeat: Seed/repeat index (recorded by the sweep harness; unused here).
        env_name: Environment name resolved through env_suite.get_env.
        sampling_type: Required string of the form '<mode>...<count>', e.g.
            'buffer512' or 'sample64'. The prefix picks the algorithm variant
            and the numeric suffix picks the per-backup sample count.
        **alg_args: Extra keyword arguments forwarded to the FQI constructor.

    Raises:
        NotImplementedError: If validation_stop == 'returns'.
        ValueError: If sampling_type is missing or has an unrecognized
            prefix or suffix.
    """
    # Fail fast with a clear message instead of an AttributeError on
    # sampling_type.endswith(...) below.
    if sampling_type is None:
        raise ValueError(
            "sampling_type is required, e.g. 'buffer512' or 'sample64'")

    env = env_suite.get_env(env_name)

    if layers == 'tabular':
        network = q_networks.TabularNetwork(env)
    else:
        network = q_networks.FCNetwork(env, layers=layers)
    ptu.initialize_network(network)

    max_project_steps = 200
    if validation_stop == 'returns':
        raise NotImplementedError("TODO(justin): reimplement")
    elif validation_stop == 'bellman':
        stop_mode = (stopping.ValidationLoss(),)
        max_project_steps = 50
    else:
        # BUG FIX: a trailing comma here made stop_mode a nested 1-tuple
        # ((AtolStop(), RtolStop()),) instead of the flat 2-tuple used by
        # the sibling experiments.
        stop_mode = (stopping.AtolStop(), stopping.RtolStop())

    # Per-backup sample count is encoded as the suffix of sampling_type.
    if sampling_type.endswith('512'):
        num_samples = 512
    elif sampling_type.endswith('256'):
        num_samples = 256
    elif sampling_type.endswith('128'):
        num_samples = 128
    elif sampling_type.endswith('64'):
        num_samples = 64
    elif sampling_type.endswith('32'):
        num_samples = 32
    elif sampling_type.endswith('16'):
        num_samples = 16
    elif sampling_type.endswith('4'):
        num_samples = 4
    else:
        # Previously an unmatched suffix surfaced later as a NameError.
        raise ValueError('Unrecognized sample count in sampling_type: %r'
                         % (sampling_type,))

    alg_args.update({
        'min_project_steps': 10,
        'max_project_steps': max_project_steps,
        'lr': 5e-3,
        'discount': 0.95,
        'n_steps': 1,
        'batch_size': 128,
        'stop_modes': stop_mode,
        'time_limit': env.time_limit,
        'env_name': env_name,
        'backup_mode': 'sampling',
        'layers': str(layers),
        'validation_stop': validation_stop,
        'num_samples': num_samples,
        'log_sampling_type': sampling_type
    })

    # The prefix of sampling_type selects the algorithm variant.
    if sampling_type.startswith('buffer'):
        fqi = replay_buffer_fqi.TabularBufferDQN(env, network, log_proj_qstar=False, **alg_args)
    elif sampling_type.startswith('sample'):
        fqi = sampling_fqi.WeightedSamplingFQI(env, network, log_proj_qstar=False, **alg_args)
    else:
        # Previously an unmatched prefix surfaced later as a NameError on fqi.
        raise ValueError('Unrecognized mode in sampling_type: %r'
                         % (sampling_type,))

    # NOTE(review): sibling experiments log to './data'; '/data' here may be
    # intentional (e.g. a mounted volume) — confirm before changing.
    with log_utils.setup_logger(algo=fqi, exp_prefix=exp_prefix, log_base_dir='/data') as log_dir:
        print('Logging to %s' % log_dir)
        try:
            for k in range(300):
                fqi.update(step=k)
        except Exception:
            # Narrowed from a bare `except:`; failures are logged best-effort.
            log_utils.save_exception()