def test_get_action(self, mock_rand, obs_dim, action_dim, filter_dims,
                    filter_sizes, strides, padding, hidden_sizes):
    """Check single and batched action sampling with stubbed models.

    The policy is constructed while ``MLPModel`` and ``CNNModel`` in
    ``garage.tf.policies.categorical_cnn_policy`` are replaced by
    ``SimpleMLPModel`` and ``SimpleCNNModel``.

    Args:
        mock_rand: Mock whose return value is forced to 0. The decorator
            that injects it is outside this view; presumably it patches
            the random source used for sampling -- confirm.
        obs_dim: Observation dimension passed to ``DummyDiscreteEnv``.
        action_dim: Action dimension passed to ``DummyDiscreteEnv``; also
            the length of the expected probability vector.
        filter_dims: Forwarded to the policy as ``conv_filters``.
        filter_sizes: Forwarded to the policy as ``conv_filter_sizes``.
        strides: Forwarded to the policy as ``conv_strides``.
        padding: Forwarded to the policy as ``conv_pad``.
        hidden_sizes: Forwarded to the policy as ``hidden_sizes``.

    """
    mock_rand.return_value = 0
    env = TfEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))

    mlp_target = ('garage.tf.policies.' 'categorical_cnn_policy.MLPModel')
    cnn_target = ('garage.tf.policies.' 'categorical_cnn_policy.CNNModel')
    # Patches are entered in the same order as before: MLP first, then CNN.
    with mock.patch(mlp_target, new=SimpleMLPModel), \
            mock.patch(cnn_target, new=SimpleCNNModel):
        policy = CategoricalCNNPolicy(env_spec=env.spec,
                                      conv_filters=filter_dims,
                                      conv_filter_sizes=filter_sizes,
                                      conv_strides=strides,
                                      conv_pad=padding,
                                      hidden_sizes=hidden_sizes)

    env.reset()
    obs, _, _, _ = env.step(1)

    # Single observation.
    action, prob = policy.get_action(obs)
    expected_prob = np.full(action_dim, 0.5)
    assert env.action_space.contains(action)
    assert action == 0
    assert np.array_equal(prob['prob'], expected_prob)

    # Batch of three identical observations: each row must satisfy the
    # same checks as the single-observation case.
    actions, probs = policy.get_actions([obs, obs, obs])
    for sampled_action, sampled_prob in zip(actions, probs['prob']):
        assert env.action_space.contains(sampled_action)
        assert sampled_action == 0
        assert np.array_equal(sampled_prob, expected_prob)
def test_get_action(self, filters, strides, padding, hidden_sizes):
    """Check that sampled actions belong to the environment action space.

    Args:
        filters: Forwarded to the policy as ``filters``.
        strides: Forwarded to the policy as ``strides``.
        padding: Forwarded to the policy as ``padding``.
        hidden_sizes: Forwarded to the policy as ``hidden_sizes``.

    """
    env = GymEnv(DummyDiscretePixelEnv())
    policy = CategoricalCNNPolicy(env_spec=env.spec,
                                  filters=filters,
                                  strides=strides,
                                  padding=padding,
                                  hidden_sizes=hidden_sizes)

    env.reset()
    step_result = env.step(1)
    obs = step_result.observation

    # Single observation.
    single_action, _ = policy.get_action(obs)
    assert env.action_space.contains(single_action)

    # Batch of three identical observations.
    batch_actions, _ = policy.get_actions([obs, obs, obs])
    for batch_action in batch_actions:
        assert env.action_space.contains(batch_action)
class TestCategoricalCNNPolicyImageObs(TfGraphTestCase):
    """Tests for CategoricalCNNPolicy on an env created with is_image=True."""

    def setup_method(self):
        """Create the image environment, run the initializer, and reset."""
        super().setup_method()
        self.env = GymEnv(DummyDiscretePixelEnv(), is_image=True)
        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)
        self.env.reset()

    @pytest.mark.parametrize('filters, strides, padding, hidden_sizes', [
        (((3, (32, 32)), ), (1, ), 'VALID', (4, )),
    ])
    def test_obs_unflattened(self, filters, strides, padding, hidden_sizes):
        """Feed a flattened sampled observation to the policy and step.

        Args:
            filters: Forwarded to the policy as ``filters``.
            strides: Forwarded to the policy as ``strides``.
            padding: Forwarded to the policy as ``padding``.
            hidden_sizes: Forwarded to the policy as ``hidden_sizes``.

        """
        self.policy = CategoricalCNNPolicy(env_spec=self.env.spec,
                                           filters=filters,
                                           strides=strides,
                                           padding=padding,
                                           hidden_sizes=hidden_sizes)

        obs_space = self.env.observation_space
        sampled_obs = obs_space.sample()
        flat_obs = obs_space.flatten(sampled_obs)

        # The policy receives the flattened observation; the resulting
        # action is then applied to the environment.
        action, _ = self.policy.get_action(flat_obs)
        self.env.step(action)
def test_get_action(self, filters, strides, padding, hidden_sizes):
    """Build the policy on a placeholder and check sampled actions.

    Args:
        filters: Forwarded to the policy as ``filters``.
        strides: Forwarded to the policy as ``strides``.
        padding: Forwarded to the policy as ``padding``.
        hidden_sizes: Forwarded to the policy as ``hidden_sizes``.

    """
    env = GarageEnv(DummyDiscretePixelEnv())
    policy = CategoricalCNNPolicy(env_spec=env.spec,
                                  filters=filters,
                                  strides=strides,
                                  padding=padding,
                                  hidden_sizes=hidden_sizes)

    # Two leading unspecified dimensions precede the observation shape.
    placeholder_shape = (None, None) + env.observation_space.shape
    obs_var = tf.compat.v1.placeholder(tf.float32,
                                       shape=placeholder_shape,
                                       name='obs')
    policy.build(obs_var)

    env.reset()
    obs, _, _, _ = env.step(1)

    # Single observation.
    single_action, _ = policy.get_action(obs)
    assert env.action_space.contains(single_action)

    # Batch of three identical observations.
    batch_actions, _ = policy.get_actions([obs, obs, obs])
    for batch_action in batch_actions:
        assert env.action_space.contains(batch_action)