def _assert_flat_bijection(trafo, expected_actions):
    """Assert that ``trafo`` maps flat indices one-to-one onto ``expected_actions``.

    For every flat index ``i`` in ``range(len(expected_actions))``:
    ``convert_from(i)`` must yield an action that is still expected and has
    not been produced before, and ``convert_to`` must map it back to ``i``.
    Once all indices are consumed, no expected action may remain unproduced.
    """
    remaining = list(expected_actions)
    for i in range(len(expected_actions)):
        a = trafo.convert_from(i)
        assert a in remaining, (a, remaining)
        assert trafo.convert_to(a) == i
        # Drop the produced action so a repeat of it fails the membership check.
        remaining = [x for x in remaining if x != a]
    assert len(remaining) == 0, remaining


def test_flatten_discrete():
    """Flattening MultiDiscrete / MultiBinary yields a Discrete space whose
    indices enumerate every joint action exactly once and round-trip."""
    md = MultiDiscrete([(0, 2), (0, 3)])
    trafo = flatten(md)
    assert trafo.target == Discrete(12)
    # Joint actions of the (0..2) x (0..3) space.
    _assert_flat_bijection(trafo, list(itertools.product(range(3), range(4))))

    # same test for binary
    md = MultiBinary(3)
    trafo = flatten(md)
    assert trafo.target == Discrete(2**3)
    # All 3-bit tuples.
    _assert_flat_bijection(trafo, list(itertools.product(range(2), repeat=3)))
def test_flatten_errors():
    """``flatten`` raises TypeError for a non-space argument and
    NotImplementedError for a Space subclass it has no rule for."""

    class UnknownSpace(gym.Space):
        """Space subclass that ``flatten`` is expected not to support."""

    # An int is not a space instance at all.
    with pytest.raises(TypeError):
        flatten(5)

    # A real (but unsupported) Space subclass; constructed inside the
    # context so any NotImplementedError from either step is caught here.
    with pytest.raises(NotImplementedError):
        flatten(UnknownSpace())
def test_flatten_single():
    """Flattening a space that is already flat keeps it as the target, and
    conversion between the source and target leaves values unchanged."""

    def expect_identity(space, value):
        # flatten() must not alter an already-flat space ...
        flat = flatten(space)
        assert flat.target == space
        # ... and the same value must correspond on both sides.
        check_convert(flat, value, value)

    expect_identity(Discrete(5), 4)
    expect_identity(Box(np.array([0.0]), np.array([1.0])), 0.5)
def test_flatten_tuple_recursive():
    """Flattening a Tuple of Boxes yields one Box whose bounds concatenate the
    components' flattened bounds; values convert in both directions."""
    # NOTE(review): despite the name, only a single-level Tuple is exercised.
    matrix_box = Box(np.zeros((2, 2)), np.ones((2, 2)), dtype=np.float32)
    vector_box = Box(np.ones(2), np.ones(2) * 2, dtype=np.float32)
    trafo = flatten(Tuple((matrix_box, vector_box)))

    # 2x2 block of the matrix box first, then the 2 entries of the vector box.
    expected_low = np.asarray([0.0, 0, 0, 0, 1, 1])
    expected_high = np.asarray([1.0, 1, 1, 1, 2, 2])
    assert trafo.target == Box(expected_low, expected_high, dtype=np.float32)

    nested_value = ([[0, 1], [1, 0]], [1, 2])
    flat_value = [0, 1, 1, 0, 1, 2]
    assert trafo.convert_to(nested_value) == pytest.approx(flat_value)

    matrix_part, vector_part = nested_value
    expected_nested = (
        pytest.approx(np.asarray(matrix_part)),
        pytest.approx(vector_part),
    )
    assert trafo.convert_from(flat_value) == expected_nested
def test_flatten_continuous():
    """Flattening a 2x2 Box yields a 4-element Box, and the flat value
    [1, 2, 3, 4] corresponds to the nested value [[1, 2], [3, 4]]."""
    nested_shape = (2, 2)
    flat_size = 4
    trafo = flatten(Box(np.zeros(nested_shape), np.ones(nested_shape)))
    assert trafo.target == Box(np.zeros(flat_size), np.ones(flat_size))
    check_convert(trafo, [1, 2, 3, 4], [[1, 2], [3, 4]])