def testBijectorIdentity(self):
  """Checks that an empty Chain behaves as the identity bijector.

  Asserts that the default name is "identity", that forward/inverse return
  their input unchanged, and that both log-det-Jacobians (event_ndims=1)
  evaluate to 0.
  """
  # self.evaluate replaces Tensor.eval() inside cached_session(), matching
  # testChainAffineExp; unlike .eval(), it also works under eager execution.
  chain = Chain()
  self.assertEqual("identity", chain.name)
  x = np.asarray([[[1., 2.],
                   [2., 3.]]])
  self.assertAllClose(x, self.evaluate(chain.forward(x)))
  self.assertAllClose(x, self.evaluate(chain.inverse(x)))
  self.assertAllClose(
      0., self.evaluate(chain.inverse_log_det_jacobian(x, event_ndims=1)))
  self.assertAllClose(
      0., self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1)))
 def testBijector(self):
   """Checks Chain((Exp(), Softplus())) against closed-form expectations.

   The forward assertion pins the composition order: forward(x) equals
   exp(softplus(x)) = 1 + exp(x), i.e. Softplus is applied before Exp. It
   follows that inverse(y) = log(y - 1), the forward log-det-Jacobian is
   sum(x), and the inverse log-det-Jacobian is -sum(log(y - 1)). Both
   log-det-Jacobians are summed over the last axis because event_ndims=1.
   """
   # self.evaluate replaces Tensor.eval() inside cached_session(), matching
   # testChainAffineExp; unlike .eval(), it also works under eager execution.
   chain = Chain((Exp(), Softplus()))
   self.assertEqual("chain_of_exp_of_softplus", chain.name)
   # NOTE(review): x contains 1., where log(x - 1.) = -inf. The first entries
   # of the inverse/ILDJ checks therefore only compare infinities, and numpy
   # warns on log(0). Points strictly above 1 would exercise more.
   x = np.asarray([[[1., 2.],
                    [2., 3.]]])
   self.assertAllClose(1. + np.exp(x), self.evaluate(chain.forward(x)))
   self.assertAllClose(np.log(x - 1.), self.evaluate(chain.inverse(x)))
   self.assertAllClose(
       -np.sum(np.log(x - 1.), axis=2),
       self.evaluate(chain.inverse_log_det_jacobian(x, event_ndims=1)))
   self.assertAllClose(
       np.sum(x, axis=2),
       self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1)))
  def testChainAffineExp(self):
    """Checks Chain([Affine, Exp]) against hand-computed values.

    With inputs [0, log 2, log 3], exp gives [1, 2, 3]. Scaling elementwise
    by diag = [1, 2, 3] gives [1, 4, 9], so Exp is applied before Affine.
    The forward log-det-Jacobian is log(1 * 2 * 3) + sum(inputs), and the
    inverse log-det-Jacobian at the outputs is its negation.
    """
    diag = np.array([1., 2., 3.], dtype=np.float32)
    bijector = Chain([Affine(scale_diag=diag), Exp()])
    inputs = [0.,
              np.log(2., dtype=np.float32),
              np.log(3., dtype=np.float32)]
    outputs = [1., 4., 9.]
    log_det_scale = np.log(6, dtype=np.float32)  # log(1 * 2 * 3)
    sum_inputs = np.sum(inputs)

    forward_vals = self.evaluate(bijector.forward(inputs))
    self.assertAllClose(outputs, forward_vals)

    inverse_vals = self.evaluate(bijector.inverse(outputs))
    self.assertAllClose(inputs, inverse_vals)

    fldj = self.evaluate(
        bijector.forward_log_det_jacobian(inputs, event_ndims=1))
    self.assertAllClose(log_det_scale + sum_inputs, fldj)

    ildj = self.evaluate(
        bijector.inverse_log_det_jacobian(outputs, event_ndims=1))
    self.assertAllClose(-log_det_scale - sum_inputs, ildj)