def testShapeTypes(self):
        """Both samplers accept int32 and int64 shape arrays.

        For each integer dtype, builds a 1-D shape array of that dtype, draws
        from the stateful and the stateless parameterized truncated normal
        with it, and checks that the evaluated samples have that shape.
        """
        trunc_kwargs = dict(means=0.0, stddevs=0.1, minvals=-1., maxvals=1.)
        for int_type in (np.int32, np.int64):
            requested = np.array([1000], dtype=int_type)
            stateful_op = random_ops.parameterized_truncated_normal(
                shape=requested, **trunc_kwargs)
            # The stateless op needs an explicit 2-element int32 seed.
            stateless_seed = random_ops.random_uniform(
                [2], seed=1234, minval=0, maxval=(2**31 - 1), dtype=np.int32)
            stateless_op = stateless.stateless_parameterized_truncated_normal(
                shape=requested, seed=stateless_seed, **trunc_kwargs)

            stateful_out = self.evaluate(stateful_op)
            stateless_out = self.evaluate(stateless_op)
            self.assertAllEqual(stateful_out.shape, requested)
            self.assertAllEqual(stateless_out.shape, requested)
    def testSamplingWithSmallStdDevFarFromBound(self):
        """A narrow distribution far above minval never yields values <= 0.

        With mean 0.8 and stddev 0.05, zero lies 16 standard deviations below
        the mean. Its probability mass is below 1e-57, so every sample from
        both the stateful and the stateless sampler must be strictly
        positive and non-NaN.
        """
        num_samples = int(1e5)
        dist_params = dict(means=0.8, stddevs=0.05, minvals=-1., maxvals=1.)
        stateful_op = random_ops.parameterized_truncated_normal(
            shape=(num_samples, ), **dist_params)
        stateless_seed = random_ops.random_uniform(
            [2], seed=1234, minval=0, maxval=(2**31 - 1), dtype=np.int32)
        stateless_op = stateless.stateless_parameterized_truncated_normal(
            shape=(num_samples, ), seed=stateless_seed, **dist_params)

        with self.session() as sess:
            drawn = sess.run([stateful_op, stateless_op])
            for batch in drawn:
                assert (~np.isnan(batch)).all()
            for batch in drawn:
                self.assertAllGreater(batch, 0.)
 def _testParameterizedTruncatedNormal(self,
                                       means,
                                       stddevs,
                                       minvals,
                                       maxvals,
                                       variance_rtol=None):
   """Draws stateless truncated normal samples and checks their statistics.

   For each type in self._random_types(), draws n samples with a fixed seed.
   It passes them to random_test_util.test_truncated_normal together with
   the distribution parameters and the mean/median/variance tolerances.

   Args:
     means: Mean of the untruncated normal.
     stddevs: Standard deviation of the untruncated normal.
     minvals: Lower truncation bound.
     maxvals: Upper truncation bound.
     variance_rtol: Relative tolerance for the variance check. If None, a
       default is chosen separately for each dtype: 6e-3 for bfloat16 and
       1e-3 otherwise.
   """
   for dtype in self._random_types():
     with self.session() as sess, self.test_scope():
       seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
       n = int(10e7)  # 1e8 samples.
       # NOTE(review): `dtype` is not passed to the sampler below; confirm the
       # samples are meant to be drawn in the parameters' own dtype.
       x = stateless.stateless_parameterized_truncated_normal(
           shape=[n],
           seed=seed_t,
           means=means,
           stddevs=stddevs,
           minvals=minvals,
           maxvals=maxvals)
       y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
       # Resolve the tolerance per iteration. Overwriting `variance_rtol`
       # itself would make the first dtype's default stick for all later
       # dtypes.
       rtol = variance_rtol
       if rtol is None:
         rtol = 6e-3 if dtype == dtypes.bfloat16 else 1e-3
       random_test_util.test_truncated_normal(
           self.assertEqual,
           self.assertAllClose,
           n,
           y,
           means=means,
           stddevs=stddevs,
           minvals=minvals,
           maxvals=maxvals,
           mean_atol=1e-3,
           median_atol=1e-3,
           variance_rtol=rtol)
    def testStatelessParameterizedTruncatedNormalHasGrads(self):
        """Checks gradients of the stateless truncated normal sampler.

        The mean/stddev gradients are compared with the reparameterization
        identities d(sample)/d(mean) = 1 and d(sample)/d(stddev) = z, where
        z = (sample - mean) / stddev.

        When scipy is importable, the minval/maxval gradients are also
        compared with the implicit-reparameterization formula. That formula
        is evaluated on the standardized sample and standardized bounds of
        the distribution actually sampled.
        """
        mean_value, stddev_value = 0.01, 1.
        minval_value, maxval_value = -1., 1.
        mean = variables.Variable(mean_value)
        stddev = variables.Variable(stddev_value)
        minval = variables.Variable(minval_value)
        maxval = variables.Variable(maxval_value)

        with self.cached_session() as sess:
            with backprop.GradientTape(persistent=True) as tape:
                samples = stateless.stateless_parameterized_truncated_normal(
                    [1], [1, 2], mean, stddev, minval, maxval)

            sess.run(
                variables.variables_initializer([mean, stddev, minval,
                                                 maxval]))
            [mean_grad,
             std_grad], mean_actual_grad, std_actual_grad = sess.run([
                 tape.gradient(samples, [mean, stddev]),
                 array_ops.ones_like(mean), (samples - mean) / stddev
             ])
            self.assertAllClose(mean_grad, mean_actual_grad)
            self.assertAllClose(std_grad, std_actual_grad[0])

            try:
                import scipy.stats  # pylint:disable=g-import-not-at-top
            except ImportError as e:
                tf_logging.warn("Cannot test truncated normal op: %s" % str(e))
                return

            samples_np, [minval_grad, maxval_grad] = sess.run(
                [samples, tape.gradient(samples, [minval, maxval])])

            # Work in standardized units so the reference describes the
            # distribution that was actually sampled (mean 0.01, not 0).
            alpha = (minval_value - mean_value) / stddev_value
            beta = (maxval_value - mean_value) / stddev_value
            z = (samples_np - mean_value) / stddev_value
            sample_cdf = scipy.stats.truncnorm(a=alpha, b=beta, loc=0.,
                                               scale=1.).cdf(z)
            # Implicit reparameterization trick:
            #   d(sample)/d(maxval) = F(z) * exp(0.5 * (z**2 - beta**2))
            #   d(sample)/d(minval) = (1 - F(z)) * exp(0.5 * (z**2 - alpha**2))
            # where F is the truncated normal CDF.
            scipy_maxval_grad = np.exp(0.5 * (z**2 - beta**2) +
                                       np.log(sample_cdf))
            scipy_minval_grad = np.exp(0.5 * (z**2 - alpha**2) +
                                       np.log1p(-sample_cdf))

            self.assertAllClose(minval_grad,
                                scipy_minval_grad[0],
                                rtol=1e-2)
            self.assertAllClose(maxval_grad,
                                scipy_maxval_grad[0],
                                rtol=1e-2)
 def testParameterizedTruncatedNormalBroadcast(self):
   """Parameters of different ranks broadcast to the requested sample shape.

   means, stddevs, minvals and maxvals have ranks 1 to 4, with trailing
   dimensions compatible with the requested shape. The result must have
   exactly the requested shape.
   """
   expected_shape = (11, 7, 5, 3, 2)
   f32 = dtypes.float32
   with self.session() as sess, self.test_scope():
     seed_ph = array_ops.placeholder(dtypes.int32, shape=[2])
     sampler = stateless.stateless_parameterized_truncated_normal(
         shape=list(expected_shape),
         seed=seed_ph,
         means=array_ops.zeros([2], dtype=f32),
         stddevs=array_ops.ones([3, 1], dtype=f32),
         minvals=-array_ops.ones([5, 1, 1], dtype=f32),
         maxvals=array_ops.ones([7, 1, 1, 1], dtype=f32))
     result = sess.run(sampler, {seed_ph: [0x12345678, 0xabcdef1]})
     self.assertEqual(expected_shape, result.shape)
    def validateKolmogorovSmirnov(self,
                                  shape,
                                  mean,
                                  stddev,
                                  minval,
                                  maxval,
                                  use_stateless=False,
                                  seed=1618):
        """Runs a Kolmogorov-Smirnov test of samples against a truncated normal.

        Args:
          shape: Shape of the sample tensor to draw.
          mean: Mean of the untruncated normal.
          stddev: Standard deviation of the untruncated normal.
          minval: Lower truncation bound.
          maxval: Upper truncation bound.
          use_stateless: If True, sample with the stateless op using a seed
            drawn from random_uniform; otherwise use the stateful op.
          seed: Passed to random_seed.set_random_seed, and used as the op
            seed of the random_uniform seed draw in the stateless case.

        If scipy cannot be imported, a warning is logged and the check is
        skipped.
        """
        # Only the scipy import should turn this check into a skip; any other
        # ImportError must surface as a failure.
        try:
            import scipy.stats  # pylint: disable=g-import-not-at-top
        except ImportError as e:
            tf_logging.warn("Cannot test truncated normal op: %s" % str(e))
            return

        random_seed.set_random_seed(seed)
        with self.cached_session():
            if use_stateless:
                new_seed = random_ops.random_uniform([2],
                                                     seed=seed,
                                                     minval=0,
                                                     maxval=(2**31 - 1),
                                                     dtype=np.int32)
                samples = stateless.stateless_parameterized_truncated_normal(
                    shape, new_seed, mean, stddev, minval, maxval).eval()
            else:
                samples = random_ops.parameterized_truncated_normal(
                    shape, mean, stddev, minval, maxval).eval()

        self.assertFalse(np.isnan(samples).any())
        # Clamp the bounds to mean +/- 10 stddevs before building the
        # reference CDF.
        minval = max(mean - stddev * 10, minval)
        maxval = min(mean + stddev * 10, maxval)
        dist = scipy.stats.norm(loc=mean, scale=stddev)
        cdf_min = dist.cdf(minval)
        cdf_max = dist.cdf(maxval)

        def truncated_cdf(x):
            """CDF of `dist` restricted to the clamped bounds, clipped to [0, 1]."""
            return np.clip((dist.cdf(x) - cdf_min) / (cdf_max - cdf_min),
                           0.0, 1.0)

        pvalue = scipy.stats.kstest(samples, truncated_cdf)[1]
        self.assertGreater(pvalue, 1e-10)
 def validateMoments(self,
                     shape,
                     mean,
                     stddev,
                     minval,
                     maxval,
                     use_stateless=False,
                     seed=1618):
     """Compares sample moments with the expected truncated normal moments.

     Draws samples of `shape` and computes moments with calculate_moments up
     to self.max_moment. For every moment index 1..len(moments)-1, asserts
     that z_test(...) against TruncatedNormalMoments(...) is below
     self.z_limit.

     Args:
       shape: Shape of the sample tensor to draw.
       mean: Mean of the untruncated normal.
       stddev: Standard deviation of the untruncated normal.
       minval: Lower truncation bound.
       maxval: Upper truncation bound.
       use_stateless: If True, sample with the stateless op using a seed
         drawn from random_uniform; otherwise use the stateful op.
       seed: Passed to random_seed.set_random_seed, and used as the op seed
         of the random_uniform seed draw in the stateless case.

     If scipy.stats cannot be imported, a warning is logged and the check is
     skipped before any sampling happens.
     """
     # TruncatedNormalMoments requires scipy.stats. Give up early, before
     # the sampling run, if it cannot be imported.
     try:
         import scipy.stats  # pylint: disable=g-import-not-at-top,unused-variable
     except ImportError as e:
         tf_logging.warn("Cannot test truncated normal op: %s" % str(e))
         return

     random_seed.set_random_seed(seed)
     with self.cached_session():
         if use_stateless:
             # Generate a seed that stateless ops can use.
             new_seed = random_ops.random_uniform([2],
                                                  seed=seed,
                                                  minval=0,
                                                  maxval=(2**31 - 1),
                                                  dtype=np.int32)
             samples = stateless.stateless_parameterized_truncated_normal(
                 shape, new_seed, mean, stddev, minval, maxval).eval()
         else:
             samples = random_ops.parameterized_truncated_normal(
                 shape, mean, stddev, minval, maxval).eval()
         self.assertFalse(np.isnan(samples).any())
     moments = calculate_moments(samples, self.max_moment)
     expected_moments = TruncatedNormalMoments(mean, stddev, minval, maxval)
     num_samples = functools.reduce(lambda x, y: x * y, shape, 1)
     for i in range(1, len(moments)):
         self.assertLess(
             z_test(moments, expected_moments, i, num_samples),
             self.z_limit)