def __init__(self, tril, v, diag=None, validate_args=False):
    """Creates an instance of _TriLPlusVDVTLightweightOperatorPD.

    WARNING: This object is not to be used outside of `Affine` where it is
    currently being temporarily used for refactoring purposes.

    Stores `tril` and `v`, and builds the operator playing the role of `D`
    (plus its inverse) in `V D V^T`. When `diag` is omitted, an identity
    operator is used for `D`, which is then also its own inverse.

    Args:
      tril: `Tensor` of shape `[B1,..,Bb, d, d]`.
      v: `Tensor` of shape `[B1,...,Bb, d, k]`.
      diag: `Tensor` of shape `[B1,...,Bb, k, k]` or None
        NOTE(review): `diag` is handed to `OperatorPDDiag` and inverted
        elementwise (`1. / diag`), which suggests it holds only the diagonal
        entries, i.e. shape `[B1,...,Bb, k]` -- confirm.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
    """
    self._m = tril
    self._v = v
    self._validate_args = validate_args
    self._inputs = [tril, v]

    if diag is not None:
        self._inputs += [diag]
        self._d = operator_pd_diag.OperatorPDDiag(
            diag, verify_pd=validate_args)
        self._d_inv = operator_pd_diag.OperatorPDDiag(
            1. / diag, verify_pd=validate_args)
        return

    # No `diag` given: D is the identity with shape [batch..., k, k], where
    # k is the last dimension of `v`.
    if v.get_shape().is_fully_defined():
        # Static shape is known; build the identity shape as a Python list.
        v_shape = v.get_shape().as_list()
        id_shape = v_shape[:-2] + [v_shape[-1], v_shape[-1]]
    else:
        # Shape only known at graph run time; build it with tensor ops.
        v_shape = array_ops.shape(v)
        id_shape = array_ops.concat(
            [v_shape[:-2], [v_shape[-1], v_shape[-1]]], 0)

    # Use the local `validate_args` (as the `diag` branch does) rather than
    # `self.validate_args`, which this constructor never assigns.
    self._d = operator_pd_identity.OperatorPDIdentity(
        id_shape, v.dtype, verify_pd=validate_args)
    self._d_inv = self._d
def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
    """Build an identity operator together with its dense counterpart.

    Args:
      batch_shape: Iterable of leading (batch) dimensions.
      k: Size of each identity matrix (each is `k x k`).
      dtype: dtype used for both the operator and the dense matrix.

    Returns:
      A pair `(operator, matrix)`: an `OperatorPDIdentity` constructed with
      shape `batch_shape + [k, k]`, and the result of calling `.eval()` on a
      batch of identity matrices of that same shape.
    """
    batch_dims = list(batch_shape)

    # Dense reference: ones on the diagonal of every batch member.
    ones_on_diag = tf.ones(batch_dims + [k], dtype=dtype)
    dense_identity = tf.batch_matrix_diag(ones_on_diag)

    # Operator that should act the same way as the dense matrix.
    identity_op = operator_pd_identity.OperatorPDIdentity(
        batch_dims + [k, k], dtype)

    # NOTE(review): `.eval()` presumably needs an active default session --
    # confirm against the calling test.
    return identity_op, dense_identity.eval()
def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
    """Build a scaled identity operator together with its dense counterpart.

    The scale factor is fixed at 2.0.

    Args:
      batch_shape: Iterable of leading (batch) dimensions.
      k: Size of each identity matrix (each is `k x k`).
      dtype: dtype used for the operator, the scale and the dense matrix.

    Returns:
      A pair `(operator, matrix)`: an `OperatorPDIdentity` constructed with
      shape `batch_shape + [k, k]` and `scale=2.0`, and the result of calling
      `.eval()` on `2.0 *` a batch of identity matrices of that same shape.
    """
    batch_dims = list(batch_shape)

    # Dense reference: 2 * I for every batch member.
    ones_on_diag = array_ops.ones(batch_dims + [k], dtype=dtype)
    factor = constant_op.constant(2.0, dtype=dtype)
    dense_scaled_identity = factor * array_ops.matrix_diag(ones_on_diag)

    # Operator that should act the same way as the dense matrix.
    scaled_identity_op = operator_pd_identity.OperatorPDIdentity(
        batch_dims + [k, k], dtype, scale=factor)

    # NOTE(review): `.eval()` presumably needs an active default session --
    # confirm against the calling test.
    return scaled_identity_op, dense_scaled_identity.eval()
def _get_identity_operator(self, v):
    """Get an `OperatorPDIdentity` to play the role of `D` in `VDV^T`.

    The identity's shape is `batch_shape(v) + [r, r]`, where `r` is the last
    dimension of `v`. It is computed from the static shape when that is fully
    defined, and with tensor ops otherwise.

    Args:
      v: `Tensor` whose last dimension gives the size of the identity and
        whose leading dimensions (all but the last two) give its batch shape.

    Returns:
      An `OperatorPDIdentity` with dtype `v.dtype`, built with
      `verify_pd=self._verify_pd`.
    """
    with ops.op_scope([v], 'get_identity_operator'):
        static_shape = v.get_shape()
        if static_shape.is_fully_defined():
            dims = static_shape.as_list()
            # Batch dims followed by [r, r].
            id_shape = dims[:-2] + [dims[-1]] * 2
        else:
            dynamic_shape = array_ops.shape(v)
            ndims = array_ops.rank(v)
            # Everything except the trailing two (matrix) dimensions.
            batch_dims = array_ops.slice(dynamic_shape, [0], [ndims - 2])
            # Last dim of v.
            last_dim = array_ops.gather(dynamic_shape, ndims - 1)
            id_shape = array_ops.concat(0, (batch_dims, [last_dim, last_dim]))
        return operator_pd_identity.OperatorPDIdentity(
            id_shape, v.dtype, verify_pd=self._verify_pd)
def _create_scale_operator(self, identity_multiplier, diag, tril,
                           perturb_diag, perturb_factor, event_ndims,
                           validate_args):
    """Construct `scale` from various components.

    The base matrix is chosen from the first available of `tril`, `diag` and
    `identity_multiplier` (in that order). If `perturb_factor` is given, the
    base is wrapped in a low-rank-update operator.

    Args:
      identity_multiplier: floating point rank 0 `Tensor` representing a
        scaling done to the identity matrix.
      diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ... k], which represents a k x k
        diagonal matrix.
      tril: Floating-point `Tensor` representing the lower triangular matrix.
        `scale_tril` has shape [N1, N2, ... k, k], which represents a k x k
        lower triangular matrix.
      perturb_diag: Floating-point `Tensor` representing the diagonal matrix
        of the low rank update.
      perturb_factor: Floating-point `Tensor` representing factor matrix.
      event_ndims: Scalar `int32` `Tensor` indicating the number of
        dimensions associated with a particular draw from the distribution.
        Must be 0 or 1
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.

    Returns:
      scale. In the case of scaling by a constant, scale is a floating point
      `Tensor`. Otherwise, scale is an `OperatorPD`.

    Raises:
      ValueError: if all of `tril`, `diag` and `identity_multiplier` are
        `None`.
    """
    # Normalize every component through `_as_tensor`.
    identity_multiplier = _as_tensor(identity_multiplier, "identity_multiplier")
    diag = _as_tensor(diag, "diag")
    tril = _as_tensor(tril, "tril")
    perturb_diag = _as_tensor(perturb_diag, "perturb_diag")
    perturb_factor = _as_tensor(perturb_factor, "perturb_factor")

    identity_multiplier = self._maybe_validate_identity_multiplier(
        identity_multiplier, validate_args)

    if perturb_factor is not None:
        perturb_factor = self._process_matrix(
            perturb_factor, min_rank=2, event_ndims=event_ndims)
    if perturb_diag is not None:
        perturb_diag = self._process_matrix(
            perturb_diag, min_rank=1, event_ndims=event_ndims)

    # Evaluated after processing, exactly where the branches below need it.
    has_low_rank_update = perturb_factor is not None

    # The following if-statements are ordered by increasingly stronger
    # assumptions in the base matrix, i.e., we process in the order:
    # TriL, Diag, Identity.

    if tril is not None:
        tril = self._preprocess_tril(
            identity_multiplier, diag, tril, event_ndims)
        if has_low_rank_update:
            return _TriLPlusVDVTLightweightOperatorPD(
                tril=tril,
                v=perturb_factor,
                diag=perturb_diag,
                validate_args=validate_args)
        return operator_pd_cholesky.OperatorPDCholesky(
            tril, verify_pd=validate_args)

    if diag is not None:
        diag = self._preprocess_diag(identity_multiplier, diag, event_ndims)
        if has_low_rank_update:
            return operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator=operator_pd_diag.OperatorPDDiag(
                    diag, verify_pd=validate_args),
                v=perturb_factor,
                diag=perturb_diag,
                verify_pd=validate_args)
        return operator_pd_diag.OperatorPDSqrtDiag(
            diag, verify_pd=validate_args)

    if identity_multiplier is None:
        raise ValueError(
            "One of tril, diag and/or identity_multiplier must be "
            "specified.")

    if not has_low_rank_update:
        # Pure scaling by a constant: hand back the multiplier itself.
        return identity_multiplier

    # Infer the shape from the V and D: replace the last dimension of V with
    # a copy of its second-to-last dimension to get a square matrix shape.
    factor_shape = array_ops.shape(perturb_factor)
    square_shape = array_ops.concat(
        [factor_shape[:-1], [factor_shape[-2]]], 0)
    base_operator = operator_pd_identity.OperatorPDIdentity(
        square_shape,
        perturb_factor.dtype.base_dtype,
        scale=identity_multiplier,
        verify_pd=validate_args)
    return operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
        operator=base_operator,
        v=perturb_factor,
        diag=perturb_diag,
        verify_pd=validate_args)