def __init__(self,
             base_kernel,
             fixed_inputs,
             diag_shift=None,
             validate_args=False,
             name='SchurComplement'):
  """Builds a SchurComplement kernel on top of `base_kernel`.

  Args:
    base_kernel: A `PositiveSemidefiniteKernel` instance. It is the kernel
      from which the block matrices are built, and of which this kernel
      computes the Schur complement.
    fixed_inputs: A Tensor holding a collection of inputs (the `Z` structure
      discussed in the class-level docstring). The bottom-right corner of
      the block matrix is derived from
      `base_kernel.matrix(fixed_inputs, fixed_inputs)`; the top-right and
      bottom-left pieces come from evaluating `base_kernel` at pairs of
      input locations together with these `fixed_inputs`. An empty
      collection is allowed (either `None` or a Tensor with a zero shape
      entry); the kernel then falls back to the trivial application of
      `base_kernel` to inputs. The assumed shape is
      `[b1, ..., bB, N, f1, ..., fF]`, where the `b`'s are batch shape
      entries, the `f`'s are feature_shape entries and `N` is the number of
      fixed inputs. The batch shape entries must be broadcast compatible
      with `base_kernel.batch_shape`. Using this kernel entails a 1-time
      O(N^3) cost for the Cholesky decomposition of the k(Z, Z) matrix.
    diag_shift: A floating point scalar that is added to the diagonal of the
      divisor_matrix before its Cholesky is computed.
    validate_args: If `True`, parameters are checked for validity despite
      possibly degrading runtime performance.
      Default value: `False`
    name: Python `str` name prefixed to Ops created by this class.
      Default value: `"SchurComplement"`
  """
  # Must stay the first statement so that only the constructor arguments
  # (and `self`) are captured.
  parameters = dict(locals())

  # The bijector modules are imported lazily: importing them at module load
  # time would create a circular dependency between `tfp.bijectors` and
  # `tfp.math`.
  # pylint: disable=g-import-not-at-top
  from tensorflow_probability.python.bijectors import cholesky_outer_product
  from tensorflow_probability.python.bijectors import invert
  # pylint: enable=g-import-not-at-top

  # NOTE(review): `validate_args` is only recorded in `parameters` here; it
  # is not forwarded to the base-class constructor -- confirm this is
  # intended.
  with tf.name_scope(name) as name:
    dtype_sources = [base_kernel, fixed_inputs, diag_shift]
    dtype = dtype_util.common_dtype(dtype_sources, tf.float32)

    self._base_kernel = base_kernel
    self._diag_shift = tensor_util.convert_nonref_to_tensor(
        diag_shift, dtype=dtype, name='diag_shift')
    self._fixed_inputs = tensor_util.convert_nonref_to_tensor(
        fixed_inputs, dtype=dtype, name='fixed_inputs')

    # Inverse of the Cholesky outer product bijector.
    outer_product = cholesky_outer_product.CholeskyOuterProduct()
    self._cholesky_bijector = invert.Invert(outer_product)

    super(SchurComplement, self).__init__(
        base_kernel.feature_ndims,
        dtype=dtype,
        name=name,
        parameters=parameters)
def _default_event_space_bijector(self):
  """Returns the default event-space bijector for this object.

  A `CorrelationCholesky` bijector is always built. When
  `self.input_output_cholesky` is truthy it is returned as is; otherwise it
  is wrapped in a `Chain` whose first listed element is a
  `CholeskyOuterProduct` bijector. Every bijector is constructed with
  `validate_args=self.validate_args`.
  """
  # TODO(b/145620027) Finalize choice of bijector.
  validate = self.validate_args
  to_cholesky = correlation_cholesky_bijector.CorrelationCholesky(
      validate_args=validate)
  if not self.input_output_cholesky:
    outer_product = cholesky_outer_product_bijector.CholeskyOuterProduct(
        validate_args=validate)
    return chain_bijector.Chain(
        [outer_product, to_cholesky], validate_args=validate)
  return to_cholesky
def _parameter_properties(cls, dtype, num_classes=None):
  """Returns the `ParameterProperties` of `loc` and `covariance_matrix`.

  Args:
    cls: The class this is invoked on (not read by the body).
    dtype: Passed to `dtype_util.eps` to obtain the `diag_shift` of the
      default constraining bijector of `covariance_matrix`.
    num_classes: Not read by the body.

  Returns:
    A `dict` with the keys `loc` (`event_ndims=1`) and `covariance_matrix`
    (`event_ndims=2`).
  """

  def _covariance_shape(sample_shape):
    # The parameter shape is `sample_shape` with its last entry repeated
    # once more at the end.
    return ps.concat([sample_shape, sample_shape[-1:]], axis=0)

  def _covariance_constraining_bijector():
    # `FillScaleTriL` (diagonal shifted by the dtype's epsilon) listed after
    # `CholeskyOuterProduct` in a single `Chain`.
    fill_tril = fill_scale_tril_bijector.FillScaleTriL(
        diag_shift=dtype_util.eps(dtype))
    return chain_bijector.Chain([
        cholesky_outer_product_bijector.CholeskyOuterProduct(),
        fill_tril,
    ])

  loc_properties = parameter_properties.ParameterProperties(event_ndims=1)
  covariance_properties = parameter_properties.ParameterProperties(
      event_ndims=2,
      shape_fn=_covariance_shape,
      default_constraining_bijector_fn=_covariance_constraining_bijector)
  return {
      'loc': loc_properties,
      'covariance_matrix': covariance_properties,
  }
def _default_event_space_bijector(self):
  """Returns the default event-space bijector for this object.

  A lower-triangular bijector is assembled as a `Chain` of a
  `TransformDiagonal` (using a `Softplus` diagonal bijector) and a
  `FillScaleTriL`. When `self.input_output_cholesky` is truthy that chain is
  returned directly; otherwise it is wrapped in a second `Chain` whose first
  listed element is a `CholeskyOuterProduct` bijector. Every bijector is
  constructed with `validate_args=self.validate_args`.
  """
  # TODO(b/145620027) Finalize choice of bijector.
  validate = self.validate_args
  softplus_diagonal = transform_diagonal_bijector.TransformDiagonal(
      diag_bijector=softplus_bijector.Softplus(validate_args=validate),
      validate_args=validate)
  fill_tril = fill_scale_tril_bijector.FillScaleTriL(validate_args=validate)
  to_scale_tril = chain_bijector.Chain(
      [softplus_diagonal, fill_tril], validate_args=validate)

  if self.input_output_cholesky:
    return to_scale_tril

  outer_product = cholesky_outer_product_bijector.CholeskyOuterProduct(
      validate_args=validate)
  return chain_bijector.Chain(
      [outer_product, to_scale_tril], validate_args=validate)
def _default_event_space_bijector(self):
  """Returns the default event-space bijector for this object.

  A `CorrelationCholesky` bijector is always built. When
  `self.input_output_cholesky` is truthy it is returned as is; otherwise it
  is placed last in a `Chain` that also lists a `_ClipByValue` and a
  `CholeskyOuterProduct` bijector.
  """
  # TODO(b/145620027) Finalize choice of bijector.
  validate = self.validate_args
  to_cholesky = correlation_cholesky_bijector.CorrelationCholesky(
      validate_args=validate)
  if self.input_output_cholesky:
    return to_cholesky

  # The output has to be clipped explicitly. The other two bijectors
  # sometimes return values that exceed the bounds by an epsilon because of
  # minute numerical errors. Even numerically stable algorithms (which the
  # other two bijectors employ) allow for symmetric errors about the true
  # value, and that is inappropriate for the one-sided validity constraint
  # associated with correlation matrices.
  clip_to_bounds = _ClipByValue(-1., tf.ones([], self.dtype))
  outer_product = cholesky_outer_product_bijector.CholeskyOuterProduct(
      validate_args=validate)
  return chain_bijector.Chain(
      [clip_to_bounds, outer_product, to_cholesky], validate_args=validate)
def __init__(self,
             base_kernel,
             fixed_inputs,
             diag_shift=None,
             validate_args=False,
             name='SchurComplement'):
  """Construct a SchurComplement kernel instance.

  Args:
    base_kernel: A `PositiveSemidefiniteKernel` instance, the kernel used to
      build the block matrices of which this kernel computes the Schur
      complement.
    fixed_inputs: A Tensor, representing a collection of inputs. The Schur
      complement that this kernel computes comes from a block matrix, whose
      bottom-right corner is derived from `base_kernel.matrix(fixed_inputs,
      fixed_inputs)`, and whose top-right and bottom-left pieces are
      constructed by computing the base_kernel at pairs of input locations
      together with these `fixed_inputs`. `fixed_inputs` is allowed to be an
      empty collection (either `None` or having a zero shape entry), in
      which case the kernel falls back to the trivial application of
      `base_kernel` to inputs. See class-level docstring for more details on
      the exact computation this does; `fixed_inputs` correspond to the `Z`
      structure discussed there. `fixed_inputs` is assumed to have shape
      `[b1, ..., bB, N, f1, ..., fF]` where the `b`'s are batch shape
      entries, the `f`'s are feature_shape entries, and `N` is the number of
      fixed inputs. Use of this kernel entails a 1-time O(N^3) cost of
      computing the Cholesky decomposition of the k(Z, Z) matrix. The batch
      shape elements of `fixed_inputs` must be broadcast compatible with
      `base_kernel.batch_shape`.
    diag_shift: A floating point scalar to be added to the diagonal of the
      divisor_matrix before computing its Cholesky.
    validate_args: If `True`, parameters are checked for validity despite
      possibly degrading runtime performance.
      Default value: `False`
    name: Python `str` name prefixed to Ops created by this class.
      Default value: `"SchurComplement"`
  """
  with tf.compat.v1.name_scope(
      name, values=[base_kernel, fixed_inputs]) as name:
    # If the base_kernel doesn't have a specified dtype, we can't pass it off
    # to common_dtype, which always expects `tf.as_dtype(dtype)` to work (and
    # it doesn't if the given `dtype` is None).
    # TODO(b/130421035): Consider changing common_dtype to allow Nones, and
    # clean this up after.
    #
    # Thus, we spell out the logic here: use the dtype of `fixed_inputs` if
    # possible. If base_kernel.dtype is not None, use the usual logic.
    if base_kernel.dtype is None:
      dtype = None if fixed_inputs is None else fixed_inputs.dtype
    else:
      dtype = dtype_util.common_dtype([base_kernel, fixed_inputs], tf.float32)
    self._base_kernel = base_kernel
    self._fixed_inputs = (None if fixed_inputs is None else
                          tf.convert_to_tensor(value=fixed_inputs, dtype=dtype))
    if not self._is_empty_fixed_inputs():
      # We create and store this matrix here, so that we get the caching
      # benefit when we later access its cholesky. If we computed the matrix
      # every time we needed the cholesky, the bijector cache wouldn't be hit.
      #
      # The matrix is built from the converted `self._fixed_inputs` rather
      # than the raw constructor argument, so it is computed from the very
      # tensor (and dtype) that this kernel stores.
      self._divisor_matrix = base_kernel.matrix(
          self._fixed_inputs, self._fixed_inputs)
      if diag_shift is not None:
        self._divisor_matrix = _add_diagonal_shift(
            self._divisor_matrix, diag_shift)

    self._cholesky_bijector = invert.Invert(
        cholesky_outer_product.CholeskyOuterProduct())
  # NOTE(review): `validate_args` is not used anywhere in this constructor --
  # confirm whether it should be forwarded to the base class.
  super(SchurComplement, self).__init__(
      base_kernel.feature_ndims, dtype=dtype, name=name)
def __init__(self,
             base_kernel,
             fixed_inputs,
             fixed_inputs_mask=None,
             fixed_inputs_is_missing=None,
             diag_shift=None,
             cholesky_fn=None,
             validate_args=False,
             name='SchurComplement',
             _precomputed_divisor_matrix_cholesky=None):
  """Construct a SchurComplement kernel instance.

  Args:
    base_kernel: A `PositiveSemidefiniteKernel` instance, the kernel used to
      build the block matrices of which this kernel computes the Schur
      complement.
    fixed_inputs: A Tensor, representing a collection of inputs. The Schur
      complement that this kernel computes comes from a block matrix, whose
      bottom-right corner is derived from `base_kernel.matrix(fixed_inputs,
      fixed_inputs)`, and whose top-right and bottom-left pieces are
      constructed by computing the base_kernel at pairs of input locations
      together with these `fixed_inputs`. `fixed_inputs` is allowed to be an
      empty collection (either `None` or having a zero shape entry), in
      which case the kernel falls back to the trivial application of
      `base_kernel` to inputs. See class-level docstring for more details on
      the exact computation this does; `fixed_inputs` correspond to the `Z`
      structure discussed there. `fixed_inputs` is assumed to have shape
      `[b1, ..., bB, N, f1, ..., fF]` where the `b`'s are batch shape
      entries, the `f`'s are feature_shape entries, and `N` is the number of
      fixed inputs. Use of this kernel entails a 1-time O(N^3) cost of
      computing the Cholesky decomposition of the k(Z, Z) matrix. The batch
      shape elements of `fixed_inputs` must be broadcast compatible with
      `base_kernel.batch_shape`.
    fixed_inputs_mask: Deprecated. A boolean Tensor of shape `[..., N]`.
      When `fixed_inputs_mask` is not None and an element of
      `fixed_inputs_mask` is `False`, this kernel will return values
      computed as if the divisor matrix did not contain the corresponding
      row or column. At most one of `fixed_inputs_mask` and
      `fixed_inputs_is_missing` may be given.
    fixed_inputs_is_missing: A boolean Tensor of shape `[..., N]`. When
      `fixed_inputs_is_missing` is not None and an element of
      `fixed_inputs_is_missing` is `True`, this kernel will return values
      computed as if the divisor matrix did not contain the corresponding
      row or column. At most one of `fixed_inputs_mask` and
      `fixed_inputs_is_missing` may be given.
    diag_shift: A floating point scalar to be added to the diagonal of the
      divisor_matrix before computing its Cholesky.
    cholesky_fn: Callable which takes a single (batch) matrix argument and
      returns a Cholesky-like lower triangular factor.
      Default value: `None`, in which case the result of
      `cholesky_util.make_cholesky_with_jitter_fn()` (called without
      arguments) is used.
    validate_args: If `True`, parameters are checked for validity despite
      possibly degrading runtime performance.
      Default value: `False`
    name: Python `str` name prefixed to Ops created by this class.
      Default value: `"SchurComplement"`
    _precomputed_divisor_matrix_cholesky: Internal parameter -- do not use.

  Raises:
    ValueError: If both `fixed_inputs_mask` and `fixed_inputs_is_missing`
      are given.
  """
  # Must stay the first statement so that only the constructor arguments
  # (and `self`) are captured.
  parameters = dict(locals())

  # Reject contradictory arguments up front, before any import or tensor
  # conversion work is done.
  if ((fixed_inputs_mask is not None) and
      (fixed_inputs_is_missing is not None)):
    raise ValueError('Expected at most one of `fixed_inputs_mask` or '
                     '`fixed_inputs_is_missing`')

  # Delayed import to avoid circular dependency between `tfp.bijectors` and
  # `tfp.math`
  # pylint: disable=g-import-not-at-top
  from tensorflow_probability.python.bijectors import cholesky_outer_product
  from tensorflow_probability.python.bijectors import invert
  # pylint: enable=g-import-not-at-top

  # NOTE(review): `validate_args` is only recorded in `parameters` here; it
  # is not forwarded to the base-class constructor -- confirm this is
  # intended.
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype(
        [base_kernel, fixed_inputs, diag_shift,
         _precomputed_divisor_matrix_cholesky], tf.float32)
    self._base_kernel = base_kernel
    self._diag_shift = tensor_util.convert_nonref_to_tensor(
        diag_shift, dtype=dtype, name='diag_shift')
    self._fixed_inputs = tensor_util.convert_nonref_to_tensor(
        fixed_inputs, dtype=dtype, name='fixed_inputs')
    self._fixed_inputs_mask = tensor_util.convert_nonref_to_tensor(
        fixed_inputs_mask, dtype=tf.bool, name='fixed_inputs_mask')
    self._fixed_inputs_is_missing = tensor_util.convert_nonref_to_tensor(
        fixed_inputs_is_missing, dtype=tf.bool,
        name='fixed_inputs_is_missing')

    self._precomputed_divisor_matrix_cholesky = (
        _precomputed_divisor_matrix_cholesky)
    if self._precomputed_divisor_matrix_cholesky is not None:
      self._precomputed_divisor_matrix_cholesky = tf.convert_to_tensor(
          self._precomputed_divisor_matrix_cholesky, dtype)

    if cholesky_fn is None:
      from tensorflow_probability.python.distributions import cholesky_util  # pylint:disable=g-import-not-at-top
      cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()
    self._cholesky_fn = cholesky_fn

    # Built exactly once, after `cholesky_fn` has been resolved, so that the
    # bijector uses the same factorization function as `self._cholesky_fn`.
    self._cholesky_bijector = invert.Invert(
        cholesky_outer_product.CholeskyOuterProduct(cholesky_fn=cholesky_fn))

    super(SchurComplement, self).__init__(
        base_kernel.feature_ndims,
        dtype=dtype,
        name=name,
        parameters=parameters)