Esempio n. 1
0
def test_flatten_discrete():
    """Check that flattening discrete product spaces gives a bijection.

    For ``MultiDiscrete`` and ``MultiBinary`` the flattened target must be a
    ``Discrete`` space of the product size. Every flat index must decode to a
    distinct action of the source space, and encoding that action must give
    back the original index.
    """

    def assert_each_action_once(trafo, expected_actions):
        # Decode every flat index, verify the round trip, and strike the
        # action off the expected list. Membership is checked before removal,
        # so a repeated action fails. Anything left over at the end means
        # some expected action was never produced.
        remaining = list(expected_actions)
        for index in range(len(remaining)):
            action = trafo.convert_from(index)
            assert action in remaining, (action, remaining)
            assert trafo.convert_to(action) == index
            remaining.remove(action)
        assert not remaining, remaining

    # Two components taking values 0..2 and 0..3 -> 3 * 4 = 12 joint actions.
    trafo = flatten(MultiDiscrete([(0, 2), (0, 3)]))
    assert trafo.target == Discrete(12)
    assert_each_action_once(trafo, itertools.product(range(3), range(4)))

    # Same check for three binary components -> 2**3 joint actions.
    trafo = flatten(MultiBinary(3))
    assert trafo.target == Discrete(2**3)
    assert_each_action_once(trafo, itertools.product((0, 1), repeat=3))
Esempio n. 2
0
def test_flatten_errors():
    """Check that ``flatten`` rejects inputs it cannot handle.

    A plain integer must raise ``TypeError``. An instance of a ``gym.Space``
    subclass with no behaviour of its own must raise ``NotImplementedError``.
    """

    class UnrecognizedSpace(gym.Space):
        # Empty subclass: it is a gym.Space, but defines nothing flatten
        # could dispatch on.
        pass

    # Each call is deferred in a lambda so the argument is built inside the
    # matching pytest.raises context, exactly as in a sequential layout.
    expectations = (
        (TypeError, lambda: flatten(5)),
        (NotImplementedError, lambda: flatten(UnrecognizedSpace())),
    )
    for expected_error, attempt in expectations:
        with pytest.raises(expected_error):
            attempt()
Esempio n. 3
0
def test_flatten_single():
    """Check that already-flat spaces are left as they are by ``flatten``.

    Covers a ``Discrete`` space and a one-element ``Box``. The flattened
    target must equal the original space, and ``check_convert`` is given the
    same value on both sides.
    """

    def expect_unchanged(space, value):
        # Flatten, confirm the target is the input space itself, then run the
        # shared conversion check with identical source/target values.
        flat = flatten(space)
        assert flat.target == space
        check_convert(flat, value, value)

    expect_unchanged(Discrete(5), 4)
    expect_unchanged(Box(np.array([0.0]), np.array([1.0])), 0.5)
Esempio n. 4
0
def test_flatten_tuple_recursive():
    """Check that a ``Tuple`` of ``Box`` spaces flattens into one ``Box``.

    The 2x2 box (bounds 0..1) contributes four entries and the length-2 box
    (bounds 1..2) two more. Values must survive conversion in both
    directions.
    """
    matrix_space = Box(np.zeros((2, 2)), np.ones((2, 2)), dtype=np.float32)
    vector_space = Box(np.ones(2), np.ones(2) * 2, dtype=np.float32)
    trafo = flatten(Tuple((matrix_space, vector_space)))

    # The matrix's four bounds come first, followed by the vector's two.
    flat_low = np.asarray([0.0, 0, 0, 0, 1, 1])
    flat_high = np.asarray([1.0, 1, 1, 1, 2, 2])
    assert trafo.target == Box(flat_low, flat_high, dtype=np.float32)

    # Nested -> flat.
    flat_sample = trafo.convert_to(([[0, 1], [1, 0]], [1, 2]))
    assert flat_sample == pytest.approx([0, 1, 1, 0, 1, 2])

    # Flat -> nested: the result is compared against a 2-tuple of approx
    # matchers, one per component space.
    nested_sample = trafo.convert_from([0, 1, 1, 0, 1, 2])
    expected_matrix = pytest.approx(np.asarray([[0, 1], [1, 0]]))
    expected_vector = pytest.approx([1, 2])
    assert nested_sample == (expected_matrix, expected_vector)
Esempio n. 5
0
def test_flatten_continuous():
    """Check that a 2x2 ``Box`` flattens into a length-4 ``Box``.

    The flattened bounds must match the original ones.
    """
    low, high = np.zeros((2, 2)), np.ones((2, 2))
    trafo = flatten(Box(low, high))

    flat_space = Box(np.zeros(4), np.ones(4))
    assert trafo.target == flat_space

    # NOTE(review): presumably check_convert(trafo, flat_value, source_value).
    # These samples lie outside the [0, 1] bounds, so only the reshaping is
    # exercised here, not containment.
    flat_value = [1, 2, 3, 4]
    square_value = [[1, 2], [3, 4]]
    check_convert(trafo, flat_value, square_value)