def testSamplePoissonHighRates(self):
  # Every rate below is >= 10 (log-rate >= log(10.)), the regime in which the
  # sampler is expected to draw via rejection sampling.
  rate = [10., 10.5, 11., 11.5, 12.0, 12.5, 13.0, 13.5, 14.0, 14.5]
  log_rate = np.log(rate)
  num_samples = int(1e5)

  # Guard: the number of draws must be enough for a DKWM test to resolve a
  # 0.04 CDF discrepancy at the requested false-fail / false-pass rates.
  required_samples = st.min_num_samples_for_dkwm_cdf_test(
      discrepancy=0.04, false_fail_rate=1e-9, false_pass_rate=1e-9)
  self.assertLess(self.evaluate(required_samples), num_samples)

  samples = poisson_lib._random_poisson_noncpu(
      shape=[num_samples],
      log_rates=log_rate,
      output_dtype=tf.float64,
      seed=test_util.test_seed())
  poisson = tfd.Poisson(log_rate=log_rate, validate_args=True)

  # Compare the empirical CDF of the draws with the exact Poisson CDF; the
  # left-continuous CDF is supplied because the distribution is discrete.
  cdf_check = st.assert_true_cdf_equal_by_dkwm(
      samples,
      poisson.cdf,
      st.left_continuous_cdf_discrete_distribution(poisson),
      false_fail_rate=1e-9)
  self.evaluate(cdf_check)

  # Also check the first two sample moments against scipy's reference values.
  sample_mean = self.evaluate(tf.math.reduce_mean(samples, axis=0))
  self.assertAllClose(sample_mean, stats.poisson.mean(rate), rtol=0.01)
  sample_var = self.evaluate(tf.math.reduce_variance(samples, axis=0))
  self.assertAllClose(sample_var, stats.poisson.var(rate), rtol=0.05)
def testSamplePoissonInvalidRates(self):
  # Mixes invalid / degenerate rates (NaN, negative, zero) with ordinary
  # positive rates and checks that the sample moments line up element-wise
  # with scipy's Poisson mean/variance for the same rates. The comparison
  # relies on assertAllClose treating matching NaN entries as equal.
  rate = [np.nan, -1., 0., 5., 7., 10., 13.0, 14., 15., 18.]
  # log(-1.) is NaN and log(0.) is -inf. These are the intended inputs, so
  # silence numpy's invalid/divide RuntimeWarnings instead of letting them
  # leak into the test log (or fail the test under warnings-as-errors).
  with np.errstate(divide='ignore', invalid='ignore'):
    log_rate = np.log(rate)
  samples = self.evaluate(
      poisson_lib._random_poisson_noncpu(
          shape=[int(1e5)],
          log_rates=log_rate,
          output_dtype=tf.float64,
          seed=test_util.test_seed()))
  self.assertAllClose(
      self.evaluate(tf.math.reduce_mean(samples, axis=0)),
      stats.poisson.mean(rate),
      rtol=0.01)
  self.assertAllClose(
      self.evaluate(tf.math.reduce_variance(samples, axis=0)),
      stats.poisson.var(rate),
      rtol=0.05)
def testSamplePoissonLowAndHighRates(self):
  # Rates straddle 10, so a single batch of draws exercises both the
  # low-rate and the high-rate sampling paths at once.
  rate = [1., 3., 5., 6., 7., 10., 13.0, 14., 15., 18.]
  log_rate = np.log(rate)
  num_samples = int(1e5)
  poisson = tfd.Poisson(log_rate=log_rate, validate_args=True)

  # Guard: enough draws for a DKWM test resolving a 0.04 CDF discrepancy.
  required_samples = st.min_num_samples_for_dkwm_cdf_test(
      discrepancy=0.04, false_fail_rate=1e-9, false_pass_rate=1e-9)
  self.assertLess(self.evaluate(required_samples), num_samples)

  samples = poisson_lib._random_poisson_noncpu(
      shape=[num_samples],
      log_rates=log_rate,
      output_dtype=tf.float64,
      seed=test_util.test_seed())

  # Empirical CDF of the draws vs. the exact (discrete) Poisson CDF.
  left_cdf = st.left_continuous_cdf_discrete_distribution(poisson)
  cdf_check = st.assert_true_cdf_equal_by_dkwm(
      samples, poisson.cdf, left_cdf, false_fail_rate=1e-9)
  self.evaluate(cdf_check)