def test_example(self, discrete, dim_state, dim_action, kind):
    """Exercise the ``Observation`` example constructors for one parameter set.

    When ``discrete`` is truthy the incoming ``dim_state`` / ``dim_action``
    are used as ``num_states`` / ``num_actions`` and the dimension tuples are
    empty; otherwise the counts are ``-1`` and the dimensions become
    one-element tuples.

    For ``kind`` equal to ``"nan"``, ``"zero"`` or ``"random"`` the matching
    ``Observation`` constructor is called and the shapes of ``state``,
    ``action`` and ``next_state`` plus the value of ``log_prob_action`` are
    checked.  Any other ``kind`` is expected to make
    ``Observation.get_example`` raise ``ValueError``.
    """
    # Both right-hand sides read the *incoming* dim_state/dim_action, so the
    # counts must be derived before the dimension names are rebound.
    num_states, num_actions = (dim_state, dim_action) if discrete else (-1, -1)
    dim_state, dim_action = (
        ((), ()) if discrete else ((dim_state,), (dim_action,))
    )

    spec = {
        "dim_state": dim_state,
        "dim_action": dim_action,
        "num_states": num_states,
        "num_actions": num_actions,
    }

    factories = {
        "nan": Observation.nan_example,
        "zero": Observation.zero_example,
        "random": Observation.random_example,
    }
    factory = factories.get(kind)

    if factory is None:
        # Unknown kind: the generic entry point has to reject it.
        with pytest.raises(ValueError):
            Observation.get_example(kind=kind, **spec)
        return

    o = factory(**spec)

    # In the discrete case the dimension tuples are empty, so the expected
    # shapes below are torch.Size([]) there.
    expected_state_shape = torch.Size(dim_state)
    expected_action_shape = torch.Size(dim_action)

    torch.testing.assert_allclose(o.state.shape, expected_state_shape)
    torch.testing.assert_allclose(o.action.shape, expected_action_shape)
    torch.testing.assert_allclose(o.next_state.shape, expected_state_shape)
    torch.testing.assert_allclose(o.log_prob_action, torch.tensor(1.0))
def _init_observation(self, observation):
    """Build and store a zero-valued template matching ``observation``.

    For each of ``observation.state`` and ``observation.action``:

    * a 0-dimensional value yields a dimension of ``1`` and a count of ``1``;
    * anything else yields the size of its last axis as the dimension and
      ``-1`` as the count.

    The result of ``Observation.zero_example`` for those values is stored in
    ``self.zero_observation``.
    """
    # NOTE(review): dimensions are passed as plain ints here -- confirm that
    # Observation.zero_example accepts ints rather than requiring tuples.
    state = observation.state
    dim_state, num_states = (
        (1, 1) if state.ndim == 0 else (state.shape[-1], -1)
    )

    action = observation.action
    dim_action, num_actions = (
        (1, 1) if action.ndim == 0 else (action.shape[-1], -1)
    )

    template_spec = {
        "dim_state": dim_state,
        "dim_action": dim_action,
        "num_states": num_states,
        "num_actions": num_actions,
    }
    self.zero_observation = Observation.zero_example(**template_spec)