def main(exp_prefix='exp', algo='exact', layers=(32, 32), repeat=0,
         env_name='grid1', **alg_args):
    """Run 300 updates of weighted exact FQI on a suite environment, with logging.

    Args:
        exp_prefix: Experiment name prefix passed to ``log_utils.setup_logger``.
        algo: Not used by this function (presumably consumed by the launching
            sweep script -- TODO confirm).
        layers: Hidden-layer sizes for ``q_networks.FCNetwork``, or the string
            ``'tabular'`` to use ``q_networks.TabularNetwork`` instead.
        repeat: Not used by this function (presumably a sweep repeat index --
            TODO confirm).
        env_name: Environment name resolved with ``env_suite.get_env``.
        **alg_args: Extra keyword arguments forwarded to
            ``exact_fqi.WeightedExactFQI``. The fixed settings applied below
            overwrite any caller-supplied values for the same keys.
    """
    env = env_suite.get_env(env_name)

    if layers == 'tabular':
        network = q_networks.TabularNetwork(env)
    else:
        network = q_networks.FCNetwork(env, layers=layers)
    ptu.initialize_network(network)

    # dict.update: these fixed settings win over any same-named caller kwargs.
    alg_args.update({
        'min_project_steps': 10,
        'max_project_steps': 300,
        'lr': 5e-3,
        'discount': 0.95,
        'n_steps': 1,
        'backup_mode': 'exact',
        'stop_modes': (stopping.AtolStop(), stopping.RtolStop()),
        'time_limit': env.time_limit,
        'env_name': env_name,
        'layers': str(layers),
    })
    fqi = exact_fqi.WeightedExactFQI(env, network, log_proj_qstar=True, **alg_args)
    with log_utils.setup_logger(algo=fqi, exp_prefix=exp_prefix, log_base_dir='./data') as log_dir:
        print('Logging to %s' % log_dir)
        try:
            for k in range(300):
                fqi.update(step=k)
        except Exception:
            # Training failures are recorded via log_utils instead of crashing the
            # run; catching Exception (not a bare except) lets Ctrl-C and
            # SystemExit propagate.
            log_utils.save_exception()
# Example #2
    def setUp(self):
        """Build the shared fixture: a wrapped Cliffwalk env, a linear Q-network, FQI args.

        Attributes set on ``self``:
            env_tab: the base ``tabular_env.CliffwalkEnv(10)``.
            env_obs: ``env_tab`` wrapped in ``RandomObsWrapper(..., 8)``.
            env: ``env_obs`` wrapped in ``TimeLimitWrapper(..., 50)``.
            network: a ``q_networks.LinearNetwork`` over ``env``, initialized
                with ``ptu.initialize_network``.
            alg_args: keyword arguments for a sampling-backup algorithm.
        """
        base_env = tabular_env.CliffwalkEnv(10)
        self.env_tab = base_env
        obs_env = random_obs_wrapper.RandomObsWrapper(base_env, 8)
        self.env_obs = obs_env
        self.env = time_limit_wrapper.TimeLimitWrapper(obs_env, 50)

        q_net = q_networks.LinearNetwork(self.env)
        self.network = q_net
        ptu.initialize_network(q_net)

        # Keyword-argument form keeps the same key insertion order as before.
        self.alg_args = dict(
            min_project_steps=10,
            max_project_steps=20,
            lr=5e-3,
            discount=0.95,
            n_steps=1,
            num_samples=32,
            stop_modes=(stopping.AtolStop(), stopping.RtolStop()),
            backup_mode='sampling',
            ent_wt=0.01,
        )
def main(exp_prefix='exp',
         validation_stop=True,
         layers=(32, 32),
         repeat=0,
         env_name='grid1',
         sampling_type=None,
         **alg_args):
    """Train sampling-based FQI (or a tabular replay-buffer DQN) under a sample budget.

    ``sampling_type`` selects both the algorithm and the per-update sample
    count. It must start with ``'buffer'`` (``replay_buffer_fqi.TabularBufferDQN``)
    or ``'sample'`` (``sampling_fqi.WeightedSamplingFQI``). It must also end in a
    positive integer, which becomes ``num_samples`` (e.g. ``'sample32'`` -> 32).

    Updates stop once ``128 * 200`` samples have been counted (``num_samples``
    per update) or once ``k`` exceeds 2000, whichever happens first.

    Args:
        exp_prefix: Experiment name prefix passed to ``log_utils.setup_logger``.
        validation_stop: Accepted for interface compatibility; unused here.
        layers: Hidden-layer sizes for ``q_networks.FCNetwork``, or ``'tabular'``
            to use ``q_networks.TabularNetwork``.
        repeat: Unused here (presumably a sweep repeat index -- TODO confirm).
        env_name: Environment name resolved with ``env_suite.get_env``.
        sampling_type: Algorithm / sample-count selector described above.
            Required despite the ``None`` default.
        **alg_args: Extra keyword arguments forwarded to the algorithm. The fixed
            settings applied below overwrite caller values for the same keys.

    Raises:
        ValueError: If ``sampling_type`` is ``None``, has an unknown prefix, or
            does not end in a positive integer.
    """
    import re

    # Validate the selector before building anything, so a bad value fails fast
    # with a clear message. Otherwise it surfaces later as an AttributeError on
    # None, or as an unbound ``num_samples`` / ``fqi``.
    if sampling_type is None:
        raise ValueError("sampling_type is required, e.g. 'sample32' or 'buffer128'")
    if not sampling_type.startswith(('buffer', 'sample')):
        raise ValueError("sampling_type must start with 'buffer' or 'sample', got %r"
                         % (sampling_type,))
    # Read the whole trailing integer. Testing suffixes one at a time mis-reads
    # values like 'sample1024', which matched the '4' suffix.
    suffix = re.search(r'(\d+)$', sampling_type)
    num_samples = int(suffix.group(1)) if suffix else 0
    if num_samples <= 0:
        raise ValueError("sampling_type must end in a positive sample count, got %r"
                         % (sampling_type,))

    env = env_suite.get_env(env_name)

    if layers == 'tabular':
        network = q_networks.TabularNetwork(env)
    else:
        network = q_networks.FCNetwork(env, layers=layers)
    ptu.initialize_network(network)

    stop_mode = (stopping.AtolStop(), stopping.RtolStop())

    alg_args.update({
        'min_project_steps': 10,
        'max_project_steps': 200,
        'lr': 5e-3,
        'discount': 0.95,
        'n_steps': 1,
        'batch_size': 128,
        'stop_modes': stop_mode,
        'time_limit': env.time_limit,
        'env_name': env_name,
        'backup_mode': 'sampling',
        'layers': str(layers),
        'num_samples': num_samples,
        'log_sampling_type': sampling_type
    })

    if sampling_type.startswith('buffer'):
        algo_cls = replay_buffer_fqi.TabularBufferDQN
    else:
        algo_cls = sampling_fqi.WeightedSamplingFQI
    fqi = algo_cls(env, network, log_proj_qstar=False, **alg_args)

    # Training limits: total samples counted, and a cap on the update index k.
    sample_budget = 128 * 200
    max_step = 2000

    with log_utils.setup_logger(algo=fqi,
                                exp_prefix=exp_prefix,
                                log_base_dir='./data') as log_dir:
        print('Logging to %s' % log_dir)
        total_num_samples = 0
        k = 0
        try:
            while total_num_samples < sample_budget and k <= max_step:
                fqi.update(step=k)
                total_num_samples += num_samples
                k += 1
        except Exception:
            # Training failures are recorded via log_utils instead of crashing the
            # run; Ctrl-C and SystemExit still propagate.
            log_utils.save_exception()
# Example #4
def main(exp_prefix='exp', validation_stop=True, layers=(32, 32), repeat=0, env_name='grid1', sampling_type=None, **alg_args):
    """Run 300 updates of sampling FQI / replay-buffer DQN with a chosen stopping rule.

    ``sampling_type`` must start with ``'buffer'``
    (``replay_buffer_fqi.TabularBufferDQN``) or ``'sample'``
    (``sampling_fqi.WeightedSamplingFQI``). It must also end in a positive
    integer, which becomes ``num_samples`` (e.g. ``'buffer64'`` -> 64).

    Args:
        exp_prefix: Experiment name prefix passed to ``log_utils.setup_logger``.
        validation_stop: Projection stopping rule.
            ``'bellman'`` uses ``stopping.ValidationLoss`` with at most 50
            projection steps. ``'returns'`` is not implemented. Any other value
            uses ``AtolStop`` + ``RtolStop`` with at most 200 steps.
        layers: Hidden-layer sizes for ``q_networks.FCNetwork``, or ``'tabular'``
            to use ``q_networks.TabularNetwork``.
        repeat: Unused here (presumably a sweep repeat index -- TODO confirm).
        env_name: Environment name resolved with ``env_suite.get_env``.
        sampling_type: Algorithm / sample-count selector described above.
            Required despite the ``None`` default.
        **alg_args: Extra keyword arguments forwarded to the algorithm. The fixed
            settings applied below overwrite caller values for the same keys.

    Raises:
        NotImplementedError: If ``validation_stop == 'returns'``.
        ValueError: If ``sampling_type`` is ``None``, has an unknown prefix, or
            does not end in a positive integer.
    """
    import re

    # Reject unsupported configurations before any environment/network work.
    if validation_stop == 'returns':
        raise NotImplementedError("TODO(justin): reimplement")
    if sampling_type is None:
        raise ValueError("sampling_type is required, e.g. 'sample32' or 'buffer128'")
    if not sampling_type.startswith(('buffer', 'sample')):
        raise ValueError("sampling_type must start with 'buffer' or 'sample', got %r"
                         % (sampling_type,))
    # Read the whole trailing integer. Testing suffixes one at a time left
    # num_samples unbound for unlisted sizes and mis-read e.g. '...1024' as 4.
    suffix = re.search(r'(\d+)$', sampling_type)
    num_samples = int(suffix.group(1)) if suffix else 0
    if num_samples <= 0:
        raise ValueError("sampling_type must end in a positive sample count, got %r"
                         % (sampling_type,))

    env = env_suite.get_env(env_name)

    if layers == 'tabular':
        network = q_networks.TabularNetwork(env)
    else:
        network = q_networks.FCNetwork(env, layers=layers)
    ptu.initialize_network(network)

    if validation_stop == 'bellman':
        stop_mode = (stopping.ValidationLoss(),)
        max_project_steps = 50
    else:
        # Flat tuple of stop criteria, same shape as the 'bellman' branch. A
        # stray trailing comma here used to nest it as ((AtolStop, RtolStop),).
        stop_mode = (stopping.AtolStop(), stopping.RtolStop())
        max_project_steps = 200

    alg_args.update({
        'min_project_steps': 10,
        'max_project_steps': max_project_steps,
        'lr': 5e-3,
        'discount': 0.95,
        'n_steps': 1,
        'batch_size': 128,
        'stop_modes': stop_mode,
        'time_limit': env.time_limit,
        'env_name': env_name,
        'backup_mode': 'sampling',
        'layers': str(layers),
        'validation_stop': validation_stop,
        'num_samples': num_samples,
        'log_sampling_type': sampling_type
    })

    if sampling_type.startswith('buffer'):
        algo_cls = replay_buffer_fqi.TabularBufferDQN
    else:
        algo_cls = sampling_fqi.WeightedSamplingFQI
    fqi = algo_cls(env, network, log_proj_qstar=False, **alg_args)

    # NOTE(review): absolute '/data' differs from the './data' used by the other
    # scripts; confirm this is an intended mount point and not a typo.
    with log_utils.setup_logger(algo=fqi, exp_prefix=exp_prefix, log_base_dir='/data') as log_dir:
        print('Logging to %s' % log_dir)
        try:
            for k in range(300):
                fqi.update(step=k)
        except Exception:
            # Training failures are recorded via log_utils instead of crashing the
            # run; Ctrl-C and SystemExit still propagate.
            log_utils.save_exception()