def benchmarkMatrixBandPartOp(self):
  for shape_ in self.shapes:
    for limits in (-1, -1), (-1, 0), (0, -1), (2, 2):
      with ops.Graph().as_default(), \
          session.Session() as sess, \
          ops.device("/cpu:0"):
        matrix = variables.Variable(array_ops.ones(shape_))
        band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
        variables.global_variables_initializer().run()
        self.run_op_benchmark(
            sess,
            control_flow_ops.group(band),
            min_iters=10,
            name="matrix_band_part_cpu_{shape}_{limits}".format(
                shape=shape_, limits=limits))

      if test_lib.is_gpu_available(True):
        with ops.Graph().as_default(), \
            session.Session() as sess, \
            ops.device("/gpu:0"):
          matrix = variables.Variable(array_ops.ones(shape_))
          band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
          variables.global_variables_initializer().run()
          self.run_op_benchmark(
              sess,
              control_flow_ops.group(band),
              min_iters=10,
              name="matrix_band_part_gpu_{shape}_{limits}".format(
                  shape=shape_, limits=limits))

def _verifyLu(self, x, output_idx_type=dtypes.int64):
  # Verify that Px = LU.
  lu, perm = linalg_ops.lu(x, output_idx_type=output_idx_type)

  # Prepare the lower factor of shape num_rows x num_rows.
  lu_shape = np.array(lu.shape.as_list())
  batch_shape = lu_shape[:-2]
  num_rows = lu_shape[-2]
  num_cols = lu_shape[-1]

  lower = array_ops.matrix_band_part(lu, -1, 0)
  if num_rows > num_cols:
    eye = linalg_ops.eye(
        num_rows, batch_shape=batch_shape, dtype=lower.dtype)
    lower = array_ops.concat([lower, eye[..., num_cols:]], axis=-1)
  elif num_rows < num_cols:
    lower = lower[..., :num_rows]

  # Fill the diagonal with ones.
  ones_diag = array_ops.ones(
      np.append(batch_shape, num_rows), dtype=lower.dtype)
  lower = array_ops.matrix_set_diag(lower, ones_diag)

  # Prepare the upper factor.
  upper = array_ops.matrix_band_part(lu, 0, -1)

  verification = math_ops.matmul(lower, upper)

  # Permute the rows of the product of the triangular factors.
  if num_rows > 0:
    # Reshape the product of the triangular factors and permutation indices
    # to a single batch dimension. This makes it easy to apply
    # invert_permutation and gather_nd ops.
    perm_reshaped = array_ops.reshape(perm, [-1, num_rows])
    verification_reshaped = array_ops.reshape(verification,
                                              [-1, num_rows, num_cols])
    # Invert the permutation in each batch.
    inv_perm_reshaped = map_fn.map_fn(array_ops.invert_permutation,
                                      perm_reshaped)
    batch_size = perm_reshaped.shape.as_list()[0]
    # Prepare the batch indices with the same shape as the permutation.
    # The corresponding batch index is paired with each of the `num_rows`
    # permutation indices.
    batch_indices = math_ops.cast(
        array_ops.broadcast_to(
            math_ops.range(batch_size)[:, None], perm_reshaped.shape),
        dtype=output_idx_type)
    permuted_verification_reshaped = array_ops.gather_nd(
        verification_reshaped,
        array_ops.stack([batch_indices, inv_perm_reshaped], axis=-1))

    # Reshape the verification matrix back to the original shape.
    verification = array_ops.reshape(permuted_verification_reshaped,
                                     lu_shape)

  self._verifyLuBase(x, lower, upper, perm, verification, output_idx_type)

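# A NumPy aside (illustrative, not part of the test): the permutation step in
# `_verifyLu` relies on the fact that gathering rows by the inverse
# permutation undoes gathering rows by the permutation itself.
import numpy as np

_perm = np.array([2, 0, 1])
_inv_perm = np.argsort(_perm)  # NumPy analogue of invert_permutation.
_m = np.arange(9).reshape(3, 3)
assert np.array_equal(_m[_perm][_inv_perm], _m)
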
def _random_cholesky_array(self, shape):
  mat = self._rng.rand(*shape)
  chol = distribution_util.matrix_diag_transform(
      mat, transform=nn_ops.softplus)
  # Zero the upper triangle because we're using this as a true Cholesky factor
  # in our tests.
  return array_ops.matrix_band_part(chol, -1, 0).eval()

def testMatrixBandPart(self, batch_shape, rows, cols):
  # TODO(b/125505881): Disabled due to LLVM backend crash.
  if self.device == 'XLA_CPU' and cols == 7 and rows == 1 and batch_shape == [
      1, 3, 2
  ]:
    return  # Skip; a bare `pass` here would fall through and run the test.
  for dtype in self.float_types:
    with self.cached_session():
      mat = np.ones(batch_shape + [rows, cols]).astype(dtype)
      batch_mat = np.tile(mat, batch_shape + [1, 1])
      for lower in -1, 0, 1, rows - 1:
        for upper in -1, 0, 1, cols - 1:
          band_np = mat
          if lower >= 0:
            band_np = np.triu(band_np, -lower)
          if upper >= 0:
            band_np = np.tril(band_np, upper)
          if batch_shape:
            band_np = np.tile(band_np, batch_shape + [1, 1])

          placeholder = array_ops.placeholder(dtype)
          with self.test_scope():
            band = array_ops.matrix_band_part(
                placeholder,
                constant_op.constant(lower, dtype=dtypes.int32),
                constant_op.constant(upper, dtype=dtypes.int32))
            feed_dict = {placeholder: batch_mat}
            self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))

def _QrGrad(op, dq, dr):
  """Gradient for Qr."""
  q, r = op.outputs
  if q.dtype.is_complex:
    raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  if r.shape.dims[-2].value != r.shape.dims[-1].value:
    raise NotImplementedError("QrGrad not implemented when ncols > nrows "
                              "or full_matrices is true and ncols != nrows.")

  qdq = math_ops.matmul(q, dq, adjoint_a=True)
  qdq_ = qdq - _linalg.adjoint(qdq)
  rdr = math_ops.matmul(r, dr, adjoint_b=True)
  rdr_ = rdr - _linalg.adjoint(rdr)
  tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
  grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
  return grad_a + grad_b

def _sample_n(self, n, seed):
  batch_shape = self.batch_shape_tensor()
  event_shape = self.event_shape_tensor()
  batch_ndims = array_ops.shape(batch_shape)[0]

  ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
  shape = array_ops.concat([[n], batch_shape, event_shape], 0)

  # Complexity: O(nbk**2)
  x = random_ops.random_normal(shape=shape,
                               mean=0.,
                               stddev=1.,
                               dtype=self.dtype,
                               seed=seed)

  # Complexity: O(nbk)
  # This parametrization is equivalent to Chi2, i.e.,
  # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
  expanded_df = self.df * array_ops.ones(
      self.scale_operator.batch_shape_tensor(),
      dtype=self.df.dtype.base_dtype)
  g = random_ops.random_gamma(shape=[n],
                              alpha=self._multi_gamma_sequence(
                                  0.5 * expanded_df, self.dimension),
                              beta=0.5,
                              dtype=self.dtype,
                              seed=distribution_util.gen_new_seed(
                                  seed, "wishart"))

  # Complexity: O(nbk**2)
  x = array_ops.matrix_band_part(x, -1, 0)  # Tri-lower.

  # Complexity: O(nbk)
  x = array_ops.matrix_set_diag(x, math_ops.sqrt(g))

  # Make batch-op ready.
  # Complexity: O(nbk**2)
  perm = array_ops.concat([math_ops.range(1, ndims), [0]], 0)
  x = array_ops.transpose(x, perm)
  shape = array_ops.concat([batch_shape, [event_shape[0]], [-1]], 0)
  x = array_ops.reshape(x, shape)

  # Complexity: O(nbM) where M is the complexity of the operator solving a
  # vector system. E.g., for LinearOperatorDiag, each matmul is O(k**2), so
  # this complexity is O(nbk**2). For LinearOperatorLowerTriangular,
  # each matmul is O(k^3) so this step has complexity O(nbk^3).
  x = self.scale_operator.matmul(x)

  # Undo make batch-op ready.
  # Complexity: O(nbk**2)
  shape = array_ops.concat([batch_shape, event_shape, [n]], 0)
  x = array_ops.reshape(x, shape)
  perm = array_ops.concat([[ndims - 1], math_ops.range(0, ndims - 1)], 0)
  x = array_ops.transpose(x, perm)

  if not self.cholesky_input_output_matrices:
    # Complexity: O(nbk^3)
    x = math_ops.matmul(x, x, adjoint_b=True)

  return x

def random_tril_matrix(shape,
                       dtype,
                       force_well_conditioned=False,
                       remove_upper=True):
  """[batch] lower triangular matrix.

  Args:
    shape: `TensorShape` or Python `list`.  Shape of the returned matrix.
    dtype: `TensorFlow` `dtype` or Python dtype.
    force_well_conditioned: Python `bool`. If `True`, returned matrix will have
      eigenvalues with modulus in `(1, 2)`.  Otherwise, eigenvalues are unit
      normal random variables.
    remove_upper: Python `bool`.
      If `True`, zero out the strictly upper triangle.
      If `False`, the lower triangle of returned matrix will have desired
      properties, but will not have the strictly upper triangle zero'd out.

  Returns:
    `Tensor` with desired shape and dtype.
  """
  with ops.name_scope("random_tril_matrix"):
    # Totally random matrix.  Has no nice properties.
    tril = random_normal(shape, dtype=dtype)
    if remove_upper:
      tril = array_ops.matrix_band_part(tril, -1, 0)

    # Create a diagonal with entries having modulus in [1, 2].
    if force_well_conditioned:
      maxval = ops.convert_to_tensor(np.sqrt(2.), dtype=dtype.real_dtype)
      diag = random_sign_uniform(
          shape[:-1], dtype=dtype, minval=1., maxval=maxval)
      tril = array_ops.matrix_set_diag(tril, diag)

    return tril

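# Usage sketch (hedged; assumes the module's `random_normal` and
# `random_sign_uniform` helpers are in scope, as in
# linear_operator_test_util):
#
#   tril = random_tril_matrix([2, 4, 4], dtypes.float32,
#                             force_well_conditioned=True)
#
# `tril` is then a batch of two 4 x 4 lower triangular matrices; with
# `force_well_conditioned=True` the diagonal entries (hence the eigenvalues
# of each triangular matrix) have modulus in [1, sqrt(2)].
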
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.batch_matmul(c, grad_b, adj_y=True)
  else:
    grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  return (grad_a, grad_b)

def sign_magnitude_positive_definite(raw,
                                     off_diagonal_scale=0.,
                                     overall_scale=0.):
  """Constructs a positive definite matrix from an unconstrained input matrix.

  We want to keep the whole matrix on a log scale, but also allow off-diagonal
  elements to be negative, so the sign of off-diagonal elements is modeled
  separately from their magnitude (using the lower and upper triangles
  respectively). Specifically:

  for i < j, we have:
    output_cholesky[i, j] = raw[j, i] / (abs(raw[j, i]) + 1) *
        exp((off_diagonal_scale + overall_scale + raw[i, j]) / 2)

  output_cholesky[i, i] = exp((raw[i, i] + overall_scale) / 2)

  output = output_cholesky^T * output_cholesky

  where raw, off_diagonal_scale, and overall_scale are un-constrained
  real-valued variables. The resulting values are stable around zero due to the
  exponential (and the softsign keeps the function smooth).

  Args:
    raw: A [..., M, M] Tensor.
    off_diagonal_scale: A scalar or [...] shaped Tensor controlling the
      relative scale of off-diagonal values in the output matrix.
    overall_scale: A scalar or [...] shaped Tensor controlling the overall
      scale of the output matrix.

  Returns:
    The `output` matrix described above, a [..., M, M] positive definite
    matrix.
  """
  raw = ops.convert_to_tensor(raw)
  diagonal = array_ops.matrix_diag_part(raw)

  def _right_pad_with_ones(tensor, target_rank):
    # Allow broadcasting even if overall_scale and off_diagonal_scale have
    # batch dimensions.
    tensor = ops.convert_to_tensor(tensor, dtype=raw.dtype.base_dtype)
    return array_ops.reshape(
        tensor,
        array_ops.concat([
            array_ops.shape(tensor),
            array_ops.ones([target_rank - array_ops.rank(tensor)],
                           dtype=target_rank.dtype)
        ], axis=0))

  # We divide the log values by 2 to compensate for the squaring that happens
  # when transforming Cholesky factors into positive definite matrices.
  sign_magnitude = (gen_math_ops.exp(
      (raw + _right_pad_with_ones(off_diagonal_scale, array_ops.rank(raw)) +
       _right_pad_with_ones(overall_scale, array_ops.rank(raw))) / 2.) *
                    nn.softsign(array_ops.matrix_transpose(raw)))
  sign_magnitude.set_shape(raw.get_shape())
  cholesky_factor = array_ops.matrix_set_diag(
      input=array_ops.matrix_band_part(sign_magnitude, 0, -1),
      diagonal=gen_math_ops.exp((diagonal + _right_pad_with_ones(
          overall_scale, array_ops.rank(diagonal))) / 2.))
  return math_ops.matmul(cholesky_factor, cholesky_factor, transpose_a=True)

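# A self-contained NumPy restatement of the docstring formulas above, useful
# as a sanity check (illustrative; the helper name is ours, not the
# library's):
import numpy as np

def _sign_magnitude_pd_reference(raw, off_diagonal_scale=0., overall_scale=0.):
  """NumPy version of sign_magnitude_positive_definite for one matrix."""
  raw = np.asarray(raw, dtype=np.float64)
  m = raw.shape[-1]
  chol = np.zeros_like(raw)
  for i in range(m):
    chol[i, i] = np.exp((raw[i, i] + overall_scale) / 2.)
    for j in range(i + 1, m):
      sign = raw[j, i] / (abs(raw[j, i]) + 1.)  # softsign of raw[j, i].
      chol[i, j] = sign * np.exp(
          (off_diagonal_scale + overall_scale + raw[i, j]) / 2.)
  return chol.T @ chol  # output_cholesky^T @ output_cholesky is PD.
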
def _sample_n(self, n, seed):
  batch_shape = self.batch_shape()
  event_shape = self.event_shape()
  batch_ndims = array_ops.shape(batch_shape)[0]

  ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
  shape = array_ops.concat(((n,), batch_shape, event_shape), 0)

  # Complexity: O(nbk^2)
  x = random_ops.random_normal(shape=shape,
                               mean=0.,
                               stddev=1.,
                               dtype=self.dtype,
                               seed=seed)

  # Complexity: O(nbk)
  # This parametrization is equivalent to Chi2, i.e.,
  # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
  g = random_ops.random_gamma(shape=(n,),
                              alpha=self._multi_gamma_sequence(
                                  0.5 * self.df, self.dimension),
                              beta=0.5,
                              dtype=self.dtype,
                              seed=distribution_util.gen_new_seed(
                                  seed, "wishart"))

  # Complexity: O(nbk^2)
  x = array_ops.matrix_band_part(x, -1, 0)  # Tri-lower.

  # Complexity: O(nbk)
  x = array_ops.matrix_set_diag(x, math_ops.sqrt(g))

  # Make batch-op ready.
  # Complexity: O(nbk^2)
  perm = array_ops.concat((math_ops.range(1, ndims), (0,)), 0)
  x = array_ops.transpose(x, perm)
  shape = array_ops.concat((batch_shape, (event_shape[0], -1)), 0)
  x = array_ops.reshape(x, shape)

  # Complexity: O(nbM) where M is the complexity of the operator solving a
  # vector system. E.g., for OperatorPDDiag, each matmul is O(k^2), so
  # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
  # O(k^3) so this step has complexity O(nbk^3).
  x = self.scale_operator_pd.sqrt_matmul(x)

  # Undo make batch-op ready.
  # Complexity: O(nbk^2)
  shape = array_ops.concat((batch_shape, event_shape, (n,)), 0)
  x = array_ops.reshape(x, shape)
  perm = array_ops.concat(((ndims - 1,), math_ops.range(0, ndims - 1)), 0)
  x = array_ops.transpose(x, perm)

  if not self.cholesky_input_output_matrices:
    # Complexity: O(nbk^3)
    x = math_ops.matmul(x, x, adjoint_b=True)

  return x

def _GradWithInverseL(l, l_inverse, grad):
  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)
  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)
  grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
  return grad_a * 0.5

def CheckUnitary(self, x):
  # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
  xx = math_ops.matmul(x, x, adjoint_a=True)
  identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
  # `is_single` is captured from the enclosing test-generation scope; single
  # precision needs a looser tolerance.
  if is_single:
    tol = 1e-5
  else:
    tol = 1e-14
  self.assertAllClose(identity.eval(), self.evaluate(xx), atol=tol)

def _forward(self, x):
  if self.validate_args:
    is_matrix = check_ops.assert_rank_at_least(x, 2)
    shape = array_ops.shape(x)
    is_square = check_ops.assert_equal(shape[-2], shape[-1])
    x = control_flow_ops.with_dependencies([is_matrix, is_square], x)
  # For safety, explicitly zero-out the upper triangular part.
  x = array_ops.matrix_band_part(x, -1, 0)
  return math_ops.matmul(x, x, adjoint_b=True)

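# NumPy analogue of the forward computation above (illustrative): zero the
# upper triangle, then form the outer product, which is symmetric positive
# semi-definite by construction.
import numpy as np

_x = np.random.rand(3, 3)
_l = np.tril(_x)   # matrix_band_part(x, -1, 0)
_y = _l @ _l.T     # matmul(x, x, adjoint_b=True), real case
assert np.allclose(_y, _y.T)
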
def Test(self):
  shape = batch_shape_ + shape_
  x = constant_op.constant(np.random.rand(*shape), dtype=dtype_)
  with self.test_session(use_gpu=True):
    for lower in -1, 0, 1, shape_[-2] - 1:
      for upper in -1, 0, 1, shape_[-1] - 1:
        y = array_ops.matrix_band_part(x, lower, upper)
        error = gradient_checker.compute_gradient_error(
            x, x.get_shape().as_list(), y, y.get_shape().as_list())
        self.assertLess(error, 1e-4)

def __init__(self,
             tril,
             is_non_singular=None,
             is_self_adjoint=None,
             is_positive_definite=None,
             is_square=None,
             name="LinearOperatorLowerTriangular"):
  r"""Initialize a `LinearOperatorLowerTriangular`.

  Args:
    tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
      The lower triangular part of `tril` defines this operator.  The strictly
      upper triangle is ignored.  Allowed dtypes: `float16`, `float32`,
      `float64`.
    is_non_singular: Expect that this operator is non-singular.
      This operator is non-singular if and only if its diagonal elements are
      all non-zero.
    is_self_adjoint: Expect that this operator is equal to its hermitian
      transpose.  This operator is self-adjoint only if it is diagonal with
      real-valued diagonal entries.  In this case it is advised to use
      `LinearOperatorDiag`.
    is_positive_definite: Expect that this operator is positive definite,
      meaning the quadratic form `x^H A x` has positive real part for all
      nonzero `x`.  Note that we do not require the operator to be
      self-adjoint to be positive-definite.  See:
      https://en.wikipedia.org/wiki/Positive-definite_matrix\
          #Extension_for_non_symmetric_matrices
    is_square: Expect that this operator acts like square [batch] matrices.
    name: A name for this `LinearOperator`.

  Raises:
    TypeError: If `tril.dtype` is not an allowed type.
    ValueError: If `is_square` is `False`.
  """

  if is_square is False:
    raise ValueError(
        "Only square lower triangular operators supported at this time.")
  is_square = True

  with ops.name_scope(name, values=[tril]):
    self._tril = ops.convert_to_tensor(tril, name="tril")
    self._check_tril(self._tril)
    self._tril = array_ops.matrix_band_part(tril, -1, 0)
    self._diag = array_ops.matrix_diag_part(self._tril)

    super(LinearOperatorLowerTriangular, self).__init__(
        dtype=self._tril.dtype,
        graph_parents=[self._tril],
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name)

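# Usage sketch via the public API (hedged; mirrors the class docstring style):
#
#   tril = [[1., 0.], [2., 3.]]
#   operator = LinearOperatorLowerTriangular(tril)
#   operator.to_dense()
#   # ==> [[1., 0.], [2., 3.]]
#   operator.matmul([[1.], [1.]])
#   # ==> [[1.], [5.]]
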
def _preprocess_tril(self, identity_multiplier, diag, tril, event_ndims):
  """Helper to preprocess a lower triangular matrix."""
  tril = array_ops.matrix_band_part(tril, -1, 0)  # Zero out TriU.
  if identity_multiplier is None and diag is None:
    return self._process_matrix(tril, min_rank=2, event_ndims=event_ndims)
  new_diag = array_ops.matrix_diag_part(tril)
  if identity_multiplier is not None:
    new_diag += identity_multiplier
  if diag is not None:
    new_diag += diag
  tril = array_ops.matrix_set_diag(tril, new_diag)
  return self._process_matrix(tril, min_rank=2, event_ndims=event_ndims)

def __init__(self,
             tril,
             is_non_singular=None,
             is_self_adjoint=None,
             is_positive_definite=None,
             name="LinearOperatorTriL"):
  """Initialize a `LinearOperatorTriL`.

  Args:
    tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
      The lower triangular part of `tril` defines this operator.  The strictly
      upper triangle is ignored.  Allowed dtypes: `float32`, `float64`.
    is_non_singular: Expect that this operator is non-singular.
      This operator is non-singular if and only if its diagonal elements are
      all non-zero.
    is_self_adjoint: Expect that this operator is equal to its hermitian
      transpose.  This operator is self-adjoint only if it is diagonal with
      real-valued diagonal entries.  In this case it is advised to use
      `LinearOperatorDiag`.
    is_positive_definite: Expect that this operator is positive definite,
      meaning the real part of all eigenvalues is positive.  We do not require
      the operator to be self-adjoint to be positive-definite.  See:
      https://en.wikipedia.org/wiki/Positive-definite_matrix
          #Extension_for_non_symmetric_matrices
    name: A name for this `LinearOperator`.

  Raises:
    TypeError: If `tril.dtype` is not an allowed type.
  """
  # TODO(langmore) Add complex types once matrix_triangular_solve works for
  # them.
  allowed_dtypes = [dtypes.float32, dtypes.float64]

  with ops.name_scope(name, values=[tril]):
    self._tril = array_ops.matrix_band_part(tril, -1, 0)
    self._diag = array_ops.matrix_diag_part(self._tril)

    dtype = self._tril.dtype
    if dtype not in allowed_dtypes:
      raise TypeError("Argument tril must have dtype in %s. Found: %s" %
                      (allowed_dtypes, dtype))

    super(LinearOperatorTriL, self).__init__(
        dtype=self._tril.dtype,
        graph_parents=[self._tril],
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        name=name)

def Test(self):
  mat = np.ones(shape_).astype(dtype_)
  batch_mat = np.tile(mat, batch_shape_ + (1, 1))
  with self.test_session(use_gpu=True):
    for lower in -1, 0, 1, shape_[-2] - 1:
      for upper in -1, 0, 1, shape_[-1] - 1:
        band_np = mat
        if lower >= 0:
          band_np = np.triu(band_np, -lower)
        if upper >= 0:
          band_np = np.tril(band_np, upper)
        # `batch_shape_ is not ()` compared identity, not value; use truthiness.
        if batch_shape_:
          band_np = np.tile(band_np, batch_shape_ + (1, 1))
        band = array_ops.matrix_band_part(batch_mat, lower, upper)
        self.assertAllEqual(band_np, band.eval())

def _operator_and_matrix(self, build_info, dtype, use_placeholder):
  shape = list(build_info.shape)
  # Upper triangle will be nonzero, but ignored.
  # Use a diagonal that ensures this matrix is well conditioned.
  tril = linear_operator_test_util.random_tril_matrix(
      shape, dtype=dtype, force_well_conditioned=True, remove_upper=False)

  lin_op_tril = tril

  if use_placeholder:
    lin_op_tril = array_ops.placeholder_with_default(lin_op_tril, shape=None)

  operator = linalg.LinearOperatorLowerTriangular(lin_op_tril)

  matrix = array_ops.matrix_band_part(tril, -1, 0)

  return operator, matrix

def TriAngSolveCompositeGrad(l, grad):
  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}

  # Compute ((l^{H} @ grad) * (tril(ones)-1/2*eye)) = middle
  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  # Compute l^{-H} @ middle = z
  l_inverse_middle = linalg_ops.matrix_triangular_solve(l, middle,
                                                        adjoint=True)

  # We need to compute z @ l^{-1}. With matrix_triangular_solve we
  # actually compute l^{-H} @ z^{H} = grad. Since we later add grad^{H}
  # we can omit the conjugate transpose here.
  z_h = math_ops.conj(array_ops.matrix_transpose(l_inverse_middle))
  grad_a = linalg_ops.matrix_triangular_solve(l, z_h, adjoint=True)
  grad_a += linalg.adjoint(grad_a)
  return grad_a * 0.5

def Test(self):
  mat = np.ones(shape_).astype(dtype_)
  batch_mat = np.tile(mat, batch_shape_ + (1, 1))
  for lower in -1, 0, 1, shape_[-2] - 1:
    for upper in -1, 0, 1, shape_[-1] - 1:
      band_np = mat
      if lower >= 0:
        band_np = np.triu(band_np, -lower)
      if upper >= 0:
        band_np = np.tril(band_np, upper)
      # `batch_shape_ is not ()` compared identity, not value; use truthiness.
      if batch_shape_:
        band_np = np.tile(band_np, batch_shape_ + (1, 1))
      for index_dtype in [dtypes_lib.int32, dtypes_lib.int64]:
        with self.test_session(use_gpu=False):
          band = array_ops.matrix_band_part(
              batch_mat,
              constant_op.constant(lower, index_dtype),
              constant_op.constant(upper, index_dtype))
          self.assertAllEqual(band_np, band.eval())

def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""
  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(
      l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
  return grad_a * 0.5

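# For intuition, the same closed form in NumPy for the real-valued case
# (illustrative; `np.linalg.inv` stands in for the triangular solve against
# the identity above):
import numpy as np

def _cholesky_grad_reference(l, grad):
  middle = l.T @ grad                         # matmul(l, grad, adjoint_a=True)
  middle = np.tril(middle)                    # matrix_band_part(middle, -1, 0)
  middle[np.diag_indices_from(middle)] *= 0.5
  l_inv = np.linalg.inv(l)
  grad_a = l_inv.T @ middle @ l_inv
  return 0.5 * (grad_a + grad_a.T)            # Symmetrize, then halve.
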
def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder):
  # Upper triangle will be nonzero, but ignored.
  # Use a diagonal that ensures this matrix is well conditioned.
  tril = linear_operator_test_util.random_tril_matrix(
      shape, dtype=dtype, force_well_conditioned=True, remove_upper=False)
  if use_placeholder:
    tril_ph = array_ops.placeholder(dtype=dtype)
    # Evaluate the tril here because (i) you cannot feed a tensor, and (ii)
    # tril is random and we want the same value used for both mat and
    # feed_dict.
    tril = tril.eval()
    operator = linalg.LinearOperatorTriL(tril_ph)
    feed_dict = {tril_ph: tril}
  else:
    operator = linalg.LinearOperatorTriL(tril)
    feed_dict = None
  mat = array_ops.matrix_band_part(tril, -1, 0)
  return operator, mat, feed_dict

def TriAngSolveCompositeGrad(l, grad):
  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}

  # Compute ((l^{H} @ grad) * (tril(ones)-1/2*eye)) = middle
  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(
      middle, 0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  # Compute l^{-H} @ middle = z
  l_inverse_middle = linalg_ops.matrix_triangular_solve(l, middle,
                                                        adjoint=True)

  # We need to compute z @ l^{-1}. With matrix_triangular_solve we
  # actually compute l^{-H} @ z^{H} = grad. Since we later add grad^{H}
  # we can omit the conjugate transpose here.
  z_h = math_ops.conj(array_ops.matrix_transpose(l_inverse_middle))
  grad_a = linalg_ops.matrix_triangular_solve(l, z_h, adjoint=True)
  grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
  return grad_a * 0.5

def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only
      # defined up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.reciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(
          v, math_ops.matmul(array_ops.matrix_diag(grad_e), v,
                             adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here
    # we symmetrize and take the lower triangle.
    grad_a = array_ops.matrix_band_part(
        grad_a + math_ops.conj(array_ops.matrix_transpose(grad_a)), -1, 0)
    grad_a = array_ops.matrix_set_diag(
        grad_a, 0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a

def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only
      # defined up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.reciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(
          v, math_ops.matmul(array_ops.matrix_diag(grad_e), v,
                             adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here
    # we symmetrize and take the lower triangle.
    grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a),
                                        -1, 0)
    grad_a = array_ops.matrix_set_diag(
        grad_a, 0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a

def testMatrixBandPart(self, batch_shape, rows, cols):
  for dtype in self.float_types:
    with self.cached_session():
      mat = np.ones(batch_shape + [rows, cols]).astype(dtype)
      batch_mat = np.tile(mat, batch_shape + [1, 1])
      for lower in -1, 0, 1, rows - 1:
        for upper in -1, 0, 1, cols - 1:
          band_np = mat
          if lower >= 0:
            band_np = np.triu(band_np, -lower)
          if upper >= 0:
            band_np = np.tril(band_np, upper)
          if batch_shape:
            band_np = np.tile(band_np, batch_shape + [1, 1])

          placeholder = array_ops.placeholder(dtype)
          with self.test_scope():
            band = array_ops.matrix_band_part(
                placeholder,
                constant_op.constant(lower, dtype=dtypes.int32),
                constant_op.constant(upper, dtype=dtypes.int32))
            feed_dict = {placeholder: batch_mat}
            self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))

def _assertions(self, x):
  if not self.validate_args:
    return []
  shape = array_ops.shape(x)
  is_matrix = check_ops.assert_rank_at_least(
      x, 2, message="Input must have rank at least 2.")
  is_square = check_ops.assert_equal(
      shape[-2], shape[-1], message="Input must be a square matrix.")
  above_diagonal = array_ops.matrix_band_part(
      array_ops.matrix_set_diag(
          x, array_ops.zeros(shape[:-1], dtype=dtypes.float32)),
      0, -1)
  is_lower_triangular = check_ops.assert_equal(
      above_diagonal, array_ops.zeros_like(above_diagonal),
      message="Input must be lower triangular.")
  # A lower triangular matrix is nonsingular iff all its diagonal entries are
  # nonzero.
  diag_part = array_ops.matrix_diag_part(x)
  is_nonsingular = check_ops.assert_none_equal(
      diag_part, array_ops.zeros_like(diag_part),
      message="Input must have all diagonal entries nonzero.")
  return [is_matrix, is_square, is_lower_triangular, is_nonsingular]

def _QrGradSquareAndDeepMatrices(q, r, dq, dr):
  """Gradient for matrix orders num_rows >= num_cols
  and full_matrices is false.
  """
  qdq = math_ops.matmul(q, dq, adjoint_a=True)
  qdq_ = qdq - _linalg.adjoint(qdq)
  rdr = math_ops.matmul(r, dr, adjoint_b=True)
  rdr_ = rdr - _linalg.adjoint(rdr)
  tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

  grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
  grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
  ret = grad_a + grad_b

  if q.dtype.is_complex:
    # Need to add a correction to the gradient formula for the complex case.
    m = rdr - _linalg.adjoint(qdq)
    eyem = _linalg.set_diag(array_ops.zeros_like(m), _linalg.diag_part(m))
    correction = eyem - math_ops.cast(math_ops.real(eyem), q.dtype)
    ret = ret + _TriangularSolve(
        math_ops.matmul(q, _linalg.adjoint(correction)), r)

  return ret

def _testMatrixBandPart(self, dtype, shape):
  with self.test_session():
    batch_shape = shape[:-2]
    mat = np.ones(shape).astype(dtype)
    batch_mat = np.tile(mat, batch_shape + [1, 1])
    for lower in -1, 0, 1, shape[-2] - 1:
      for upper in -1, 0, 1, shape[-1] - 1:
        band_np = mat
        if lower >= 0:
          band_np = np.triu(band_np, -lower)
        if upper >= 0:
          band_np = np.tril(band_np, upper)
        if batch_shape:
          band_np = np.tile(band_np, batch_shape + [1, 1])

        placeholder = array_ops.placeholder(dtype)
        with self.test_scope():
          band = array_ops.matrix_band_part(
              placeholder,
              constant_op.constant(lower, dtype=dtypes.int32),
              constant_op.constant(upper, dtype=dtypes.int32))
          feed_dict = {placeholder: batch_mat}
          self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))

def test_matrix_triangular_solve(self):
  for lower in (True, False):
    for adjoint in (True, False):
      for stack_a in (True, False):
        for stack_b in (True, False):
          shape_a = (2, 4, 3, 3) if stack_a else (4, 3, 3)
          shape_b = (2, 4, 3, 5) if stack_b else (4, 3, 5)
          x = array_ops.matrix_band_part(
              random_ops.random_uniform(shape_a) +
              linalg_ops.eye(3),  # Ensure well-conditioned.
              *((-1, 0) if lower else (0, -1)))  # Ensure triangular.
          y = random_ops.random_uniform(shape_b)

          # pylint: disable=cell-var-from-loop
          def loop_fn(i):
            a = array_ops.gather(x, i) if stack_a else x
            b = array_ops.gather(y, i) if stack_b else y
            return linalg_ops.matrix_triangular_solve(
                a, b, lower=lower, adjoint=adjoint)

          # pylint: enable=cell-var-from-loop

          self._test_loop_fn(loop_fn, 2)

def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""
  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(
      l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5

def _testMatrixBandPart(self, dtype, shape):
  with self.cached_session():
    batch_shape = shape[:-2]
    mat = np.ones(shape).astype(dtype)
    batch_mat = np.tile(mat, batch_shape + [1, 1])
    for lower in -1, 0, 1, shape[-2] - 1:
      for upper in -1, 0, 1, shape[-1] - 1:
        band_np = mat
        if lower >= 0:
          band_np = np.triu(band_np, -lower)
        if upper >= 0:
          band_np = np.tril(band_np, upper)
        if batch_shape:
          band_np = np.tile(band_np, batch_shape + [1, 1])

        placeholder = array_ops.placeholder(dtype)
        with self.test_scope():
          band = array_ops.matrix_band_part(
              placeholder,
              constant_op.constant(lower, dtype=dtypes.int32),
              constant_op.constant(upper, dtype=dtypes.int32))
          feed_dict = {placeholder: batch_mat}
          self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))

def _random_pd_matrix(self, *shape):
  mat = rng.rand(*shape)
  chol = ds.matrix_diag_transform(mat, transform=nn_ops.softplus)
  chol = array_ops.matrix_band_part(chol, -1, 0)
  return math_ops.matmul(chol, chol, adjoint_b=True).eval()

def CheckUnitary(self, x, tol):
  # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
  xx = math_ops.matmul(x, x, adjoint_a=True)
  identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
  self.assertAllClose(identity, xx, atol=tol)

def CheckUnitary(self, x):
  # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
  xx = math_ops.matmul(x, x, adjoint_a=True)
  identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
  precision = self.AdjustedNorm(xx.eval() - identity.eval())
  self.assertTrue(np.all(precision < 5.0))

def _random_chol(self, *shape):
  mat = self._rng.rand(*shape)
  chol = ds.matrix_diag_transform(mat, transform=nn_ops.softplus)
  chol = array_ops.matrix_band_part(chol, -1, 0)
  sigma = math_ops.matmul(chol, chol, adjoint_b=True)
  return chol.eval(), sigma.eval()

def _batch_sqrt_matmul(self, x, transpose_x=False):
  chol = array_ops.matrix_band_part(self._chol, -1, 0)
  # tf.batch_matmul is defined x * y, so "y" is on the right, not "x".
  return math_ops.matmul(chol, x, adjoint_b=transpose_x)

def _MatrixBandPartGrad(op, grad):
  """Gradient for MatrixBandPart."""
  num_lower = op.inputs[1]
  num_upper = op.inputs[2]
  return (array_ops.matrix_band_part(grad, num_lower, num_upper), None, None)

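# Why this gradient is just another band-part: the op multiplies its input
# elementwise by a fixed 0/1 band mask, so the vector-Jacobian product masks
# the upstream gradient identically. Quick NumPy check (illustrative):
import numpy as np

_g = np.arange(9.).reshape(3, 3)
_mask = np.triu(np.tril(np.ones((3, 3)), 1), 0)  # num_lower=0, num_upper=1
assert np.array_equal(np.tril(np.triu(_g, 0), 1), _g * _mask)
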
def _matmul(self, x, transpose_x=False):
  # tf.matmul is defined a * b.
  chol = array_ops.matrix_band_part(self._chol, -1, 0)
  chol_times_x = math_ops.matmul(
      chol, x, transpose_a=True, transpose_b=transpose_x)
  return math_ops.matmul(chol, chol_times_x)

def fill_lower_triangular(x, validate_args=False,
                          name="fill_lower_triangular"):
  """Creates a (batch of) lower triangular matrix from a vector of inputs.

  If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
  b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`.

  Although the non-batch complexity is O(n^2), large constants and sub-optimal
  vectorization mean this function is about 5x slower than simply zeroing out
  the upper triangular, i.e., `tril = tf.matrix_band_part(X, -1, 0)`.  It
  becomes competitive only when several matmul/cholesky/etc ops can be elided
  in constructing the input.  Example: wiring a fully connected layer as a
  covariance matrix; this function reduces the final layer by 2x and possibly
  reduces the network arch complexity considerably.  In most cases it is
  better to simply build a full matrix and zero out the upper triangular
  elements, e.g., `tril = tf.matrix_band_part(full, -1, 0)`, rather than
  directly construct a lower triangular.

  Example:

  ```python
  fill_lower_triangular([1, 2, 3, 4, 5, 6])
  # Returns: [[1, 0, 0],
  #           [2, 3, 0],
  #           [4, 5, 6]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `distribution_util_test.py`, function `_fill_lower_triangular`.

  Args:
    x: `Tensor` representing lower triangular elements.
    validate_args: `Boolean`, default `False`.  Whether to ensure the shape of
      `x` can be mapped to a lower triangular matrix (controls non-static
      checks only).
    name: `String`. The name to give this op.

  Returns:
    tril: `Tensor` with lower triangular elements filled from `x`.

  Raises:
    ValueError: if `x` has a static shape which cannot be mapped to a lower
      triangular matrix.
  """
  # TODO(jvdillon): Replace this code with dedicated op when it exists.
  with ops.name_scope(name, values=(x,)):
    x = ops.convert_to_tensor(x, name="x")
    if (x.get_shape().ndims is not None and
        x.get_shape()[-1].value is not None):
      d = x.get_shape()[-1].value
      # d = n(n+1)/2 implies n is:
      n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
      d_inferred = n * (n + 1) / 2
      if d != d_inferred:
        raise ValueError("Input cannot be mapped to a lower triangular; "
                         "n*(n+1)/2 = %d != %d" % (d_inferred, d))
      final_shape = x.get_shape()[:-1].concatenate(
          tensor_shape.TensorShape([n, n]))
    else:
      d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
      # d = n(n+1)/2 implies n is:
      # Note: the original text read `dtypes.sqrt`, which does not exist;
      # `math_ops.sqrt` is the op intended here.
      n = math_ops.cast(0.5 * (math_ops.sqrt(1. + 8. * d) - 1.),
                        dtype=dtypes.int32)
      if validate_args:
        is_valid_input_shape = check_ops.assert_equal(
            n * (n + 1) / 2, d,
            message="Input cannot be mapped to a lower triangular.")
        n = control_flow_ops.with_dependencies([is_valid_input_shape], n)
      final_shape = x.get_shape()[:-1].concatenate(
          tensor_shape.TensorShape([None, None]))

    def tril_ids(n):
      """Internal helper to create vector of linear indices into y."""
      # Build the ids statically; chose 512 because it implies 1MiB.
      if not contrib_framework.is_tensor(n) and n <= 512:
        ids = np.arange(n**2, dtype=np.int32)
        rows = (ids / n).astype(np.int32)  # Implicit floor.
        # We need to stop incrementing the index when we encounter
        # upper-triangular elements.  The idea here is to compute the
        # lower-right number of zeros then by "symmetry" subtract this from
        # the total number of zeros, n(n-1)/2.
        # Then we note that: n(n-1)/2 - (n-r)*(n-r-1)/2 = r(2n-r-1)/2
        offset = (rows * (2 * n - rows - 1) / 2).astype(np.int32)
        # We could also zero out when (rows < cols) == (rows < ids-n*rows).
        # mask = (ids <= (n + 1) * rows).astype(np.int32)
      else:
        ids = math_ops.range(n**2)
        rows = math_ops.cast(ids / n, dtype=dtypes.int32)
        offset = math_ops.cast(rows * (2 * n - rows - 1) / 2,
                               dtype=dtypes.int32)
      return ids - offset

    # Special-case non-batch case.
    if x.get_shape().ndims == 1:
      y = array_ops.gather(x, array_ops.reshape(tril_ids(n), [n, n]))
      y = array_ops.matrix_band_part(y, -1, 0)
      y.set_shape(y.get_shape().merge_with(final_shape))
      return y

    # Make ids for each batch dim.
    if (x.get_shape().ndims is not None and
        x.get_shape()[:-1].is_fully_defined()):
      batch_shape = np.asarray(x.get_shape()[:-1].as_list(), dtype=np.int32)
      m = np.prod(batch_shape).astype(np.int32)
    else:
      batch_shape = array_ops.shape(x)[:-1]
      m = array_ops.reduce_prod(array_ops.shape(x)[:-1])
    batch_ids = math_ops.range(m)

    # Assemble the tril_ids into batch,tril_id pairs.
    idx = array_ops.pack([
        array_ops.tile(array_ops.expand_dims(batch_ids, 1), [1, n * n]),
        array_ops.tile(array_ops.expand_dims(tril_ids(n), 0), [m, 1])
    ])
    idx = array_ops.transpose(idx, [1, 2, 0])

    # Gather up, reshape, and return.
    y = array_ops.reshape(x, [-1, d])
    y = array_ops.gather_nd(y, idx)
    y = array_ops.reshape(y, array_ops.concat_v2([batch_shape, [n, n]], 0))
    y = array_ops.matrix_band_part(y, -1, 0)
    y.set_shape(y.get_shape().merge_with(final_shape))
    return y

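# For intuition, a pure NumPy sketch of the same mapping (the docstring above
# points at a similar reference in `distribution_util_test.py`; this
# unbatched version is ours and purely illustrative):
import math
import numpy as np

def _fill_lower_triangular_reference(x):
  x = np.asarray(x)
  d = x.shape[-1]
  n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
  y = np.zeros((n, n), dtype=x.dtype)
  y[np.tril_indices(n)] = x  # Row-major fill of the lower triangle.
  return y

assert np.array_equal(
    _fill_lower_triangular_reference([1, 2, 3, 4, 5, 6]),
    [[1, 0, 0], [2, 3, 0], [4, 5, 6]])
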
def _batch_matmul(self, x, transpose_x=False):
  # tf.matmul is defined x * y, so "y" is on the right, not "x".
  chol = array_ops.matrix_band_part(self._chol, -1, 0)
  chol_times_x = math_ops.matmul(
      chol, x, adjoint_a=True, adjoint_b=transpose_x)
  return math_ops.matmul(chol, chol_times_x)

def _to_dense(self):
  chol = array_ops.matrix_band_part(self._chol, -1, 0)
  return math_ops.matmul(chol, chol, adjoint_b=True)

def _sqrt_to_dense(self):
  chol = array_ops.matrix_band_part(self._chol, -1, 0)
  return array_ops.identity(chol)

def _tridiag(self, d, diag_value, offdiag_value):
  """d x d matrix with given value on diag, and one super/sub diag."""
  diag_mat = linalg_ops.eye(d) * (diag_value - offdiag_value)
  three_bands = array_ops.matrix_band_part(
      array_ops.fill([d, d], offdiag_value), 1, 1)
  return diag_mat + three_bands

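# The same construction in NumPy for d=3 (illustrative): band-limit a
# constant matrix to one super/sub diagonal, then correct the diagonal.
import numpy as np

_d, _diag_value, _offdiag_value = 3, 2., 1.
_three_bands = np.triu(np.tril(np.full((_d, _d), _offdiag_value), 1), -1)
_tridiag_mat = np.eye(_d) * (_diag_value - _offdiag_value) + _three_bands
# ==> [[2., 1., 0.],
#      [1., 2., 1.],
#      [0., 1., 2.]]
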
def fill_triangular(x, upper=False, name=None):
  """Creates a (batch of) triangular matrix from a vector of inputs.

  Created matrix can be lower- or upper-triangular. (It is more efficient to
  create the matrix as upper or lower, rather than transpose.)

  Triangular matrix elements are filled in a clockwise spiral. See example,
  below.

  If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
  b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if `x` cannot be mapped to a triangular matrix.
  """

  with ops.name_scope(name, "fill_triangular", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if x.shape.with_rank_at_least(1)[-1].value is not None:
      # Formula derived by solving for n: m = n(n+1)/2.
      m = np.int32(x.shape[-1].value)
      n = np.sqrt(0.25 + 2. * m) - 0.5
      if n != np.floor(n):
        raise ValueError("Input right-most shape ({}) does not "
                         "correspond to a triangular matrix.".format(m))
      n = np.int32(n)
      static_final_shape = x.shape[:-1].concatenate([n, n])
    else:
      m = array_ops.shape(x)[-1]
      # For derivation, see above. Casting automatically lops off the 0.5, so
      # we omit it.  We don't validate n is an integer because this has
      # graph-execution cost; an error will be thrown from the reshape, below.
      n = math_ops.cast(
          math_ops.sqrt(0.25 + math_ops.cast(2 * m, dtype=dtypes.float32)),
          dtype=dtypes.int32)
      static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate(
          [None, None])

    # We now concatenate the "tail" of `x` to `x` (and reverse one of them).
    #
    # We do this based on the insight that the input `x` provides `ceil(n/2)`
    # rows of an `n x n` matrix, some of which will get zeroed out being on
    # the wrong side of the diagonal. The first row will not get zeroed out
    # at all, and we need `floor(n/2)` more rows, so the first is what we
    # omit from `x_tail`. If we then stack those `ceil(n/2)` rows with the
    # `floor(n/2)` rows provided by a reversed tail, it is exactly the other
    # set of elements of the reversed tail which will be zeroed out for being
    # on the wrong side of the diagonal further up/down the matrix. And, in
    # doing so, we've filled the triangular matrix in a clockwise spiral
    # pattern. Neat!
    #
    # Try it out in numpy:
    #  n = 3
    #  x = np.arange(n * (n + 1) / 2)
    #  m = x.shape[0]
    #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
    #  x_tail = x[(m - (n**2 - m)):]
    #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
    #  # ==> array([[3, 4, 5],
    #  #            [5, 4, 3],
    #  #            [2, 1, 0]])
    #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
    #  # ==> array([[0, 1, 2],
    #  #            [3, 4, 5],
    #  #            [5, 4, 3]])
    #
    # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
    # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
    # Furthermore observe that:
    #  m - (n**2 - m)
    #  = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2)
    #  = 2 (n**2 / 2 + n / 2) - n**2
    #  = n**2 + n - n**2
    #  = n
    if upper:
      x_list = [x, array_ops.reverse(x[..., n:], axis=[-1])]
    else:
      x_list = [x[..., n:], array_ops.reverse(x, axis=[-1])]
    new_shape = (
        static_final_shape.as_list()
        if static_final_shape.is_fully_defined()
        else array_ops.concat([array_ops.shape(x)[:-1], [n, n]], axis=0))
    x = array_ops.reshape(array_ops.concat(x_list, axis=-1), new_shape)
    x = array_ops.matrix_band_part(
        x,
        num_lower=(0 if upper else -1),
        num_upper=(-1 if upper else 0))
    x.set_shape(static_final_shape)
    return x

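# A compact NumPy sketch of the same spiral fill (the docstring above points
# at `_fill_triangular` in `util_test.py`; this unbatched version is ours and
# purely illustrative):
import numpy as np

def _fill_triangular_reference(x, upper=False):
  x = np.asarray(x)
  m = x.shape[-1]
  n = np.int32(np.sqrt(0.25 + 2. * m) - 0.5)
  if upper:
    mat = np.concatenate([x, x[n:][::-1]], axis=-1).reshape(n, n)
    return np.triu(mat)
  mat = np.concatenate([x[n:], x[::-1]], axis=-1).reshape(n, n)
  return np.tril(mat)

assert np.array_equal(
    _fill_triangular_reference([1, 2, 3, 4, 5, 6]),
    [[4, 0, 0], [6, 5, 0], [3, 2, 1]])
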
def loop_fn(i):
  return array_ops.matrix_band_part(
      array_ops.gather(x, i), num_lower=num_lower, num_upper=num_upper)

def make_tril_scale(loc=None,
                    scale_tril=None,
                    scale_diag=None,
                    scale_identity_multiplier=None,
                    shape_hint=None,
                    validate_args=False,
                    assert_positive=False,
                    name=None):
  """Creates a LinOp representing a lower triangular matrix.

  Args:
    loc: Floating-point `Tensor`. This is used for inferring shape in the case
      where only `scale_identity_multiplier` is set.
    scale_tril: Floating-point `Tensor` representing the lower triangular
      matrix. `scale_tril` has shape [N1, N2, ...  k, k], which represents a
      k x k lower triangular matrix. When `None` no `scale_tril` term is added
      to the LinOp. The upper triangular elements above the diagonal are
      ignored.
    scale_diag: Floating-point `Tensor` representing the diagonal matrix.
      `scale_diag` has shape [N1, N2, ...  k], which represents a k x k
      diagonal matrix. When `None` no diagonal term is added to the LinOp.
    scale_identity_multiplier: floating point rank 0 `Tensor` representing a
      scaling done to the identity matrix. When
      `scale_identity_multiplier = scale_diag = scale_tril = None` then
      `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added
      to `scale`.
    shape_hint: scalar integer `Tensor` representing a hint at the dimension
      of the identity matrix when only `scale_identity_multiplier` is set.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    assert_positive: Python `bool` indicating whether LinOp should be checked
      for being positive definite.
    name: Python `str` name given to ops managed by this object.

  Returns:
    `LinearOperator` representing a lower triangular matrix.

  Raises:
    ValueError: If only `scale_identity_multiplier` is set and `loc` and
      `shape_hint` are both None.
  """

  def _maybe_attach_assertion(x):
    if not validate_args:
      return x
    if assert_positive:
      return control_flow_ops.with_dependencies([
          check_ops.assert_positive(
              array_ops.matrix_diag_part(x),
              message="diagonal part must be positive"),
      ], x)
    return control_flow_ops.with_dependencies([
        check_ops.assert_none_equal(
            array_ops.matrix_diag_part(x),
            array_ops.zeros([], x.dtype),
            message="diagonal part must be non-zero"),
    ], x)

  with ops.name_scope(name, "make_tril_scale",
                      values=[loc, scale_diag, scale_identity_multiplier]):

    loc = _convert_to_tensor(loc, name="loc")
    scale_tril = _convert_to_tensor(scale_tril, name="scale_tril")
    scale_diag = _convert_to_tensor(scale_diag, name="scale_diag")
    scale_identity_multiplier = _convert_to_tensor(
        scale_identity_multiplier, name="scale_identity_multiplier")

  if scale_tril is not None:
    scale_tril = array_ops.matrix_band_part(scale_tril, -1, 0)  # Zero out TriU.
    tril_diag = array_ops.matrix_diag_part(scale_tril)
    if scale_diag is not None:
      tril_diag += scale_diag
    if scale_identity_multiplier is not None:
      tril_diag += scale_identity_multiplier[..., array_ops.newaxis]

    scale_tril = array_ops.matrix_set_diag(scale_tril, tril_diag)

    return linalg.LinearOperatorLowerTriangular(
        tril=_maybe_attach_assertion(scale_tril),
        is_non_singular=True,
        is_self_adjoint=False,
        is_positive_definite=assert_positive)

  return make_diag_scale(
      loc=loc,
      scale_diag=scale_diag,
      scale_identity_multiplier=scale_identity_multiplier,
      shape_hint=shape_hint,
      validate_args=validate_args,
      assert_positive=assert_positive,
      name=name)

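# Usage sketch (hedged; derived from the code above, which folds all three
# terms into the diagonal before building the operator):
#
#   scale = make_tril_scale(
#       scale_tril=[[1., 0.], [2., 3.]],
#       scale_diag=[0.5, 0.5],
#       scale_identity_multiplier=1.)
#   scale.to_dense()
#   # ==> [[2.5, 0. ],
#   #      [2. , 4.5]]
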
def _verifyLu(self, x, output_idx_type=dtypes.int64):
  # Verify that Px = LU.
  with self.cached_session(use_gpu=True) as sess:
    lu, perm = linalg_ops.lu(x, output_idx_type=output_idx_type)

    # Prepare the lower factor of shape num_rows x num_rows.
    lu_shape = np.array(lu.shape.as_list())
    batch_shape = lu_shape[:-2]
    num_rows = lu_shape[-2]
    num_cols = lu_shape[-1]

    lower = array_ops.matrix_band_part(lu, -1, 0)
    if num_rows > num_cols:
      eye = linalg_ops.eye(
          num_rows, batch_shape=batch_shape, dtype=lower.dtype)
      lower = array_ops.concat([lower, eye[..., num_cols:]], axis=-1)
    elif num_rows < num_cols:
      lower = lower[..., :num_rows]

    # Fill the diagonal with ones.
    ones_diag = array_ops.ones(
        np.append(batch_shape, num_rows), dtype=lower.dtype)
    lower = array_ops.matrix_set_diag(lower, ones_diag)

    # Prepare the upper factor.
    upper = array_ops.matrix_band_part(lu, 0, -1)

    verification = math_ops.matmul(lower, upper)

    # Permute the rows of the product of the triangular factors.
    if num_rows > 0:
      # Reshape the product of the triangular factors and permutation indices
      # to a single batch dimension. This makes it easy to apply
      # invert_permutation and gather_nd ops.
      perm_reshaped = array_ops.reshape(perm, [-1, num_rows])
      verification_reshaped = array_ops.reshape(verification,
                                                [-1, num_rows, num_cols])
      # Invert the permutation in each batch.
      inv_perm_reshaped = functional_ops.map_fn(array_ops.invert_permutation,
                                                perm_reshaped)
      batch_size = perm_reshaped.shape.as_list()[0]
      # Prepare the batch indices with the same shape as the permutation.
      # The corresponding batch index is paired with each of the `num_rows`
      # permutation indices.
      batch_indices = math_ops.cast(
          array_ops.broadcast_to(
              math_ops.range(batch_size)[:, None], perm_reshaped.shape),
          dtype=output_idx_type)
      permuted_verification_reshaped = array_ops.gather_nd(
          verification_reshaped,
          array_ops.stack([batch_indices, inv_perm_reshaped], axis=-1))

      # Reshape the verification matrix back to the original shape.
      verification = array_ops.reshape(permuted_verification_reshaped,
                                       lu_shape)

    self._verifyLuBase(sess, x, lower, upper, perm, verification,
                       output_idx_type)

def _sqrt_matmul(self, x, transpose_x=False):
  chol = array_ops.matrix_band_part(self._chol, -1, 0)
  # tf.matmul is defined a * b.
  return math_ops.matmul(chol, x, adjoint_b=transpose_x)

def _get_tril(self):
  """Gets the `tril` kwarg, with upper part zero-d out."""
  return array_ops.matrix_band_part(self._tril, -1, 0)
