예제 #1
0
 def testScalarCongruency(self):
     """Checks scalar congruency of the inverted Exp bijector.

     `assert_scalar_congruency` is run on `Invert(Exp())` over the interval
     [1e-3, 1.5] with a relative tolerance of 0.05.
     """
     with self.test_session():
         inverted_exp = bijectors.Invert(bijectors.Exp())
         # The lower bound is kept strictly positive, presumably because the
         # forward map of Invert(Exp()) is a logarithm -- confirm.
         congruency_kwargs = dict(lower_x=1e-3, upper_x=1.5, rtol=0.05)
         assert_scalar_congruency(inverted_exp, **congruency_kwargs)
예제 #2
0
 def testBijector(self):
     """Checks that Invert(fwd) mirrors fwd's maps and log-det-Jacobians.

     For each candidate bijector `fwd`, `rev = Invert(fwd)` must be named
     "invert_<fwd.name>". Each of rev's forward/inverse maps and
     log-det-Jacobians, evaluated on a fixed input, must match the opposite
     map of `fwd`.
     """
     with self.test_session():
         candidates = [
             bijectors.Identity(),
             bijectors.Exp(event_ndims=1),
             bijectors.Affine(shift=[0., 1.],
                              scale_diag=[2., 3.],
                              event_ndims=1),
             bijectors.Softplus(event_ndims=1),
             bijectors.SoftmaxCentered(event_ndims=1),
             bijectors.SigmoidCentered(),
         ]
         for fwd in candidates:
             rev = bijectors.Invert(fwd)
             self.assertEqual("_".join(["invert", fwd.name]), rev.name)
             x = [[[1., 2.], [2., 3.]]]
             # Each pair is (method on fwd, mirrored method on rev). Both are
             # evaluated on the same input and must agree.
             mirrored = (
                 (fwd.inverse, rev.forward),
                 (fwd.forward, rev.inverse),
                 (fwd.forward_log_det_jacobian, rev.inverse_log_det_jacobian),
                 (fwd.inverse_log_det_jacobian, rev.forward_log_det_jacobian),
             )
             for fwd_method, rev_method in mirrored:
                 expected = fwd_method(x).eval()
                 actual = rev_method(x).eval()
                 self.assertAllClose(expected, actual)
예제 #3
0
 def testDocstringExample(self):
   """Runs the docstring example: a Gamma transformed by Invert(Exp()).

   Builds a TransformedDistribution from Gamma(concentration=1., rate=2.)
   and the bijector Invert(Exp()), then asserts that a single draw from it
   has an empty shape.
   """
   with self.cached_session():
     base_gamma = gamma_lib.Gamma(concentration=1., rate=2.)
     inverted_exp = bijectors.Invert(bijectors.Exp())
     transformed = transformed_distribution_lib.TransformedDistribution(
         distribution=base_gamma, bijector=inverted_exp)
     draw_shape = array_ops.shape(transformed.sample()).eval()
     self.assertAllEqual([], draw_shape)
예제 #4
0
 def testShapeGetters(self):
   """Checks that Invert(SoftmaxCentered) swaps the event-shape mappings.

   The inverted bijector must map event shape [2] to [1] going forward and
   [1] back to [2] going inverse. Both the TensorShape getters and the
   `*_tensor` getters (evaluated in the session) are checked.
   """
   with self.cached_session():
     inverted = bijectors.Invert(bijectors.SoftmaxCentered(validate_args=True))
     wide = tensor_shape.TensorShape([2])
     narrow = tensor_shape.TensorShape([1])

     # Forward direction: [2] -> [1].
     self.assertAllEqual(narrow, inverted.forward_event_shape(wide))
     fwd_dynamic = inverted.forward_event_shape_tensor(wide.as_list())
     self.assertAllEqual(narrow.as_list(), fwd_dynamic.eval())

     # Inverse direction: [1] -> [2].
     self.assertAllEqual(wide, inverted.inverse_event_shape(narrow))
     inv_dynamic = inverted.inverse_event_shape_tensor(narrow.as_list())
     self.assertAllEqual(wide.as_list(), inv_dynamic.eval())