def test_get_action(self, mock_rand, obs_dim, action_dim, filter_dims,
                    filter_sizes, strides, padding, hidden_sizes):
    """Check single and batched action sampling against stubbed models.

    The policy is constructed while ``MLPModel`` and ``CNNModel`` in
    ``metarl.tf.policies.categorical_cnn_policy`` are replaced by
    ``SimpleMLPModel`` and ``SimpleCNNModel``. The test then asserts that
    every sampled action is ``0``, lies inside the action space, and is
    reported with a probability array filled with ``0.5``.

    Args:
        mock_rand: Mock object whose ``return_value`` is forced to 0.
            NOTE(review): presumably injected by a ``mock.patch``
            decorator on a random sampler - confirm at the decorator.
        obs_dim: Observation dimension passed to ``DummyDiscreteEnv``.
        action_dim: Action dimension passed to ``DummyDiscreteEnv``; also
            the shape of the expected probability array.
        filter_dims: Forwarded to the policy as ``conv_filters``.
        filter_sizes: Forwarded to the policy as ``conv_filter_sizes``.
        strides: Forwarded to the policy as ``conv_strides``.
        padding: Forwarded to the policy as ``conv_pad``.
        hidden_sizes: Forwarded to the policy as ``hidden_sizes``.

    """
    mock_rand.return_value = 0
    dummy_env = DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim)
    env = TfEnv(dummy_env)

    # Swap both model classes for their simple stand-ins while the policy
    # is being constructed.
    mlp_patch = mock.patch(
        'metarl.tf.policies.categorical_cnn_policy.MLPModel',
        new=SimpleMLPModel)
    cnn_patch = mock.patch(
        'metarl.tf.policies.categorical_cnn_policy.CNNModel',
        new=SimpleCNNModel)
    with mlp_patch, cnn_patch:
        policy = CategoricalCNNPolicy(env_spec=env.spec,
                                      conv_filters=filter_dims,
                                      conv_filter_sizes=filter_sizes,
                                      conv_strides=strides,
                                      conv_pad=padding,
                                      hidden_sizes=hidden_sizes)

    env.reset()
    obs, _, _, _ = env.step(1)
    expected_prob = np.full(action_dim, 0.5)

    # Single-observation path.
    single_action, single_info = policy.get_action(obs)
    assert env.action_space.contains(single_action)
    assert single_action == 0
    assert np.array_equal(single_info['prob'], expected_prob)

    # Batched path: the same observation repeated three times.
    batch_actions, batch_info = policy.get_actions([obs] * 3)
    for sampled_action, sampled_prob in zip(batch_actions,
                                            batch_info['prob']):
        assert env.action_space.contains(sampled_action)
        assert sampled_action == 0
        assert np.array_equal(sampled_prob, expected_prob)
def test_get_action(self, filters, strides, padding, hidden_sizes):
    """Check that sampled actions fall inside the environment action space.

    Builds a ``CategoricalCNNPolicy`` on a float32 placeholder shaped like
    the environment's observations, then samples one action and a batch of
    three actions from the same observation.

    Args:
        filters: Forwarded to the policy as ``filters``.
        strides: Forwarded to the policy as ``strides``.
        padding: Forwarded to the policy as ``padding``.
        hidden_sizes: Forwarded to the policy as ``hidden_sizes``.

    """
    env = MetaRLEnv(DummyDiscretePixelEnv())
    policy = CategoricalCNNPolicy(env_spec=env.spec,
                                  filters=filters,
                                  strides=strides,
                                  padding=padding,
                                  hidden_sizes=hidden_sizes)

    # Two leading unknown dimensions precede the observation shape.
    # NOTE(review): presumably (batch, time) - confirm against policy.build.
    input_shape = (None, None) + env.observation_space.shape
    obs_var = tf.compat.v1.placeholder(tf.float32,
                                       shape=input_shape,
                                       name='obs')
    policy.build(obs_var)

    env.reset()
    obs, _, _, _ = env.step(1)

    # Single-observation path.
    single_action, _ = policy.get_action(obs)
    assert env.action_space.contains(single_action)

    # Batched path: the same observation repeated three times.
    batch_actions, _ = policy.get_actions([obs] * 3)
    for sampled_action in batch_actions:
        assert env.action_space.contains(sampled_action)