def testWeights(self, mode):
    """Smoke-test a single update of WeightedSamplingFQI for one weighting scheme.

    Builds the algorithm from the fixture's ``env``, ``network`` and
    ``alg_args`` with ``weighting_scheme=mode``, resets the logger, and then
    performs one ``update`` call with step 1. The test passes as long as no
    exception is raised.
    """
    env, network, extra_args = self.env, self.network, self.alg_args
    algorithm = sampling_fqi.WeightedSamplingFQI(
        env,
        network,
        weighting_scheme=mode,
        **extra_args
    )
    # Reset logger state before stepping (presumably so earlier test cases
    # do not leak logging state into this one -- TODO confirm).
    log_utils.reset_logger()
    first_step = 1
    algorithm.update(first_step)
def main(exp_prefix='exp', layers=(32, 32), repeat=0, env_name='grid1',
         sampling_type=None, **alg_args):
    """Run one sampling-based FQI experiment for a fixed 600 update steps.

    Builds a Q-network for ``env_name`` and configures either a replay-buffer
    (``'buffer...'``) or weighted-sampling (``'sample...'``) FQI algorithm.
    It then runs ``fqi.update`` for steps 0..599 inside a ``log_utils``
    logging context rooted at ``./data``.

    Args:
        exp_prefix: Experiment prefix passed to ``log_utils.setup_logger``.
        layers: Hidden-layer sizes for ``q_networks.FCNetwork``, or the string
            ``'tabular'`` to use ``q_networks.TabularNetwork`` instead.
        repeat: Not used inside this function (presumably distinguishes
            repeated runs in a hyperparameter sweep -- TODO confirm).
        env_name: Environment name passed to ``env_suite.get_env``.
        sampling_type: Must start with ``'buffer'`` or ``'sample'`` (selects
            the algorithm class) and end with ``512``, ``64`` or ``32``
            (sets ``num_samples``), e.g. ``'sample64'``.
        **alg_args: Extra algorithm keyword arguments. The fixed settings
            below are merged in with ``dict.update`` and override any
            same-named keys supplied by the caller.

    Raises:
        ValueError: If ``sampling_type`` is missing or not one of the
            recognised forms. This is checked before any environment or
            network is built.

    Exceptions raised during training are recorded with
    ``log_utils.save_exception`` and not re-raised.
    """
    # Validate up front so a bad sweep configuration fails fast with a clear
    # message, instead of an AttributeError (None default) or an
    # UnboundLocalError (num_samples / fqi never assigned) halfway through.
    if not isinstance(sampling_type, str):
        raise ValueError(
            'sampling_type must be a string, got %r' % (sampling_type,))
    if not sampling_type.startswith(('buffer', 'sample')):
        raise ValueError(
            "sampling_type must start with 'buffer' or 'sample', got %r"
            % sampling_type)
    supported_sizes = (512, 64, 32)
    # Suffixes are tested largest-first, matching the original if/elif order.
    num_samples = next(
        (size for size in supported_sizes
         if sampling_type.endswith(str(size))),
        None)
    if num_samples is None:
        raise ValueError(
            'sampling_type %r must end with one of %s'
            % (sampling_type, ', '.join(str(s) for s in supported_sizes)))

    env = env_suite.get_env(env_name)
    if layers == 'tabular':
        network = q_networks.TabularNetwork(env)
    else:
        network = q_networks.FCNetwork(env, layers=layers)
    ptu.initialize_network(network)

    # No stopping criteria are used in this experiment.
    stop_mode = tuple()

    alg_args.update({
        'min_project_steps': 10,
        'lr': 5e-3,
        'discount': 0.95,
        'n_steps': 1,
        'batch_size': 128,
        'stop_modes': stop_mode,
        'time_limit': env.time_limit,
        'env_name': env_name,
        'backup_mode': 'sampling',
        'layers': str(layers),
        'num_samples': num_samples,
        'log_sampling_type': sampling_type,
    })

    # The prefix was validated above, so exactly one branch applies.
    if sampling_type.startswith('buffer'):
        fqi = replay_buffer_fqi.TabularBufferDQN(
            env, network, log_proj_qstar=False, **alg_args)
    else:
        fqi = sampling_fqi.WeightedSamplingFQI(
            env, network, log_proj_qstar=False, **alg_args)

    with log_utils.setup_logger(algo=fqi, exp_prefix=exp_prefix,
                                log_base_dir='./data') as log_dir:
        print('Logging to %s' % log_dir)
        try:
            for k in range(600):
                fqi.update(step=k)
        except Exception:
            # Deliberate best-effort: record the failure and return normally.
            # Catching Exception (not a bare except) lets Ctrl-C / SystemExit
            # still abort the run.
            log_utils.save_exception()
def main(exp_prefix='exp', validation_stop=True, layers=(32, 32), repeat=0,
         env_name='grid1', sampling_type=None, **alg_args):
    """Run one sampling-based FQI experiment under a fixed sample budget.

    Builds a Q-network for ``env_name`` and configures either a replay-buffer
    (``'buffer...'``) or weighted-sampling (``'sample...'``) FQI algorithm
    with Atol/Rtol projection stopping. It then calls ``fqi.update`` until the
    cumulative ``num_samples`` consumed reaches ``128 * 200``, or the step
    counter passes 2000, whichever comes first. Logs go under ``./data``.

    Args:
        exp_prefix: Experiment prefix passed to ``log_utils.setup_logger``.
        validation_stop: Accepted for signature compatibility but not used
            by this function.
        layers: Hidden-layer sizes for ``q_networks.FCNetwork``, or the string
            ``'tabular'`` to use ``q_networks.TabularNetwork`` instead.
        repeat: Not used inside this function (presumably distinguishes
            repeated runs in a hyperparameter sweep -- TODO confirm).
        env_name: Environment name passed to ``env_suite.get_env``.
        sampling_type: Must start with ``'buffer'`` or ``'sample'`` (selects
            the algorithm class) and end with one of 512, 256, 128, 64, 32,
            16, 8 or 4 (sets ``num_samples``), e.g. ``'buffer16'``.
        **alg_args: Extra algorithm keyword arguments. The fixed settings
            below are merged in with ``dict.update`` and override any
            same-named keys supplied by the caller.

    Raises:
        ValueError: If ``sampling_type`` is missing or not one of the
            recognised forms. This is checked before any environment or
            network is built.

    Exceptions raised during training are recorded with
    ``log_utils.save_exception`` and not re-raised.
    """
    # Validate up front so a bad sweep configuration fails fast with a clear
    # message, instead of an AttributeError (None default) or an
    # UnboundLocalError (num_samples / fqi never assigned) halfway through.
    if not isinstance(sampling_type, str):
        raise ValueError(
            'sampling_type must be a string, got %r' % (sampling_type,))
    if not sampling_type.startswith(('buffer', 'sample')):
        raise ValueError(
            "sampling_type must start with 'buffer' or 'sample', got %r"
            % sampling_type)
    # Largest-first, so '128' is matched before '8' and '64' before '4'.
    supported_sizes = (512, 256, 128, 64, 32, 16, 8, 4)
    num_samples = next(
        (size for size in supported_sizes
         if sampling_type.endswith(str(size))),
        None)
    if num_samples is None:
        raise ValueError(
            'sampling_type %r must end with one of %s'
            % (sampling_type, ', '.join(str(s) for s in supported_sizes)))

    env = env_suite.get_env(env_name)
    if layers == 'tabular':
        network = q_networks.TabularNetwork(env)
    else:
        network = q_networks.FCNetwork(env, layers=layers)
    ptu.initialize_network(network)

    stop_mode = (stopping.AtolStop(), stopping.RtolStop())

    alg_args.update({
        'min_project_steps': 10,
        'max_project_steps': 200,
        'lr': 5e-3,
        'discount': 0.95,
        'n_steps': 1,
        'batch_size': 128,
        'stop_modes': stop_mode,
        'time_limit': env.time_limit,
        'env_name': env_name,
        'backup_mode': 'sampling',
        'layers': str(layers),
        'num_samples': num_samples,
        'log_sampling_type': sampling_type,
    })

    # The prefix was validated above, so exactly one branch applies.
    if sampling_type.startswith('buffer'):
        fqi = replay_buffer_fqi.TabularBufferDQN(
            env, network, log_proj_qstar=False, **alg_args)
    else:
        fqi = sampling_fqi.WeightedSamplingFQI(
            env, network, log_proj_qstar=False, **alg_args)

    with log_utils.setup_logger(algo=fqi, exp_prefix=exp_prefix,
                                log_base_dir='./data') as log_dir:
        print('Logging to %s' % log_dir)
        # Sample budget = 128 * 200 (batch_size times 200). The step cap
        # bounds runs with small num_samples that would otherwise take many
        # updates to exhaust the budget.
        total_num_samples = 0
        k = 0
        try:
            while (total_num_samples < 128 * 200) and (k <= 2000):
                fqi.update(step=k)
                total_num_samples += num_samples
                k += 1
        except Exception:
            # Deliberate best-effort: record the failure and return normally.
            # Catching Exception (not a bare except) lets Ctrl-C / SystemExit
            # still abort the run.
            log_utils.save_exception()
def main(exp_prefix='exp', validation_stop=True, layers=(32, 32), repeat=0,
         env_name='grid1', sampling_type=None, **alg_args):
    """Run one sampling-based FQI experiment with a selectable stopping rule.

    Builds a Q-network for ``env_name``, picks projection stopping criteria
    from ``validation_stop``, and configures either a replay-buffer
    (``'buffer...'``) or weighted-sampling (``'sample...'``) FQI algorithm.
    It then runs ``fqi.update`` for steps 0..299 inside a ``log_utils``
    logging context rooted at ``/data``.

    Args:
        exp_prefix: Experiment prefix passed to ``log_utils.setup_logger``.
        validation_stop: ``'bellman'`` uses ``stopping.ValidationLoss`` with
            at most 50 projection steps. ``'returns'`` is not implemented.
            Any other value uses ``stopping.AtolStop`` and
            ``stopping.RtolStop`` with at most 200 projection steps. The
            value is also recorded in ``alg_args``.
        layers: Hidden-layer sizes for ``q_networks.FCNetwork``, or the string
            ``'tabular'`` to use ``q_networks.TabularNetwork`` instead.
        repeat: Not used inside this function (presumably distinguishes
            repeated runs in a hyperparameter sweep -- TODO confirm).
        env_name: Environment name passed to ``env_suite.get_env``.
        sampling_type: Must start with ``'buffer'`` or ``'sample'`` (selects
            the algorithm class) and end with one of 512, 256, 128, 64, 32,
            16, 8 or 4 (sets ``num_samples``), e.g. ``'sample32'``.
        **alg_args: Extra algorithm keyword arguments. The fixed settings
            below are merged in with ``dict.update`` and override any
            same-named keys supplied by the caller.

    Raises:
        NotImplementedError: If ``validation_stop == 'returns'``.
        ValueError: If ``sampling_type`` is missing or not one of the
            recognised forms.
        Both are raised before any environment or network is built.

    Exceptions raised during training are recorded with
    ``log_utils.save_exception`` and not re-raised.
    """
    # Fail fast on configurations that can never run, before any setup work.
    if validation_stop == 'returns':
        raise NotImplementedError("TODO(justin): reimplement")
    if not isinstance(sampling_type, str):
        raise ValueError(
            'sampling_type must be a string, got %r' % (sampling_type,))
    if not sampling_type.startswith(('buffer', 'sample')):
        raise ValueError(
            "sampling_type must start with 'buffer' or 'sample', got %r"
            % sampling_type)
    # Largest-first, so '128' is matched before '8' and '64' before '4'.
    supported_sizes = (512, 256, 128, 64, 32, 16, 8, 4)
    num_samples = next(
        (size for size in supported_sizes
         if sampling_type.endswith(str(size))),
        None)
    if num_samples is None:
        raise ValueError(
            'sampling_type %r must end with one of %s'
            % (sampling_type, ', '.join(str(s) for s in supported_sizes)))

    env = env_suite.get_env(env_name)
    if layers == 'tabular':
        network = q_networks.TabularNetwork(env)
    else:
        network = q_networks.FCNetwork(env, layers=layers)
    ptu.initialize_network(network)

    max_project_steps = 200
    if validation_stop == 'bellman':
        stop_mode = (stopping.ValidationLoss(),)
        max_project_steps = 50
    else:
        # A flat tuple of stop criteria, the same shape as the 'bellman'
        # branch. A trailing comma here would wrap it in a second tuple.
        stop_mode = (stopping.AtolStop(), stopping.RtolStop())

    alg_args.update({
        'min_project_steps': 10,
        'max_project_steps': max_project_steps,
        'lr': 5e-3,
        'discount': 0.95,
        'n_steps': 1,
        'batch_size': 128,
        'stop_modes': stop_mode,
        'time_limit': env.time_limit,
        'env_name': env_name,
        'backup_mode': 'sampling',
        'layers': str(layers),
        'validation_stop': validation_stop,
        'num_samples': num_samples,
        'log_sampling_type': sampling_type,
    })

    # The prefix was validated above, so exactly one branch applies.
    if sampling_type.startswith('buffer'):
        fqi = replay_buffer_fqi.TabularBufferDQN(
            env, network, log_proj_qstar=False, **alg_args)
    else:
        fqi = sampling_fqi.WeightedSamplingFQI(
            env, network, log_proj_qstar=False, **alg_args)

    # NOTE(review): this is an absolute '/data', unlike the './data' used by
    # the sibling experiment scripts. Possibly a container mount point --
    # confirm before changing.
    with log_utils.setup_logger(algo=fqi, exp_prefix=exp_prefix,
                                log_base_dir='/data') as log_dir:
        print('Logging to %s' % log_dir)
        try:
            for k in range(300):
                fqi.update(step=k)
        except Exception:
            # Deliberate best-effort: record the failure and return normally.
            # Catching Exception (not a bare except) lets Ctrl-C / SystemExit
            # still abort the run.
            log_utils.save_exception()