def testIntegerSamplesIncludeUpperBound(self, dtype):
    """Check that an integer spec bounded to [3, 3] always samples 3.

    Floating-point dtypes are not exercised; the method returns without
    asserting anything for them.
    """
    bound = 3
    if not dtype.is_floating_point:
        degenerate_spec = BoundedTensorSpec(self._shape, dtype, bound, bound)
        drawn = degenerate_spec.sample()
        self.assertEqual(drawn.shape, self._shape)
        # With minimum == maximum, the inclusive upper bound is the only
        # possible value.
        self.assertTrue(torch.all(drawn == bound))
def testIntegerSamplesExcludeMaxOfDtype(self, dtype):
    """Check sampling when both bounds sit one below the dtype's maximum.

    Floating-point dtypes and ``torch.uint8`` are skipped. Per the original
    comment, uint8 has special sampling logic.
    """
    skip = dtype.is_floating_point or dtype == torch.uint8
    if skip:
        return
    # One below the largest representable value of this integer dtype.
    target = np.iinfo(torch_dtype_to_str(dtype)).max - 1
    spec = BoundedTensorSpec(self._shape, dtype, target, target)
    batch_dims = (1, )
    sample = spec.sample(outer_dims=batch_dims)
    # The outer dims are prepended to the spec's own shape.
    self.assertEqual(sample.shape, batch_dims + self._shape)
    self.assertTrue(torch.all(sample == target))
def testBoundedTensorSpecSample(self, dtype):
    """Check floating-point sampling bounds and min > max validation.

    Only floating-point dtypes are exercised; other dtypes return without
    asserting anything.
    """
    if dtype.is_floating_point:
        # Per-element minimum of length 30 paired with a scalar maximum.
        # NOTE(review): presumably self._shape ends in 30 so the bounds
        # broadcast; confirm against the fixture.
        lower = (0, ) * 30
        spec = BoundedTensorSpec(self._shape, dtype, lower, 3)
        sample = spec.sample()
        self.assertEqual(self._shape, sample.shape)
        self.assertTrue(torch.all(sample <= 3))
        self.assertTrue(torch.all(0 <= sample))
        # The last minimum (2) exceeds the last maximum (1), so construction
        # must fail with AssertionError.
        bad_minimum = (0, ) * 29 + (2, )
        bad_maximum = (1, ) * 30
        self.assertRaises(AssertionError, BoundedTensorSpec, self._shape,
                          dtype, bad_minimum, bad_maximum)