def test_get_action_dict_space(self):
    """Test if observations from dict obs spaces are properly flattened."""
    env = GymEnv(DummyDictEnv(obs_space_type='box', act_space_type='box'))
    policy = TanhGaussianMLPPolicy(env_spec=env.spec,
                                   hidden_nonlinearity=None,
                                   hidden_sizes=(1, ),
                                   hidden_w_init=nn.init.ones_,
                                   output_w_init=nn.init.ones_)
    obs = env.reset()[0]

    # A single (dict) observation should yield an action whose shape
    # matches the environment's action space.
    action, _ = policy.get_action(obs)
    assert env.action_space.shape == action.shape

    # A batch of observations should yield a batch of actions, each
    # with the action-space shape.
    actions, _ = policy.get_actions(np.array([obs, obs]))
    for single_action in actions:
        assert env.action_space.shape == single_action.shape
def test_get_action_np(self, hidden_sizes):
    """Test Policy get action function with numpy inputs."""
    env_spec = GymEnv(DummyBoxEnv())
    obs_dim = env_spec.observation_space.flat_dim
    act_dim = env_spec.action_space.flat_dim

    init_std = 2.
    # All-ones weights with no hidden nonlinearity make the mean output
    # deterministic and easy to predict from an all-ones observation.
    policy = TanhGaussianMLPPolicy(env_spec=env_spec,
                                   hidden_sizes=hidden_sizes,
                                   init_std=init_std,
                                   hidden_nonlinearity=None,
                                   std_parameterization='exp',
                                   hidden_w_init=nn.init.ones_,
                                   output_w_init=nn.init.ones_)

    obs = np.ones((obs_dim, ), dtype=np.float32)
    action, info = policy.get_action(obs)

    # With unit weights and a unit observation the pre-tanh mean is 1.0
    # in every action dimension.
    expected_mean = np.full((act_dim, ), 1.0, dtype=np.float32)
    assert np.allclose(info['mean'], expected_mean, rtol=1e-3)
    assert action.shape == (act_dim, )