Example 1
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.batch_matmul(c, grad_b, adj_y=True)
  else:
    grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
  if lower_a:
    grad_a = array_ops.batch_matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.batch_matrix_band_part(grad_a, 0, -1)
  return (grad_a, grad_b)
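
For intuition, here is a minimal NumPy/SciPy sketch (illustrative only, not part of TensorFlow) that checks these gradient formulas against finite differences for the lower-triangular, non-adjoint case: with c = a^{-1} b, the code above returns grad_b = a^{-T} grad and grad_a = -tril(grad_b c^T).

import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
k = 4
a = np.tril(rng.normal(size=(k, k))) + k * np.eye(k)  # well-conditioned, lower-tri
b = rng.normal(size=(k, 2))
g = rng.normal(size=(k, 2))  # upstream gradient dL/dc
c = solve_triangular(a, b, lower=True)

# Analytic gradients, mirroring the code above (adjoint_a=False, lower_a=True).
grad_b = solve_triangular(a, g, lower=True, trans='T')  # adjoint=not adjoint_a
grad_a = np.tril(-grad_b @ c.T)  # batch_matrix_band_part(grad_a, -1, 0)

# Finite-difference check of dL/da for the scalar loss L = sum(g * c).
eps = 1e-6
num = np.zeros_like(a)
for i in range(k):
  for j in range(i + 1):  # only lower-triangular entries are real inputs
    da = np.zeros_like(a)
    da[i, j] = eps
    num[i, j] = np.sum(g * (solve_triangular(a + da, b, lower=True) - c)) / eps
assert np.allclose(num, grad_a, atol=1e-4)

The final band_part in the gradient reflects that only one triangle of `a` is an actual input; gradient components for the masked-out entries are dropped.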
Example 2
 def _batch_matmul(self, x, transpose_x=False):
     # tf.batch_matmul is defined x * y, so "y" is on the right, not "x".
     chol = array_ops.batch_matrix_band_part(self._chol, -1, 0)
     chol_times_x = math_ops.batch_matmul(chol,
                                          x,
                                          adj_x=True,
                                          adj_y=transpose_x)
     return math_ops.batch_matmul(chol, chol_times_x)
Example 3
 def _matmul(self, x, transpose_x=False):
     # tf.matmul is defined a * b.
     chol = array_ops.batch_matrix_band_part(self._chol, -1, 0)
     chol_times_x = math_ops.matmul(chol,
                                    x,
                                    transpose_a=True,
                                    transpose_b=transpose_x)
     return math_ops.matmul(chol, chol_times_x)
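
Both `_batch_matmul` above and `_matmul` here compute `A x` for the operator `A = chol @ chol^T` without ever materializing `A`; only the lower band of the stored factor participates. A minimal NumPy sketch (illustrative, not the TF API) of the identity:

import numpy as np

rng = np.random.default_rng(0)
k = 3
raw = rng.normal(size=(k, k))  # stored factor; the upper triangle is ignored
chol = np.tril(raw)            # batch_matrix_band_part(raw, -1, 0)
x = rng.normal(size=(k, 2))

a = chol @ chol.T              # the operator's dense matrix
assert np.allclose(chol @ (chol.T @ x), a @ x)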
Example 4
    def _sample_n(self, n, seed):
        batch_shape = self.batch_shape()
        event_shape = self.event_shape()
        batch_ndims = array_ops.shape(batch_shape)[0]

        ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
        shape = array_ops.concat(0, ((n, ), batch_shape, event_shape))

        # Complexity: O(nbk^2)
        x = random_ops.random_normal(shape=shape,
                                     mean=0.,
                                     stddev=1.,
                                     dtype=self.dtype,
                                     seed=seed)

        # Complexity: O(nbk)
        # This parametrization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
        g = random_ops.random_gamma(shape=(n, ),
                                    alpha=self._multi_gamma_sequence(
                                        0.5 * self.df, self.dimension),
                                    beta=0.5,
                                    dtype=self.dtype,
                                    seed=seed)

        # Complexity: O(nbk^2)
        x = array_ops.batch_matrix_band_part(x, -1, 0)  # Tri-lower.

        # Complexity: O(nbk)
        x = array_ops.batch_matrix_set_diag(x, math_ops.sqrt(g))

        # Make batch-op ready.
        # Complexity: O(nbk^2)
        perm = array_ops.concat(0, (math_ops.range(1, ndims), (0, )))
        x = array_ops.transpose(x, perm)
        shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
        x = array_ops.reshape(x, shape)

        # Complexity: O(nbM) where M is the complexity of the operator solving a
        # vector system.  E.g., for OperatorPDDiag, each matmul is O(k^2), so
        # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
        # O(k^3) so this step has complexity O(nbk^3).
        x = self.scale_operator_pd.sqrt_matmul(x)

        # Undo make batch-op ready.
        # Complexity: O(nbk^2)
        shape = array_ops.concat(0, (batch_shape, event_shape, (n, )))
        x = array_ops.reshape(x, shape)
        perm = array_ops.concat(0,
                                ((ndims - 1, ), math_ops.range(0, ndims - 1)))
        x = array_ops.transpose(x, perm)

        if not self.cholesky_input_output_matrices:
            # Complexity: O(nbk^3)
            x = math_ops.batch_matmul(x, x, adj_y=True)

        return x
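
For reference, a minimal single-sample NumPy sketch of the Bartlett construction this sampler implements (the helper name `wishart_bartlett` is hypothetical; `scale_chol` stands in for the scale operator's Cholesky factor): the strictly-lower entries are N(0, 1), the i-th diagonal entry is sqrt(Chi2(df - i)) drawn via Gamma(alpha=(df - i)/2, beta=1/2), and the result is pushed through the scale factor.

import numpy as np

def wishart_bartlett(df, scale_chol, rng):
  k = scale_chol.shape[0]
  x = np.tril(rng.normal(size=(k, k)), k=-1)  # N(0,1) strictly below the diagonal
  # Gamma(alpha=(df - i)/2, beta=1/2) == Chi2(df - i); note scale = 1/beta = 2.
  g = rng.gamma(shape=(df - np.arange(k)) / 2., scale=2.)
  x += np.diag(np.sqrt(g))                    # batch_matrix_set_diag step
  factor = scale_chol @ x                     # sqrt_matmul step
  return factor @ factor.T                    # final batch_matmul(x, x, adj_y=True)

rng = np.random.default_rng(0)
scale_chol = np.tril(np.ones((3, 3)))
w = wishart_bartlett(df=5., scale_chol=scale_chol, rng=rng)
assert np.allclose(w, w.T)  # symmetric (and positive definite with probability 1)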
Example 5
  def _sample_n(self, n, seed):
    batch_shape = self.batch_shape()
    event_shape = self.event_shape()
    batch_ndims = array_ops.shape(batch_shape)[0]

    ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
    shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

    # Complexity: O(nbk^2)
    x = random_ops.random_normal(shape=shape,
                                 mean=0.,
                                 stddev=1.,
                                 dtype=self.dtype,
                                 seed=seed)

    # Complexity: O(nbk)
    # This parametrization is equivalent to Chi2, i.e.,
    # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
    g = random_ops.random_gamma(shape=(n,),
                                alpha=self._multi_gamma_sequence(
                                    0.5 * self.df, self.dimension),
                                beta=0.5,
                                dtype=self.dtype,
                                seed=seed)

    # Complexity: O(nbk^2)
    x = array_ops.batch_matrix_band_part(x, -1, 0)  # Tri-lower.

    # Complexity: O(nbk)
    x = array_ops.batch_matrix_set_diag(x, math_ops.sqrt(g))

    # Make batch-op ready.
    # Complexity: O(nbk^2)
    perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
    x = array_ops.transpose(x, perm)
    shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
    x = array_ops.reshape(x, shape)

    # Complexity: O(nbM) where M is the complexity of the operator solving a
    # vector system.  E.g., for OperatorPDDiag, each matmul is O(k^2), so
    # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
    # O(k^3) so this step has complexity O(nbk^3).
    x = self.scale_operator_pd.sqrt_matmul(x)

    # Undo make batch-op ready.
    # Complexity: O(nbk^2)
    shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
    x = array_ops.reshape(x, shape)
    perm = array_ops.concat(0, ((ndims-1,), math_ops.range(0, ndims-1)))
    x = array_ops.transpose(x, perm)

    if not self.cholesky_input_output_matrices:
      # Complexity: O(nbk^3)
      x = math_ops.batch_matmul(x, x, adj_y=True)

    return x
Example 6
def _BatchMatrixTriangularSolveGrad(op, grad):
    """Gradient for BatchMatrixTriangularSolve."""
    a = op.inputs[0]
    adjoint_a = op.get_attr("adjoint")
    lower_a = op.get_attr("lower")
    c = op.outputs[0]
    grad_b = linalg_ops.batch_matrix_triangular_solve(a,
                                                      grad,
                                                      lower=lower_a,
                                                      adjoint=not adjoint_a)
    if adjoint_a:
        grad_a = -math_ops.batch_matmul(c, grad_b, adj_y=True)
    else:
        grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
    if lower_a:
        grad_a = array_ops.batch_matrix_band_part(grad_a, -1, 0)
    else:
        grad_a = array_ops.batch_matrix_band_part(grad_a, 0, -1)
    return (grad_a, grad_b)
Example 7
  def sqrt_matmul(self, x, name='sqrt_matmul'):
    """Left (batch) matmul `x` by a sqrt of this matrix:  `Sx` where `A = S S^T`.

    Args:
      x: `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
        as self.
      name:  A name scope to use for ops added by this method.

    Returns:
      Shape `[N1,...,Nb, k]` `Tensor` holding the product `S x`.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([x] + self.inputs, name):
        chol_lower = array_ops.batch_matrix_band_part(self._chol, -1, 0)
        return math_ops.batch_matmul(chol_lower, x)
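
A minimal NumPy sketch (illustrative, not the TF API) of why a left multiply by the Cholesky factor is the useful notion of "matmul by a sqrt": if `z ~ N(0, I)` and `A = S @ S^T`, then `S @ z` has covariance `A`. The tolerance is loose because this is a statistical check:

import numpy as np

rng = np.random.default_rng(0)
s = np.tril(rng.normal(size=(2, 2))) + 2. * np.eye(2)  # a lower-tri sqrt of A
a = s @ s.T
z = rng.normal(size=(2, 200000))                       # columns are iid N(0, I)
emp_cov = np.cov(s @ z)                                # empirical covariance of S z
assert np.allclose(emp_cov, a, atol=0.1)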
Example 8
    def sqrt_matmul(self, x, name='sqrt_matmul'):
        """Left (batch) matmul `x` by a sqrt of this matrix:  `Sx` where `A = S S^T`.

    Args:
      x: `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
        as self.
      name:  A name scope to use for ops added by this method.

    Returns:
      Shape `[N1,...,Nb, k]` `Tensor` holding the product `S x`.
    """
        with ops.name_scope(self.name):
            with ops.op_scope([x] + self.inputs, name):
                chol_lower = array_ops.batch_matrix_band_part(
                    self._chol, -1, 0)
                return math_ops.batch_matmul(chol_lower, x)
Example 9
  def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
    """Multivariate Normal distributions on `R^k`.

    User must provide means `mu`, which are tensors of rank `N+1` (`N >= 0`)
    with the last dimension having length `k`.

    User must provide exactly one of `sigma` (the covariance matrices) or
    `sigma_chol` (the cholesky decompositions of the covariance matrices).
    `sigma` or `sigma_chol` must be of rank `N+2`.  The last two dimensions
    must both have length `k`.  The first `N` dimensions correspond to batch
    indices.

    If `sigma_chol` is not provided, the batch cholesky factorization of `sigma`
    is calculated for you.

    The shapes of `mu` and `sigma` must match for the first `N` dimensions.

    Regardless of which parameter is provided, the covariance matrices must all
    be **positive definite** (an error is raised if one of them is not).

    Args:
      mu: (N+1)-D.  `float` or `double` tensor, the means of the distributions.
      sigma: (N+2)-D.  (optional) `float` or `double` tensor, the covariances
        of the distribution(s).  The first `N+1` dimensions must match
        those of `mu`.  Must be batch-positive-definite.
      sigma_chol: (N+2)-D.  (optional) `float` or `double` tensor, a
        lower-triangular factorization of `sigma`
        (`sigma = sigma_chol . sigma_chol^*`).  The first `N+1` dimensions
        must match those of `mu`.  The tensor itself need not be batch
        lower triangular: we ignore the upper triangular part.  However,
        the batch diagonals must be positive (i.e., sigma_chol must be
        batch-positive-definite).
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if neither or both of sigma and sigma_chol are provided.
      TypeError: if mu and sigma (resp. sigma_chol) are different dtypes.
    """
    if (sigma is None) == (sigma_chol is None):
      raise ValueError("Exactly one of sigma and sigma_chol must be provided")

    with ops.op_scope([mu, sigma, sigma_chol], name, "MultivariateNormal"):
      sigma_or_half = sigma_chol if sigma is None else sigma

      mu = ops.convert_to_tensor(mu)
      sigma_or_half = ops.convert_to_tensor(sigma_or_half)

      contrib_tensor_util.assert_same_float_dtype((mu, sigma_or_half))

      with ops.control_dependencies([
          _assert_compatible_shapes(mu, sigma_or_half)]):
        mu = array_ops.identity(mu, name="mu")

        # Store the dimensionality of the MVNs
        self._k = array_ops.gather(array_ops.shape(mu), array_ops.rank(mu) - 1)

        if sigma_chol is not None:
          # Ensure we only keep the lower triangular part.
          sigma_chol = array_ops.batch_matrix_band_part(
              sigma_chol, num_lower=-1, num_upper=0)
          sigma_det = _determinant_from_sigma_chol(sigma_chol)
          with ops.control_dependencies([
              _assert_batch_positive_definite(sigma_chol)]):
            self._sigma = math_ops.batch_matmul(
                sigma_chol, sigma_chol, adj_y=True, name="sigma")
            self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
            self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
            self._mu = array_ops.identity(mu, "mu")
        else:  # sigma is not None
          sigma_chol = linalg_ops.batch_cholesky(sigma)
          sigma_det = _determinant_from_sigma_chol(sigma_chol)
          # batch_cholesky checks for PSD, so we can just use it here.
          with ops.control_dependencies([sigma_chol]):
            self._sigma = array_ops.identity(sigma, "sigma")
            self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
            self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
            self._mu = array_ops.identity(mu, "mu")
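
`_determinant_from_sigma_chol` is not shown here, but the identity it presumably exploits is standard: for `sigma = sigma_chol @ sigma_chol^T` with lower-triangular `sigma_chol`, `det(sigma)` is the squared product of the factor's diagonal. A minimal NumPy sketch:

import numpy as np

rng = np.random.default_rng(0)
k = 4
chol = np.tril(rng.normal(size=(k, k))) + k * np.eye(k)
sigma = chol @ chol.T
assert np.allclose(np.linalg.det(sigma), np.prod(np.diag(chol)) ** 2)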
Example 10
  def sample_n(self, n, seed=None, name='sample'):
    # pylint: disable=line-too-long
    """Generate `n` samples.

    Complexity: O(nbk^3)

    The sampling procedure is based on the [Bartlett decomposition](
    https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition)
    and [using a Gamma distribution to generate Chi2 random variates](
    https://en.wikipedia.org/wiki/Chi-squared_distribution#Gamma.2C_exponential.2C_and_related_distributions).

    Args:
      n: `Scalar` `Tensor` of type `int32` or `int64`, the number of
        observations to sample.
      seed: Python integer; random number generator seed.
      name: The name of this op.

    Returns:
      samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
          with values of type `self.dtype`.
    """
    with ops.name_scope(self.name):
      with ops.name_scope(name, values=[n] + list(self.inputs.values())):
        n = ops.convert_to_tensor(n, name='n')
        if n.dtype != dtypes.int32:
          raise TypeError('n.dtype=%s which is not int32' % n.dtype)
        batch_shape = self.batch_shape()
        event_shape = self.event_shape()
        batch_ndims = array_ops.shape(batch_shape)[0]

        ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
        shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

        # Complexity: O(nbk^2)
        x = random_ops.random_normal(shape=shape,
                                     mean=0.,
                                     stddev=1.,
                                     dtype=self.dtype,
                                     seed=seed)

        # Complexity: O(nbk)
        # This parametrization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
        g = random_ops.random_gamma(shape=(n,),
                                    alpha=self._multi_gamma_sequence(
                                        0.5 * self.df, self.dimension),
                                    beta=0.5,
                                    dtype=self.dtype,
                                    seed=seed)

        # Complexity: O(nbk^2)
        x = array_ops.batch_matrix_band_part(x, -1, 0)  # Tri-lower.

        # Complexity: O(nbk)
        x = array_ops.batch_matrix_set_diag(x, math_ops.sqrt(g))

        # Make batch-op ready.
        # Complexity: O(nbk^2)
        perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
        x = array_ops.transpose(x, perm)
        shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
        x = array_ops.reshape(x, shape)

        # Complexity: O(nbM) where M is the complexity of the operator solving a
        # vector system.  E.g., for OperatorPDDiag, each matmul is O(k^2), so
        # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
        # O(k^3) so this step has complexity O(nbk^3).
        x = self.scale_operator_pd.sqrt_matmul(x)

        # Undo make batch-op ready.
        # Complexity: O(nbk^2)
        shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
        x = array_ops.reshape(x, shape)
        perm = array_ops.concat(0, ((ndims-1,), math_ops.range(0, ndims-1)))
        x = array_ops.transpose(x, perm)

        if not self.cholesky_input_output_matrices:
          # Complexity: O(nbk^3)
          x = math_ops.batch_matmul(x, x, adj_y=True)

        # Set shape hints.
        if self.scale_operator_pd.get_shape().ndims is not None:
          x.set_shape(tensor_shape.TensorShape(
              [tensor_util.constant_value(n)] +
              self.scale_operator_pd.get_shape().as_list()))
        elif x.get_shape().ndims is not None:
          x.get_shape()[0].merge_with(
              tensor_shape.Dimension(tensor_util.constant_value(n)))

        return x
Example 11
def _BatchMatrixBandPartGrad(op, grad):
    num_lower = op.inputs[1]
    num_upper = op.inputs[2]
    return (array_ops.batch_matrix_band_part(grad, num_lower,
                                             num_upper), None, None)
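
Since band_part is an elementwise 0/1 mask, its gradient is the same mask applied to the incoming gradient, which is exactly what the function above returns (with `None` for the integer `num_lower`/`num_upper` inputs). A minimal NumPy sketch (illustrative, not the TF API):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 4))
g = rng.normal(size=(4, 4))          # upstream gradient dL/dy
mask = np.tril(np.ones_like(x))      # band_part(., -1, 0) as a 0/1 mask

# For L = sum(g * y) with y = mask * x, dL/dx = g * mask.
eps = 1e-6
for (i, j) in [(2, 0), (0, 3)]:      # one in-band entry, one masked-out entry
  dx = np.zeros_like(x)
  dx[i, j] = eps
  dL = np.sum(g * (mask * (x + dx))) - np.sum(g * (mask * x))
  assert abs(dL / eps - (g * mask)[i, j]) < 1e-6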
Example 12
 def _matmul(self, x, transpose_x=False):
   # tf.matmul is defined a * b.
   chol = array_ops.batch_matrix_band_part(self._chol, -1, 0)
   chol_times_x = math_ops.matmul(
       chol, x, transpose_a=True, transpose_b=transpose_x)
   return math_ops.matmul(chol, chol_times_x)
Example 13
 def _to_dense(self):
     chol = array_ops.batch_matrix_band_part(self._chol, -1, 0)
     return math_ops.batch_matmul(chol, chol, adj_y=True)
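
A minimal NumPy sketch of the point of the band_part call in `_to_dense`: only the lower triangle of the stored factor participates, so arbitrary values above the diagonal cannot leak into the dense matrix:

import numpy as np

rng = np.random.default_rng(0)
raw = rng.normal(size=(3, 3))
junk = raw + np.triu(rng.normal(size=(3, 3)), k=1)   # corrupt the upper triangle
dense = lambda m: np.tril(m) @ np.tril(m).T          # tril == band_part(m, -1, 0)
assert np.allclose(dense(raw), dense(junk))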
Example 14
 def _sqrt_to_dense(self):
     chol = array_ops.batch_matrix_band_part(self._chol, -1, 0)
     return array_ops.identity(chol)
Example 15
    def sample_n(self, n, seed=None, name="sample"):
        # pylint: disable=line-too-long
        """Generate `n` samples.

    Complexity: O(nbk^3)

    The sampling procedure is based on the [Bartlett decomposition](
    https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition)
    and [using a Gamma distribution to generate Chi2 random variates](
    https://en.wikipedia.org/wiki/Chi-squared_distribution#Gamma.2C_exponential.2C_and_related_distributions).

    Args:
      n: `Scalar` `Tensor` of type `int32` or `int64`, the number of
        observations to sample.
      seed: Python integer; random number generator seed.
      name: The name of this op.

    Returns:
      samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
          with values of type `self.dtype`.
    """
        with ops.name_scope(self.name):
            with ops.name_scope(name, values=[n] + list(self.inputs.values())):
                n = ops.convert_to_tensor(n, name="n")
                if n.dtype != dtypes.int32:
                    raise TypeError("n.dtype=%s which is not int32" % n.dtype)
                batch_shape = self.batch_shape()
                event_shape = self.event_shape()
                batch_ndims = array_ops.shape(batch_shape)[0]

                ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
                shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

                # Complexity: O(nbk^2)
                x = random_ops.random_normal(shape=shape, mean=0.0, stddev=1.0, dtype=self.dtype, seed=seed)

                # Complexity: O(nbk)
                # This parametrization is equivalent to Chi2, i.e.,
                # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
                g = random_ops.random_gamma(
                    shape=(n,),
                    alpha=self._multi_gamma_sequence(0.5 * self.df, self.dimension),
                    beta=0.5,
                    dtype=self.dtype,
                    seed=seed,
                )

                # Complexity: O(nbk^2)
                x = array_ops.batch_matrix_band_part(x, -1, 0)  # Tri-lower.

                # Complexity: O(nbk)
                x = array_ops.batch_matrix_set_diag(x, math_ops.sqrt(g))

                # Make batch-op ready.
                # Complexity: O(nbk^2)
                perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
                x = array_ops.transpose(x, perm)
                shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
                x = array_ops.reshape(x, shape)

                # Complexity: O(nbM) where M is the complexity of the operator solving a
                # vector system.  E.g., for OperatorPDDiag, each matmul is O(k^2), so
                # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
                # O(k^3) so this step has complexity O(nbk^3).
                x = self.scale_operator_pd.sqrt_matmul(x)

                # Undo make batch-op ready.
                # Complexity: O(nbk^2)
                shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
                x = array_ops.reshape(x, shape)
                perm = array_ops.concat(0, ((ndims - 1,), math_ops.range(0, ndims - 1)))
                x = array_ops.transpose(x, perm)

                if not self.cholesky_input_output_matrices:
                    # Complexity: O(nbk^3)
                    x = math_ops.batch_matmul(x, x, adj_y=True)

                # Set shape hints.
                if self.scale_operator_pd.get_shape().ndims is not None:
                    x.set_shape(
                        tensor_shape.TensorShape(
                            [tensor_util.constant_value(n)] + self.scale_operator_pd.get_shape().as_list()
                        )
                    )
                elif x.get_shape().ndims is not None:
                    x.get_shape()[0].merge_with(tensor_shape.Dimension(tensor_util.constant_value(n)))

                return x
Example 16
 def _sqrt_matmul(self, x, transpose_x=False):
     chol = array_ops.batch_matrix_band_part(self._chol, -1, 0)
     # tf.matmul is defined a * b
     return math_ops.matmul(chol, x, transpose_b=transpose_x)
Example 17
 def _batch_matmul(self, x, transpose_x=False):
   # tf.batch_matmul is defined x * y, so "y" is on the right, not "x".
   chol = array_ops.batch_matrix_band_part(self._chol, -1, 0)
   chol_times_x = math_ops.batch_matmul(
       chol, x, adj_x=True, adj_y=transpose_x)
   return math_ops.batch_matmul(chol, chol_times_x)
Example 18
 def _sqrt_matmul(self, x, transpose_x=False):
   chol = array_ops.batch_matrix_band_part(self._chol, -1, 0)
   # tf.matmul is defined a * b
   return math_ops.matmul(chol, x, transpose_b=transpose_x)
Example 19
 def _batch_sqrt_matmul(self, x, transpose_x=False):
   chol = array_ops.batch_matrix_band_part(self._chol, -1, 0)
   # tf.batch_matmul is defined x * y, so "y" is on the right, not "x".
   return math_ops.batch_matmul(chol, x, adj_y=transpose_x)
Example 20
 def _sqrt_to_dense(self):
   chol = array_ops.batch_matrix_band_part(self._chol, -1, 0)
   return array_ops.identity(chol)
Example 21
    def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
        """Multivariate Normal distributions on `R^k`.

    User must provide means `mu`, which are tensors of rank `N+1` (`N >= 0`)
    with the last dimension having length `k`.

    User must provide exactly one of `sigma` (the covariance matrices) or
    `sigma_chol` (the cholesky decompositions of the covariance matrices).
    `sigma` or `sigma_chol` must be of rank `N+2`.  The last two dimensions
    must both have length `k`.  The first `N` dimensions correspond to batch
    indices.

    If `sigma_chol` is not provided, the batch cholesky factorization of `sigma`
    is calculated for you.

    The shapes of `mu` and `sigma` must match for the first `N` dimensions.

    Regardless of which parameter is provided, the covariance matrices must all
    be **positive definite** (an error is raised if one of them is not).

    Args:
      mu: (N+1)-D.  `float` or `double` tensor, the means of the distributions.
      sigma: (N+2)-D.  (optional) `float` or `double` tensor, the covariances
        of the distribution(s).  The first `N+1` dimensions must match
        those of `mu`.  Must be batch-positive-definite.
      sigma_chol: (N+2)-D.  (optional) `float` or `double` tensor, a
        lower-triangular factorization of `sigma`
        (`sigma = sigma_chol . sigma_chol^*`).  The first `N+1` dimensions
        must match those of `mu`.  The tensor itself need not be batch
        lower triangular: we ignore the upper triangular part.  However,
        the batch diagonals must be positive (i.e., sigma_chol must be
        batch-positive-definite).
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if neither or both of sigma and sigma_chol are provided.
      TypeError: if mu and sigma (resp. sigma_chol) are different dtypes.
    """
        if (sigma is None) == (sigma_chol is None):
            raise ValueError(
                "Exactly one of sigma and sigma_chol must be provided")

        with ops.op_scope([mu, sigma, sigma_chol], name, "MultivariateNormal"):
            sigma_or_half = sigma_chol if sigma is None else sigma

            mu = ops.convert_to_tensor(mu)
            sigma_or_half = ops.convert_to_tensor(sigma_or_half)

            contrib_tensor_util.assert_same_float_dtype((mu, sigma_or_half))

            with ops.control_dependencies(
                [_assert_compatible_shapes(mu, sigma_or_half)]):
                mu = array_ops.identity(mu, name="mu")

                # Store the dimensionality of the MVNs
                self._k = array_ops.gather(array_ops.shape(mu),
                                           array_ops.rank(mu) - 1)

                if sigma_chol is not None:
                    # Ensure we only keep the lower triangular part.
                    sigma_chol = array_ops.batch_matrix_band_part(sigma_chol,
                                                                  num_lower=-1,
                                                                  num_upper=0)
                    sigma_det = _determinant_from_sigma_chol(sigma_chol)
                    with ops.control_dependencies(
                        [_assert_batch_positive_definite(sigma_chol)]):
                        self._sigma = math_ops.batch_matmul(sigma_chol,
                                                            sigma_chol,
                                                            adj_y=True,
                                                            name="sigma")
                        self._sigma_chol = array_ops.identity(
                            sigma_chol, "sigma_chol")
                        self._sigma_det = array_ops.identity(
                            sigma_det, "sigma_det")
                        self._mu = array_ops.identity(mu, "mu")
                else:  # sigma is not None
                    sigma_chol = linalg_ops.batch_cholesky(sigma)
                    sigma_det = _determinant_from_sigma_chol(sigma_chol)
                    # batch_cholesky checks for PSD, so we can just use it here.
                    with ops.control_dependencies([sigma_chol]):
                        self._sigma = array_ops.identity(sigma, "sigma")
                        self._sigma_chol = array_ops.identity(
                            sigma_chol, "sigma_chol")
                        self._sigma_det = array_ops.identity(
                            sigma_det, "sigma_det")
                        self._mu = array_ops.identity(mu, "mu")
Example 22
 def _batch_sqrt_matmul(self, x, transpose_x=False):
     chol = array_ops.batch_matrix_band_part(self._chol, -1, 0)
     # tf.batch_matmul is defined x * y, so "y" is on the right, not "x".
     return math_ops.batch_matmul(chol, x, adj_y=transpose_x)
Example 23
def _BatchMatrixBandPartGrad(op, grad):
  num_lower = op.inputs[1]
  num_upper = op.inputs[2]
  return (array_ops.batch_matrix_band_part(grad, num_lower, num_upper), None,
          None)
Example 24
 def _to_dense(self):
   chol = array_ops.batch_matrix_band_part(self._chol, -1, 0)
   return math_ops.batch_matmul(chol, chol, adj_y=True)