示例#1
0
 def test_convert_tf(self):
     """A Dict space converts to a Dict of int64 TensorFlow tensors."""
     motion = Dict({'position': Discrete(2), 'velocity': Discrete(3)})
     placeholders = motion.to_tf_placeholder('test', 1)
     assert isinstance(placeholders, Dict)
     members = list(placeholders.spaces.values())
     assert all(isinstance(member, tf.Tensor) for member in members)
     assert all(member.dtype == tf.int64 for member in members)
示例#2
0
 def test_unflatten_n(self):
     """unflatten_n recovers the indices that flatten_n encoded."""
     space = Discrete(3)
     indices = np.asarray([0, 1, 2])
     one_hot = space.flatten_n(indices)
     expected = np.asarray([0, 1, 2])
     recovered = space.unflatten_n(one_hot)
     assert np.array_equal(recovered, expected)
示例#3
0
 def test_convert_tf(self):
     """A Discrete space converts to an int64 TensorFlow placeholder."""
     space = Discrete(10)
     placeholder = space.to_tf_placeholder('test', 1)
     assert isinstance(placeholder, tf.Tensor)
     assert space.dtype == np.int64
     assert placeholder.dtype == tf.int64
     # Leading None presumably comes from the batch-dims argument (1) --
     # TODO confirm; the trailing 10 matches the space's size.
     shape = placeholder.get_shape().as_list()
     assert shape == [None, 10]
示例#4
0
 def test_flatten_n(self):
     """Tuple.flatten_n of Tuple.unflatten_n output matches a fixed row."""
     inner = Discrete(3)
     pair = Tuple((Discrete(2), inner))
     # NOTE(review): the input is encoded with the inner Discrete(3), not
     # with the Tuple itself; the single expected row relies on how
     # Tuple.unflatten_n splits this array -- confirm this is intended.
     encoded = inner.flatten_n(np.asarray([0, 1, 0, 1, 2]))
     decoded = pair.unflatten_n(encoded)
     reencoded = pair.flatten_n(decoded)
     expected = np.asarray([[1., 0., 1., 0., 0.]])
     assert np.array_equal(reencoded, expected)
示例#5
0
 def test_convert_theano(self):
     """A Dict space converts to a Dict of int64 Theano variables."""
     motion = Dict({'position': Discrete(2), 'velocity': Discrete(3)})
     variables = motion.to_theano_tensor('test', 1)
     assert isinstance(variables, Dict)
     members = list(variables.spaces.values())
     assert all(
         isinstance(member, theano.tensor.TensorVariable)
         for member in members)
     assert all(member.dtype == 'int64' for member in members)
示例#6
0
    def test_pickleable(self):
        """A Dict space survives pickling and still contains the sample."""
        sample = {'position': 1, 'velocity': 2}
        space = Dict({'position': Discrete(2), 'velocity': Discrete(3)})
        restored = pickle.loads(pickle.dumps(space))

        assert space.contains(sample)
        assert restored
        assert restored.contains(sample)
示例#7
0
 def test_convert_theano(self):
     """A Tuple space converts to a tuple of typed Theano variables."""
     space = Tuple((Box(0.0, 1.0, (3, 4)), Discrete(2)))
     variables = space.to_theano_tensor('test', 1)
     assert isinstance(variables, tuple)
     assert all(
         isinstance(var, theano.tensor.TensorVariable) for var in variables)
     dtypes = [var.dtype for var in variables]
     assert dtypes == ['float32', 'int64']
示例#8
0
 def test_convert_tf(self):
     """A Tuple space converts to a tuple of typed, shaped TF tensors."""
     space = Tuple((Box(0.0, 1.0, (3, 4)), Discrete(2)))
     placeholders = space.to_tf_placeholder('test', 1)
     assert isinstance(placeholders, tuple)
     assert all(isinstance(ph, tf.Tensor) for ph in placeholders)
     dtypes = [ph.dtype for ph in placeholders]
     assert dtypes == [tf.float32, tf.int64]
     shapes = [ph.get_shape().as_list() for ph in placeholders]
     assert shapes == [[None, 3, 4], [None, 2]]
示例#9
0
    def _to_akro_space(self, space):
        """Convert a gym space into the equivalent akro space.

        Composite spaces (Dict and Tuple) are converted recursively, so every
        nested subspace is converted as well.

        Args:
            space (gym.spaces.Space): Space to convert. Supported types are
                gym Box, Dict, Discrete and Tuple.

        Returns:
            akro.Space: The converted space.

        Raises:
            NotImplementedError: If ``space`` is of an unsupported type.
        """
        from collections import OrderedDict

        if isinstance(space, GymBox):
            return Box(low=space.low, high=space.high, dtype=space.dtype)
        if isinstance(space, GymDict):
            # Convert children the same way the Tuple branch does, instead of
            # passing raw gym subspaces through. An OrderedDict keeps the
            # source key order; a plain dict may be re-sorted by the Dict
            # constructor (gym behavior -- confirm for the akro version used).
            return Dict(
                OrderedDict((key, self._to_akro_space(subspace))
                            for key, subspace in space.spaces.items()))
        if isinstance(space, GymDiscrete):
            return Discrete(space.n)
        if isinstance(space, GymTuple):
            return Tuple([self._to_akro_space(sub) for sub in space.spaces])
        raise NotImplementedError(
            'Cannot convert space of type {} to an akro space'.format(
                type(space).__name__))
示例#10
0
 def test_hash(self):
     """Two Discrete spaces of the same size hash identically."""
     first, second = Discrete(10), Discrete(10)
     assert hash(first) == hash(second)
示例#11
0
 def test_weighted_sample_unnormalized(self):
     """weighted_sample yields a valid index for weights not summing to 1."""
     space = Discrete(4)
     sampled = space.weighted_sample(np.array([1., 2., 3., 5.]))
     assert 0 <= sampled < space.n
示例#12
0
 def test_hash(self):
     """The hash of Discrete(10) is pinned to 10."""
     assert hash(Discrete(10)) == 10
示例#13
0
 def test_pickleable(self):
     """A Tuple space survives pickling with equal subspaces."""
     original = Tuple((Discrete(3), Discrete(2)))
     payload = pickle.dumps(original)
     restored = pickle.loads(payload)
     assert restored
     assert restored.spaces == original.spaces
示例#14
0
 def test_weighted_sample(self):
     """weighted_sample yields a valid index for normalized weights."""
     space = Discrete(4)
     sampled = space.weighted_sample([0.1, 0.2, 0.3, 0.4])
     assert 0 <= sampled < space.n
示例#15
0
 def test_unflatten_n(self):
     """Tuple.unflatten_n on a fixed encoded batch yields [(0, 0)]."""
     inner = Discrete(3)
     pair = Tuple((Discrete(2), inner))
     # NOTE(review): the batch is encoded with the inner Discrete(3) rather
     # than with the Tuple; the expected value depends on how
     # Tuple.unflatten_n splits it -- confirm this is intended.
     encoded = inner.flatten_n(np.asarray([0, 1, 0, 1, 2]))
     decoded = pair.unflatten_n(encoded)
     assert decoded == [(0, 0)]
示例#16
0
 def test_flatten_n(self):
     """flatten_n one-hot encodes each index of a batch."""
     space = Discrete(3)
     one_hot = space.flatten_n(np.asarray([0, 1, 2]))
     # Indices 0, 1, 2 in order produce the 3x3 identity matrix.
     assert np.array_equal(one_hot, np.eye(3))
示例#17
0
 def test_flat_dim(self):
     """Discrete(10) reports a flat dimension of 10."""
     assert Discrete(10).flat_dim == 10
示例#18
0
 def test_flatten(self):
     """flatten sets every requested index to 1."""
     space = Discrete(10)
     hot = [3, 5, 7]
     encoded = space.flatten(hot)
     assert (encoded[hot] == 1).all()
示例#19
0
 def test_flat_dim(self):
     """A Tuple of Discrete(3) and Discrete(2) has flat dimension 5."""
     space = Tuple((Discrete(3), Discrete(2)))
     expected = 5
     assert space.flat_dim == expected
示例#20
0
 def test_pickleable(self):
     """A Discrete space survives pickling and keeps its size."""
     restored = pickle.loads(pickle.dumps(Discrete(10)))
     assert restored
     assert restored.n == 10
示例#21
0
 def test_concat(self):
     """Concatenating two Discrete spaces raises NotImplementedError.

     The spaces are built outside ``pytest.raises`` so that only the
     ``concat`` call can satisfy the expectation; a NotImplementedError
     raised while constructing the spaces would otherwise make this test
     pass for the wrong reason.
     """
     disc1 = Discrete(4)
     disc2 = Discrete(4)
     with pytest.raises(NotImplementedError):
         disc1.concat(disc2)
示例#22
0
 def test_flatten(self):
     """Flattening a Tuple sets the slots for each component's value."""
     space = Tuple((Discrete(3), Discrete(2)))
     encoded = space.flatten([2, 0])
     # Presumably slots 0-2 belong to Discrete(3) and 3-4 to Discrete(2),
     # so values (2, 0) land at indices 2 and 3 -- confirm layout.
     assert encoded[2] == 1
     assert encoded[3] == 1
示例#23
0
 def test_unflatten(self):
     """unflatten(flatten(x)) round-trips a Tuple value to a tuple."""
     space = Tuple((Discrete(3), Discrete(2)))
     value = [2, 0]
     round_trip = space.unflatten(space.flatten(value))
     assert round_trip == (2, 0)
示例#24
0
 def test_convert_theano(self):
     """A Discrete space converts to an int64 Theano tensor variable."""
     space = Discrete(10)
     variable = space.to_theano_tensor('test', 1)
     assert isinstance(variable, theano.tensor.TensorVariable)
     assert space.dtype == np.int64
     assert variable.dtype == 'int64'
示例#25
0
 def test_unflatten(self):
     """Unflattening the encoding of [3, 5, 7] is pinned to 3."""
     space = Discrete(10)
     encoded = space.flatten([3, 5, 7])
     decoded = space.unflatten(encoded)
     assert decoded == 3
示例#26
0
 def test_hash(self):
     """Tuple hashing is stable and agrees for structurally equal Tuples.

     A literal expected hash value is deliberately not pinned: that value
     depends on how ``Tuple.__hash__`` and Python's built-in hashing are
     implemented (CPython changed its tuple-hash algorithm in 3.8), so it
     differs across interpreter versions and platforms.
     """
     tup = Tuple((Discrete(3), Discrete(2)))
     value = hash(tup)
     # Repeated hashing of the same object must give the same result.
     assert hash(tup) == value
     # An independently built, identical Tuple must hash the same way.
     assert hash(Tuple((Discrete(3), Discrete(2)))) == value
示例#27
0
 def test_hash(self):
     """Two identically built Tuple spaces hash identically."""
     left = Tuple((Discrete(3), Discrete(2)))
     right = Tuple((Discrete(3), Discrete(2)))
     assert hash(left) == hash(right)