def testMeanVariance(self):
     """Sample mean/variance of a scalar PoissonLogNormal compound.

     Builds the distribution with loc=0. and scale=1. and a 10-point
     quadrature, then hands it to `run_test_sample_consistent_mean_variance`
     (defined elsewhere; presumably compares sample-based moments against
     the analytic ones) with rtol=0.02.
     """
     def _scalar(value):
         # shape=None leaves the placeholder's static shape unspecified when
         # the test variant does not pin static shapes.
         return array_ops.placeholder_with_default(
             value, shape=[] if self.static_shape else None)

     with self.cached_session() as sess:
         loc = _scalar(0.)
         scale = _scalar(1.)
         dist = poisson_lognormal.PoissonLogNormalQuadratureCompound(
             loc=loc, scale=scale, quadrature_size=10, validate_args=True)
         self.run_test_sample_consistent_mean_variance(
             sess.run, dist, rtol=0.02)
 def testSampleProbConsistent(self):
     with self.cached_session() as sess:
         pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
             loc=array_ops.placeholder_with_default(
                 -2., shape=[] if self.static_shape else None),
             scale=array_ops.placeholder_with_default(
                 1.1, shape=[] if self.static_shape else None),
             quadrature_size=10,
             validate_args=True)
         self.run_test_sample_consistent_log_prob(sess.run,
                                                  pln,
                                                  batch_size=1,
                                                  rtol=0.1)
 def testMeanVarianceBroadcastBoth(self):
     with self.cached_session() as sess:
         pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
             loc=array_ops.placeholder_with_default(
                 [[0.], [-0.5]],
                 shape=[2, 1] if self.static_shape else None),
             scale=array_ops.placeholder_with_default(
                 [[1., 0.9]], shape=[1, 2] if self.static_shape else None),
             quadrature_size=10,
             validate_args=True)
         self.run_test_sample_consistent_mean_variance(sess.run,
                                                       pln,
                                                       rtol=0.1,
                                                       atol=0.01)