def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve.

  Returns `(grad_a, grad_b)` for the op's two inputs. `grad_b` is obtained by
  solving against the matrix input with the opposite `adjoint` setting, and
  `grad_a` is masked to the triangle selected by the op's `lower` attr.
  """
  matrix = op.inputs[0]
  solution = op.outputs[0]
  is_adjoint = op.get_attr("adjoint")
  is_lower = op.get_attr("lower")

  # Back-propagate through the solve with the opposite adjoint flag.
  grad_rhs = linalg_ops.matrix_triangular_solve(
      matrix, grad, lower=is_lower, adjoint=not is_adjoint)

  # -c * grad_b^H when the forward op used the adjoint, else -grad_b * c^H.
  left, right = (solution, grad_rhs) if is_adjoint else (grad_rhs, solution)
  grad_matrix = -math_ops.batch_matmul(left, right, adj_y=True)

  # Keep only the triangle named by `lower`.
  num_lower, num_upper = (-1, 0) if is_lower else (0, -1)
  grad_matrix = array_ops.batch_matrix_band_part(
      grad_matrix, num_lower, num_upper)
  return (grad_matrix, grad_rhs)
def _batch_matmul(self, x, transpose_x=False):
  """Return `L L^H x` (`L L^H x^H` if `transpose_x`).

  `L` is the lower triangle of `self._chol`; its upper triangle is ignored.
  """
  # tf.batch_matmul is defined x * y, so "y" is on the right, not "x".
  lower = array_ops.batch_matrix_band_part(self._chol, -1, 0)
  return math_ops.batch_matmul(
      lower,
      math_ops.batch_matmul(lower, x, adj_x=True, adj_y=transpose_x))
def _matmul(self, x, transpose_x=False):
  """Return `L L^T x` (`L L^T x^T` if `transpose_x`).

  `L` is the lower triangle of `self._chol`; its upper triangle is ignored.
  """
  # tf.matmul is defined a * b.
  lower = array_ops.batch_matrix_band_part(self._chol, -1, 0)
  return math_ops.matmul(
      lower,
      math_ops.matmul(lower, x, transpose_a=True, transpose_b=transpose_x))
def _sample_n(self, n, seed):
  """Draw `n` samples using the Bartlett decomposition.

  Args:
    n: Number of samples; becomes the leading dimension of the result.
    seed: Python integer random seed, or `None`.

  Returns:
    `Tensor` of shape `(n,) + batch_shape + event_shape`. If
    `self.cholesky_input_output_matrices` is false, each sample is `x x^H`.
    Otherwise it is the factor `x` itself, which is
    `scale_operator_pd.sqrt_matmul` applied to the lower-triangular Bartlett
    matrix.
  """
  batch_shape = self.batch_shape()
  event_shape = self.event_shape()
  batch_ndims = array_ops.shape(batch_shape)[0]
  ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2

  shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

  # Complexity: O(nbk^2)
  x = random_ops.random_normal(shape=shape,
                               mean=0.,
                               stddev=1.,
                               dtype=self.dtype,
                               seed=seed)

  # Passing the same `seed` to both random ops gives them the same op-level
  # seed, which can correlate the normal and Gamma streams. Derive a
  # distinct but deterministic seed for the Gamma draws (`None` stays `None`).
  if seed is None:
    gamma_seed = None
  else:
    import hashlib  # pylint: disable=g-import-not-at-top
    gamma_seed = int(hashlib.md5(
        (str(seed) + "wishart").encode("utf-8")).hexdigest()[:8], 16)
    gamma_seed &= 0x7FFFFFFF

  # Complexity: O(nbk)
  # This parametrization is equivalent to Chi2, i.e.,
  # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
  g = random_ops.random_gamma(shape=(n,),
                              alpha=self._multi_gamma_sequence(
                                  0.5 * self.df, self.dimension),
                              beta=0.5,
                              dtype=self.dtype,
                              seed=gamma_seed)

  # Complexity: O(nbk^2)
  x = array_ops.batch_matrix_band_part(x, -1, 0)  # Tri-lower.

  # Complexity: O(nbk)
  x = array_ops.batch_matrix_set_diag(x, math_ops.sqrt(g))

  # Make batch-op ready: move the sample dim last, then fold it into columns.
  # Complexity: O(nbk^2)
  perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
  x = array_ops.transpose(x, perm)
  shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
  x = array_ops.reshape(x, shape)

  # Complexity: O(nbM) where M is the complexity of the operator solving a
  # vector system. E.g., for OperatorPDDiag, each matmul is O(k^2), so
  # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
  # O(k^3) so this step has complexity O(nbk^3).
  x = self.scale_operator_pd.sqrt_matmul(x)

  # Undo make batch-op ready.
  # Complexity: O(nbk^2)
  shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
  x = array_ops.reshape(x, shape)
  perm = array_ops.concat(0, ((ndims - 1,), math_ops.range(0, ndims - 1)))
  x = array_ops.transpose(x, perm)

  if not self.cholesky_input_output_matrices:
    # Complexity: O(nbk^3)
    x = math_ops.batch_matmul(x, x, adj_y=True)

  return x
def _sample_n(self, n, seed):
  """Draw `n` samples using the Bartlett decomposition.

  Args:
    n: Number of samples; becomes the leading dimension of the result.
    seed: Python integer random seed, or `None`.

  Returns:
    `Tensor` of shape `(n,) + batch_shape + event_shape`. If
    `self.cholesky_input_output_matrices` is false, each sample is `x x^H`.
    Otherwise it is the factor `x` itself, which is
    `scale_operator_pd.sqrt_matmul` applied to the lower-triangular Bartlett
    matrix.
  """
  batch_shape = self.batch_shape()
  event_shape = self.event_shape()
  batch_ndims = array_ops.shape(batch_shape)[0]
  ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2

  shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

  # Complexity: O(nbk^2)
  x = random_ops.random_normal(shape=shape,
                               mean=0.,
                               stddev=1.,
                               dtype=self.dtype,
                               seed=seed)

  # Passing the same `seed` to both random ops gives them the same op-level
  # seed, which can correlate the normal and Gamma streams. Derive a
  # distinct but deterministic seed for the Gamma draws (`None` stays `None`).
  if seed is None:
    gamma_seed = None
  else:
    import hashlib  # pylint: disable=g-import-not-at-top
    gamma_seed = int(hashlib.md5(
        (str(seed) + "wishart").encode("utf-8")).hexdigest()[:8], 16)
    gamma_seed &= 0x7FFFFFFF

  # Complexity: O(nbk)
  # This parametrization is equivalent to Chi2, i.e.,
  # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
  g = random_ops.random_gamma(shape=(n,),
                              alpha=self._multi_gamma_sequence(
                                  0.5 * self.df, self.dimension),
                              beta=0.5,
                              dtype=self.dtype,
                              seed=gamma_seed)

  # Complexity: O(nbk^2)
  x = array_ops.batch_matrix_band_part(x, -1, 0)  # Tri-lower.

  # Complexity: O(nbk)
  x = array_ops.batch_matrix_set_diag(x, math_ops.sqrt(g))

  # Make batch-op ready: move the sample dim last, then fold it into columns.
  # Complexity: O(nbk^2)
  perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
  x = array_ops.transpose(x, perm)
  shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
  x = array_ops.reshape(x, shape)

  # Complexity: O(nbM) where M is the complexity of the operator solving a
  # vector system. E.g., for OperatorPDDiag, each matmul is O(k^2), so
  # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
  # O(k^3) so this step has complexity O(nbk^3).
  x = self.scale_operator_pd.sqrt_matmul(x)

  # Undo make batch-op ready.
  # Complexity: O(nbk^2)
  shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
  x = array_ops.reshape(x, shape)
  perm = array_ops.concat(0, ((ndims - 1,), math_ops.range(0, ndims - 1)))
  x = array_ops.transpose(x, perm)

  if not self.cholesky_input_output_matrices:
    # Complexity: O(nbk^3)
    x = math_ops.batch_matmul(x, x, adj_y=True)

  return x
def _BatchMatrixTriangularSolveGrad(op, grad):
  """Gradient for BatchMatrixTriangularSolve.

  Returns `(grad_a, grad_b)` for the op's two inputs. `grad_b` is obtained by
  a batch triangular solve against the matrix input with the opposite
  `adjoint` setting. `grad_a` is masked to the triangle selected by the op's
  `lower` attr.
  """
  matrix, solution = op.inputs[0], op.outputs[0]
  is_adjoint = op.get_attr("adjoint")
  is_lower = op.get_attr("lower")

  grad_rhs = linalg_ops.batch_matrix_triangular_solve(
      matrix, grad, lower=is_lower, adjoint=not is_adjoint)

  # -c * grad_b^H when the forward op used the adjoint, else -grad_b * c^H.
  if is_adjoint:
    product = math_ops.batch_matmul(solution, grad_rhs, adj_y=True)
  else:
    product = math_ops.batch_matmul(grad_rhs, solution, adj_y=True)
  grad_matrix = -product

  # Keep only the triangle named by `lower`.
  if is_lower:
    grad_matrix = array_ops.batch_matrix_band_part(grad_matrix, -1, 0)
  else:
    grad_matrix = array_ops.batch_matrix_band_part(grad_matrix, 0, -1)
  return (grad_matrix, grad_rhs)
def sqrt_matmul(self, x, name='sqrt_matmul'):
  """Left (batch) matmul `x` by a sqrt of this matrix: `S x` where `A = S S^T`.

  `S` is the lower triangle of the stored factor `self._chol`; its upper
  triangle is ignored.

  Args:
    x: `Tensor` of shape `[N1,...,Nb, k, r]` (a batch of `k x r` matrices
      whose batch dimensions match this operator's) and same `dtype` as self.
    name:  A name scope to use for ops added by this method.

  Returns:
    Shape `[N1,...,Nb, k, r]` `Tensor` holding the product `S x`.
  """
  with ops.name_scope(self.name):
    with ops.op_scope([x] + self.inputs, name):
      chol_lower = array_ops.batch_matrix_band_part(self._chol, -1, 0)
      return math_ops.batch_matmul(chol_lower, x)
def sqrt_matmul(self, x, name='sqrt_matmul'):
  """Left (batch) matmul `x` by a sqrt of this matrix: `S x` where `A = S S^T`.

  `S` is the lower triangle of the stored factor `self._chol`; its upper
  triangle is ignored.

  Args:
    x: `Tensor` of shape `[N1,...,Nb, k, r]` (a batch of `k x r` matrices
      whose batch dimensions match this operator's) and same `dtype` as self.
    name:  A name scope to use for ops added by this method.

  Returns:
    Shape `[N1,...,Nb, k, r]` `Tensor` holding the product `S x`.
  """
  with ops.name_scope(self.name):
    with ops.op_scope([x] + self.inputs, name):
      chol_lower = array_ops.batch_matrix_band_part(
          self._chol, -1, 0)
      return math_ops.batch_matmul(chol_lower, x)
def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
  """Multivariate Normal distributions on `R^k`.

  User must provide means `mu`, which are tensors of rank `N+1` (`N >= 0`)
  with the last dimension having length `k`.

  User must provide exactly one of `sigma` (the covariance matrices) or
  `sigma_chol` (the cholesky decompositions of the covariance matrices).
  `sigma` or `sigma_chol` must be of rank `N+2`.  The last two dimensions
  must both have length `k`.  The first `N` dimensions correspond to batch
  indices.

  If `sigma_chol` is not provided, the batch cholesky factorization of `sigma`
  is calculated for you.

  The shapes of `mu` and `sigma` must match for the first `N` dimensions.

  Regardless of which parameter is provided, the covariance matrices must all
  be **positive definite** (an error is raised if one of them is not).

  Args:
    mu: (N+1)-D.  `float` or `double` tensor, the means of the distributions.
    sigma: (N+2)-D.  (optional) `float` or `double` tensor, the covariances
      of the distribution(s).  The first `N+1` dimensions must match
      those of `mu`.  Must be batch-positive-definite.
    sigma_chol: (N+2)-D.  (optional) `float` or `double` tensor, a
      lower-triangular factorization of `sigma`
      (`sigma = sigma_chol . sigma_chol^*`).  The first `N+1` dimensions
      must match those of `mu`.  The tensor itself need not be batch
      lower triangular: we ignore the upper triangular part.  However,
      the batch diagonals must be positive (i.e., sigma_chol must be
      batch-positive-definite).
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: unless exactly one of `sigma` and `sigma_chol` is provided
      (i.e. if neither or both are given).
    TypeError: if mu and sigma (resp. sigma_chol) are different dtypes.
  """
  if (sigma is None) == (sigma_chol is None):
    raise ValueError("Exactly one of sigma and sigma_chol must be provided")

  with ops.op_scope([mu, sigma, sigma_chol], name, "MultivariateNormal"):
    sigma_or_half = sigma_chol if sigma is None else sigma

    mu = ops.convert_to_tensor(mu)
    sigma_or_half = ops.convert_to_tensor(sigma_or_half)

    contrib_tensor_util.assert_same_float_dtype((mu, sigma_or_half))

    with ops.control_dependencies([
        _assert_compatible_shapes(mu, sigma_or_half)]):
      mu = array_ops.identity(mu, name="mu")

    # Store the dimensionality of the MVNs
    self._k = array_ops.gather(array_ops.shape(mu), array_ops.rank(mu) - 1)

    # Both branches below use `sigma_or_half`, the already converted and
    # dtype-checked tensor, rather than re-converting the raw argument.
    if sigma_chol is not None:
      # Ensure we only keep the lower triangular part.
      sigma_chol = array_ops.batch_matrix_band_part(
          sigma_or_half, num_lower=-1, num_upper=0)
      sigma_det = _determinant_from_sigma_chol(sigma_chol)
      with ops.control_dependencies([
          _assert_batch_positive_definite(sigma_chol)]):
        self._sigma = math_ops.batch_matmul(
            sigma_chol, sigma_chol, adj_y=True, name="sigma")
        self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
        self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
        self._mu = array_ops.identity(mu, "mu")
    else:  # sigma is not None
      sigma_chol = linalg_ops.batch_cholesky(sigma_or_half)
      sigma_det = _determinant_from_sigma_chol(sigma_chol)
      # batch_cholesky checks for PSD; so we can just use it here.
      with ops.control_dependencies([sigma_chol]):
        self._sigma = array_ops.identity(sigma_or_half, "sigma")
        self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
        self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
        self._mu = array_ops.identity(mu, "mu")
def sample_n(self, n, seed=None, name='sample'):
  # pylint: disable=line-too-long
  """Generate `n` samples.

  Complexity: O(nbk^3)

  The sampling procedure is based on the [Bartlett decomposition](
  https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition)
  and [using a Gamma distribution to generate Chi2 random variates](
  https://en.wikipedia.org/wiki/Chi-squared_distribution#Gamma.2C_exponential.2C_and_related_distributions).

  Args:
    n: `Scalar` `Tensor` of type `int32`, the number of observations to
      sample.
    seed: Python integer; random number generator seed.
    name: The name of this op.

  Returns:
    samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
        with values of type `self.dtype`.

  Raises:
    TypeError: if `n` is not `int32`.
  """
  with ops.name_scope(self.name):
    with ops.name_scope(name, values=[n] + list(self.inputs.values())):
      n = ops.convert_to_tensor(n, name='n')
      if n.dtype != dtypes.int32:
        raise TypeError('n.dtype=%s which is not int32' % n.dtype)
      batch_shape = self.batch_shape()
      event_shape = self.event_shape()
      batch_ndims = array_ops.shape(batch_shape)[0]
      ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2

      shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

      # Complexity: O(nbk^2)
      x = random_ops.random_normal(shape=shape,
                                   mean=0.,
                                   stddev=1.,
                                   dtype=self.dtype,
                                   seed=seed)

      # Passing the same `seed` to both random ops gives them the same
      # op-level seed, which can correlate the normal and Gamma streams.
      # Derive a distinct but deterministic seed for the Gamma draws
      # (`None` stays `None`).
      if seed is None:
        gamma_seed = None
      else:
        import hashlib  # pylint: disable=g-import-not-at-top
        gamma_seed = int(hashlib.md5(
            (str(seed) + 'wishart').encode('utf-8')).hexdigest()[:8], 16)
        gamma_seed &= 0x7FFFFFFF

      # Complexity: O(nbk)
      # This parametrization is equivalent to Chi2, i.e.,
      # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
      g = random_ops.random_gamma(shape=(n,),
                                  alpha=self._multi_gamma_sequence(
                                      0.5 * self.df, self.dimension),
                                  beta=0.5,
                                  dtype=self.dtype,
                                  seed=gamma_seed)

      # Complexity: O(nbk^2)
      x = array_ops.batch_matrix_band_part(x, -1, 0)  # Tri-lower.

      # Complexity: O(nbk)
      x = array_ops.batch_matrix_set_diag(x, math_ops.sqrt(g))

      # Make batch-op ready.
      # Complexity: O(nbk^2)
      perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
      x = array_ops.transpose(x, perm)
      shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
      x = array_ops.reshape(x, shape)

      # Complexity: O(nbM) where M is the complexity of the operator solving
      # a vector system. E.g., for OperatorPDDiag, each matmul is O(k^2), so
      # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
      # O(k^3) so this step has complexity O(nbk^3).
      x = self.scale_operator_pd.sqrt_matmul(x)

      # Undo make batch-op ready.
      # Complexity: O(nbk^2)
      shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
      x = array_ops.reshape(x, shape)
      perm = array_ops.concat(0, ((ndims-1,), math_ops.range(0, ndims-1)))
      x = array_ops.transpose(x, perm)

      if not self.cholesky_input_output_matrices:
        # Complexity: O(nbk^3)
        x = math_ops.batch_matmul(x, x, adj_y=True)

      # Set shape hints.
      if self.scale_operator_pd.get_shape().ndims is not None:
        x.set_shape(tensor_shape.TensorShape(
            [tensor_util.constant_value(n)] +
            self.scale_operator_pd.get_shape().as_list()))
      elif x.get_shape().ndims is not None:
        # Dimension.merge_with returns a new Dimension and does not mutate the
        # tensor's shape, so the static sample size must go through set_shape.
        x.set_shape(
            tensor_shape.TensorShape([tensor_util.constant_value(n)])
            .concatenate(x.get_shape()[1:]))

      return x
def _BatchMatrixBandPartGrad(op, grad):
  """Gradient for BatchMatrixBandPart.

  The incoming gradient is passed through the same band mask. The two
  band-width inputs receive no gradient.
  """
  num_lower, num_upper = op.inputs[1], op.inputs[2]
  masked_grad = array_ops.batch_matrix_band_part(grad, num_lower, num_upper)
  return masked_grad, None, None
def _matmul(self, x, transpose_x=False):
  """Compute `L L^T x`, or `L L^T x^T` when `transpose_x` is set.

  `L` is the lower-triangular part of `self._chol`.
  """
  # tf.matmul is defined a * b.
  factor = array_ops.batch_matrix_band_part(self._chol, -1, 0)
  inner = math_ops.matmul(
      factor, x, transpose_a=True, transpose_b=transpose_x)
  return math_ops.matmul(factor, inner)
def _to_dense(self):
  """Materialize `L L^H`, with `L` the lower triangle of `self._chol`."""
  lower = array_ops.batch_matrix_band_part(self._chol, -1, 0)
  dense = math_ops.batch_matmul(lower, lower, adj_y=True)
  return dense
def _sqrt_to_dense(self):
  """Return the lower triangle of `self._chol` (upper triangle zeroed)."""
  return array_ops.identity(
      array_ops.batch_matrix_band_part(self._chol, -1, 0))
def sample_n(self, n, seed=None, name="sample"):
  # pylint: disable=line-too-long
  """Generate `n` samples.

  Complexity: O(nbk^3)

  The sampling procedure is based on the [Bartlett decomposition](
  https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition)
  and [using a Gamma distribution to generate Chi2 random variates](
  https://en.wikipedia.org/wiki/Chi-squared_distribution#Gamma.2C_exponential.2C_and_related_distributions).

  Args:
    n: `Scalar` `Tensor` of type `int32`, the number of observations to
      sample.
    seed: Python integer; random number generator seed.
    name: The name of this op.

  Returns:
    samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
        with values of type `self.dtype`.

  Raises:
    TypeError: if `n` is not `int32`.
  """
  with ops.name_scope(self.name):
    with ops.name_scope(name, values=[n] + list(self.inputs.values())):
      n = ops.convert_to_tensor(n, name="n")
      if n.dtype != dtypes.int32:
        raise TypeError("n.dtype=%s which is not int32" % n.dtype)
      batch_shape = self.batch_shape()
      event_shape = self.event_shape()
      batch_ndims = array_ops.shape(batch_shape)[0]
      ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2

      shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

      # Complexity: O(nbk^2)
      x = random_ops.random_normal(shape=shape,
                                   mean=0.0,
                                   stddev=1.0,
                                   dtype=self.dtype,
                                   seed=seed)

      # Passing the same `seed` to both random ops gives them the same
      # op-level seed, which can correlate the normal and Gamma streams.
      # Derive a distinct but deterministic seed for the Gamma draws
      # (`None` stays `None`).
      if seed is None:
        gamma_seed = None
      else:
        import hashlib  # pylint: disable=g-import-not-at-top
        gamma_seed = int(hashlib.md5(
            (str(seed) + "wishart").encode("utf-8")).hexdigest()[:8], 16)
        gamma_seed &= 0x7FFFFFFF

      # Complexity: O(nbk)
      # This parametrization is equivalent to Chi2, i.e.,
      # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
      g = random_ops.random_gamma(
          shape=(n,),
          alpha=self._multi_gamma_sequence(0.5 * self.df, self.dimension),
          beta=0.5,
          dtype=self.dtype,
          seed=gamma_seed,
      )

      # Complexity: O(nbk^2)
      x = array_ops.batch_matrix_band_part(x, -1, 0)  # Tri-lower.

      # Complexity: O(nbk)
      x = array_ops.batch_matrix_set_diag(x, math_ops.sqrt(g))

      # Make batch-op ready.
      # Complexity: O(nbk^2)
      perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
      x = array_ops.transpose(x, perm)
      shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
      x = array_ops.reshape(x, shape)

      # Complexity: O(nbM) where M is the complexity of the operator solving
      # a vector system. E.g., for OperatorPDDiag, each matmul is O(k^2), so
      # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
      # O(k^3) so this step has complexity O(nbk^3).
      x = self.scale_operator_pd.sqrt_matmul(x)

      # Undo make batch-op ready.
      # Complexity: O(nbk^2)
      shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
      x = array_ops.reshape(x, shape)
      perm = array_ops.concat(0, ((ndims - 1,), math_ops.range(0, ndims - 1)))
      x = array_ops.transpose(x, perm)

      if not self.cholesky_input_output_matrices:
        # Complexity: O(nbk^3)
        x = math_ops.batch_matmul(x, x, adj_y=True)

      # Set shape hints.
      if self.scale_operator_pd.get_shape().ndims is not None:
        x.set_shape(
            tensor_shape.TensorShape(
                [tensor_util.constant_value(n)]
                + self.scale_operator_pd.get_shape().as_list()
            )
        )
      elif x.get_shape().ndims is not None:
        # Dimension.merge_with returns a new Dimension and does not mutate the
        # tensor's shape, so the static sample size must go through set_shape.
        x.set_shape(
            tensor_shape.TensorShape([tensor_util.constant_value(n)])
            .concatenate(x.get_shape()[1:])
        )

      return x
def _sqrt_matmul(self, x, transpose_x=False):
  """Compute `L x` (`L x^T` if `transpose_x`), `L` = lower triangle of `self._chol`."""
  # tf.matmul is defined a * b
  return math_ops.matmul(
      array_ops.batch_matrix_band_part(self._chol, -1, 0),
      x,
      transpose_b=transpose_x)
def _batch_matmul(self, x, transpose_x=False):
  """Compute `L L^H x` (`L L^H x^H` if `transpose_x`) batch-wise.

  `L` is the lower-triangular part of `self._chol`.
  """
  # tf.batch_matmul is defined x * y, so "y" is on the right, not "x".
  factor = array_ops.batch_matrix_band_part(self._chol, -1, 0)
  inner = math_ops.batch_matmul(factor, x, adj_x=True, adj_y=transpose_x)
  return math_ops.batch_matmul(factor, inner)
def _batch_sqrt_matmul(self, x, transpose_x=False):
  """Compute `L x` (`L x^H` if `transpose_x`) batch-wise.

  `L` is the lower-triangular part of `self._chol`.
  """
  # tf.batch_matmul is defined x * y, so "y" is on the right, not "x".
  return math_ops.batch_matmul(
      array_ops.batch_matrix_band_part(self._chol, -1, 0),
      x,
      adj_y=transpose_x)
def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
  """Multivariate Normal distributions on `R^k`.

  User must provide means `mu`, which are tensors of rank `N+1` (`N >= 0`)
  with the last dimension having length `k`.

  User must provide exactly one of `sigma` (the covariance matrices) or
  `sigma_chol` (the cholesky decompositions of the covariance matrices).
  `sigma` or `sigma_chol` must be of rank `N+2`.  The last two dimensions
  must both have length `k`.  The first `N` dimensions correspond to batch
  indices.

  If `sigma_chol` is not provided, the batch cholesky factorization of `sigma`
  is calculated for you.

  The shapes of `mu` and `sigma` must match for the first `N` dimensions.

  Regardless of which parameter is provided, the covariance matrices must all
  be **positive definite** (an error is raised if one of them is not).

  Args:
    mu: (N+1)-D.  `float` or `double` tensor, the means of the distributions.
    sigma: (N+2)-D.  (optional) `float` or `double` tensor, the covariances
      of the distribution(s).  The first `N+1` dimensions must match
      those of `mu`.  Must be batch-positive-definite.
    sigma_chol: (N+2)-D.  (optional) `float` or `double` tensor, a
      lower-triangular factorization of `sigma`
      (`sigma = sigma_chol . sigma_chol^*`).  The first `N+1` dimensions
      must match those of `mu`.  The tensor itself need not be batch
      lower triangular: we ignore the upper triangular part.  However,
      the batch diagonals must be positive (i.e., sigma_chol must be
      batch-positive-definite).
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: unless exactly one of `sigma` and `sigma_chol` is provided
      (i.e. if neither or both are given).
    TypeError: if mu and sigma (resp. sigma_chol) are different dtypes.
  """
  if (sigma is None) == (sigma_chol is None):
    raise ValueError(
        "Exactly one of sigma and sigma_chol must be provided")

  with ops.op_scope([mu, sigma, sigma_chol], name, "MultivariateNormal"):
    sigma_or_half = sigma_chol if sigma is None else sigma

    mu = ops.convert_to_tensor(mu)
    sigma_or_half = ops.convert_to_tensor(sigma_or_half)

    contrib_tensor_util.assert_same_float_dtype((mu, sigma_or_half))

    with ops.control_dependencies(
        [_assert_compatible_shapes(mu, sigma_or_half)]):
      mu = array_ops.identity(mu, name="mu")

    # Store the dimensionality of the MVNs
    self._k = array_ops.gather(array_ops.shape(mu), array_ops.rank(mu) - 1)

    # Both branches below use `sigma_or_half`, the already converted and
    # dtype-checked tensor, rather than re-converting the raw argument.
    if sigma_chol is not None:
      # Ensure we only keep the lower triangular part.
      sigma_chol = array_ops.batch_matrix_band_part(sigma_or_half,
                                                    num_lower=-1,
                                                    num_upper=0)
      sigma_det = _determinant_from_sigma_chol(sigma_chol)
      with ops.control_dependencies(
          [_assert_batch_positive_definite(sigma_chol)]):
        self._sigma = math_ops.batch_matmul(sigma_chol, sigma_chol,
                                            adj_y=True, name="sigma")
        self._sigma_chol = array_ops.identity(
            sigma_chol, "sigma_chol")
        self._sigma_det = array_ops.identity(
            sigma_det, "sigma_det")
        self._mu = array_ops.identity(mu, "mu")
    else:  # sigma is not None
      sigma_chol = linalg_ops.batch_cholesky(sigma_or_half)
      sigma_det = _determinant_from_sigma_chol(sigma_chol)
      # batch_cholesky checks for PSD; so we can just use it here.
      with ops.control_dependencies([sigma_chol]):
        self._sigma = array_ops.identity(sigma_or_half, "sigma")
        self._sigma_chol = array_ops.identity(
            sigma_chol, "sigma_chol")
        self._sigma_det = array_ops.identity(
            sigma_det, "sigma_det")
        self._mu = array_ops.identity(mu, "mu")