def __init__(self,
             loc=None,
             scale_tril_log_diag=None,
             validate_args=False,
             allow_nan_stats=True,
             name="MultivariateNormalTriL"):
    """Construct the distribution from `loc` and `scale_tril_log_diag`.

    The `scale` operator handed to the parent constructor is chosen as
    follows:

    * if `scale_tril_log_diag` is `None`, a
      `custom_lin_op.LinearOperatorLogIdentity` whose `num_rows` is the size
      of the last dimension of `loc`;
    * otherwise, a `custom_lin_op.LinearOperatorLowerTriangularLogDiagonal`
      wrapping `scale_tril_log_diag`.

    Both arguments are converted to tensors of a common dtype (falling back
    to `tf.float32`) before the operator is built.

    Args:
      loc: Optional value convertible to a `Tensor`. May be `None` only when
        `scale_tril_log_diag` is given.
      scale_tril_log_diag: Optional value convertible to a `Tensor`.
        NOTE(review): judging by the name, presumably a lower-triangular
        matrix whose diagonal is stored in log space -- confirm against
        `custom_lin_op.LinearOperatorLowerTriangularLogDiagonal`.
      validate_args: Python `bool`, default `False`. Forwarded to the parent
        constructor and used as `assert_proper_shapes` for the identity
        operator.
      allow_nan_stats: Python `bool`, default `True`. Forwarded to the parent
        constructor.
      name: Python `str` used as the name scope for the ops created here.

    Raises:
      ValueError: if both `loc` and `scale_tril_log_diag` are `None`.
    """
    # Must stay the first statement so only the call arguments are captured.
    parameters = dict(locals())

    def _convert_to_tensor(x, name, dtype):
        return None if x is None else tf.convert_to_tensor(
            x, name=name, dtype=dtype)

    if loc is None and scale_tril_log_diag is None:
        raise ValueError(
            "Must specify one or both of `loc`, `scale_tril_log_diag`.")

    with tf.name_scope(name) as name:
        with tf.name_scope("init", values=[loc, scale_tril_log_diag]):
            dtype = dtype_util.common_dtype(
                [loc, scale_tril_log_diag], tf.float32)
            loc = _convert_to_tensor(loc, name="loc", dtype=dtype)
            # NOTE(review): the op name is left as "scale_tril" so existing
            # graph op names are unchanged.
            scale_tril_log_diag = _convert_to_tensor(
                scale_tril_log_diag, name="scale_tril", dtype=dtype)
            if scale_tril_log_diag is None:
                # `loc` is guaranteed non-None here by the check above.
                scale = custom_lin_op.LinearOperatorLogIdentity(
                    num_rows=distribution_util.dimension_size(loc, -1),
                    dtype=loc.dtype,
                    is_self_adjoint=True,
                    is_positive_definite=True,
                    assert_proper_shapes=validate_args)
            else:
                # No explicit non-singularity validation is done here; the
                # operator is simply flagged as non-singular.
                scale = custom_lin_op.LinearOperatorLowerTriangularLogDiagonal(
                    scale_tril_log_diag,
                    is_non_singular=True,
                    is_self_adjoint=False,
                    is_positive_definite=False)
    super(MultivariateNormalTriLLogDiagonal, self).__init__(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    self._parameters = parameters
def __init__(self,
             loc=None,
             scale_diag=None,
             scale_perturb_factor=None,
             scale_perturb_diag=None,
             validate_args=False,
             allow_nan_stats=True,
             name='MultivariateNormalDiagPlusLowRank'):
    """Build a multivariate normal on `R^k` with diag-plus-low-rank scale.

    `covariance = scale @ scale.T`, where a (non-batch) `scale` matrix is

    ```none
    scale = diag(scale_diag) +
            scale_perturb_factor @ diag(scale_perturb_diag) @
            scale_perturb_factor.T
    ```

    with

    * `scale_diag.shape = [k]`,
    * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and
    * `scale_perturb_diag.shape = [r]`.

    Additional leading dimensions (if any) index batches. The `batch_shape`
    is the broadcast shape between the `loc` and `scale` arguments; the
    `event_shape` is the last dimension of the matrix implied by `scale`,
    and the last dimension of `loc` (if provided) must broadcast with it.

    Args:
      loc: Floating-point `Tensor`. `None` means `loc` is implicitly `0`.
        When specified, must have shape `[B1, ..., Bb, k]` where `b >= 0`
        and `k` is the event size.
      scale_diag: Floating-point `Tensor` representing a non-singular
        diagonal matrix added to `scale`. Must have shape
        `[B1, ..., Bb, k]`, `b >= 0`, characterizing `b`-batches of `k x k`
        diagonal matrices. When `None`, an identity operator is used; its
        size is read from the last dimension of `loc` or, if `loc` is also
        `None`, from the second-to-last dimension of
        `scale_perturb_factor`.
      scale_perturb_factor: Floating-point `Tensor` representing a rank-`r`
        perturbation added to `scale`. Must have shape
        `[B1, ..., Bb, k, r]`, `b >= 0`. When `None`, no rank-`r` update is
        added.
      scale_perturb_diag: Floating-point `Tensor` representing a
        non-singular diagonal matrix inside the rank-`r` perturbation. Must
        have shape `[B1, ..., Bb, r]`, `b >= 0`. When `None`, an identity
        matrix is used inside the perturbation. May only be given together
        with `scale_perturb_factor`.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value '`NaN`' to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if `loc`, `scale_diag` and `scale_perturb_factor` are all
        `None`.
      ValueError: if `scale_perturb_diag` is given without
        `scale_perturb_factor`.
    """
    # Snapshot of the call arguments; must precede any other local binding.
    parameters = dict(locals())

    if loc is None and scale_diag is None and scale_perturb_factor is None:
        raise ValueError('At least one of `loc`, `scale_diag`, or '
                         '`scale_perturb_factor` required.')
    if scale_perturb_factor is None and scale_perturb_diag is not None:
        raise ValueError('`scale_perturb_diag` should be set only if '
                         '`scale_perturb_factor` is also set.')

    with tf.name_scope(name) as name:
        common_dtype = dtype_util.common_dtype(
            [loc, scale_diag, scale_perturb_factor, scale_perturb_diag],
            tf.float32)
        loc = tensor_util.convert_nonref_to_tensor(
            loc, dtype=common_dtype, name='loc')
        scale_diag = tensor_util.convert_nonref_to_tensor(
            scale_diag, dtype=common_dtype, name='scale_diag')
        scale_perturb_factor = tensor_util.convert_nonref_to_tensor(
            scale_perturb_factor,
            dtype=common_dtype,
            name='scale_perturb_factor')
        scale_perturb_diag = tensor_util.convert_nonref_to_tensor(
            scale_perturb_diag, dtype=common_dtype, name='scale_perturb_diag')

        if scale_diag is None:
            # This might not be tape-safe if the shape is unknown and the
            # source tensor is assigned to later.
            if loc is None:
                event_size = distribution_util.dimension_size(
                    scale_perturb_factor, -2)
            else:
                event_size = distribution_util.dimension_size(loc, -1)
            scale_operator = tf.linalg.LinearOperatorIdentity(
                num_rows=event_size,
                dtype=common_dtype,
                is_self_adjoint=True,
                is_positive_definite=True,
                assert_proper_shapes=validate_args)
        else:
            scale_operator = tf.linalg.LinearOperatorDiag(
                diag=scale_diag,
                is_non_singular=True,
                is_self_adjoint=True,
                is_positive_definite=False)

        if scale_perturb_factor is not None:
            # Wrap the base operator with the rank-r update. With no explicit
            # inner diagonal the update is flagged as positive.
            scale_operator = tf.linalg.LinearOperatorLowRankUpdate(
                scale_operator,
                u=scale_perturb_factor,
                diag_update=scale_perturb_diag,
                is_diag_update_positive=scale_perturb_diag is None,
                is_self_adjoint=True,
                is_non_singular=True,
                is_square=True)

        super(MultivariateNormalDiagPlusLowRank, self).__init__(
            loc=loc,
            scale=scale_operator,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
        self._parameters = parameters
def __init__(self,
             loc=None,
             scale_tril=None,
             validate_args=False,
             allow_nan_stats=True,
             name='MultivariateNormalTriL'):
    """Build a multivariate normal on `R^k` from a lower-triangular scale.

    `covariance = scale @ scale.T`, where a (non-batch) `scale` matrix is

    ```none
    scale = scale_tril
    ```

    and `scale_tril` is a lower-triangular `k x k` matrix with non-zero
    diagonal, i.e., `tf.diag_part(scale_tril) != 0`. Additional leading
    dimensions (if any) index batches.

    The `batch_shape` is the broadcast shape between the `loc` and `scale`
    arguments. The `event_shape` is given by the last dimension of the
    matrix implied by `scale`; the last dimension of `loc` (if provided)
    must broadcast with it.

    Args:
      loc: Floating-point `Tensor`. `None` means `loc` is implicitly `0`.
        When specified, may have shape `[B1, ..., Bb, k]` where `b >= 0` and
        `k` is the event size.
      scale_tril: Floating-point, lower-triangular `Tensor` with non-zero
        diagonal elements, of shape `[B1, ..., Bb, k, k]` where `b >= 0` and
        `k` is the event size. When `None`, an identity operator sized by
        the last dimension of `loc` is used.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if neither `loc` nor `scale_tril` are specified.
    """
    # Snapshot of the call arguments; must precede any other local binding.
    parameters = dict(locals())

    if scale_tril is None and loc is None:
        raise ValueError('Must specify one or both of `loc`, `scale_tril`.')

    with tf.name_scope(name) as name:
        common_dtype = dtype_util.common_dtype([loc, scale_tril], tf.float32)
        loc = tensor_util.convert_nonref_to_tensor(
            loc, name='loc', dtype=common_dtype)
        scale_tril = tensor_util.convert_nonref_to_tensor(
            scale_tril, name='scale_tril', dtype=common_dtype)
        self._scale_tril = scale_tril

        if scale_tril is not None:
            # Non-singularity of `scale_tril` is not validated here:
            # LinearOperatorLowerTriangular has an assert_non_singular
            # method that is called by the Bijector.
            scale_operator = tf.linalg.LinearOperatorLowerTriangular(
                scale_tril,
                is_non_singular=True,
                is_self_adjoint=False,
                is_positive_definite=False)
        else:
            # `loc` is necessarily set in this branch (see the check above).
            scale_operator = tf.linalg.LinearOperatorIdentity(
                num_rows=distribution_util.dimension_size(loc, -1),
                dtype=loc.dtype,
                is_self_adjoint=True,
                is_positive_definite=True,
                assert_proper_shapes=validate_args)

        super(MultivariateNormalTriL, self).__init__(
            loc=loc,
            scale=scale_operator,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
        self._parameters = parameters
def _create_scale_operator(self, identity_multiplier, diag, tril,
                           perturb_diag, perturb_factor, shift, validate_args,
                           dtype):
    """Assemble the `scale` object from its optional components.

    Args:
      identity_multiplier: floating point rank 0 `Tensor` representing a
        scaling done to the identity matrix.
      diag: Floating-point `Tensor` representing the diagonal matrix. `diag`
        has shape `[N1, N2, ... k]`, which represents a k x k diagonal
        matrix.
      tril: Floating-point `Tensor` representing the lower triangular
        matrix. `tril` has shape `[N1, N2, ... k, k]`, which represents a
        k x k lower triangular matrix.
      perturb_diag: Floating-point `Tensor` representing the diagonal matrix
        of the low rank update.
      perturb_factor: Floating-point `Tensor` representing factor matrix.
      shift: Floating-point `Tensor` representing `shift` in
        `scale @ X + shift`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      dtype: `DType` for arg `Tensor` conversions.

    Returns:
      scale. In the case of scaling by a constant, scale is a floating point
      `Tensor`. Otherwise, scale is a `LinearOperator`.

    Raises:
      ValueError: if all of `tril`, `diag` and `identity_multiplier` are
        `None`. (Not raised directly in this method; presumably comes from
        `_make_tril_scale`.)
    """
    # Convert every component in one pass, in the original argument order.
    identity_multiplier, diag, tril, perturb_diag, perturb_factor = (
        _as_tensor(value, label, dtype)
        for value, label in (
            (identity_multiplier, "identity_multiplier"),
            (diag, "diag"),
            (tril, "tril"),
            (perturb_diag, "perturb_diag"),
            (perturb_factor, "perturb_factor"),
        ))

    # If possible, use the low rank update to infer the shape of the
    # identity matrix, when scale represents a scaled identity matrix with a
    # low rank update.
    shape_hint = (
        None if perturb_factor is None
        else distribution_util.dimension_size(perturb_factor, axis=-2))

    if self._is_only_identity_multiplier:
        if not validate_args:
            return identity_multiplier
        nonzero_assertion = assert_util.assert_none_equal(
            identity_multiplier,
            tf.zeros([], identity_multiplier.dtype),
            ["identity_multiplier should be non-zero."])
        return distribution_util.with_dependencies(
            [nonzero_assertion], identity_multiplier)

    base_scale = _make_tril_scale(
        loc=shift,
        scale_tril=tril,
        scale_diag=diag,
        scale_identity_multiplier=identity_multiplier,
        validate_args=validate_args,
        assert_positive=False,
        shape_hint=shape_hint)

    if perturb_factor is None:
        return base_scale

    return tf.linalg.LinearOperatorLowRankUpdate(
        base_scale,
        u=perturb_factor,
        diag_update=perturb_diag,
        is_diag_update_positive=perturb_diag is None,
        is_non_singular=True,  # Implied by is_positive_definite=True.
        is_self_adjoint=True,
        is_positive_definite=True,
        is_square=True)
def __init__(self,
             loc=None,
             scale_tril=None,
             validate_args=False,
             allow_nan_stats=True,
             name="MultivariateNormalTriL"):
    """Build a multivariate normal on `R^k` from a lower-triangular scale.

    `covariance = scale @ scale.T`, where a (non-batch) `scale` matrix is

    ```none
    scale = scale_tril
    ```

    and `scale_tril` is a lower-triangular `k x k` matrix with non-zero
    diagonal, i.e., `tf.diag_part(scale_tril) != 0`. Additional leading
    dimensions (if any) index batches.

    The `batch_shape` is the broadcast shape between the `loc` and `scale`
    arguments. The `event_shape` is given by the last dimension of the
    matrix implied by `scale`; the last dimension of `loc` (if provided)
    must broadcast with it.

    Args:
      loc: Floating-point `Tensor`. `None` means `loc` is implicitly `0`.
        When specified, may have shape `[B1, ..., Bb, k]` where `b >= 0` and
        `k` is the event size.
      scale_tril: Floating-point, lower-triangular `Tensor` with non-zero
        diagonal elements, of shape `[B1, ..., Bb, k, k]` where `b >= 0` and
        `k` is the event size. When `None`, an identity operator sized by
        the last dimension of `loc` is used.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if neither `loc` nor `scale_tril` are specified.
    """
    # Snapshot of the call arguments; must precede any other local binding.
    parameters = dict(locals())

    if scale_tril is None and loc is None:
        raise ValueError("Must specify one or both of `loc`, `scale_tril`.")

    with tf.name_scope(name) as name:
        with tf.name_scope("init", values=[loc, scale_tril]):
            common_dtype = dtype_util.common_dtype(
                [loc, scale_tril], tf.float32)
            # `None` arguments are passed through untouched; everything else
            # becomes a tensor of the common dtype.
            if loc is not None:
                loc = tf.convert_to_tensor(
                    loc, name="loc", dtype=common_dtype)
            if scale_tril is not None:
                scale_tril = tf.convert_to_tensor(
                    scale_tril, name="scale_tril", dtype=common_dtype)

            if scale_tril is not None:
                # Non-singularity of `scale_tril` is not validated here:
                # LinearOperatorLowerTriangular has an assert_non_singular
                # method that is called by the Bijector.
                scale_operator = tf.linalg.LinearOperatorLowerTriangular(
                    scale_tril,
                    is_non_singular=True,
                    is_self_adjoint=False,
                    is_positive_definite=False)
            else:
                # `loc` is necessarily set in this branch (see check above).
                scale_operator = tf.linalg.LinearOperatorIdentity(
                    num_rows=distribution_util.dimension_size(loc, -1),
                    dtype=loc.dtype,
                    is_self_adjoint=True,
                    is_positive_definite=True,
                    assert_proper_shapes=validate_args)
    super(MultivariateNormalTriL, self).__init__(
        loc=loc,
        scale=scale_operator,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    self._parameters = parameters
def _create_scale_operator(self, identity_multiplier, diag, tril,
                           perturb_diag, perturb_factor, shift,
                           validate_args):
    """Assemble the `scale` object from its optional components.

    Args:
      identity_multiplier: floating point rank 0 `Tensor` representing a
        scaling done to the identity matrix.
      diag: Floating-point `Tensor` representing the diagonal matrix. `diag`
        has shape `[N1, N2, ... k]`, which represents a k x k diagonal
        matrix.
      tril: Floating-point `Tensor` representing the lower triangular
        matrix. `tril` has shape `[N1, N2, ... k, k]`, which represents a
        k x k lower triangular matrix.
      perturb_diag: Floating-point `Tensor` representing the diagonal matrix
        of the low rank update.
      perturb_factor: Floating-point `Tensor` representing factor matrix.
      shift: Floating-point `Tensor` representing `shift` in
        `scale @ X + shift`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.

    Returns:
      scale. In the case of scaling by a constant, scale is a floating point
      `Tensor`. Otherwise, scale is a `LinearOperator`.

    Raises:
      ValueError: if all of `tril`, `diag` and `identity_multiplier` are
        `None`. (Not raised directly in this method; presumably comes from
        `distribution_util.make_tril_scale`.)
    """
    identity_multiplier = _as_tensor(
        identity_multiplier, "identity_multiplier")
    diag = _as_tensor(diag, "diag")
    tril = _as_tensor(tril, "tril")
    perturb_diag = _as_tensor(perturb_diag, "perturb_diag")
    perturb_factor = _as_tensor(perturb_factor, "perturb_factor")

    has_low_rank_update = perturb_factor is not None

    # If possible, use the low rank update to infer the shape of the
    # identity matrix, when scale represents a scaled identity matrix with a
    # low rank update.
    if has_low_rank_update:
        shape_hint = distribution_util.dimension_size(
            perturb_factor, axis=-2)
    else:
        shape_hint = None

    if self._is_only_identity_multiplier:
        if not validate_args:
            return identity_multiplier
        nonzero_assertion = tf.assert_none_equal(
            identity_multiplier,
            tf.zeros([], identity_multiplier.dtype),
            ["identity_multiplier should be non-zero."])
        return control_flow_ops.with_dependencies(
            [nonzero_assertion], identity_multiplier)

    base_scale = distribution_util.make_tril_scale(
        loc=shift,
        scale_tril=tril,
        scale_diag=diag,
        scale_identity_multiplier=identity_multiplier,
        validate_args=validate_args,
        assert_positive=False,
        shape_hint=shape_hint)

    if not has_low_rank_update:
        return base_scale

    return tf.linalg.LinearOperatorLowRankUpdate(
        base_scale,
        u=perturb_factor,
        diag_update=perturb_diag,
        is_diag_update_positive=perturb_diag is None,
        is_non_singular=True,  # Implied by is_positive_definite=True.
        is_self_adjoint=True,
        is_positive_definite=True,
        is_square=True)