コード例 #1
0
ファイル: test_datatypes.py プロジェクト: sebimarkgraf/rllib
    def test_example(self, discrete, dim_state, dim_action, kind):
        """Check the ``Observation`` example factories for a given ``kind``.

        For ``kind`` in {"nan", "zero", "random"} the matching
        ``Observation.<kind>_example`` classmethod is called and the shapes of
        ``state``/``action``/``next_state`` plus ``log_prob_action`` are
        checked. Any other ``kind`` must make ``Observation.get_example`` raise
        ``ValueError``.
        """
        if discrete:
            # Discrete case: the fixture values are used as the number of
            # states/actions, and the example shapes are scalar ``()``.
            num_states, num_actions = dim_state, dim_action
            dim_state, dim_action = (), ()
        else:
            # Continuous case: counts are passed as -1 (presumably a
            # "not discrete" sentinel) and dims become 1-D shapes.
            num_states, num_actions = -1, -1
            dim_state, dim_action = (dim_state,), (dim_action,)

        example_kwargs = dict(
            dim_state=dim_state,
            dim_action=dim_action,
            num_states=num_states,
            num_actions=num_actions,
        )
        factories = {
            "nan": Observation.nan_example,
            "zero": Observation.zero_example,
            "random": Observation.random_example,
        }

        if kind not in factories:
            with pytest.raises(ValueError):
                Observation.get_example(kind=kind, **example_kwargs)
            return

        o = factories[kind](**example_kwargs)

        # Compare shapes with ``==`` rather than ``assert_allclose``: the latter
        # converts a ``torch.Size`` to a tensor of its sizes, and older PyTorch
        # releases broadcast ``expected`` to ``actual``, so a scalar shape
        # checked against ``torch.Size([n])`` compared two empty tensors and
        # passed. In the discrete case ``dim_state == dim_action == ()``, so
        # ``torch.Size(dim_state)`` is the scalar shape ``torch.Size([])`` and
        # one set of assertions covers both cases.
        assert o.state.shape == torch.Size(dim_state)
        assert o.action.shape == torch.Size(dim_action)
        assert o.next_state.shape == torch.Size(dim_state)
        # NOTE(review): ``assert_allclose`` is deprecated in newer PyTorch;
        # switch to ``torch.testing.assert_close`` once the minimum supported
        # torch version allows it.
        torch.testing.assert_allclose(o.log_prob_action, torch.tensor(1.0))
コード例 #2
0
def create_er_from_transitions(discrete, dim_state, dim_action, max_len,
                               num_steps, num_transitions):
    """Return an ``ExperienceReplay`` pre-filled with random observations.

    Args:
        discrete: if true, ``dim_state``/``dim_action`` are passed as the
            number of states/actions and the observations get scalar shapes
            ``()``; otherwise they are wrapped as 1-D shapes
            ``(dim_state,)``/``(dim_action,)`` and the counts are set to -1.
        dim_state: state dimension or number of states (see ``discrete``).
        dim_action: action dimension or number of actions (see ``discrete``).
        max_len: ``max_len`` argument forwarded to ``ExperienceReplay``.
        num_steps: ``num_steps`` argument forwarded to ``ExperienceReplay``.
        num_transitions: how many ``Observation.random_example`` draws are
            appended to the buffer.

    Returns:
        The populated ``ExperienceReplay``.
    """
    if not discrete:
        sample_kwargs = dict(
            dim_state=(dim_state, ),
            dim_action=(dim_action, ),
            num_states=-1,
            num_actions=-1,
        )
    else:
        sample_kwargs = dict(
            dim_state=(),
            dim_action=(),
            num_states=dim_state,
            num_actions=dim_action,
        )

    buffer = ExperienceReplay(max_len, num_steps=num_steps)
    for _ in range(num_transitions):
        buffer.append(Observation.random_example(**sample_kwargs))
    return buffer
コード例 #3
0
ファイル: test_datatypes.py プロジェクト: sebimarkgraf/rllib
    def test_clone(self, discrete, dim_state, dim_action):
        """``Observation.clone`` returns an equal but separate object.

        The clone must not be the source object yet must compare equal to it.
        Each pair of fields must satisfy ``Observation._is_equal_nan`` while
        being distinct objects.
        """
        if not discrete:
            counts = (-1, -1)
            shapes = ((dim_state,), (dim_action,))
        else:
            counts = (dim_state, dim_action)
            shapes = ((), ())
        state_shape, action_shape = shapes
        n_states, n_actions = counts

        source = Observation.random_example(
            dim_state=state_shape,
            dim_action=action_shape,
            num_states=n_states,
            num_actions=n_actions,
        )
        duplicate = source.clone()

        assert source is not duplicate
        assert source == duplicate
        for field, field_copy in zip(source, duplicate):
            assert Observation._is_equal_nan(field, field_copy)
            assert field is not field_copy
コード例 #4
0
    def test_append(self, discrete, dim_state, dim_action, max_len, num_steps):
        """Appending one observation to a pre-filled buffer updates ``valid``.

        The buffer is first filled with 200 random transitions and one more
        random observation is appended. Afterwards, the ``valid`` entries at
        ``ptr - 1`` and ``ptr - 2`` must be 1, the ``num_steps`` entries
        starting at ``ptr`` must be 0, and the entry stored at ``ptr - 1``
        must not be the appended object itself.
        """
        memory = create_er_from_transitions(
            discrete, dim_state, dim_action, max_len, num_steps,
            num_transitions=200,
        )

        if not discrete:
            sample_kwargs = dict(
                dim_state=(dim_state, ),
                dim_action=(dim_action, ),
                num_states=-1,
                num_actions=-1,
            )
        else:
            sample_kwargs = dict(
                dim_state=(),
                dim_action=(),
                num_states=dim_state,
                num_actions=dim_action,
            )
        observation = Observation.random_example(**sample_kwargs)
        memory.append(observation)

        last = (memory.ptr - 1) % max_len
        assert memory.valid[last] == 1
        assert memory.valid[(memory.ptr - 2) % max_len] == 1
        for offset in range(num_steps):
            assert memory.valid[(memory.ptr + offset) % max_len] == 0
        assert memory.memory[last] is not observation