def testTriLAdjoint(self):
        """Checks forward/inverse/log-det for an adjoint triangular operator."""
        lower = np.array(
            [[[3, 0, 0], [2, -1, 0], [3, 2, 1]],
             [[2, 0, 0], [3, -2, 0], [4, 3, 2]]],
            dtype=np.float32)
        operator = tf.linalg.LinearOperatorLowerTriangular(
            lower, is_non_singular=True)
        bijector = tfb.ScaleMatvecLinearOperator(
            scale=operator, adjoint=True, validate_args=True)

        x = np.array(
            [[[1, 0, -1], [2, 3, 4]], [[4, 1, -7], [6, 9, 8]]],
            dtype=np.float32)
        # adjoint=True means the bijector applies A^T, so build the reference
        # y from the transposed matrix.
        # If we made the bijector do x*A+b then this would be simplified to:
        # y = np.matmul(x, lower).
        upper = np.transpose(lower, [0, 2, 1])
        y = np.matmul(upper, x[..., np.newaxis])[..., 0]
        # Log-det is computed from the triangular diagonal; transposing does
        # not change it.
        ildj = -np.log(np.abs(np.diagonal(lower, axis1=-2, axis2=-1))).sum()

        self.assertStartsWith(bijector.name, 'scale_matvec_linear_operator')
        self.assertAllClose(y, self.evaluate(bijector.forward(x)))
        self.assertAllClose(x, self.evaluate(bijector.inverse(y)))
        self.assertAllClose(
            ildj,
            self.evaluate(bijector.inverse_log_det_jacobian(y, event_ndims=2)))
        self.assertAllClose(
            self.evaluate(
                -bijector.inverse_log_det_jacobian(y, event_ndims=2)),
            self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=2)))
 def testMean(self):
   shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
   diag = np.array([[1, 2, 3], [2, 3, 2]], dtype=np.float32)
   fake_mvn = self._cls()(
       tfd.MultivariateNormalDiag(
           loc=tf.zeros_like(shift),
           scale_diag=tf.ones_like(diag),
           validate_args=True),
       tfb.Chain([
           tfb.Shift(shift=shift),
           tfb.ScaleMatvecLinearOperator(
               scale=tf.linalg.LinearOperatorDiag(diag, is_non_singular=True))
       ], validate_args=True),
       validate_args=True)
   self.assertAllClose(shift, self.evaluate(fake_mvn.mean()))
 def testEntropy(self):
   shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
   diag = np.array([[1, 2, 3], [2, 3, 2]], dtype=np.float32)
   actual_mvn_entropy = np.concatenate(
       [[stats.multivariate_normal(shift[i], np.diag(diag[i]**2)).entropy()]
        for i in range(len(diag))])
   fake_mvn = self._cls()(
       tfd.MultivariateNormalDiag(
           loc=tf.zeros_like(shift),
           scale_diag=tf.ones_like(diag),
           validate_args=True),
       tfb.Chain([
           tfb.Shift(shift=shift),
           tfb.ScaleMatvecLinearOperator(
               scale=tf.linalg.LinearOperatorDiag(diag, is_non_singular=True))
       ], validate_args=True),
       validate_args=True)
   self.assertAllClose(actual_mvn_entropy, self.evaluate(fake_mvn.entropy()))
    def testDiag(self):
        """Checks forward/inverse/log-det for a diagonal operator."""
        diag = np.array([[1, 2, 3], [2, 5, 6]], dtype=np.float32)
        bijector = tfb.ScaleMatvecLinearOperator(
            scale=tf.linalg.LinearOperatorDiag(diag, is_non_singular=True),
            validate_args=True)

        x = np.array([[1, 0, -1], [2, 3, 4]], dtype=np.float32)
        # A diagonal matvec reduces to an elementwise product.
        y = x * diag
        # Per-batch log-det: sum of log|diag entries| over the event axis.
        ildj = -np.log(np.abs(diag)).sum(axis=-1)

        self.assertStartsWith(bijector.name, 'scale_matvec_linear_operator')
        self.assertAllClose(y, self.evaluate(bijector.forward(x)))
        self.assertAllClose(x, self.evaluate(bijector.inverse(y)))
        self.assertAllClose(
            ildj,
            self.evaluate(bijector.inverse_log_det_jacobian(y, event_ndims=1)))
        self.assertAllClose(
            self.evaluate(
                -bijector.inverse_log_det_jacobian(y, event_ndims=1)),
            self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=1)))