예제 #1
0
 def test_convert_theano(self):
     """Check that to_theano_tensor returns a tuple of TensorVariables
     whose dtypes follow the component spaces (float32 Box, int64 Discrete).
     """
     space = Tuple((Box(0.0, 1.0, (3, 4)), Discrete(2)))
     tensors = space.to_theano_tensor('test', 1)
     assert isinstance(tensors, tuple)
     # Every component must be a Theano symbolic tensor.
     assert all(
         isinstance(t, theano.tensor.TensorVariable) for t in tensors)
     dtypes = [t.dtype for t in tensors]
     assert dtypes == ['float32', 'int64']
예제 #2
0
 def test_convert_tf(self):
     """Check that to_tf_placeholder returns a tuple of typed, shaped tensors."""
     space = Tuple((Box(0.0, 1.0, (3, 4)), Discrete(2)))
     placeholders = space.to_tf_placeholder('test', 1)
     assert isinstance(placeholders, tuple)
     assert all(isinstance(p, tf.Tensor) for p in placeholders)
     dtypes = [p.dtype for p in placeholders]
     assert dtypes == [tf.float32, tf.int64]
     # Each shape is expected to start with an unknown (None) leading dim.
     shapes = [p.get_shape().as_list() for p in placeholders]
     assert shapes == [[None, 3, 4], [None, 2]]
예제 #3
0
 def test_flatten_n(self):
     """Check the result of a Tuple unflatten_n -> flatten_n round trip."""
     inner = Discrete(3)
     space = Tuple((Discrete(2), inner))
     flat_obs = inner.flatten_n(np.asarray([0, 1, 0, 1, 2]))
     # Unflatten the batch with the Tuple space, then flatten it back.
     restored = space.flatten_n(space.unflatten_n(flat_obs))
     expected = np.asarray([[1., 0., 1., 0., 0.]])
     assert np.array_equal(restored, expected)
예제 #4
0
파일: base.py 프로젝트: quantumahesh/garage
    def _to_akro_space(self, space):
        """
        Converts a gym.space into an akro.space.

        Composite spaces (Tuple and Dict) are converted recursively, so
        their subspaces are converted as well.

        Args:
            space (gym.spaces.Space): The gym space to convert.

        Returns:
            akro.spaces.Space: The equivalent akro space.

        Raises:
            NotImplementedError: If `space` is not a gym Box, Dict,
                Discrete or Tuple (or contains a subspace that is not).
        """
        if isinstance(space, GymBox):
            return Box(low=space.low, high=space.high, dtype=space.dtype)
        elif isinstance(space, GymDict):
            import collections
            # Convert every subspace too, mirroring the Tuple branch; an
            # OrderedDict keeps the original key order instead of letting
            # the Dict constructor re-sort a plain dict.
            return Dict(
                collections.OrderedDict(
                    (key, self._to_akro_space(subspace))
                    for key, subspace in space.spaces.items()))
        elif isinstance(space, GymDiscrete):
            return Discrete(space.n)
        elif isinstance(space, GymTuple):
            return Tuple([self._to_akro_space(s) for s in space.spaces])
        else:
            raise NotImplementedError(
                'Conversion of {} to an akro space is not supported'.format(
                    type(space).__name__))
예제 #5
0
 def test_hash(self):
     """Two Tuple spaces built from identical components hash equally."""
     first = Tuple((Discrete(3), Discrete(2)))
     second = Tuple((Discrete(3), Discrete(2)))
     assert hash(first) == hash(second)
예제 #6
0
    def test_concat(self):
        """Check the flat dimension of two concatenated Tuple spaces."""
        left = Tuple((Box(0, 1, (5, )), Box(0, 1, (10, ))))
        right = Tuple((Box(0, 1, (5, )), Box(0, 1, (10, ))))
        combined = left.concat(right)
        # 30 corresponds to two tuples of (5 + 10) elements each.
        assert combined.flat_dim == 30
예제 #7
0
 def test_unflatten_n(self):
     """Check Tuple.unflatten_n on the output of Discrete.flatten_n."""
     inner = Discrete(3)
     space = Tuple((Discrete(2), inner))
     encoded = inner.flatten_n(np.asarray([0, 1, 0, 1, 2]))
     assert space.unflatten_n(encoded) == [(0, 0)]
예제 #8
0
 def test_unflatten(self):
     """Round-trip a sample through flatten/unflatten; it returns as a tuple."""
     space = Tuple((Discrete(3), Discrete(2)))
     flat = space.flatten([2, 0])
     assert space.unflatten(flat) == (2, 0)
예제 #9
0
 def test_flatten(self):
     """Check the full flat encoding of a Tuple of Discrete spaces.

     The component encodings are expected to be concatenated in order:
     [0, 0, 1] for index 2 of Discrete(3), then [1, 0] for index 0 of
     Discrete(2).
     """
     tup = Tuple((Discrete(3), Discrete(2)))
     x = [2, 0]
     arr = tup.flatten(x)
     # Pin every element, not only the two expected hot positions, so that
     # stray non-zero entries elsewhere in the encoding are also caught.
     assert list(arr) == [0, 0, 1, 1, 0]
예제 #10
0
 def test_flat_dim(self):
     """The expected flat_dim 5 equals 3 + 2, the two Discrete sizes."""
     space = Tuple((Discrete(3), Discrete(2)))
     expected_dim = 5
     assert space.flat_dim == expected_dim
예제 #11
0
 def test_pickleable(self):
     """A Tuple space survives a pickle round trip with equal subspaces."""
     original = Tuple((Discrete(3), Discrete(2)))
     payload = pickle.dumps(original)
     restored = pickle.loads(payload)
     assert restored
     assert restored.spaces == original.spaces
예제 #12
0
 def test_hash(self):
     """Pin the hash of a fixed Tuple space to a known value."""
     # NOTE(review): this pinned value assumes Tuple's hash is deterministic
     # across interpreter runs and platforms — confirm.
     expected = 3713083796995235906
     space = Tuple((Discrete(3), Discrete(2)))
     assert space.__hash__() == expected