예제 #1
0
 def test_not_contains(self):
     """OneHot(3).contains() rejects arrays that are not length-3 one-hot vectors."""
     space = spaces.OneHot(3)
     rejected = [
         np.array([1, 0, 0, 0]),              # one element too many
         np.array([[1, 0], [0, 0], [0, 0]]),  # 2-D array of shape (3, 2)
         np.array([1, 1, 0]),                 # more than one hot entry
         np.array([1, 2, 0]),                 # an entry other than 0/1
         np.array([]),                        # empty array
     ]
     for candidate in rejected:
         self.assertFalse(space.contains(candidate))
예제 #2
0
 def test_0(self):
     """Constructing a OneHot space with zero categories raises AssertionError."""
     # NOTE(review): if the constructor's check is a bare `assert`, it is
     # stripped under `python -O` and this test would then fail.
     self.assertRaises(AssertionError, spaces.OneHot, 0)
예제 #3
0
 def test_contains_dtype(self):
     """A uint8 OneHot space accepts a float32 one-hot array."""
     space = spaces.OneHot(3, dtype=np.uint8)
     # Open question from the original author: it is not settled whether
     # contains() *should* accept an array whose dtype differs from the
     # space's. This test pins the current behaviour (float32 accepted).
     candidate = np.array([1, 0, 0], dtype=np.float32)
     self.assertTrue(space.contains(candidate))
예제 #4
0
 def test_dtype(self):
     """The dtype passed to OneHot is reported by the space and its samples."""
     expected = np.uint8
     space = spaces.OneHot(3, dtype=expected)
     self.assertEqual(space.dtype, expected)
     drawn = space.sample()
     self.assertEqual(drawn.dtype, expected)
예제 #5
0
 def test_shape(self):
     """A OneHot space over n categories has shape (n,)."""
     n = 4
     self.assertEqual(spaces.OneHot(n).shape, (n,))
예제 #6
0
 def test_sample_contains(self):
     """Every value drawn by sample() is accepted by contains() on the same space."""
     space = spaces.OneHot(3)
     draws = 10
     for _ in range(draws):
         drawn = space.sample()
         self.assertTrue(space.contains(drawn))
예제 #7
0
 def test_1(self):
     """A single-category OneHot space samples exactly the vector [1.]."""
     sample = spaces.OneHot(1).sample()
     # assertEqual on ndarrays relies on the truth value of the elementwise
     # `==`. It works for size-1 arrays, but it also passes for a bare scalar
     # 1.0 or a (1, 1) array. Pin the shape explicitly, then compare values
     # with numpy's array assertion.
     self.assertEqual(np.shape(sample), (1,))
     np.testing.assert_array_equal(sample, np.array([1.]))
예제 #8
0
 def test_n(self):
     """policy.onehot over a Discrete(5) space yields members of OneHot(5)."""
     space = gym.spaces.Discrete(5)
     # p is called with None below; presumably the uniform policy ignores its
     # observation argument -- confirm against policy.uniform.
     p = policy.onehot(policy.uniform(space))
     onehot_space = spaces.OneHot(5)
     for _ in range(10):
         # The contains() result must be asserted, otherwise the test checks
         # nothing. The one-hot-wrapped policy's output is expected to be a
         # one-hot vector, so it is checked against the OneHot space, not the
         # Discrete one (whose contains() accepts integer indices).
         self.assertTrue(onehot_space.contains(p(None)))