def testTruncatedNormal(self):
  """Draws stateless truncated-normal samples and validates them.

  For every dtype yielded by `self._random_types()`, samples `num_samples`
  values from `stateless.stateless_truncated_normal` using a fixed seed fed
  through an int32 placeholder, then hands the samples (together with
  `self.assertEqual` / `self.assertAllClose`) to
  `random_test_util.test_truncated_normal` for checking.
  """
  num_samples = 10000000
  fixed_seed = [0x12345678, 0xabcdef12]
  for dtype in self._random_types():
    with self.cached_session() as sess, self.test_scope():
      seed_ph = array_ops.placeholder(dtypes.int32, shape=[2])
      sample_op = stateless.stateless_truncated_normal(
          shape=[num_samples], seed=seed_ph, dtype=dtype)
      samples = sess.run(sample_op, {seed_ph: fixed_seed})
      random_test_util.test_truncated_normal(
          self.assertEqual, self.assertAllClose, dtype, num_samples, samples)
    def testTruncatedNormalIsInRange(self):
      """Checks truncated-normal samples against the interval and its moments.

      For each dtype from `self._random_types()`, draws 10,000,000 samples of a
      standard normal (loc 0, scale 1) truncated to [-2, 2] via
      `stateless.stateless_truncated_normal`, with a fixed seed fed through an
      int32 placeholder. Asserts that every sample lies inside the interval,
      then compares the sample mean, median and variance with closed-form
      values for the truncated normal distribution.

      NOTE(review): `testTruncatedNormalIsInRange` is defined more than once in
      this file; if the copies share one class body, only the last definition
      is kept and this one never runs. Confirm and remove the redundant copy.
      """
      num_samples = 10000000
      lo, hi = -2., 2.  # truncation bounds
      loc, scale = 0., 1.  # parameters of the untruncated normal

      def phi(t):
        # Standard normal probability density function.
        return math.exp(-(t**2) / 2.) / math.sqrt(2 * math.pi)

      def cdf(t):
        # Standard normal CDF expressed through the complementary error function.
        return .5 * math.erfc(-t / math.sqrt(2))

      # Bounds in standard-score units and the probability mass between them.
      std_lo = (lo - loc) / scale
      std_hi = (hi - loc) / scale
      mass = cdf(std_hi) - cdf(std_lo)

      # Closed-form moments of the truncated normal. For more information on
      # these calculations, see: Burkardt, John. "The Truncated Normal
      # Distribution". Department of Scientific Computing website. Florida
      # State University.
      density_gap = (phi(std_lo) - phi(std_hi)) / mass
      want_mean = loc + density_gap * scale
      want_var = scale**2 * (
          1 + (std_lo * phi(std_lo) - std_hi * phi(std_hi)) / mass
          - density_gap**2)

      for dtype in self._random_types():
        with self.cached_session() as sess, self.test_scope():
          seed_ph = array_ops.placeholder(dtypes.int32, shape=[2])
          sample_op = stateless.stateless_truncated_normal(
              shape=[num_samples], seed=seed_ph, dtype=dtype)
          draws = sess.run(sample_op, {seed_ph: [0x12345678, 0xabcdef12]})

          # Every sample must fall inside the closed interval [lo, hi].
          self.assertEqual((draws >= lo).sum(), num_samples)
          self.assertEqual((draws <= hi).sum(), num_samples)

          draws = draws.astype(float)
          self.assertAllClose(np.mean(draws), want_mean, atol=5e-4)

          # The median is the inverse normal CDF (ndtri) evaluated halfway
          # between the CDF values at the two bounds.
          want_median = loc + self.evaluate(
              special_math.ndtri((cdf(std_lo) + cdf(std_hi)) / 2.)) * scale
          self.assertAllClose(np.median(draws), want_median, atol=8e-4)

          self.assertAllClose(
              np.var(draws),
              want_var,
              rtol=5e-3 if dtype == dtypes.bfloat16 else 1e-3)
  def testTruncatedNormalIsInRange(self):
    """Checks truncated-normal samples against the interval and its moments.

    For each dtype from `self._random_types()`, draws n = 10,000,000 samples of
    a standard normal (mu = 0, sigma = 1) truncated to [a, b] = [-2, 2] with
    `stateless.stateless_truncated_normal`, using a fixed seed fed through an
    int32 placeholder. Asserts that every sample lies in [a, b], then compares
    the sample mean, median and variance with closed-form values for the
    truncated normal distribution.

    NOTE(review): an earlier definition with this same name appears in this
    file; if both are in one class body, this later definition replaces it.
    Confirm and remove the redundant copy.
    """
    for dtype in self._random_types():
      with self.cached_session() as sess, self.test_scope():
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        n = 10000000
        x = stateless.stateless_truncated_normal(
            shape=[n], seed=seed_t, dtype=dtype)
        y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})

        def normal_cdf(x):
          # Standard normal CDF via the complementary error function.
          return .5 * math.erfc(-x / math.sqrt(2))

        def normal_pdf(x):
          # Standard normal probability density function.
          return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)

        def probit(x):
          # Inverse standard normal CDF, evaluated with special_math.ndtri.
          return self.evaluate(special_math.ndtri(x))

        a = -2.
        b = 2.
        mu = 0.
        sigma = 1.

        # Bounds in standard-score units and the probability mass between them.
        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma
        z = normal_cdf(beta) - normal_cdf(alpha)

        # assertEqual (not assertTrue(count == n)) reports the actual in-range
        # count when the check fails.
        self.assertEqual((y >= a).sum(), n)
        self.assertEqual((y <= b).sum(), n)

        # For more information on these calculations, see:
        # Burkardt, John. "The Truncated Normal Distribution".
        # Department of Scientific Computing website. Florida State University.
        expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
        y = y.astype(float)
        actual_mean = np.mean(y)
        self.assertAllClose(actual_mean, expected_mean, atol=5e-4)

        expected_median = mu + probit(
            (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
        actual_median = np.median(y)
        self.assertAllClose(actual_median, expected_median, atol=8e-4)

        expected_variance = sigma**2 * (1 + (
            (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
                (normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
        actual_variance = np.var(y)
        self.assertAllClose(actual_variance, expected_variance,
                            rtol=5e-3 if dtype == dtypes.bfloat16 else 1e-3)