def _create_scale_operator(self, identity_multiplier, diag, tril,
                           perturb_diag, perturb_factor, shift,
                           validate_args):
    """Assemble the `scale` used in `scale @ X + shift` from its pieces.

    Args:
      identity_multiplier: Floating-point rank 0 `Tensor` representing a
        scaling applied to the identity matrix.
      diag: Floating-point `Tensor` representing a diagonal matrix;
        forwarded as `scale_diag`. Shape `[N1, N2, ... k]` represents a
        `k x k` diagonal matrix.
      tril: Floating-point `Tensor` representing a lower triangular
        matrix; forwarded as `scale_tril`.
        NOTE(review): the earlier doc listed shape `[N1, N2, ... k]` for
        a `k x k` lower triangular matrix; presumably
        `[N1, N2, ... k, k]` was meant -- confirm.
      perturb_diag: Floating-point `Tensor` representing the diagonal
        matrix of the low rank update.
      perturb_factor: Floating-point `Tensor` representing the factor
        matrix of the low rank update.
      shift: Floating-point `Tensor` representing `shift` in
        `scale @ X + shift`. Only forwarded (as `loc`) to
        `distribution_util.make_tril_scale`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.

    Returns:
      scale: When `self._is_only_identity_multiplier` is set, the
        `identity_multiplier` `Tensor` itself (guarded by a non-zero
        assertion if `validate_args`). Otherwise the result of
        `make_tril_scale`, wrapped in a `LinearOperatorLowRankUpdate`
        when `perturb_factor` is given.

    Raises:
      ValueError: if all of `tril`, `diag` and `identity_multiplier` are
        `None`. This method has no explicit `raise`; any such error
        originates in the helpers it calls.
    """
    # Conversions run in the same left-to-right order as the arguments are
    # listed. `_as_tensor` presumably passes `None` through unchanged,
    # since `perturb_factor` is still compared against `None` below.
    identity_multiplier, diag, tril, perturb_diag, perturb_factor = (
        _as_tensor(identity_multiplier, "identity_multiplier"),
        _as_tensor(diag, "diag"),
        _as_tensor(tril, "tril"),
        _as_tensor(perturb_diag, "perturb_diag"),
        _as_tensor(perturb_factor, "perturb_factor"),
    )
    has_low_rank_update = perturb_factor is not None

    # If possible, use the low rank update to infer the shape of the
    # identity matrix, for the case where scale is a scaled identity
    # matrix plus a low rank update.
    shape_hint = (
        distribution_util.dimension_size(perturb_factor, axis=-2)
        if has_low_rank_update else None)

    if self._is_only_identity_multiplier:
        # Pure scalar scaling: hand back the multiplier itself rather
        # than building a LinearOperator.
        if not validate_args:
            return identity_multiplier
        zero = array_ops.zeros([], identity_multiplier.dtype)
        nonzero_assertion = check_ops.assert_none_equal(
            identity_multiplier,
            zero,
            ["identity_multiplier should be non-zero."])
        return control_flow_ops.with_dependencies(
            [nonzero_assertion], identity_multiplier)

    base_scale = distribution_util.make_tril_scale(
        loc=shift,
        scale_tril=tril,
        scale_diag=diag,
        scale_identity_multiplier=identity_multiplier,
        validate_args=validate_args,
        assert_positive=False,
        shape_hint=shape_hint)

    if not has_low_rank_update:
        return base_scale

    return linalg.LinearOperatorLowRankUpdate(
        base_scale,
        u=perturb_factor,
        diag_update=perturb_diag,
        # With no explicit diagonal the update is flagged as positive.
        is_diag_update_positive=perturb_diag is None,
        is_non_singular=True,  # Implied by is_positive_definite=True.
        is_self_adjoint=True,
        is_positive_definite=True,
        is_square=True)
def __init__(self,
             loc=None,
             scale_diag=None,
             scale_identity_multiplier=None,
             scale_perturb_factor=None,
             scale_perturb_diag=None,
             validate_args=False,
             allow_nan_stats=True,
             name="MultivariateNormalDiagPlusLowRank"):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with
    this.

    Recall that `covariance = scale @ scale.T`. A (non-batch) `scale`
    matrix is:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier * ones(k)) +
        scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T
    ```

    where:

    * `scale_diag.shape = [k]`,
    * `scale_identity_multiplier.shape = []`,
    * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and,
    * `scale_perturb_diag.shape = [r]`.

    Additional leading dimensions (if any) will index batches.

    If both `scale_diag` and `scale_identity_multiplier` are `None`, then
    `scale` is the Identity matrix.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]`
        where `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`,
        `b >= 0`, and characterizes `b`-batches of `k x k` diagonal
        matrices added to `scale`. When both `scale_identity_multiplier`
        and `scale_diag` are `None` then `scale` is the `Identity`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor`
        representing a scaled-identity-matrix added to `scale`. May have
        shape `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of
        scaled `k x k` identity matrices added to `scale`. When both
        `scale_identity_multiplier` and `scale_diag` are `None` then
        `scale` is the `Identity`.
      scale_perturb_factor: Floating-point `Tensor` representing a rank-`r`
        perturbation added to `scale`. May have shape
        `[B1, ..., Bb, k, r]`, `b >= 0`, and characterizes `b`-batches of
        rank-`r` updates to `scale`. When `None`, no rank-`r` update is
        added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing a diagonal
        matrix inside the rank-`r` perturbation added to `scale`. May have
        shape `[B1, ..., Bb, r]`, `b >= 0`, and characterizes `b`-batches
        of `r x r` diagonal matrices inside the perturbation added to
        `scale`. When `None`, an identity matrix is used inside the
        perturbation. Can only be specified if `scale_perturb_factor` is
        also specified.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified.
      ValueError: if `scale_perturb_diag` is specified without
        `scale_perturb_factor`.
    """
    # Must stay the first statement: it snapshots exactly the constructor
    # arguments (plus `self`) before any other local is bound.
    parameters = dict(locals())

    # Enforce the documented precondition up front, before any ops are
    # created. Without this check a lone `scale_perturb_diag` reaches
    # `LinearOperatorLowRankUpdate` with `u=None` and fails there with an
    # error that does not name the offending argument.
    if scale_perturb_diag is not None and scale_perturb_factor is None:
        raise ValueError(
            "`scale_perturb_diag` can only be specified if "
            "`scale_perturb_factor` is also specified.")

    def _convert_to_tensor(x, name):
        """Convert `x` to a `Tensor`, leaving `None` untouched."""
        return None if x is None else ops.convert_to_tensor(x, name=name)

    with ops.name_scope(name) as name:
        with ops.name_scope("init", values=[
            loc, scale_diag, scale_identity_multiplier,
            scale_perturb_factor, scale_perturb_diag
        ]):
            # After the check above, a low-rank update exists exactly when
            # a perturbation factor was supplied.
            has_low_rank = scale_perturb_factor is not None
            # The diagonal part is asserted positive only when a low-rank
            # update is layered on top of it.
            scale = distribution_util.make_diag_scale(
                loc=loc,
                scale_diag=scale_diag,
                scale_identity_multiplier=scale_identity_multiplier,
                validate_args=validate_args,
                assert_positive=has_low_rank)
            scale_perturb_factor = _convert_to_tensor(
                scale_perturb_factor, name="scale_perturb_factor")
            scale_perturb_diag = _convert_to_tensor(
                scale_perturb_diag, name="scale_perturb_diag")
            if has_low_rank:
                scale = linalg.LinearOperatorLowRankUpdate(
                    scale,
                    u=scale_perturb_factor,
                    diag_update=scale_perturb_diag,
                    # With no explicit diagonal the update is flagged as
                    # positive.
                    is_diag_update_positive=scale_perturb_diag is None,
                    # Implied by is_positive_definite=True.
                    is_non_singular=True,
                    is_self_adjoint=True,
                    is_positive_definite=True,
                    is_square=True)
    super(MultivariateNormalDiagPlusLowRank, self).__init__(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    self._parameters = parameters