def _log_determinant_from_sigma_chol(sigma_chol):
  # log(det(sigma)) = 2 * sum(log(diag(chol(sigma)))); the reduction axis is
  # the last axis of the batched diagonal.
  det_last_dim = array_ops.rank(sigma_chol) - 2
  sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
  log_det = 2.0 * math_ops.reduce_sum(
      math_ops.log(sigma_batch_diag), reduction_indices=det_last_dim)
  log_det.set_shape(sigma_chol.get_shape()[:-2])
  return log_det
def _determinant_from_sigma_chol(sigma_chol):
  # det(sigma) = prod(diag(chol(sigma)))**2; the reduction axis is the last
  # axis of the batched diagonal.
  det_last_dim = array_ops.rank(sigma_chol) - 2
  sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
  det = math_ops.square(math_ops.reduce_prod(
      sigma_batch_diag, reduction_indices=det_last_dim))
  det.set_shape(sigma_chol.get_shape()[:-2])
  return det
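Both helpers above reduce a determinant over the last two dimensions to an operation on the Cholesky diagonal. A minimal NumPy sketch of the underlying identity (illustrative only, not the TF graph code above): for `sigma = L @ L.T` with lower-triangular `L`, `det(sigma) = prod(diag(L))**2` and `log det(sigma) = 2 * sum(log(diag(L)))`.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 4))
sigma = a @ a.T + 4 * np.eye(4)      # symmetric positive definite
chol = np.linalg.cholesky(sigma)     # lower triangular, positive diagonal

det_via_chol = np.prod(np.diag(chol)) ** 2
log_det_via_chol = 2.0 * np.sum(np.log(np.diag(chol)))

np.testing.assert_allclose(det_via_chol, np.linalg.det(sigma))
np.testing.assert_allclose(log_det_via_chol, np.linalg.slogdet(sigma)[1])
```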
def variance(self, name="variance"): """Variance of the Wishart distribution. This function should not be confused with the covariance of the Wishart. The covariance matrix would have shape `q x q` where, `q = dimension * (dimension+1) / 2` and having elements corresponding to some mapping from a lower-triangular matrix to a vector-space. This function returns the diagonal of the Covariance matrix but shaped as a `dimension x dimension` matrix. Args: name: The name of this op. Returns: variance: `Tensor` of dtype `self.dtype`. """ with ops.name_scope(self.name): with ops.name_scope(name, values=list(self.inputs.values())): x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense() d = array_ops.expand_dims(array_ops.batch_matrix_diag_part(x), -1) v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True) if self.cholesky_input_output_matrices: return linalg_ops.batch_cholesky(v) else: return v
def _assert_batch_positive_definite(sigma_chol):
  """Add assertions checking that the sigmas are all Positive Definite.

  Given `sigma_chol == cholesky(sigma)`, it is sufficient to check that
  `all(diag(sigma_chol) > 0)`. This is because a symmetric matrix is
  positive definite if and only if its Cholesky factor exists and has
  strictly positive diagonal entries.

  Args:
    sigma_chol: N-D. The lower triangular cholesky decomposition of `sigma`.

  Returns:
    An assertion op to use with `control_dependencies`, verifying that
    `sigma_chol` is positive definite.
  """
  sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
  return logging_ops.Assert(
      math_ops.reduce_all(sigma_batch_diag > 0),
      ["sigma_chol is not positive definite. batched diagonals: ",
       sigma_batch_diag, " shaped: ", array_ops.shape(sigma_batch_diag)])
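A minimal NumPy sketch of the property the assertion relies on (the TF code receives a user-supplied lower-triangular factor, so the sign of its diagonal is the whole check; `chol_is_positive_definite` is a hypothetical helper, not an API from the code above):

```python
import numpy as np

def chol_is_positive_definite(sigma_chol):
  # Mirrors reduce_all(batch_matrix_diag_part(sigma_chol) > 0): for a
  # lower-triangular L, sigma = L @ L.T is positive definite iff diag(L)
  # has no zeros, and requiring diag(L) > 0 pins down the canonical factor.
  diag = np.diagonal(sigma_chol, axis1=-2, axis2=-1)
  return np.all(diag > 0, axis=-1)

batch = np.stack([np.eye(3), np.diag([1.0, -1.0, 1.0])])
print(chol_is_positive_definite(batch))  # [ True False]
```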
def batch_matrix_diag_transform(matrix, transform=None, name=None):
  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

  Create a trainable covariance defined by a Cholesky factor:

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive. If the upper triangle was zero, this would be
  # a valid Cholesky factor.
  chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # OperatorPDCholesky ignores the upper triangle.
  operator = OperatorPDCholesky(chol)
  ```

  Example of heteroskedastic 2-D linear regression.

  ```python
  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  dist = tf.contrib.distributions.MVNCholesky(mu, chol)

  # Standard log loss. Minimizing this will "train" mu and chol, and then
  # dist will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_pdf(labels))
  ```

  Args:
    matrix: Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform: Element-wise function mapping `Tensors` to `Tensors`. To be
      applied to the diagonal of `matrix`. If `None`, `matrix` is returned
      unchanged. Defaults to `None`.
    name: A name to give created ops. Defaults to
      "batch_matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
  with ops.name_scope(name, "batch_matrix_diag_transform", [matrix]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    if transform is None:
      return matrix
    # Replace the diag with transformed diag.
    diag = array_ops.batch_matrix_diag_part(matrix)
    transformed_diag = transform(diag)
    transformed_mat = array_ops.batch_matrix_set_diag(matrix, transformed_diag)

  return transformed_mat
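A NumPy sketch of the same idea, under the assumption that only the diagonal should be touched (`np_batch_matrix_diag_transform` is a hypothetical stand-in for the TF op above, not its implementation):

```python
import numpy as np

def np_batch_matrix_diag_transform(matrix, transform):
  out = matrix.copy()
  i = np.arange(matrix.shape[-1])
  out[..., i, i] = transform(matrix[..., i, i])  # touch only the diagonal
  return out

softplus = lambda t: np.log1p(np.exp(t))
batch = np.random.randn(5, 2, 2)
# Positive diagonal + zeroed upper triangle => a valid Cholesky factor.
chol = np.tril(np_batch_matrix_diag_transform(batch, softplus))
assert np.all(np.diagonal(chol, axis1=-2, axis2=-1) > 0)
```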
def _variance(self):
  x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense()
  d = array_ops.expand_dims(array_ops.batch_matrix_diag_part(x), -1)
  v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True)
  if self.cholesky_input_output_matrices:
    return linalg_ops.cholesky(v)
  return v
def _variance(self):
  x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense()
  d = array_ops.expand_dims(array_ops.batch_matrix_diag_part(x), -1)
  v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True)
  if self.cholesky_input_output_matrices:
    return linalg_ops.batch_cholesky(v)
  return v
def _sqrt_log_det(self):
  # The matrix determinant lemma states:
  # det(M + VDV^T) = det(D^{-1} + V^T M^{-1} V) * det(D) * det(M)
  #                = det(C) * det(D) * det(M)
  #
  # Here we compute the Cholesky factor of "C", then pass the result on.
  diag_chol_c = array_ops.batch_matrix_diag_part(
      self._chol_capacitance(batch_mode=False))
  return self._sqrt_log_det_core(diag_chol_c)
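The matrix determinant lemma quoted in the comment is easy to verify numerically. A minimal NumPy sketch (illustrative only; the names `m`, `v`, `d` mirror the M, V, D of the comment):

```python
import numpy as np

rng = np.random.default_rng(1)
k, r = 5, 2
m = np.eye(k) * 3.0                       # "M": k x k, nonsingular
v = rng.standard_normal((k, r))           # "V": k x r
d = np.diag(rng.uniform(0.5, 2.0, r))     # "D": r x r, nonsingular

lhs = np.linalg.det(m + v @ d @ v.T)
capacitance = np.linalg.inv(d) + v.T @ np.linalg.inv(m) @ v   # "C"
rhs = np.linalg.det(capacitance) * np.linalg.det(d) * np.linalg.det(m)
np.testing.assert_allclose(lhs, rhs)
```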
def _batch_log_det(self):
  """Log determinant of every batch member."""
  # Note that array_ops.diag_part does not seem more efficient for non-batch,
  # and would give a bad result for a batch matrix, so always use
  # batch_matrix_diag_part.
  diag = array_ops.batch_matrix_diag_part(self._chol)
  det = 2.0 * math_ops.reduce_sum(math_ops.log(diag), reduction_indices=[-1])
  det.set_shape(self.get_shape()[:-2])
  return det
def _BatchMatrixSetDiagGrad(op, grad):
  # Infer the static shape of the diagonal from whichever of the forward
  # inputs and the upstream gradient has shape information available; fall
  # back to a dynamic shape computation otherwise.
  diag_shape = op.inputs[1].get_shape()
  diag_shape = diag_shape.merge_with(op.inputs[0].get_shape()[:-1])
  diag_shape = diag_shape.merge_with(grad.get_shape()[:-1])
  if diag_shape.is_fully_defined():
    diag_shape = diag_shape.as_list()
  else:
    diag_shape = array_ops.shape(grad)
    diag_shape = array_ops.slice(diag_shape, [0], [array_ops.rank(grad) - 1])
  # The matrix input receives the upstream gradient with its diagonal zeroed
  # out; the diag input receives the diagonal of the upstream gradient.
  grad_input = array_ops.batch_matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.batch_matrix_diag_part(grad)
  return (grad_input, grad_diag)
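A NumPy sketch of why the gradient splits this way: `matrix_set_diag(input, diag)` copies `input` off the diagonal and `diag` on it, so each output entry depends on exactly one of the two arguments (illustrative only, single matrix rather than a batch):

```python
import numpy as np

grad = np.arange(9.0).reshape(3, 3)    # upstream gradient
i = np.arange(3)

grad_input = grad.copy()
grad_input[i, i] = 0.0                 # input: gradient with diag zeroed
grad_diag = grad[i, i]                 # diag: diagonal of the gradient

print(grad_input)
print(grad_diag)                       # [0. 4. 8.]
```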
def _check_chol(self, chol):
  """Verify that `chol` is proper."""
  chol = ops.convert_to_tensor(chol, name='chol')
  if not self.verify_pd:
    return chol

  shape = array_ops.shape(chol)
  rank = array_ops.rank(chol)

  is_matrix = check_ops.assert_rank_at_least(chol, 2)
  is_square = check_ops.assert_equal(
      array_ops.gather(shape, rank - 2), array_ops.gather(shape, rank - 1))

  deps = [is_matrix, is_square]
  diag = array_ops.batch_matrix_diag_part(chol)
  deps.append(check_ops.assert_positive(diag))

  return control_flow_ops.with_dependencies(deps, chol)
def __init__(self, chol, verify_pd=True, name='OperatorPDCholesky'):
  """Initialize an OperatorPDCholesky.

  Args:
    chol: Shape `[N1,...,Nb, k, k]` tensor with `b >= 0`, `k >= 1`, and
      positive diagonal elements. The strict upper triangle of `chol` is
      never used, and the user may set these elements to zero, or ignore
      them.
    verify_pd: Whether to check that `chol` has positive diagonal (this is
      equivalent to it being a Cholesky factor of a symmetric positive
      definite matrix). If `verify_pd` is `False`, correct behavior is not
      guaranteed.
    name: A name to prepend to all ops created by this class.
  """
  self._verify_pd = verify_pd
  self._name = name
  with ops.name_scope(name):
    with ops.op_scope([chol], 'init'):
      self._diag = array_ops.batch_matrix_diag_part(chol)
      self._chol = self._check_chol(chol)
def _add_to_tensor(self, mat):
  # Add to a tensor in O(k) time!
  mat_diag = array_ops.batch_matrix_diag_part(mat)
  new_diag = constant_op.constant(1, dtype=self.dtype) + mat_diag
  return array_ops.batch_matrix_set_diag(mat, new_diag)
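The O(k) claim follows because adding the identity to a `k x k` matrix changes only the `k` diagonal entries. A NumPy sketch of the same trick (illustrative only):

```python
import numpy as np

mat = np.arange(9.0).reshape(3, 3)
i = np.arange(3)
out = mat.copy()
out[i, i] += 1.0                       # touch k entries, not k*k
np.testing.assert_allclose(out, mat + np.eye(3))
```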
def _inverse_log_det_jacobian(self, x):  # pylint: disable=unused-argument
  return -math_ops.reduce_sum(
      math_ops.log(array_ops.batch_matrix_diag_part(self.scale)),
      reduction_indices=[-1])
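A NumPy sketch of the identity behind the one-liner: for `y = scale @ x` with triangular `scale`, `|det(dy/dx)| = prod(diag(scale))`, so the inverse map contributes `-sum(log(diag(scale)))` to the log density (assuming a positive diagonal, as the surrounding Cholesky-based code does; illustrative only):

```python
import numpy as np

scale = np.array([[2.0, 0.0], [1.0, 3.0]])   # lower triangular
ildj = -np.sum(np.log(np.diag(scale)))
np.testing.assert_allclose(ildj, -np.log(abs(np.linalg.det(scale))))
```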
def _BatchMatrixDiagGrad(_, grad):
  # batch_matrix_diag embeds a vector into the diagonal of a zero matrix, so
  # its gradient is just the diagonal of the upstream gradient.
  return array_ops.batch_matrix_diag_part(grad)
def log_prob(self, x, name='log_prob'):
  """Log of the probability density/mass function.

  Args:
    x: `float` or `double` `Tensor`.
    name: The name to give this op.

  Returns:
    log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
      values of type `self.dtype`.
  """
  with ops.name_scope(self.name):
    with ops.name_scope(name, values=[x] + list(self.inputs.values())):
      x = ops.convert_to_tensor(x, name='x')
      contrib_tensor_util.assert_same_float_dtype(
          (self.scale_operator_pd, x))
      if self.cholesky_input_output_matrices:
        x_sqrt = x
      else:
        # Complexity: O(nbk^3)
        x_sqrt = linalg_ops.batch_cholesky(x)

      batch_shape = self.batch_shape()
      event_shape = self.event_shape()
      ndims = array_ops.rank(x_sqrt)
      # sample_ndims = ndims - batch_ndims - event_ndims
      sample_ndims = ndims - array_ops.shape(batch_shape)[0] - 2
      sample_shape = array_ops.slice(
          array_ops.shape(x_sqrt), [0], [sample_ndims])

      # We need to be able to pre-multiply each matrix by its corresponding
      # batch scale matrix. Since a Distribution Tensor supports multiple
      # samples per batch, this means we need to reshape the input matrix `x`
      # so that the first b dimensions are batch dimensions and the last two
      # are of shape [dimension, dimensions*number_of_samples]. Doing these
      # gymnastics allows us to do a batch_solve.
      #
      # After we're done with sqrt_solve (the batch operation) we need to undo
      # this reshaping so what we're left with is a Tensor partitionable by
      # sample, batch, event dimensions.

      # Complexity: O(nbk^2) since transpose must access every element.
      scale_sqrt_inv_x_sqrt = x_sqrt
      perm = array_ops.concat(0, (math_ops.range(sample_ndims, ndims),
                                  math_ops.range(0, sample_ndims)))
      scale_sqrt_inv_x_sqrt = array_ops.transpose(scale_sqrt_inv_x_sqrt, perm)
      shape = array_ops.concat(
          0, (batch_shape,
              (math_ops.cast(self.dimension, dtype=dtypes.int32), -1)))
      scale_sqrt_inv_x_sqrt = array_ops.reshape(scale_sqrt_inv_x_sqrt, shape)

      # Complexity: O(nbM*k) where M is the complexity of the operator solving
      # a vector system. E.g., for OperatorPDDiag, each solve is O(k), so
      # this complexity is O(nbk^2). For OperatorPDCholesky, each solve is
      # O(k^2) so this step has complexity O(nbk^3).
      scale_sqrt_inv_x_sqrt = self.scale_operator_pd.sqrt_solve(
          scale_sqrt_inv_x_sqrt)

      # Undo make batch-op ready.
      # Complexity: O(nbk^2)
      shape = array_ops.concat(0, (batch_shape, event_shape, sample_shape))
      scale_sqrt_inv_x_sqrt = array_ops.reshape(scale_sqrt_inv_x_sqrt, shape)
      perm = array_ops.concat(0, (math_ops.range(ndims - sample_ndims, ndims),
                                  math_ops.range(0, ndims - sample_ndims)))
      scale_sqrt_inv_x_sqrt = array_ops.transpose(scale_sqrt_inv_x_sqrt, perm)

      # Write V = SS', X = LL'. Then:
      # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
      #              = tr[inv(S) L L' inv(S)']
      #              = tr[(inv(S) L) (inv(S) L)']
      #              = sum_{ik} (inv(S) L)_{ik}^2
      # The second equality follows from the cyclic permutation property.
      # Complexity: O(nbk^2)
      trace_scale_inv_x = math_ops.reduce_sum(
          math_ops.square(scale_sqrt_inv_x_sqrt),
          reduction_indices=[-2, -1])

      # Complexity: O(nbk)
      half_log_det_x = math_ops.reduce_sum(
          math_ops.log(array_ops.batch_matrix_diag_part(x_sqrt)),
          reduction_indices=[-1])

      # Complexity: O(nbk^2)
      log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
                  0.5 * trace_scale_inv_x -
                  self.log_normalizing_constant())

      # Set shape hints.
      # Try to merge what we know from the input then what we know from the
      # parameters of this distribution.
      if x.get_shape().ndims is not None:
        log_prob.set_shape(x.get_shape()[:-2])
      if (log_prob.get_shape().ndims is not None and
          self.get_batch_shape().ndims is not None and
          self.get_batch_shape().ndims > 0):
        log_prob.get_shape()[-self.get_batch_shape().ndims:].merge_with(
            self.get_batch_shape())

      return log_prob
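The trace identity in the comment above is easy to verify numerically. A minimal NumPy sketch (illustrative only, not the graph code): with `V = S @ S.T` and `X = L @ L.T`, `tr(inv(V) @ X)` equals the sum of squares of the entries of `inv(S) @ L`, which is why a single triangular solve plus a squared Frobenius norm suffices.

```python
import numpy as np

rng = np.random.default_rng(2)
k = 4
s = np.linalg.cholesky(np.cov(rng.standard_normal((k, 50))) + np.eye(k))
l = np.linalg.cholesky(np.cov(rng.standard_normal((k, 50))) + np.eye(k))
v, x = s @ s.T, l @ l.T

lhs = np.trace(np.linalg.inv(v) @ x)
rhs = np.sum(np.linalg.solve(s, l) ** 2)   # "sqrt_solve", then sum of squares
np.testing.assert_allclose(lhs, rhs)
```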
def _add_to_tensor(self, mat):
  mat_diag = array_ops.batch_matrix_diag_part(mat)
  new_diag = math_ops.square(self._diag) + mat_diag
  return array_ops.batch_matrix_set_diag(mat, new_diag)
def _batch_sqrt_log_det(self):
  # Here we compute the Cholesky factor of "C", then pass the result on.
  diag_chol_c = array_ops.batch_matrix_diag_part(
      self._chol_capacitance(batch_mode=True))
  return self._sqrt_log_det_core(diag_chol_c)
def _log_prob(self, x):
  if self.cholesky_input_output_matrices:
    x_sqrt = x
  else:
    # Complexity: O(nbk^3)
    x_sqrt = linalg_ops.batch_cholesky(x)

  batch_shape = self.batch_shape()
  event_shape = self.event_shape()
  ndims = array_ops.rank(x_sqrt)
  # sample_ndims = ndims - batch_ndims - event_ndims
  sample_ndims = ndims - array_ops.shape(batch_shape)[0] - 2
  sample_shape = array_ops.slice(
      array_ops.shape(x_sqrt), [0], [sample_ndims])

  # We need to be able to pre-multiply each matrix by its corresponding
  # batch scale matrix. Since a Distribution Tensor supports multiple
  # samples per batch, this means we need to reshape the input matrix `x`
  # so that the first b dimensions are batch dimensions and the last two
  # are of shape [dimension, dimensions*number_of_samples]. Doing these
  # gymnastics allows us to do a batch_solve.
  #
  # After we're done with sqrt_solve (the batch operation) we need to undo
  # this reshaping so what we're left with is a Tensor partitionable by
  # sample, batch, event dimensions.

  # Complexity: O(nbk^2) since transpose must access every element.
  scale_sqrt_inv_x_sqrt = x_sqrt
  perm = array_ops.concat(0, (math_ops.range(sample_ndims, ndims),
                              math_ops.range(0, sample_ndims)))
  scale_sqrt_inv_x_sqrt = array_ops.transpose(scale_sqrt_inv_x_sqrt, perm)
  shape = array_ops.concat(
      0, (batch_shape,
          (math_ops.cast(self.dimension, dtype=dtypes.int32), -1)))
  scale_sqrt_inv_x_sqrt = array_ops.reshape(scale_sqrt_inv_x_sqrt, shape)

  # Complexity: O(nbM*k) where M is the complexity of the operator solving
  # a vector system. E.g., for OperatorPDDiag, each solve is O(k), so
  # this complexity is O(nbk^2). For OperatorPDCholesky, each solve is
  # O(k^2) so this step has complexity O(nbk^3).
  scale_sqrt_inv_x_sqrt = self.scale_operator_pd.sqrt_solve(
      scale_sqrt_inv_x_sqrt)

  # Undo make batch-op ready.
  # Complexity: O(nbk^2)
  shape = array_ops.concat(0, (batch_shape, event_shape, sample_shape))
  scale_sqrt_inv_x_sqrt = array_ops.reshape(scale_sqrt_inv_x_sqrt, shape)
  perm = array_ops.concat(0, (math_ops.range(ndims - sample_ndims, ndims),
                              math_ops.range(0, ndims - sample_ndims)))
  scale_sqrt_inv_x_sqrt = array_ops.transpose(scale_sqrt_inv_x_sqrt, perm)

  # Write V = SS', X = LL'. Then:
  # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
  #              = tr[inv(S) L L' inv(S)']
  #              = tr[(inv(S) L) (inv(S) L)']
  #              = sum_{ik} (inv(S) L)_{ik}^2
  # The second equality follows from the cyclic permutation property.
  # Complexity: O(nbk^2)
  trace_scale_inv_x = math_ops.reduce_sum(
      math_ops.square(scale_sqrt_inv_x_sqrt),
      reduction_indices=[-2, -1])

  # Complexity: O(nbk)
  half_log_det_x = math_ops.reduce_sum(
      math_ops.log(array_ops.batch_matrix_diag_part(x_sqrt)),
      reduction_indices=[-1])

  # Complexity: O(nbk^2)
  log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
              0.5 * trace_scale_inv_x -
              self.log_normalizing_constant())

  # Set shape hints.
  # Try to merge what we know from the input then what we know from the
  # parameters of this distribution.
  if x.get_shape().ndims is not None:
    log_prob.set_shape(x.get_shape()[:-2])
  if (log_prob.get_shape().ndims is not None and
      self.get_batch_shape().ndims is not None and
      self.get_batch_shape().ndims > 0):
    log_prob.get_shape()[-self.get_batch_shape().ndims:].merge_with(
        self.get_batch_shape())

  return log_prob