def testScalarCongruency(self):
  """Checks scalar congruency of Chain(Exp, Softplus) for x in [1e-3, 1.5]."""
  with self.test_session():
    exp_of_softplus = chain_lib.Chain(
        (exp_lib.Exp(), softplus_lib.Softplus()))
    bijector_test_util.assert_scalar_congruency(
        exp_of_softplus,
        lower_x=1e-3,
        upper_x=1.5,
        rtol=0.05)
def testBijectorLogDetJacobianEventDimsOne(self):
  """With event_ndims=1 the ILDJ is the element-wise ILDJ summed over axis 1."""
  with self.test_session():
    softplus = softplus_lib.Softplus(event_ndims=1)
    y_values = 2 * rng.rand(2, 10)
    # Reduce the per-element values over the single event dimension.
    per_element = self._softplus_ildj_before_reduction(y_values)
    expected = np.sum(per_element, axis=1)
    self.assertAllClose(
        expected, softplus.inverse_log_det_jacobian(y_values).eval())
def testBijectorForwardInverseEventDimsOne(self):
  """Forward/inverse of Softplus(event_ndims=1) match the NumPy reference."""
  with self.test_session():
    softplus = softplus_lib.Softplus(event_ndims=1)
    self.assertEqual("softplus", softplus.name)
    x_in = 2 * rng.randn(2, 10)
    y_ref = self._softplus(x_in)
    forward_out = softplus.forward(x_in).eval()
    self.assertAllClose(y_ref, forward_out)
    inverse_out = softplus.inverse(y_ref).eval()
    self.assertAllClose(x_in, inverse_out)
def testBijectorLogDetJacobianEventDimsZero(self):
  """With event_ndims=0 the ILDJ equals the unreduced element-wise ILDJ."""
  with self.test_session():
    softplus = softplus_lib.Softplus(event_ndims=0)
    y_values = 2 * rng.rand(2, 10)
    # Scalar events: nothing to sum, compare element-wise directly.
    expected = self._softplus_ildj_before_reduction(y_values)
    actual = softplus.inverse_log_det_jacobian(y_values).eval()
    self.assertAllClose(expected, actual)
def testBijectiveAndFinite32bit(self):
  """Softplus is bijective and finite on float32 grids of x and y."""
  with self.test_session():
    softplus = softplus_lib.Softplus(event_ndims=0)
    x_grid = np.linspace(-20., 20., 100).astype(np.float32)
    y_grid = np.logspace(-10, 10, 100).astype(np.float32)
    bijector_test_util.assert_bijective_and_finite(
        softplus, x_grid, y_grid,
        rtol=1e-2,
        atol=1e-2)
def testBijectorForwardInverseEventDimsZero(self):
  """Forward/inverse of Softplus(event_ndims=0) match the NumPy reference."""
  with self.test_session():
    softplus = softplus_lib.Softplus(event_ndims=0)
    self.assertEqual("softplus", softplus.name)
    x_in = 2 * rng.randn(2, 10)
    y_ref = self._softplus(x_in)
    self.assertAllClose(y_ref, softplus.forward(x_in).eval())
    self.assertAllClose(x_in, softplus.inverse(y_ref).eval())
    # The combined call's first element is checked against the plain inverse.
    inverse_part = softplus.inverse_and_inverse_log_det_jacobian(y_ref)[0]
    self.assertAllClose(x_in, inverse_part.eval())
def testBijector(self):
  """Chain(Exp, Softplus) applies Softplus then Exp: y = 1 + exp(x)."""
  with self.test_session():
    exp_bij = exp_lib.Exp(event_ndims=1)
    softplus_bij = softplus_lib.Softplus(event_ndims=1)
    chain = chain_lib.Chain((exp_bij, softplus_bij))
    self.assertEqual("chain_of_exp_of_softplus", chain.name)
    pts = np.asarray([[[1., 2.], [2., 3.]]])
    # exp(softplus(x)) == 1 + exp(x); its inverse is log(y - 1).
    self.assertAllClose(1. + np.exp(pts), chain.forward(pts).eval())
    self.assertAllClose(np.log(pts - 1.), chain.inverse(pts).eval())
    # Jacobian terms are summed over the trailing event axis (event_ndims=1).
    expected_ildj = -np.sum(np.log(pts - 1.), axis=2)
    self.assertAllClose(expected_ildj,
                        chain.inverse_log_det_jacobian(pts).eval())
    expected_fldj = np.sum(pts, axis=2)
    self.assertAllClose(expected_fldj,
                        chain.forward_log_det_jacobian(pts).eval())
def testBijectiveAndFinite16bit(self):
  """Softplus is bijective and finite on (narrower) float16 grids."""
  with self.test_session():
    softplus = softplus_lib.Softplus(event_ndims=0)
    # softplus(-20) underflows to zero in float16, so the x range is narrower
    # than in the 32-bit test.
    x_grid = np.linspace(-10., 20., 100).astype(np.float16)
    # float16 stays inside the open set (0, inf) only for a limited logspace
    # range (observed as roughly (-7, 4)), so a smaller range is used here.
    y_grid = np.logspace(-6, 3, 100).astype(np.float16)
    bijector_test_util.assert_bijective_and_finite(
        softplus, x_grid, y_grid,
        rtol=1e-1,
        atol=1e-3)
def __init__(self,
             diag_bijector=None,
             diag_shift=1e-5,
             validate_args=False,
             name="scale_tril"):
  """Instantiates the `ScaleTriL` bijector.

  Args:
    diag_bijector: `Bijector` instance, used to transform the output diagonal
      to be positive.
      Default value: `None` (i.e., `tfb.Softplus()`).
    diag_shift: Float value broadcastable and added to all diagonal entries
      after applying the `diag_bijector`. Setting a positive
      value forces the output diagonal entries to be positive, but
      prevents inverting the transformation for matrices with
      diagonal entries less than this value. Pass `None` to apply no shift.
      Default value: `1e-5`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
      Default value: `False` (i.e., arguments are not validated).
    name: Python `str` name given to ops managed by this object.
      Default value: `scale_tril`.
  """
  if diag_bijector is None:
    diag_bijector = softplus.Softplus(validate_args=validate_args)

  if diag_shift is not None:
    # Chain applies its bijectors last-to-first, so `diag_bijector` runs
    # first and the shift is added to its output.
    diag_bijector = chain.Chain(
        [affine_scalar.AffineScalar(shift=diag_shift), diag_bijector])

  # NOTE(review): assumes the parent applies the list last-to-first like
  # Chain, i.e. the vector is filled into a triangular matrix first and its
  # diagonal is transformed afterwards — confirm against the base class.
  super(ScaleTriL, self).__init__(
      [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
       fill_triangular.FillTriangular()],
      validate_args=validate_args,
      name=name)
def testScalarCongruency(self):
  """Checks scalar congruency of Softplus for x in [-2, 2]."""
  with self.test_session():
    softplus = softplus_lib.Softplus(event_ndims=0)
    bijector_test_util.assert_scalar_congruency(
        softplus,
        lower_x=-2.,
        upper_x=2.)