def testChainIldjWithPlaceholder(self): chain = Chain((Exp(), Exp())) samples = array_ops.placeholder( dtype=np.float32, shape=[None, 10], name="samples") ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0) self.assertTrue(ildj is not None) with self.cached_session(): ildj.eval({samples: np.zeros([2, 10], np.float32)})
def testMinEventNdimsShapeChangingAddRemoveDims(self): chain = Chain([ ShapeChanging(2, 1), ShapeChanging(3, 0), ShapeChanging(1, 2)]) self.assertEqual(4, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims)
def testMinEventNdimsShapeChangingRemoveDims(self): chain = Chain([ShapeChanging(3, 0)]) self.assertEqual(3, chain.forward_min_event_ndims) self.assertEqual(0, chain.inverse_min_event_ndims) chain = Chain([ShapeChanging(3, 0), Affine()]) self.assertEqual(3, chain.forward_min_event_ndims) self.assertEqual(0, chain.inverse_min_event_ndims) chain = Chain([Affine(), ShapeChanging(3, 0)]) self.assertEqual(4, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims) chain = Chain([ShapeChanging(3, 0), ShapeChanging(3, 0)]) self.assertEqual(6, chain.forward_min_event_ndims) self.assertEqual(0, chain.inverse_min_event_ndims)
def testMinEventNdimsChain(self): chain = Chain([Exp(), Exp(), Exp()]) self.assertEqual(0, chain.forward_min_event_ndims) self.assertEqual(0, chain.inverse_min_event_ndims) chain = Chain([Affine(), Affine(), Affine()]) self.assertEqual(1, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims) chain = Chain([Exp(), Affine()]) self.assertEqual(1, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims) chain = Chain([Affine(), Exp()]) self.assertEqual(1, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims) chain = Chain([Affine(), Exp(), Softplus(), Affine()]) self.assertEqual(1, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims)
def testBijectorIdentity(self): with self.cached_session(): chain = Chain() self.assertEqual("identity", chain.name) x = np.asarray([[[1., 2.], [2., 3.]]]) self.assertAllClose(x, chain.forward(x).eval()) self.assertAllClose(x, chain.inverse(x).eval()) self.assertAllClose( 0., chain.inverse_log_det_jacobian(x, event_ndims=1).eval()) self.assertAllClose( 0., chain.forward_log_det_jacobian(x, event_ndims=1).eval())
def testBijector(self): with self.cached_session(): chain = Chain((Exp(), Softplus())) self.assertEqual("chain_of_exp_of_softplus", chain.name) x = np.asarray([[[1., 2.], [2., 3.]]]) self.assertAllClose(1. + np.exp(x), chain.forward(x).eval()) self.assertAllClose(np.log(x - 1.), chain.inverse(x).eval()) self.assertAllClose( -np.sum(np.log(x - 1.), axis=2), chain.inverse_log_det_jacobian(x, event_ndims=1).eval()) self.assertAllClose( np.sum(x, axis=2), chain.forward_log_det_jacobian(x, event_ndims=1).eval())
def testChainAffineExp(self): scale_diag = np.array([1., 2., 3.], dtype=np.float32) chain = Chain([Affine(scale_diag=scale_diag), Exp()]) x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)] y = [1., 4., 9.] self.assertAllClose(y, self.evaluate(chain.forward(x))) self.assertAllClose(x, self.evaluate(chain.inverse(y))) self.assertAllClose( np.log(6, dtype=np.float32) + np.sum(x), self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1))) self.assertAllClose( -np.log(6, dtype=np.float32) - np.sum(x), self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1)))
def testShapeGetters(self): with self.cached_session(): chain = Chain([ SoftmaxCentered(validate_args=True), SoftmaxCentered(validate_args=True), ]) x = tensor_shape.TensorShape([1]) y = tensor_shape.TensorShape([2 + 1]) self.assertAllEqual(y, chain.forward_event_shape(x)) self.assertAllEqual( y.as_list(), chain.forward_event_shape_tensor(x.as_list()).eval()) self.assertAllEqual(x, chain.inverse_event_shape(y)) self.assertAllEqual( x.as_list(), chain.inverse_event_shape_tensor(y.as_list()).eval())
def testScalarCongruency(self): with self.cached_session(): chain = Chain((Exp(), Softplus())) assert_scalar_congruency( chain, lower_x=1e-3, upper_x=1.5, rtol=0.05)