Beispiel #1
0
 def testScalarCongruency(self):
     """Checks that Chain(Exp, Softplus) passes the scalar-congruency check.

     The check runs between lower_x=1e-3 and upper_x=1.5 with a relative
     tolerance of 0.05.
     """
     with self.test_session():
         exp_of_softplus = chain_lib.Chain(
             (exp_lib.Exp(), softplus_lib.Softplus()))
         bijector_test_util.assert_scalar_congruency(
             exp_of_softplus, lower_x=1e-3, upper_x=1.5, rtol=0.05)
Beispiel #2
0
 def testBijectorIdentity(self):
     """Checks that an empty Chain is named "identity" and acts as identity.

     Forward and inverse must return the input unchanged, and both
     log-det-Jacobians must be zero.
     """
     with self.test_session():
         identity_chain = chain_lib.Chain()
         self.assertEqual("identity", identity_chain.name)
         x = np.asarray([[[1., 2.], [2., 3.]]])
         # Forward first, then inverse, mirroring the original check order.
         for transform in (identity_chain.forward, identity_chain.inverse):
             self.assertAllClose(x, transform(x).eval())
         # Inverse log-det-Jacobian first, then forward, as before.
         for log_det_jacobian in (identity_chain.inverse_log_det_jacobian,
                                  identity_chain.forward_log_det_jacobian):
             self.assertAllClose(0., log_det_jacobian(x).eval())
Beispiel #3
0
 def testBijector(self):
     """Checks Chain(Exp, Softplus) with event_ndims=1 against closed forms.

     Applying exp after softplus gives forward(x) = 1 + exp(x) and
     inverse(y) = log(y - 1). The log-det-Jacobians are summed over the
     last axis (axis=2 for the rank-3 input used here).
     """
     with self.test_session():
         exp_of_softplus = chain_lib.Chain(
             (exp_lib.Exp(event_ndims=1),
              softplus_lib.Softplus(event_ndims=1)))
         self.assertEqual("chain_of_exp_of_softplus", exp_of_softplus.name)
         x = np.asarray([[[1., 2.], [2., 3.]]])
         # NOTE(review): x contains 1., where log(x - 1.) evaluates to -inf,
         # so the inverse checks compare infinities at that entry.
         log_x_minus_one = np.log(x - 1.)
         self.assertAllClose(1. + np.exp(x),
                             exp_of_softplus.forward(x).eval())
         self.assertAllClose(log_x_minus_one,
                             exp_of_softplus.inverse(x).eval())
         self.assertAllClose(
             -np.sum(log_x_minus_one, axis=2),
             exp_of_softplus.inverse_log_det_jacobian(x).eval())
         self.assertAllClose(
             np.sum(x, axis=2),
             exp_of_softplus.forward_log_det_jacobian(x).eval())
Beispiel #4
0
 def testShapeGetters(self):
   """Checks static and dynamic event-shape getters of a two-link Chain.

   The chain is expected to map the scalar event shape `[]` to `[3]` on
   forward, and `[3]` back to `[]` on inverse.
   """
   with self.test_session():
     chained = chain_lib.Chain([
         softmax_centered_lib.SoftmaxCentered(
             event_ndims=1, validate_args=True),
         softmax_centered_lib.SoftmaxCentered(
             event_ndims=0, validate_args=True),
     ])
     event_shape_in = tensor_shape.TensorShape([])
     event_shape_out = tensor_shape.TensorShape([2 + 1])

     # Forward: static TensorShape getter, then the Tensor-valued getter.
     self.assertAllEqual(event_shape_out,
                         chained.forward_event_shape(event_shape_in))
     forward_shape_tensor = chained.forward_event_shape_tensor(
         event_shape_in.as_list())
     self.assertAllEqual(event_shape_out.as_list(),
                         forward_shape_tensor.eval())

     # Inverse: same pair of checks in the opposite direction.
     self.assertAllEqual(event_shape_in,
                         chained.inverse_event_shape(event_shape_out))
     inverse_shape_tensor = chained.inverse_event_shape_tensor(
         event_shape_out.as_list())
     self.assertAllEqual(event_shape_in.as_list(),
                         inverse_shape_tensor.eval())
Beispiel #5
0
    def __init__(self,
                 diag_bijector=None,
                 diag_shift=1e-5,
                 validate_args=False,
                 name="scale_tril"):
        """Instantiates the `ScaleTriL` bijector.

        Builds the list `[TransformDiagonal(diag_bijector), FillTriangular()]`
        and passes it, with `validate_args` and `name`, to the parent
        constructor (presumably `Chain`; confirm against the class header).

        Args:
          diag_bijector: `Bijector` instance, used to transform the output
            diagonal to be positive.
            Default value: `None` (i.e., `tfb.Softplus()`).
          diag_shift: Float value broadcastable and added to all diagonal
            entries after applying the `diag_bijector`. Setting a positive
            value forces the output diagonal entries to be positive, but
            prevents inverting the transformation for matrices with diagonal
            entries less than this value. Pass `None` to apply no shift.
            Default value: `1e-5`.
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.
            Default value: `False` (i.e., arguments are not validated).
          name: Python `str` name given to ops managed by this object.
            Default value: `scale_tril`.
        """
        if diag_bijector is None:
            diag_bijector = softplus.Softplus(validate_args=validate_args)

        if diag_shift is not None:
            # The shift is listed before `diag_bijector` so that, with Chain's
            # last-to-first forward order, it is added after `diag_bijector`
            # runs, as the docstring promises.
            diag_bijector = chain.Chain(
                [affine_scalar.AffineScalar(shift=diag_shift), diag_bijector])

        super(ScaleTriL, self).__init__(
            [
                transform_diagonal.TransformDiagonal(
                    diag_bijector=diag_bijector),
                fill_triangular.FillTriangular(),
            ],
            validate_args=validate_args,
            name=name)