def test_get_params_internal(self, obs_dim):
    """Compare get_params_internal() with the scope's trainable variables.

    Builds the baseline with the regressor class patched to
    SimpleMLPRegressor, then checks that the baseline's internal parameters
    equal the TensorFlow trainable variables collected under the
    'ContinuousMLPBaselineWithModel' scope.
    """
    regressor_target = ('garage.tf.baselines.'
                        'continuous_mlp_baseline_with_model.'
                        'ContinuousMLPRegressorWithModel')
    env = TfEnv(DummyBoxEnv(obs_dim=obs_dim))

    # The patch only needs to be active while the baseline is constructed.
    with mock.patch(regressor_target, new=SimpleMLPRegressor):
        baseline = ContinuousMLPBaselineWithModel(env_spec=env.spec)

    expected_vars = tf.trainable_variables(
        scope='ContinuousMLPBaselineWithModel')
    internal_params = baseline.get_params_internal()
    assert np.array_equal(internal_params, expected_vars)
def test_ppo_pendulum_continuous_baseline(self):
    """Test PPO with Pendulum environment.

    Trains PPO with a GaussianMLPPolicy and a
    ContinuousMLPBaselineWithModel baseline on 'InvertedDoublePendulum-v2'
    for 10 epochs and asserts that the value returned by ``runner.train``
    is greater than 30.
    """
    with LocalRunner(self.sess) as runner:
        env = TfEnv(normalize(gym.make('InvertedDoublePendulum-v2')))
        try:
            policy = GaussianMLPPolicy(
                env_spec=env.spec,
                hidden_sizes=(64, 64),
                hidden_nonlinearity=tf.nn.tanh,
                output_nonlinearity=None,
            )
            baseline = ContinuousMLPBaselineWithModel(
                env_spec=env.spec,
                regressor_args=dict(hidden_sizes=(32, 32)),
            )
            algo = PPO(env_spec=env.spec,
                       policy=policy,
                       baseline=baseline,
                       max_path_length=100,
                       discount=0.99,
                       lr_clip_range=0.01,
                       optimizer_args=dict(batch_size=32, max_epochs=10))
            runner.setup(algo, env)
            last_avg_ret = runner.train(n_epochs=10, batch_size=2048)
            assert last_avg_ret > 30
        finally:
            # Close the environment even when setup/training raises or the
            # threshold assertion fails, so it is not leaked into later tests.
            env.close()
def test_ppo_pendulum_recurrent_continuous_baseline(self):
    """Test PPO with Pendulum environment and recurrent policy.

    Trains PPO with a GaussianLSTMPolicy and a
    ContinuousMLPBaselineWithModel baseline on 'InvertedDoublePendulum-v2'
    for 10 epochs and asserts that the value returned by ``runner.train``
    is greater than 100.
    """
    with LocalTFRunner(snapshot_config) as runner:
        env = TfEnv(normalize(gym.make('InvertedDoublePendulum-v2')))
        try:
            policy = GaussianLSTMPolicy(env_spec=env.spec, )
            baseline = ContinuousMLPBaselineWithModel(
                env_spec=env.spec,
                regressor_args=dict(hidden_sizes=(32, 32)),
            )
            algo = PPO(
                env_spec=env.spec,
                policy=policy,
                baseline=baseline,
                max_path_length=100,
                discount=0.99,
                gae_lambda=0.95,
                lr_clip_range=0.2,
                optimizer_args=dict(
                    batch_size=32,
                    max_epochs=10,
                ),
                stop_entropy_gradient=True,
                entropy_method='max',
                policy_ent_coeff=0.02,
                center_adv=False,
            )
            runner.setup(algo, env)
            last_avg_ret = runner.train(n_epochs=10, batch_size=2048)
            assert last_avg_ret > 100
        finally:
            # Close the environment even when setup/training raises or the
            # threshold assertion fails, so it is not leaked into later tests.
            env.close()
def test_fit(self, obs_dim):
    """Fit the baseline on two one-step paths and check its predictions.

    The paths hold constant observations filled with 1 and 2 and returns of
    1 and 2 respectively; after fitting, predicting on those same two
    observations is expected to give [1, 2].
    """
    regressor_target = ('garage.tf.baselines.'
                        'continuous_mlp_baseline_with_model.'
                        'ContinuousMLPRegressorWithModel')
    env = TfEnv(DummyBoxEnv(obs_dim=obs_dim))

    # The patch only needs to be active while the baseline is constructed.
    with mock.patch(regressor_target, new=SimpleMLPRegressor):
        baseline = ContinuousMLPBaselineWithModel(env_spec=env.spec)

    targets = (1, 2)
    training_paths = [{
        'observations': [np.full(obs_dim, value)],
        'returns': [value]
    } for value in targets]
    baseline.fit(training_paths)

    query = {
        'observations': [np.full(obs_dim, value) for value in targets]
    }
    predicted = baseline.predict(query)
    assert np.array_equal(predicted, [1, 2])
def test_is_pickleable(self):
    """Check that predictions survive a pickle round trip.

    Loads 1.0 into the 'SimpleMLPModel/return_var' variable, records a
    prediction, pickles the baseline, unpickles it inside a session on a
    fresh graph, and asserts that the restored baseline predicts the same
    values.
    """
    regressor_target = ('garage.tf.baselines.'
                        'continuous_mlp_baseline_with_model.'
                        'ContinuousMLPRegressorWithModel')
    env = TfEnv(DummyBoxEnv(obs_dim=(1, )))

    # The patch only needs to be active while the baseline is constructed.
    with mock.patch(regressor_target, new=SimpleMLPRegressor):
        baseline = ContinuousMLPBaselineWithModel(env_spec=env.spec)

    query = {'observations': [np.full(1, 1) for _ in range(2)]}

    # Fetch the existing variable (reuse=True) and overwrite its value.
    with tf.variable_scope('ContinuousMLPBaselineWithModel', reuse=True):
        return_var = tf.get_variable('SimpleMLPModel/return_var')
    return_var.load(1.0)

    expected = baseline.predict(query)
    serialized = pickle.dumps(baseline)

    # Restore and predict under a separate graph/session.
    with tf.Session(graph=tf.Graph()):
        restored = pickle.loads(serialized)
        actual = restored.predict(query)

    assert np.array_equal(expected, actual)
def test_param_values(self, obs_dim):
    """Check get_param_values()/set_param_values() across two baselines.

    Creates two baselines, loads 1.0 into the second one's
    'SimpleMLPModel/return_var' so their parameter values differ, then
    copies the first baseline's values into the second and asserts that
    they match afterwards.
    """
    regressor_target = ('garage.tf.baselines.'
                        'continuous_mlp_baseline_with_model.'
                        'ContinuousMLPRegressorWithModel')
    env = TfEnv(DummyBoxEnv(obs_dim=obs_dim))

    # The patch only needs to be active while the baselines are constructed.
    with mock.patch(regressor_target, new=SimpleMLPRegressor):
        source = ContinuousMLPBaselineWithModel(env_spec=env.spec)
        target = ContinuousMLPBaselineWithModel(
            env_spec=env.spec, name='ContinuousMLPBaselineWithModel2')

    # Manually change a parameter of the second baseline
    # ('ContinuousMLPBaselineWithModel2') so the two start out different.
    with tf.variable_scope('ContinuousMLPBaselineWithModel2', reuse=True):
        return_var = tf.get_variable('SimpleMLPModel/return_var')
    return_var.load(1.0)

    source_values = source.get_param_values()
    target_values = target.get_param_values()
    assert not np.array_equal(source_values, target_values)

    target.set_param_values(source_values)
    target_values = target.get_param_values()
    assert np.array_equal(source_values, target_values)