def _testParameterizedTruncatedNormal(self,
                                      means,
                                      stddevs,
                                      minvals,
                                      maxvals,
                                      variance_rtol=None):
   """Checks stateless parameterized truncated-normal samples for each dtype.

   For every dtype in `self._random_types()`, draws `n` samples from
   `stateless.stateless_parameterized_truncated_normal` with a fixed seed and
   validates them with `random_test_util.test_truncated_normal`
   (mean_atol=1e-3, median_atol=1e-3).

   Args:
     means: Mean parameter forwarded to the sampler and the checker.
     stddevs: Standard-deviation parameter forwarded to both.
     minvals: Lower truncation bound forwarded to both.
     maxvals: Upper truncation bound forwarded to both.
     variance_rtol: Relative tolerance for the variance check. When None, a
       per-dtype default is used: 6e-3 for bfloat16, 1e-3 otherwise.
   """
   n = 100000000  # 1e8 samples (the value int(10e7) evaluates to).
   for dtype in self._random_types():
     # Resolve the tolerance into a fresh local on every iteration. Rebinding
     # `variance_rtol` itself would freeze the first dtype's default for all
     # later dtypes (e.g. bfloat16 would inherit float32's 1e-3).
     if variance_rtol is None:
       rtol = 6e-3 if dtype == dtypes.bfloat16 else 1e-3
     else:
       rtol = variance_rtol
     with self.session() as sess, self.test_scope():
       seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
       # NOTE(review): `dtype` is not passed to the sampler call below; the
       # output dtype presumably follows means/stddevs/minvals/maxvals, so
       # every iteration may be sampling the same dtype. TODO confirm and
       # cast the parameters to `dtype` if per-dtype coverage is intended.
       x = stateless.stateless_parameterized_truncated_normal(
           shape=[n],
           seed=seed_t,
           means=means,
           stddevs=stddevs,
           minvals=minvals,
           maxvals=maxvals)
       y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
       random_test_util.test_truncated_normal(
           self.assertEqual,
           self.assertAllClose,
           n,
           y,
           means=means,
           stddevs=stddevs,
           minvals=minvals,
           maxvals=maxvals,
           mean_atol=1e-3,
           median_atol=1e-3,
           variance_rtol=rtol)
 def testTruncatedNormal(self):
   """Validates seeded Generator truncated-normal draws for each float dtype."""
   sample_count = 10000000
   for float_type in self._floats:
     # A new generator per dtype starts every dtype from the seed-123 state.
     generator = random.Generator(seed=123)
     samples = generator.truncated_normal(
         shape=[sample_count], dtype=float_type).numpy()
     random_test_util.test_truncated_normal(
         self.assertEqual,
         self.assertAllClose,
         float_type,
         sample_count,
         samples)
Ejemplo n.º 3
0
 def testTruncatedNormal(self):
   """Runs the truncated-normal statistics check once per float dtype."""
   num_draws = 10000000
   for current_dtype in self._floats:
     # Reconstructed each iteration so each dtype uses the seed-123 state.
     seeded_gen = random.Generator(seed=123)
     draws = seeded_gen.truncated_normal(shape=[num_draws], dtype=current_dtype)
     values = draws.numpy()
     random_test_util.test_truncated_normal(
         self.assertEqual, self.assertAllClose, current_dtype, num_draws,
         values)
Ejemplo n.º 4
0
 def testTruncatedNormal(self, alg):
   """Checks truncated-normal draws on the XLA device for each float dtype.

   Args:
     alg: RNG algorithm passed to `random.Generator.from_seed`.
   """
   sample_count = 10000000
   with ops.device(xla_device_name()):
     for float_dtype in self._floats:
       # Each dtype gets its own generator built from the same seed and alg.
       rng = random.Generator.from_seed(seed=123, alg=alg)
       samples = rng.truncated_normal(
           shape=[sample_count], dtype=float_dtype).numpy()
       random_test_util.test_truncated_normal(
           self.assertEqual,
           self.assertAllClose,
           float_dtype,
           sample_count,
           samples)
Ejemplo n.º 5
0
 def testTruncatedNormal(self):
   for dtype in self._random_types():
     with self.cached_session() as sess, self.test_scope():
       seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
       n = 10000000
       x = stateless.stateless_truncated_normal(
           shape=[n], seed=seed_t, dtype=dtype)
       y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
       random_test_util.test_truncated_normal(
           self.assertEqual, self.assertAllClose, dtype, n, y)
Ejemplo n.º 6
0
 def testTruncatedNormal(self, alg, dtype):
   """Validates XLA-device truncated-normal draws for one (alg, dtype) pair.

   Args:
     alg: RNG algorithm passed to `random.Generator.from_seed`.
     dtype: Sample dtype; first vetted by `self.check_dtype`.
   """
   self.check_dtype(dtype)
   # bfloat16 is given a looser variance tolerance than other dtypes.
   if dtype == dtypes.bfloat16:
     rtol = 1e-2
   else:
     rtol = 5e-3
   sample_count = 100000
   with ops.device(xla_device_name()):
     rng = random.Generator.from_seed(seed=123, alg=alg)
     samples = rng.truncated_normal(shape=[sample_count], dtype=dtype).numpy()
     random_test_util.test_truncated_normal(
         self.assertEqual,
         self.assertAllClose,
         sample_count,
         samples,
         mean_atol=2e-3,
         median_atol=4e-3,
         variance_rtol=rtol)