Esempio n. 1
0
 def testIntegerSamplesIncludeUpperBound(self, dtype):
     """Check that integer sampling can produce the inclusive upper bound.

     With ``minimum == maximum == 3`` the only admissible value is 3, so the
     sample must have the spec's shape and every element must equal 3.
     Floating-point dtypes are reported as skipped instead of silently
     passing.
     """
     if dtype.is_floating_point:  # Only test on integer dtypes.
         self.skipTest("only integer dtypes are tested")
     bound = 3
     spec = BoundedTensorSpec(self._shape, dtype, bound, bound)
     sample = spec.sample()
     self.assertEqual(sample.shape, self._shape)
     self.assertTrue(torch.all(sample == bound))
Esempio n. 2
0
 def testIntegerSamplesExcludeMaxOfDtype(self, dtype):
     """Check integer sampling when both bounds sit at ``max(dtype) - 1``.

     The bounds are kept one below the dtype's maximum. Presumably this is
     because inclusive-upper-bound sampling works with ``maximum + 1``
     internally, which would overflow at the dtype max.
     NOTE(review): confirm against ``BoundedTensorSpec.sample``.
     The test also checks that ``outer_dims`` is prepended to the sample shape.
     """
     # Exclude non integer types and uint8 (has special sampling logic).
     if dtype.is_floating_point or dtype == torch.uint8:
         self.skipTest("non-integer dtypes and uint8 are not covered here")
     info = np.iinfo(torch_dtype_to_str(dtype))
     bound = info.max - 1
     spec = BoundedTensorSpec(self._shape, dtype, bound, bound)
     sample = spec.sample(outer_dims=(1, ))
     # The requested outer dimension comes before the spec's own shape.
     self.assertEqual(sample.shape, (1, ) + self._shape)
     self.assertTrue(torch.all(sample == bound))
Esempio n. 3
0
    def testBoundedTensorSpecSample(self, dtype):
        """Check float sampling with broadcast bounds, and bound validation.

        A per-element ``minimum`` tuple is combined with a scalar ``maximum``,
        and every sampled value must lie in ``[0, 3]``. Building a spec whose
        last minimum exceeds its last maximum must raise ``AssertionError``.
        Non-floating-point dtypes are reported as skipped instead of silently
        passing.
        """
        if not dtype.is_floating_point:
            self.skipTest("only floating-point dtypes are tested")
        # Minimum and maximum shape broadcasting. This assumes the last
        # dimension of self._shape is 30. TODO confirm against setUp.
        spec = BoundedTensorSpec(self._shape, dtype, (0, ) * 30, 3)
        sample = spec.sample()
        self.assertEqual(self._shape, sample.shape)
        self.assertTrue(torch.all(sample <= 3))
        self.assertTrue(torch.all(0 <= sample))

        # The last minimum (2) is greater than the last maximum (1).
        with self.assertRaises(AssertionError):
            BoundedTensorSpec(self._shape, dtype, (0, ) * 29 + (2, ),
                              (1, ) * 30)