Example #1
0
 def testSampleProbConsistent(self):
     """Run the sample/log-prob consistency helper on a scalar compound.

     The compound is built with loc=-2., scale=1.1 and the 10-point
     Gauss-Hermite grid/weights returned by
     `np.polynomial.hermite.hermgauss`.
     """
     with self.test_session() as sess:
         hermite_grid_and_probs = np.polynomial.hermite.hermgauss(deg=10)
         compound = poisson_lognormal.PoissonLogNormalQuadratureCompound(
             loc=-2.,
             scale=1.1,
             quadrature_grid_and_probs=hermite_grid_and_probs,
             validate_args=True)
         self.run_test_sample_consistent_log_prob(
             sess.run, compound, rtol=0.1)
Example #2
0
 def testMeanVarianceBroadcastBoth(self):
   """Run the mean/variance consistency helper with broadcasting params.

   `loc` is given as a [2, 1] nested list and `scale` as a [1, 2] nested
   list, so each parameter varies along a different axis.
   """
   with self.test_session() as sess:
     hermite_grid_and_probs = np.polynomial.hermite.hermgauss(deg=10)
     compound = poisson_lognormal.PoissonLogNormalQuadratureCompound(
         loc=[[0.], [-0.5]],
         scale=[[1., 0.9]],
         quadrature_grid_and_probs=hermite_grid_and_probs,
         validate_args=True)
     # NOTE(review): the session object itself is passed here, whereas
     # similar tests pass `sess.run`; confirm which the helper expects.
     self.run_test_sample_consistent_mean_variance(
         sess, compound, rtol=0.1, atol=0.01)
Example #3
0
 def testMeanVariance(self):
     """Run the mean/variance consistency helper on a scalar compound.

     The compound uses loc=0., scale=1. and a 10-point Gauss-Hermite
     quadrature from `np.polynomial.hermite.hermgauss`.
     """
     with self.test_session() as sess:
         hermite_grid_and_probs = np.polynomial.hermite.hermgauss(deg=10)
         compound = poisson_lognormal.PoissonLogNormalQuadratureCompound(
             loc=0.,
             scale=1.,
             quadrature_grid_and_probs=hermite_grid_and_probs,
             validate_args=True)
         self.run_test_sample_consistent_mean_variance(
             sess.run, compound, rtol=0.02)
Example #4
0
 def testMeanVariance(self):
     """Run the mean/variance consistency helper with placeholder params.

     `loc` (0.) and `scale` (1.) are wrapped in `placeholder_with_default`.
     Their declared shape is scalar when `self.static_shape` is truthy and
     `None` (unspecified) otherwise.
     """
     with self.test_session() as sess:
         param_shape = [] if self.static_shape else None
         loc = array_ops.placeholder_with_default(0., shape=param_shape)
         scale = array_ops.placeholder_with_default(1., shape=param_shape)
         compound = poisson_lognormal.PoissonLogNormalQuadratureCompound(
             loc=loc,
             scale=scale,
             quadrature_size=10,
             validate_args=True)
         self.run_test_sample_consistent_mean_variance(
             sess.run, compound, rtol=0.02)
Example #5
0
 def testSampleProbConsistent(self):
     """Run the sample/log-prob consistency helper with placeholder params.

     `loc` (-2.) and `scale` (1.1) are wrapped in `placeholder_with_default`.
     Their declared shape is scalar when `self.static_shape` is truthy and
     `None` (unspecified) otherwise. The helper is invoked with
     `batch_size=1`.
     """
     with self.test_session() as sess:
         param_shape = [] if self.static_shape else None
         loc = array_ops.placeholder_with_default(-2., shape=param_shape)
         scale = array_ops.placeholder_with_default(1.1, shape=param_shape)
         compound = poisson_lognormal.PoissonLogNormalQuadratureCompound(
             loc=loc,
             scale=scale,
             quadrature_size=10,
             validate_args=True)
         self.run_test_sample_consistent_log_prob(
             sess.run, compound, batch_size=1, rtol=0.1)
Example #6
0
 def testMeanVarianceBroadcastBoth(self):
     """Run the mean/variance consistency helper with broadcasting params.

     `loc` is a [2, 1] value and `scale` a [1, 2] value, both wrapped in
     `placeholder_with_default`. Those shapes are declared only when
     `self.static_shape` is truthy; otherwise the shape is `None`.
     """
     with self.test_session() as sess:
         use_static = self.static_shape
         loc = array_ops.placeholder_with_default(
             [[0.], [-0.5]], shape=[2, 1] if use_static else None)
         scale = array_ops.placeholder_with_default(
             [[1., 0.9]], shape=[1, 2] if use_static else None)
         compound = poisson_lognormal.PoissonLogNormalQuadratureCompound(
             loc=loc,
             scale=scale,
             quadrature_size=10,
             validate_args=True)
         self.run_test_sample_consistent_mean_variance(
             sess.run, compound, rtol=0.1, atol=0.01)
Example #7
0
 def testSampleProbConsistentDynamicQuadrature(self):
     """Run the sample/log-prob consistency helper with a fed quadrature.

     The 10-point Gauss-Hermite grid and weights from
     `np.polynomial.hermite.hermgauss` are supplied at run time through
     float32 placeholders. The compound is therefore built without static
     knowledge of the quadrature values.
     """
     with self.test_session() as sess:
         qgrid = array_ops.placeholder(dtype=dtypes.float32)
         qprobs = array_ops.placeholder(dtype=dtypes.float32)
         g, p = np.polynomial.hermite.hermgauss(deg=10)
         pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
             loc=-2.,
             scale=1.1,
             # Pass the placeholders rather than the numpy arrays `(g, p)`:
             # otherwise the feed_dict below targets tensors the graph never
             # reads, and the quadrature is silently static.
             quadrature_grid_and_probs=(qgrid, qprobs),
             validate_args=True)
         # hermgauss returns float64 arrays; they are fed into float32
         # placeholders. NOTE(review): relies on Session.run casting feeds to
         # the placeholder dtype.
         feed = {qgrid: g, qprobs: p}
         self.run_test_sample_consistent_log_prob(
             lambda x: sess.run(x, feed_dict=feed),
             pln,
             rtol=0.1)