Example #1
def _log_determinant_from_sigma_chol(sigma_chol):
    det_last_dim = array_ops.rank(sigma_chol) - 2
    sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
    log_det = 2.0 * math_ops.reduce_sum(math_ops.log(sigma_batch_diag),
                                        reduction_indices=det_last_dim)
    log_det.set_shape(sigma_chol.get_shape()[:-2])
    return log_det
Example #2
File: mvn.py  Project: 0-T-0/tensorflow
def _determinant_from_sigma_chol(sigma_chol):
  det_last_dim = array_ops.rank(sigma_chol) - 2
  sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
  det = math_ops.square(math_ops.reduce_prod(
      sigma_batch_diag, reduction_indices=det_last_dim))
  det.set_shape(sigma_chol.get_shape()[:-2])
  return det
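Both helpers above rely on the same fact: if `sigma = chol * chol^T` with `chol` lower triangular, then `det(sigma) = prod(diag(chol))**2` and `log det(sigma) = 2 * sum(log(diag(chol)))`. A minimal NumPy sketch (purely illustrative, not part of the original project) that checks both identities:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(3, 3))
sigma = a @ a.T + 3.0 * np.eye(3)      # a positive definite matrix
chol = np.linalg.cholesky(sigma)       # lower triangular, positive diagonal

det_via_chol = np.prod(np.diag(chol)) ** 2
log_det_via_chol = 2.0 * np.sum(np.log(np.diag(chol)))

assert np.allclose(det_via_chol, np.linalg.det(sigma))
assert np.allclose(log_det_via_chol, np.linalg.slogdet(sigma)[1])
```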
Example #3
    def variance(self, name="variance"):
        """Variance of the Wishart distribution.

        This function should not be confused with the covariance of the Wishart.
        That covariance matrix would have shape `q x q` where
        `q = dimension * (dimension+1) / 2`, with elements corresponding to some
        mapping from a lower-triangular matrix to a vector space.

        This function returns the diagonal of that covariance matrix, but shaped
        as a `dimension x dimension` matrix.

        Args:
          name: The name of this op.

        Returns:
          variance: `Tensor` of dtype `self.dtype`.
        """
        with ops.name_scope(self.name):
            with ops.name_scope(name, values=list(self.inputs.values())):
                x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense()
                d = array_ops.expand_dims(array_ops.batch_matrix_diag_part(x), -1)
                v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True)
                if self.cholesky_input_output_matrices:
                    return linalg_ops.batch_cholesky(v)
                else:
                    return v
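The expression built here is the elementwise Wishart variance `Var[W_ij] = df * (S_ij**2 + S_ii * S_jj)`: with `x = sqrt(df) * S`, `x**2` gives the first term and the outer product of `diag(x)` gives the second. A rough Monte-Carlo sanity check (illustrative only; it uses `scipy.stats.wishart` rather than the TensorFlow code above):

```python
import numpy as np
from scipy.stats import wishart

df = 7.0
scale = np.array([[2.0, 0.5],
                  [0.5, 1.0]])

# Empirical elementwise variance from many Wishart draws.
samples = wishart(df=df, scale=scale).rvs(size=200_000, random_state=0)
empirical = samples.var(axis=0)

# Closed form used by `variance()` above: df * S**2 plus the outer-product term.
closed_form = df * (scale ** 2 + np.outer(np.diag(scale), np.diag(scale)))

print(np.max(np.abs(empirical - closed_form) / closed_form))  # should be small
```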
Example #4
def _log_determinant_from_sigma_chol(sigma_chol):
  det_last_dim = array_ops.rank(sigma_chol) - 2
  sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
  log_det = 2.0 * math_ops.reduce_sum(
      math_ops.log(sigma_batch_diag), reduction_indices=det_last_dim)
  log_det.set_shape(sigma_chol.get_shape()[:-2])
  return log_det
Example #5
def _assert_batch_positive_definite(sigma_chol):
    """Add assertions checking that the sigmas are all Positive Definite.

    Given `sigma_chol == cholesky(sigma)`, it is sufficient to check that
    `all(diag(sigma_chol) > 0)`: a lower triangular factor with a strictly
    positive diagonal is nonsingular, so `sigma = sigma_chol * sigma_chol^T` is
    positive definite; conversely, the Cholesky factor of a positive definite
    matrix always has a positive diagonal.

    Args:
      sigma_chol: N-D.  The lower triangular cholesky decomposition of `sigma`.

    Returns:
      An assertion op to use with `control_dependencies`, verifying that
      `sigma_chol` is positive definite.
    """
    sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
    return logging_ops.Assert(
        math_ops.reduce_all(sigma_batch_diag > 0),
        [
            "sigma_chol is not positive definite.  batched diagonals: ",
            sigma_batch_diag,
            " shaped: ",
            array_ops.shape(sigma_batch_diag),
        ],
    )
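The assertion boils down to a sign test on the batched diagonals. A tiny NumPy sketch of the same test (names here are illustrative, not from the source): a triangular factor with a strictly positive diagonal is nonsingular, so `chol @ chol.T` is positive definite, while a zero on the diagonal yields a singular product.

```python
import numpy as np

def has_positive_diag(chol):
    # Same condition the Assert above checks, for a [batch of] square matrices.
    return bool(np.all(np.diagonal(chol, axis1=-2, axis2=-1) > 0))

good = np.array([[2.0, 0.0],
                 [0.3, 1.5]])
bad = np.array([[2.0, 0.0],
                [0.3, 0.0]])   # zero diagonal entry -> bad @ bad.T is singular

print(has_positive_diag(good), np.linalg.det(good @ good.T))  # True, > 0
print(has_positive_diag(bad), np.linalg.det(bad @ bad.T))     # False, 0.0
```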
Example #6
def _determinant_from_sigma_chol(sigma_chol):
    det_last_dim = array_ops.rank(sigma_chol) - 2
    sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
    det = math_ops.square(
        math_ops.reduce_prod(sigma_batch_diag, reduction_indices=det_last_dim))
    det.set_shape(sigma_chol.get_shape()[:-2])
    return det
Example #7
  def variance(self, name='variance'):
    """Variance of the Wishart distribution.

    This function should not be confused with the covariance of the Wishart.
    That covariance matrix would have shape `q x q` where
    `q = dimension * (dimension+1) / 2`, with elements corresponding to some
    mapping from a lower-triangular matrix to a vector space.

    This function returns the diagonal of that covariance matrix, but shaped
    as a `dimension x dimension` matrix.

    Args:
      name: The name of this op.

    Returns:
      variance: `Tensor` of dtype `self.dtype`.
    """
    with ops.name_scope(self.name):
      with ops.name_scope(name, values=list(self.inputs.values())):
        x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense()
        d = array_ops.expand_dims(array_ops.batch_matrix_diag_part(x), -1)
        v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True)
        if self.cholesky_input_output_matrices:
          return linalg_ops.batch_cholesky(v)
        else:
          return v
Example #8
def batch_matrix_diag_transform(matrix, transform=None, name=None):
    """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

  Create a trainable covariance defined by a Cholesky factor:

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive.  If the upper triangle was zero, this would be a
  # valid Cholesky factor.
  chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # OperatorPDCholesky ignores the upper triangle.
  operator = OperatorPDCholesky(chol)
  ```

  Example of heteroskedastic 2-D linear regression.

  ```python
  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  dist = tf.contrib.distributions.MVNCholesky(mu, chol)

  # Standard log loss.  Minimizing this will "train" mu and chol, and then dist
  # will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_pdf(labels))
  ```

  Args:
    matrix:  Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform:  Element-wise function mapping `Tensors` to `Tensors`.  To
      be applied to the diagonal of `matrix`.  If `None`, `matrix` is returned
      unchanged.  Defaults to `None`.
    name:  A name to give created ops.
      Defaults to "batch_matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
    with ops.name_scope(name, 'batch_matrix_diag_transform', [matrix]):
        matrix = ops.convert_to_tensor(matrix, name='matrix')
        if transform is None:
            return matrix
        # Replace the diag with transformed diag.
        diag = array_ops.batch_matrix_diag_part(matrix)
        transformed_diag = transform(diag)
        transformed_mat = array_ops.batch_matrix_set_diag(
            matrix, transformed_diag)

    return transformed_mat
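The same diagonal-only transform is easy to express with the current `tf.linalg` API. A minimal sketch, assuming TensorFlow 2.x (the helper name here is made up for illustration):

```python
import tensorflow as tf

def matrix_diag_transform(matrix, transform=None):
    """Apply `transform` to the diagonal of a [batch] matrix; keep the rest."""
    if transform is None:
        return matrix
    return tf.linalg.set_diag(matrix, transform(tf.linalg.diag_part(matrix)))

# Example: turn an unconstrained batch of 2 x 2 matrices into valid-looking
# Cholesky factors by forcing the diagonal to be positive.
raw = tf.random.normal([4, 2, 2])
chol = matrix_diag_transform(raw, transform=tf.nn.softplus)
```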
Example #9
 def _variance(self):
     x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense()
     d = array_ops.expand_dims(array_ops.batch_matrix_diag_part(x), -1)
     v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True)
     if self.cholesky_input_output_matrices:
         return linalg_ops.cholesky(v)
     return v
Example #10
 def _variance(self):
   x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense()
   d = array_ops.expand_dims(array_ops.batch_matrix_diag_part(x), -1)
   v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True)
   if self.cholesky_input_output_matrices:
     return linalg_ops.batch_cholesky(v)
   return v
Example #11
def batch_matrix_diag_transform(matrix, transform=None, name=None):
  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

  Create a trainable covariance defined by a Cholesky factor:

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive.  If the upper triangle was zero, this would be a
  # valid Cholesky factor.
  chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # OperatorPDCholesky ignores the upper triangle.
  operator = OperatorPDCholesky(chol)
  ```

  Example of heteroskedastic 2-D linear regression.

  ```python
  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  dist = tf.contrib.distributions.MVNCholesky(mu, chol)

  # Standard log loss.  Minimizing this will "train" mu and chol, and then dist
  # will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_pdf(labels))
  ```

  Args:
    matrix:  Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform:  Element-wise function mapping `Tensors` to `Tensors`.  To
      be applied to the diagonal of `matrix`.  If `None`, `matrix` is returned
      unchanged.  Defaults to `None`.
    name:  A name to give created ops.
      Defaults to "batch_matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
  with ops.name_scope(name, "batch_matrix_diag_transform", [matrix]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    if transform is None:
      return matrix
    # Replace the diag with transformed diag.
    diag = array_ops.batch_matrix_diag_part(matrix)
    transformed_diag = transform(diag)
    transformed_mat = array_ops.batch_matrix_set_diag(matrix, transformed_diag)

  return transformed_mat
Example #12
 def _sqrt_log_det(self):
     # The matrix determinant lemma states:
     # det(M + VDV^T) = det(D^{-1} + V^T M^{-1} V) * det(D) * det(M)
     #                = det(C) * det(D) * det(M)
     #
     # Here we compute the Cholesky factor of "C", then pass the result on.
     diag_chol_c = array_ops.batch_matrix_diag_part(self._chol_capacitance(batch_mode=False))
     return self._sqrt_log_det_core(diag_chol_c)
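The comment relies on the matrix determinant lemma. A quick NumPy verification of the identity (variable names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
k, r = 4, 2
m = np.diag(rng.uniform(1.0, 2.0, size=k))   # "M": an easily invertible base matrix
d = np.diag(rng.uniform(0.5, 1.5, size=r))   # "D": low-rank update weights
v = rng.normal(size=(k, r))                  # "V": update directions

# det(M + V D V^T) == det(D^-1 + V^T M^-1 V) * det(D) * det(M)
lhs = np.linalg.det(m + v @ d @ v.T)
capacitance = np.linalg.inv(d) + v.T @ np.linalg.inv(m) @ v   # the "C" in the comment
rhs = np.linalg.det(capacitance) * np.linalg.det(d) * np.linalg.det(m)
assert np.allclose(lhs, rhs)
```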
Example #13
 def _sqrt_log_det(self):
   # The matrix determinant lemma states:
   # det(M + VDV^T) = det(D^{-1} + V^T M^{-1} V) * det(D) * det(M)
   #                = det(C) * det(D) * det(M)
   #
   # Here we compute the Cholesky factor of "C", then pass the result on.
   diag_chol_c = array_ops.batch_matrix_diag_part(self._chol_capacitance(
       batch_mode=False))
   return self._sqrt_log_det_core(diag_chol_c)
Example #14
 def _batch_log_det(self):
   """Log determinant of every batch member."""
   # Note that array_ops.diag_part does not seem more efficient for non-batch,
   # and would give a bad result for a batch matrix, so always use
   # batch_matrix_diag_part.
   diag = array_ops.batch_matrix_diag_part(self._chol)
   det = 2.0 * math_ops.reduce_sum(math_ops.log(diag), reduction_indices=[-1])
   det.set_shape(self.get_shape()[:-2])
   return det
Example #15
 def _batch_log_det(self):
     """Log determinant of every batch member."""
     # Note that array_ops.diag_part does not seem more efficient for non-batch,
     # and would give a bad result for a batch matrix, so always use
     # batch_matrix_diag_part.
     diag = array_ops.batch_matrix_diag_part(self._chol)
     det = 2.0 * math_ops.reduce_sum(math_ops.log(diag),
                                     reduction_indices=[-1])
     det.set_shape(self.get_shape()[:-2])
     return det
Example #16
def _BatchMatrixSetDiagGrad(op, grad):
  diag_shape = op.inputs[1].get_shape()
  diag_shape = diag_shape.merge_with(op.inputs[0].get_shape()[:-1])
  diag_shape = diag_shape.merge_with(grad.get_shape()[:-1])
  if diag_shape.is_fully_defined():
    diag_shape = diag_shape.as_list()
  else:
    diag_shape = array_ops.shape(grad)
    diag_shape = array_ops.slice(diag_shape, [0], [array_ops.rank(grad) - 1])
  grad_input = array_ops.batch_matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.batch_matrix_diag_part(grad)
  return (grad_input, grad_diag)
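The rule implemented above (the incoming gradient flows to the input matrix with its diagonal zeroed out, and the diagonal of the incoming gradient flows to the `diag` argument) can be checked against present-day TensorFlow autodiff. A small sketch, assuming TF 2.x:

```python
import tensorflow as tf

mat = tf.Variable(tf.random.normal([3, 3]))
diag = tf.Variable(tf.random.normal([3]))
weights = tf.random.normal([3, 3])   # plays the role of the upstream gradient

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(weights * tf.linalg.set_diag(mat, diag))

g_mat, g_diag = tape.gradient(loss, [mat, diag])
# Gradient w.r.t. the matrix: `weights` with a zeroed diagonal.
expected_mat = tf.linalg.set_diag(weights, tf.zeros([3]))
# Gradient w.r.t. the diagonal: the diagonal of `weights`.
expected_diag = tf.linalg.diag_part(weights)

print(tf.reduce_max(tf.abs(g_mat - expected_mat)).numpy())    # ~0.0
print(tf.reduce_max(tf.abs(g_diag - expected_diag)).numpy())  # ~0.0
```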
Example #17
def _BatchMatrixSetDiagGrad(op, grad):
  diag_shape = op.inputs[1].get_shape()
  diag_shape = diag_shape.merge_with(op.inputs[0].get_shape()[:-1])
  diag_shape = diag_shape.merge_with(grad.get_shape()[:-1])
  if diag_shape.is_fully_defined():
    diag_shape = diag_shape.as_list()
  else:
    diag_shape = array_ops.shape(grad)
    diag_shape = array_ops.slice(diag_shape, [0], [array_ops.rank(grad) - 1])
  grad_input = array_ops.batch_matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.batch_matrix_diag_part(grad)
  return (grad_input, grad_diag)
Example #18
    def _check_chol(self, chol):
        """Verify that `chol` is proper."""
        chol = ops.convert_to_tensor(chol, name='chol')
        if not self.verify_pd:
            return chol

        shape = array_ops.shape(chol)
        rank = array_ops.rank(chol)

        is_matrix = check_ops.assert_rank_at_least(chol, 2)
        is_square = check_ops.assert_equal(array_ops.gather(shape, rank - 2),
                                           array_ops.gather(shape, rank - 1))

        deps = [is_matrix, is_square]
        diag = array_ops.batch_matrix_diag_part(chol)
        deps.append(check_ops.assert_positive(diag))

        return control_flow_ops.with_dependencies(deps, chol)
Example #19
  def _check_chol(self, chol):
    """Verify that `chol` is proper."""
    chol = ops.convert_to_tensor(chol, name='chol')
    if not self.verify_pd:
      return chol

    shape = array_ops.shape(chol)
    rank = array_ops.rank(chol)

    is_matrix = check_ops.assert_rank_at_least(chol, 2)
    is_square = check_ops.assert_equal(
        array_ops.gather(shape, rank - 2), array_ops.gather(shape, rank - 1))

    deps = [is_matrix, is_square]
    diag = array_ops.batch_matrix_diag_part(chol)
    deps.append(check_ops.assert_positive(diag))

    return control_flow_ops.with_dependencies(deps, chol)
Example #20
    def __init__(self, chol, verify_pd=True, name='OperatorPDCholesky'):
        """Initialize an OperatorPDCholesky.

        Args:
          chol:  Shape `[N1,...,Nb, k, k]` tensor with `b >= 0`, `k >= 1`, and
            positive diagonal elements.  The strict upper triangle of `chol` is
            never used, and the user may set these elements to zero, or ignore
            them.
          verify_pd: Whether to check that `chol` has a positive diagonal (this
            is equivalent to it being a Cholesky factor of a symmetric positive
            definite matrix).  If `verify_pd` is `False`, correct behavior is
            not guaranteed.
          name:  A name to prepend to all ops created by this class.
        """
        self._verify_pd = verify_pd
        self._name = name
        with ops.name_scope(name):
            with ops.op_scope([chol], 'init'):
                self._diag = array_ops.batch_matrix_diag_part(chol)
                self._chol = self._check_chol(chol)
Example #21
  def __init__(self, chol, verify_pd=True, name='OperatorPDCholesky'):
    """Initialize an OperatorPDCholesky.

    Args:
      chol:  Shape `[N1,...,Nb, k, k]` tensor with `b >= 0`, `k >= 1`, and
        positive diagonal elements.  The strict upper triangle of `chol` is
        never used, and the user may set these elements to zero, or ignore them.
      verify_pd: Whether to check that `chol` has a positive diagonal (this is
        equivalent to it being a Cholesky factor of a symmetric positive
        definite matrix).  If `verify_pd` is `False`, correct behavior is not
        guaranteed.
      name:  A name to prepend to all ops created by this class.
    """
    self._verify_pd = verify_pd
    self._name = name
    with ops.name_scope(name):
      with ops.op_scope([chol], 'init'):
        self._diag = array_ops.batch_matrix_diag_part(chol)
        self._chol = self._check_chol(chol)
Example #22
def _assert_batch_positive_definite(sigma_chol):
  """Add assertions checking that the sigmas are all Positive Definite.

  Given `sigma_chol == cholesky(sigma)`, it is sufficient to check that
  `all(diag(sigma_chol) > 0)`: a lower triangular factor with a strictly
  positive diagonal is nonsingular, so `sigma = sigma_chol * sigma_chol^T` is
  positive definite; conversely, the Cholesky factor of a positive definite
  matrix always has a positive diagonal.

  Args:
    sigma_chol: N-D.  The lower triangular cholesky decomposition of `sigma`.

  Returns:
    An assertion op to use with `control_dependencies`, verifying that
    `sigma_chol` is positive definite.
  """
  sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
  return logging_ops.Assert(
      math_ops.reduce_all(sigma_batch_diag > 0),
      ["sigma_chol is not positive definite.  batched diagonals: ",
       sigma_batch_diag, " shaped: ", array_ops.shape(sigma_batch_diag)])
Example #23
 def _add_to_tensor(self, mat):
   # Add to a tensor in O(k) time!
   mat_diag = array_ops.batch_matrix_diag_part(mat)
   new_diag = constant_op.constant(1, dtype=self.dtype) + mat_diag
   return array_ops.batch_matrix_set_diag(mat, new_diag)
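The "O(k)" remark refers to the fact that adding the identity only touches the k diagonal entries, so no dense `k x k` identity has to be materialized. A tiny NumPy sketch of the equivalence:

```python
import numpy as np

mat = np.arange(9.0).reshape(3, 3)

fast = mat.copy()
fast[np.diag_indices(3)] += 1.0            # O(k): only the diagonal is updated

assert np.allclose(fast, mat + np.eye(3))  # same result as the dense O(k^2) add
```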
Example #24
 def _inverse_log_det_jacobian(self, x):  # pylint: disable=unused-argument
     return -math_ops.reduce_sum(math_ops.log(
         array_ops.batch_matrix_diag_part(self.scale)),
                                 reduction_indices=[-1])
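For a lower-triangular `scale`, the forward map `y = scale @ x` has Jacobian determinant `prod(diag(scale))`, so the inverse direction contributes `-sum(log(diag(scale)))` to a log density, which is exactly what the method returns. A NumPy sanity check (illustrative):

```python
import numpy as np

scale = np.array([[2.0, 0.0],
                  [0.7, 1.5]])

log_det_forward = np.sum(np.log(np.diag(scale)))   # triangular: product of the diagonal
assert np.allclose(log_det_forward, np.linalg.slogdet(scale)[1])
print(-log_det_forward)   # the inverse log det Jacobian, as computed above
```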
Example #25
def _BatchMatrixDiagGrad(_, grad):
    return array_ops.batch_matrix_diag_part(grad)
Example #26
  def log_prob(self, x, name='log_prob'):
    """Log of the probability density/mass function.

    Args:
      x: `float` or `double` `Tensor`.
      name: The name to give this op.

    Returns:
      log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    with ops.name_scope(self.name):
      with ops.name_scope(name, values=[x] + list(self.inputs.values())):
        x = ops.convert_to_tensor(x, name='x')
        contrib_tensor_util.assert_same_float_dtype(
            (self.scale_operator_pd, x))
        if self.cholesky_input_output_matrices:
          x_sqrt = x
        else:
          # Complexity: O(nbk^3)
          x_sqrt = linalg_ops.batch_cholesky(x)

        batch_shape = self.batch_shape()
        event_shape = self.event_shape()
        ndims = array_ops.rank(x_sqrt)
        # sample_ndims = ndims - batch_ndims - event_ndims
        sample_ndims = ndims - array_ops.shape(batch_shape)[0] - 2
        sample_shape = array_ops.slice(
            array_ops.shape(x_sqrt), [0], [sample_ndims])

        # We need to be able to pre-multiply each matrix by its corresponding
        # batch scale matrix.  Since a Distribution Tensor supports multiple
        # samples per batch, this means we need to reshape the input matrix `x`
        # so that the first b dimensions are batch dimensions and the last two
        # are of shape [dimension, dimension*number_of_samples]. Doing these
        # gymnastics allows us to do a batch_solve.
        #
        # After we're done with sqrt_solve (the batch operation) we need to undo
        # this reshaping so what we're left with is a Tensor partitionable by
        # sample, batch, event dimensions.

        # Complexity: O(nbk^2) since transpose must access every element.
        scale_sqrt_inv_x_sqrt = x_sqrt
        perm = array_ops.concat(0, (math_ops.range(sample_ndims, ndims),
                                    math_ops.range(0, sample_ndims)))
        scale_sqrt_inv_x_sqrt = array_ops.transpose(scale_sqrt_inv_x_sqrt, perm)
        shape = array_ops.concat(
            0, (batch_shape,
                (math_ops.cast(self.dimension, dtype=dtypes.int32), -1)))
        scale_sqrt_inv_x_sqrt = array_ops.reshape(scale_sqrt_inv_x_sqrt, shape)

        # Complexity: O(nbM*k) where M is the complexity of the operator solving
        # a vector system.  E.g., for OperatorPDDiag, each solve is O(k), so
        # this complexity is O(nbk^2). For OperatorPDCholesky, each solve is
        # O(k^2) so this step has complexity O(nbk^3).
        scale_sqrt_inv_x_sqrt = self.scale_operator_pd.sqrt_solve(
            scale_sqrt_inv_x_sqrt)

        # Undo make batch-op ready.
        # Complexity: O(nbk^2)
        shape = array_ops.concat(0, (batch_shape, event_shape, sample_shape))
        scale_sqrt_inv_x_sqrt = array_ops.reshape(scale_sqrt_inv_x_sqrt, shape)
        perm = array_ops.concat(0, (math_ops.range(ndims - sample_ndims, ndims),
                                    math_ops.range(0, ndims - sample_ndims)))
        scale_sqrt_inv_x_sqrt = array_ops.transpose(scale_sqrt_inv_x_sqrt, perm)

        # Write V = SS', X = LL'. Then:
        # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
        #              = tr[inv(S) L L' inv(S)']
        #              = tr[(inv(S) L) (inv(S) L)']
        #              = sum_{ik} (inv(S) L)_{ik}^2
        # The second equality follows from the cyclic permutation property.
        # Complexity: O(nbk^2)
        trace_scale_inv_x = math_ops.reduce_sum(
            math_ops.square(scale_sqrt_inv_x_sqrt),
            reduction_indices=[-2, -1])

        # Complexity: O(nbk)
        half_log_det_x = math_ops.reduce_sum(
            math_ops.log(array_ops.batch_matrix_diag_part(x_sqrt)),
            reduction_indices=[-1])

        # Complexity: O(nbk^2)
        log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
                    0.5 * trace_scale_inv_x -
                    self.log_normalizing_constant())

        # Set shape hints.
        # Try to merge what we know from the input then what we know from the
        # parameters of this distribution.
        if x.get_shape().ndims is not None:
          log_prob.set_shape(x.get_shape()[:-2])
        if (log_prob.get_shape().ndims is not None and
            self.get_batch_shape().ndims is not None and
            self.get_batch_shape().ndims > 0):
          log_prob.get_shape()[-self.get_batch_shape().ndims:].merge_with(
              self.get_batch_shape())

        return log_prob
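The trace manipulation in the comment above (`tr[inv(V) X] = sum of squares of the entries of inv(S) L` when `V = S S'` and `X = L L'`) is what lets the code get away with a `sqrt_solve` followed by an elementwise square. A NumPy check of that identity (illustrative only):

```python
import numpy as np

rng = np.random.default_rng(1)
k = 3

def random_spd(rng, k):
    a = rng.normal(size=(k, k))
    return a @ a.T + k * np.eye(k)

s = np.linalg.cholesky(random_spd(rng, k))   # V = S S^T
l = np.linalg.cholesky(random_spd(rng, k))   # X = L L^T
v, x = s @ s.T, l @ l.T

lhs = np.trace(np.linalg.inv(v) @ x)
rhs = np.sum(np.linalg.solve(s, l) ** 2)     # solve(s, l) == inv(S) @ L
assert np.allclose(lhs, rhs)
```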
Example #27
 def _add_to_tensor(self, mat):
     # Add to a tensor in O(k) time!
     mat_diag = array_ops.batch_matrix_diag_part(mat)
     new_diag = constant_op.constant(1, dtype=self.dtype) + mat_diag
     return array_ops.batch_matrix_set_diag(mat, new_diag)
Example #28
 def _add_to_tensor(self, mat):
   mat_diag = array_ops.batch_matrix_diag_part(mat)
   new_diag = math_ops.square(self._diag) + mat_diag
   return array_ops.batch_matrix_set_diag(mat, new_diag)
Example #29
 def _batch_sqrt_log_det(self):
   # Here we compute the Cholesky factor of "C", then pass the result on.
   diag_chol_c = array_ops.batch_matrix_diag_part(self._chol_capacitance(
       batch_mode=True))
   return self._sqrt_log_det_core(diag_chol_c)
Example #30
  def _log_prob(self, x):
    if self.cholesky_input_output_matrices:
      x_sqrt = x
    else:
      # Complexity: O(nbk^3)
      x_sqrt = linalg_ops.batch_cholesky(x)

    batch_shape = self.batch_shape()
    event_shape = self.event_shape()
    ndims = array_ops.rank(x_sqrt)
    # sample_ndims = ndims - batch_ndims - event_ndims
    sample_ndims = ndims - array_ops.shape(batch_shape)[0] - 2
    sample_shape = array_ops.slice(
        array_ops.shape(x_sqrt), [0], [sample_ndims])

    # We need to be able to pre-multiply each matrix by its corresponding
    # batch scale matrix.  Since a Distribution Tensor supports multiple
    # samples per batch, this means we need to reshape the input matrix `x`
    # so that the first b dimensions are batch dimensions and the last two
    # are of shape [dimension, dimension*number_of_samples]. Doing these
    # gymnastics allows us to do a batch_solve.
    #
    # After we're done with sqrt_solve (the batch operation) we need to undo
    # this reshaping so what we're left with is a Tensor partitionable by
    # sample, batch, event dimensions.

    # Complexity: O(nbk^2) since transpose must access every element.
    scale_sqrt_inv_x_sqrt = x_sqrt
    perm = array_ops.concat(0, (math_ops.range(sample_ndims, ndims),
                                math_ops.range(0, sample_ndims)))
    scale_sqrt_inv_x_sqrt = array_ops.transpose(scale_sqrt_inv_x_sqrt, perm)
    shape = array_ops.concat(
        0, (batch_shape,
            (math_ops.cast(self.dimension, dtype=dtypes.int32), -1)))
    scale_sqrt_inv_x_sqrt = array_ops.reshape(scale_sqrt_inv_x_sqrt, shape)

    # Complexity: O(nbM*k) where M is the complexity of the operator solving
    # a vector system.  E.g., for OperatorPDDiag, each solve is O(k), so
    # this complexity is O(nbk^2). For OperatorPDCholesky, each solve is
    # O(k^2) so this step has complexity O(nbk^3).
    scale_sqrt_inv_x_sqrt = self.scale_operator_pd.sqrt_solve(
        scale_sqrt_inv_x_sqrt)

    # Undo make batch-op ready.
    # Complexity: O(nbk^2)
    shape = array_ops.concat(0, (batch_shape, event_shape, sample_shape))
    scale_sqrt_inv_x_sqrt = array_ops.reshape(scale_sqrt_inv_x_sqrt, shape)
    perm = array_ops.concat(0, (math_ops.range(ndims - sample_ndims, ndims),
                                math_ops.range(0, ndims - sample_ndims)))
    scale_sqrt_inv_x_sqrt = array_ops.transpose(scale_sqrt_inv_x_sqrt, perm)

    # Write V = SS', X = LL'. Then:
    # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
    #              = tr[inv(S) L L' inv(S)']
    #              = tr[(inv(S) L) (inv(S) L)']
    #              = sum_{ik} (inv(S) L)_{ik}^2
    # The second equality follows from the cyclic permutation property.
    # Complexity: O(nbk^2)
    trace_scale_inv_x = math_ops.reduce_sum(
        math_ops.square(scale_sqrt_inv_x_sqrt),
        reduction_indices=[-2, -1])

    # Complexity: O(nbk)
    half_log_det_x = math_ops.reduce_sum(
        math_ops.log(array_ops.batch_matrix_diag_part(x_sqrt)),
        reduction_indices=[-1])

    # Complexity: O(nbk^2)
    log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
                0.5 * trace_scale_inv_x -
                self.log_normalizing_constant())

    # Set shape hints.
    # Try to merge what we know from the input then what we know from the
    # parameters of this distribution.
    if x.get_shape().ndims is not None:
      log_prob.set_shape(x.get_shape()[:-2])
    if (log_prob.get_shape().ndims is not None and
        self.get_batch_shape().ndims is not None and
        self.get_batch_shape().ndims > 0):
      log_prob.get_shape()[-self.get_batch_shape().ndims:].merge_with(
          self.get_batch_shape())

    return log_prob
Example #31
 def _add_to_tensor(self, mat):
     mat_diag = array_ops.batch_matrix_diag_part(mat)
     new_diag = math_ops.square(self._diag) + mat_diag
     return array_ops.batch_matrix_set_diag(mat, new_diag)
Example #32
def _BatchMatrixDiagGrad(_, grad):
  return array_ops.batch_matrix_diag_part(grad)
Example #33
 def _inverse_log_det_jacobian(self, x):  # pylint: disable=unused-argument
   return -math_ops.reduce_sum(
       math_ops.log(array_ops.batch_matrix_diag_part(self.scale)),
       reduction_indices=[-1])