def __init__(self, tril, v, diag=None, validate_args=False):
    """Creates an instance of _TriLPlusVDVTLightweightOperatorPD.

    WARNING: This object is not to be used outside of `Affine` where it is
    currently being temporarily used for refactoring purposes.

    Stores `tril` and `v`, and builds two helper operators: one for the
    middle matrix `D` of the `V D V^T` update and one for its inverse.
    When `diag` is omitted, `D` is an identity operator and doubles as its
    own inverse.

    Args:
      tril: `Tensor` of shape `[B1,..,Bb, d, d]`.
      v: `Tensor` of shape `[B1,...,Bb, d, k]`.
      diag: `Tensor` of shape `[B1,...,Bb, k, k]` or None.
        NOTE(review): the value is handed directly to `OperatorPDDiag`,
        which may expect only the diagonal entries (`[B1,...,Bb, k]`) --
        confirm the documented shape.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
    """
    self._m = tril
    self._v = v
    self._validate_args = validate_args
    self._inputs = [tril, v]

    if diag is None:
        # No explicit `D`: use a `[..., k, k]` identity, where `k` is the
        # last dimension of `v`. Prefer the static shape when it is fully
        # known, otherwise assemble the shape from the dynamic one.
        static_shape = v.get_shape()
        if static_shape.is_fully_defined():
            dims = static_shape.as_list()
            identity_shape = dims[:-2] + [dims[-1], dims[-1]]
        else:
            dims = array_ops.shape(v)
            identity_shape = array_ops.concat(
                [dims[:-2], [dims[-1], dims[-1]]], 0)
        identity = operator_pd_identity.OperatorPDIdentity(
            identity_shape, v.dtype, verify_pd=self.validate_args)
        self._d = identity
        # The identity is its own inverse, so both slots share one object.
        self._d_inv = identity
    else:
        self._inputs.append(diag)
        self._d = operator_pd_diag.OperatorPDDiag(
            diag, verify_pd=validate_args)
        # Elementwise reciprocal of `diag` gives the inverse operator.
        self._d_inv = operator_pd_diag.OperatorPDDiag(
            1. / diag, verify_pd=validate_args)
def __init__(self,
             operator,
             v,
             diag=None,
             verify_pd=True,
             verify_shapes=True,
             name="OperatorPDSqrtVDVTUpdate"):
    """Initialize an `OperatorPDSqrtVDVTUpdate`.

    Args:
      operator: Subclass of `OperatorPDBase`. Represents the (batch)
        positive definite matrix `M` in `R^{k x k}`.
      v: `Tensor` defining batch matrix of same `dtype` and `batch_shape`
        as `operator`, and last two dimensions of shape `(k, r)`.
      diag: Optional `Tensor` defining batch vector of same `dtype` and
        `batch_shape` as `operator`, and last dimension of size `r`. If
        `None`, the update becomes `VV^T` rather than `VDV^T`.
      verify_pd: `Boolean`. If `True`, add asserts that `diag > 0`, which,
        along with the positive definiteness of `operator`, is sufficient
        to make the resulting operator positive definite.
      verify_shapes: `Boolean`. If `True`, check that `operator`, `v`, and
        `diag` have compatible shapes.
      name: A name to prepend to `Op` names.

    Raises:
      TypeError: If `operator` is not an instance of `OperatorPDBase`.
    """
    if not isinstance(operator, operator_pd.OperatorPDBase):
        raise TypeError("operator was not instance of OperatorPDBase.")

    with ops.name_scope(name), ops.name_scope(
            "init", values=operator.inputs + [v, diag]):
        self._operator = operator
        self._v = ops.convert_to_tensor(v, name="v")
        self._verify_pd = verify_pd
        self._verify_shapes = verify_shapes
        self._name = name

        # The resulting operator stays PD as long as `diag` is PSD, but the
        # Woodbury and determinant lemmas need `diag` to be PD. Hence `diag`
        # is required to be PD whenever "verify_pd" is requested.
        if diag is None:
            # No `D` supplied: an identity operator serves as both `D` and
            # its inverse.
            self._diag = None
            identity = self._get_identity_operator(self._v)
            self._diag_operator = identity
            self._diag_inv_operator = identity
        else:
            self._diag = ops.convert_to_tensor(diag, name="diag")
            self._diag_operator = operator_pd_diag.OperatorPDDiag(
                diag, verify_pd=self.verify_pd)
            # The inverse of a PD matrix is PD, so skip verification here.
            self._diag_inv_operator = operator_pd_diag.OperatorPDDiag(
                1 / self._diag, verify_pd=False)

        self._check_types(operator, self._v, self._diag)

        # The static shape check always runs; the dynamic check is only a
        # fallback when the static one was inconclusive and shape
        # verification was requested.
        statically_checked = self._check_shapes_static(
            operator, self._v, self._diag)
        if not statically_checked:
            if self._verify_shapes:
                self._v, self._diag = self._check_shapes_dynamic(
                    operator, self._v, self._diag)
def __init__(self,
             mu,
             diag_large,
             v,
             diag_small=None,
             validate_args=False,
             allow_nan_stats=True,
             name="MultivariateNormalDiagPlusVDVT"):
    """Multivariate Normal distributions on `R^k`.

    For every batch member, this distribution represents `k` random
    variables `(X_1,...,X_k)`, with mean `E[X_i] = mu[i]`, and covariance
    matrix `C_{ij} := E[(X_i - mu[i])(X_j - mu[j])]`

    The user initializes this class by providing the mean `mu`, and a
    lightweight definition of `C`:

    ```
    C = SS^T = SS = (M + V D V^T) (M + V D V^T)
    M is diagonal (k x k)
    V = is shape (k x r), typically r << k
    D = is diagonal (r x r), optional (defaults to identity).
    ```

    Args:
      mu: Rank `n + 1` floating point tensor with shape `[N1,...,Nn, k]`,
        `n >= 0`. The means.
      diag_large: Rank `n + 1` floating point tensor, shape
        `[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `M`.
      v: Rank `n + 2` floating point tensor, shape `[N1,...,Nn, k, r]`
        `n >= 0`. Defines the matrix `V`.
      diag_small: Rank `n + 1` floating point tensor, shape
        `[N1,...,Nn, r]` `n >= 0`. Defines the diagonal matrix `D`.
        Default is `None`, which means `D` will be the identity matrix.
      validate_args: `Boolean`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for
        any batch member. If `True`, batch members with valid parameters
        leading to undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.
    """
    # Record the constructor arguments. `locals()` is copied because, on
    # interpreters where it returns the frame's shared mapping, a later
    # re-sync (e.g. triggered by a debugger or tracer) would otherwise add
    # `self`, `cov`, `ns` and `parameters` itself to the stored dict.
    parameters = dict(locals())
    parameters.pop("self")

    with ops.name_scope(name, values=[diag_large, v, diag_small]) as ns:
        # Square-root-of-covariance operator `M + V D V^T`, with `M` given
        # by `diag_large` and the optional `D` by `diag_small`.
        cov = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            operator_pd_diag.OperatorPDDiag(
                diag_large, verify_pd=validate_args),
            v,
            diag=diag_small,
            verify_pd=validate_args,
            verify_shapes=validate_args)
    super(MultivariateNormalDiagPlusVDVT, self).__init__(
        mu,
        cov,
        allow_nan_stats=allow_nan_stats,
        validate_args=validate_args,
        name=ns)
    # Assigned after the base-class initializer so it is not overwritten by
    # anything the parent sets.
    self._parameters = parameters
def _create_scale_operator(self, identity_multiplier, diag, tril,
                           perturb_diag, perturb_factor, event_ndims,
                           validate_args):
    """Construct `scale` from various components.

    The base matrix is chosen from the first available of `tril`, `diag`
    and `identity_multiplier` (in that order). If `perturb_factor` is
    given, the base matrix is wrapped in a low rank `V D V^T` update.

    Args:
      identity_multiplier: floating point rank 0 `Tensor` representing a
        scaling done to the identity matrix.
      diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ... k], which represents a k x k
        diagonal matrix.
      tril: Floating-point `Tensor` representing the lower triangular
        matrix. `scale_tril` has shape [N1, N2, ... k, k], which represents
        a k x k lower triangular matrix.
      perturb_diag: Floating-point `Tensor` representing the diagonal
        matrix of the low rank update.
      perturb_factor: Floating-point `Tensor` representing factor matrix.
      event_ndims: Scalar `int32` `Tensor` indicating the number of
        dimensions associated with a particular draw from the distribution.
        Must be 0 or 1
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.

    Returns:
      scale. In the case of scaling by a constant, scale is a
      floating point `Tensor`. Otherwise, scale is an `OperatorPD`.

    Raises:
      ValueError: if all of `tril`, `diag` and `identity_multiplier` are
        `None`.
    """
    identity_multiplier = _as_tensor(
        identity_multiplier, "identity_multiplier")
    diag = _as_tensor(diag, "diag")
    tril = _as_tensor(tril, "tril")
    perturb_diag = _as_tensor(perturb_diag, "perturb_diag")
    perturb_factor = _as_tensor(perturb_factor, "perturb_factor")

    identity_multiplier = self._maybe_validate_identity_multiplier(
        identity_multiplier, validate_args)

    if perturb_factor is not None:
        perturb_factor = self._process_matrix(
            perturb_factor, min_rank=2, event_ndims=event_ndims)

    if perturb_diag is not None:
        perturb_diag = self._process_matrix(
            perturb_diag, min_rank=1, event_ndims=event_ndims)

    # Evaluated after processing so it reflects the final `perturb_factor`.
    has_low_rank_update = perturb_factor is not None

    def _with_low_rank_update(base_operator):
        # Wrap `base_operator` in the `V D V^T` update.
        return operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            operator=base_operator,
            v=perturb_factor,
            diag=perturb_diag,
            verify_pd=validate_args)

    # The branches below are ordered by increasingly stronger assumptions
    # on the base matrix, i.e., we process in the order:
    # TriL, Diag, Identity.
    if tril is not None:
        tril = self._preprocess_tril(
            identity_multiplier, diag, tril, event_ndims)
        if has_low_rank_update:
            return _TriLPlusVDVTLightweightOperatorPD(
                tril=tril,
                v=perturb_factor,
                diag=perturb_diag,
                validate_args=validate_args)
        return operator_pd_cholesky.OperatorPDCholesky(
            tril, verify_pd=validate_args)

    if diag is not None:
        diag = self._preprocess_diag(identity_multiplier, diag, event_ndims)
        if has_low_rank_update:
            return _with_low_rank_update(
                operator_pd_diag.OperatorPDDiag(
                    diag, verify_pd=validate_args))
        return operator_pd_diag.OperatorPDSqrtDiag(
            diag, verify_pd=validate_args)

    if identity_multiplier is not None:
        if not has_low_rank_update:
            # Pure scaling by a constant: return the multiplier itself.
            return identity_multiplier
        # Infer the shape from the V and D.
        factor_shape = array_ops.shape(perturb_factor)
        identity_shape = array_ops.concat(
            [factor_shape[:-1], [factor_shape[-2]]], 0)
        scaled_identity = operator_pd_identity.OperatorPDIdentity(
            identity_shape,
            perturb_factor.dtype.base_dtype,
            scale=identity_multiplier,
            verify_pd=validate_args)
        return _with_low_rank_update(scaled_identity)

    raise ValueError(
        "One of tril, diag and/or identity_multiplier must be "
        "specified.")