def test_example(self, discrete, dim_state, dim_action, kind):
    """Check the example observations produced for a given ``kind``.

    For the kinds ``"nan"``, ``"zero"`` and ``"random"`` the matching
    ``Observation.*_example`` constructor is called and the shapes of
    ``state``, ``action`` and ``next_state`` plus the value of
    ``log_prob_action`` are verified.  Any other ``kind`` is passed to
    ``Observation.get_example``, which is expected to raise ``ValueError``.

    Parameters
    ----------
    discrete: bool
        When true, ``dim_state``/``dim_action`` are passed on as
        ``num_states``/``num_actions`` and the dimensions become ``()``.
        Otherwise the counts are ``-1`` and the dimensions become 1-tuples.
    dim_state, dim_action:
        Size of the state/action (or their count in the discrete case).
    kind: str
        Name of the example flavour under test.
    """
    if discrete:
        num_states, num_actions = dim_state, dim_action
        dim_state, dim_action = (), ()
    else:
        num_states, num_actions = -1, -1
        dim_state, dim_action = (dim_state,), (dim_action,)

    example_kwargs = {
        "dim_state": dim_state,
        "dim_action": dim_action,
        "num_states": num_states,
        "num_actions": num_actions,
    }
    builders = {
        "nan": Observation.nan_example,
        "zero": Observation.zero_example,
        "random": Observation.random_example,
    }

    if kind not in builders:
        # Unknown kinds must be rejected by the generic entry point.
        with pytest.raises(ValueError):
            Observation.get_example(kind=kind, **example_kwargs)
        return

    o = builders[kind](**example_kwargs)

    # ``dim_state``/``dim_action`` are already ``()`` in the discrete case, so
    # a single set of checks covers both settings.  Shapes are compared
    # exactly instead of through a numeric-closeness helper.
    assert o.state.shape == torch.Size(dim_state)
    assert o.action.shape == torch.Size(dim_action)
    assert o.next_state.shape == torch.Size(dim_state)
    torch.testing.assert_allclose(o.log_prob_action, torch.tensor(1.0))
def create_er_from_transitions(
    discrete, dim_state, dim_action, max_len, num_steps, num_transitions
):
    """Create a memory holding ``num_transitions`` random transitions.

    An ``ExperienceReplay`` is built with ``max_len`` and ``num_steps`` and
    then filled by appending ``num_transitions`` observations obtained from
    ``Observation.random_example``.

    In the discrete case ``dim_state``/``dim_action`` are forwarded as
    ``num_states``/``num_actions`` with empty dimension tuples; otherwise the
    counts are ``-1`` and the dimensions are wrapped in 1-tuples.

    Returns
    -------
    The populated ``ExperienceReplay`` instance.
    """
    if discrete:
        state_shape, action_shape = (), ()
        state_count, action_count = dim_state, dim_action
    else:
        state_shape, action_shape = (dim_state,), (dim_action,)
        state_count, action_count = -1, -1

    replay = ExperienceReplay(max_len, num_steps=num_steps)
    for _ in range(num_transitions):
        replay.append(
            Observation.random_example(
                dim_state=state_shape,
                dim_action=action_shape,
                num_states=state_count,
                num_actions=action_count,
            )
        )
    return replay
def test_clone(self, discrete, dim_state, dim_action):
    """Check that ``clone`` yields an equal but distinct observation.

    A random observation is created, cloned, and the clone is required to
    compare equal to the source while being a different object.  The same is
    then required of every pair of entries obtained by iterating the two
    observations side by side (equality via ``Observation._is_equal_nan``).
    """
    if discrete:
        state_shape, action_shape = (), ()
        state_count, action_count = dim_state, dim_action
    else:
        state_shape, action_shape = (dim_state,), (dim_action,)
        state_count, action_count = -1, -1

    source = Observation.random_example(
        dim_state=state_shape,
        dim_action=action_shape,
        num_states=state_count,
        num_actions=action_count,
    )
    copy = source.clone()

    assert source is not copy
    assert source == copy
    for entry, copied_entry in zip(source, copy):
        assert Observation._is_equal_nan(entry, copied_entry)
        assert entry is not copied_entry
def test_append(self, discrete, dim_state, dim_action, max_len, num_steps):
    """Check the bookkeeping of the memory after appending one observation.

    A memory pre-filled with 200 transitions receives one extra random
    observation.  Afterwards the two slots just behind ``memory.ptr`` must be
    flagged valid, the ``num_steps`` slots starting at ``memory.ptr`` must be
    flagged invalid, and the entry stored in the most recent slot must not be
    the very object that was passed to ``append``.
    """
    num_transitions = 200
    # The helper takes the raw integer dimensions, so it is called before the
    # shape tuples are derived below.
    memory = create_er_from_transitions(
        discrete, dim_state, dim_action, max_len, num_steps, num_transitions
    )

    if discrete:
        state_shape, action_shape = (), ()
        state_count, action_count = dim_state, dim_action
    else:
        state_shape, action_shape = (dim_state,), (dim_action,)
        state_count, action_count = -1, -1

    observation = Observation.random_example(
        dim_state=state_shape,
        dim_action=action_shape,
        num_states=state_count,
        num_actions=action_count,
    )
    memory.append(observation)

    for steps_back in (1, 2):
        assert memory.valid[(memory.ptr - steps_back) % max_len] == 1
    for steps_ahead in range(num_steps):
        assert memory.valid[(memory.ptr + steps_ahead) % max_len] == 0
    assert memory.memory[(memory.ptr - 1) % max_len] is not observation