def _testParameterizedTruncatedNormal(self,
                                      means,
                                      stddevs,
                                      minvals,
                                      maxvals,
                                      variance_rtol=None):
  """Samples a stateless parameterized truncated normal and checks its stats.

  The sample is drawn once per dtype in `self._random_types()`, using a fixed
  seed fed through an int32 placeholder. It is then validated with
  `random_test_util.test_truncated_normal`.

  Args:
    means: Forwarded as `means` to both the sampler and the checker.
    stddevs: Forwarded as `stddevs` to both the sampler and the checker.
    minvals: Forwarded as `minvals` to both the sampler and the checker.
    maxvals: Forwarded as `maxvals` to both the sampler and the checker.
    variance_rtol: Relative tolerance for the variance check. When None, a
      per-dtype default is used: 6e-3 for bfloat16, 1e-3 for other dtypes.
  """
  n = int(10e7)
  for dtype in self._random_types():
    # Resolve the tolerance per iteration. Writing the default back into
    # `variance_rtol` would pin the first dtype's default for all later
    # dtypes, so e.g. float32 could inherit the looser bfloat16 tolerance.
    if variance_rtol is None:
      rtol = 6e-3 if dtype == dtypes.bfloat16 else 1e-3
    else:
      rtol = variance_rtol
    with self.session() as sess, self.test_scope():
      seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
      # NOTE(review): `dtype` is not passed to the sampler. The sample dtype
      # presumably follows the dtype of `means`/`stddevs`/`minvals`/`maxvals`.
      # Confirm this loop really exercises each dtype.
      x = stateless.stateless_parameterized_truncated_normal(
          shape=[n],
          seed=seed_t,
          means=means,
          stddevs=stddevs,
          minvals=minvals,
          maxvals=maxvals)
      y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
      random_test_util.test_truncated_normal(
          self.assertEqual,
          self.assertAllClose,
          n,
          y,
          means=means,
          stddevs=stddevs,
          minvals=minvals,
          maxvals=maxvals,
          mean_atol=1e-3,
          median_atol=1e-3,
          variance_rtol=rtol)
def testTruncatedNormal(self):
  """Checks `Generator.truncated_normal` samples for every dtype in `_floats`.

  A new generator seeded with 123 is constructed for each dtype. Its
  10,000,000 samples are handed to `random_test_util.test_truncated_normal`.
  """
  sample_count = 10000000
  for float_type in self._floats:
    generator = random.Generator(seed=123)
    drawn = generator.truncated_normal(shape=[sample_count], dtype=float_type)
    random_test_util.test_truncated_normal(
        self.assertEqual, self.assertAllClose, float_type, sample_count,
        drawn.numpy())
def testTruncatedNormal(self):
  """Validates truncated-normal statistics for each float dtype.

  For every entry of `self._floats`, the generator is re-seeded with 123 and
  the resulting samples are checked by the shared truncated-normal helper.
  """
  num_samples = 10000000
  for dt in self._floats:
    values = random.Generator(seed=123).truncated_normal(
        shape=[num_samples], dtype=dt).numpy()
    random_test_util.test_truncated_normal(
        self.assertEqual, self.assertAllClose, dt, num_samples, values)
def testTruncatedNormal(self, alg):
  """Checks truncated-normal sampling on the XLA device for RNG algorithm `alg`.

  Args:
    alg: Forwarded to `random.Generator.from_seed` as the generator algorithm.
  """
  sample_count = 10000000
  with ops.device(xla_device_name()):
    for float_type in self._floats:
      rng = random.Generator.from_seed(seed=123, alg=alg)
      drawn = rng.truncated_normal(shape=[sample_count], dtype=float_type)
      random_test_util.test_truncated_normal(
          self.assertEqual, self.assertAllClose, float_type, sample_count,
          drawn.numpy())
def testTruncatedNormal(self):
  """Checks stateless truncated-normal samples for every random dtype.

  The seed is supplied at run time through an int32 placeholder of shape [2].
  Each dtype is sampled in its own cached session and test scope.
  """
  sample_count = 10000000
  for sample_dtype in self._random_types():
    with self.cached_session() as session, self.test_scope():
      seed_placeholder = array_ops.placeholder(dtypes.int32, shape=[2])
      sampled = stateless.stateless_truncated_normal(
          shape=[sample_count], seed=seed_placeholder, dtype=sample_dtype)
      values = session.run(sampled,
                           {seed_placeholder: [0x12345678, 0xabcdef12]})
      random_test_util.test_truncated_normal(
          self.assertEqual, self.assertAllClose, sample_dtype, sample_count,
          values)
def testTruncatedNormal(self, alg, dtype):
  """Checks truncated-normal sampling on the XLA device for one (alg, dtype).

  Args:
    alg: Forwarded to `random.Generator.from_seed` as the generator algorithm.
    dtype: Sample dtype. It is passed to `self.check_dtype` before any work.
  """
  self.check_dtype(dtype)
  sample_count = 100000
  with ops.device(xla_device_name()):
    rng = random.Generator.from_seed(seed=123, alg=alg)
    samples = rng.truncated_normal(shape=[sample_count], dtype=dtype).numpy()
    # bfloat16 is given a looser variance tolerance than other dtypes.
    if dtype == dtypes.bfloat16:
      var_tol = 1e-2
    else:
      var_tol = 5e-3
    random_test_util.test_truncated_normal(
        self.assertEqual, self.assertAllClose, sample_count, samples,
        mean_atol=2e-3, median_atol=4e-3, variance_rtol=var_tol)