def _default_event_space_bijector(self):
  """Assemble the default event-space bijector for this distribution.

  The result is a `Chain` whose component list is, in order, a `Shift` by
  `self.loc`, a `ScaleMatvecLinearOperator` built from `self.scale`, and a
  `Softplus`. Each component, and the chain itself, is constructed with
  `self.validate_args`.

  Returns:
    A `chain_bijector.Chain` instance.
  """
  validate_args = self.validate_args
  shift = shift_bijector.Shift(shift=self.loc, validate_args=validate_args)
  scale_matvec = scale_matvec_linear_operator.ScaleMatvecLinearOperator(
      scale=self.scale, validate_args=validate_args)
  softplus = softplus_bijector.Softplus(validate_args=validate_args)
  return chain_bijector.Chain(
      [shift, scale_matvec, softplus], validate_args=validate_args)
def __init__(self,
             loc=None,
             scale=None,
             validate_args=False,
             allow_nan_stats=True,
             name='MultivariateNormalLinearOperator'):
  """Construct Multivariate Normal distribution on `R^k`.

  The distribution is expressed as a standard scalar `Normal` pushed through
  a `ScaleMatvecLinearOperator(scale)` bijector, optionally followed by a
  `Shift(loc)` when `loc` is given.

  The `batch_shape` is the broadcast shape between the `loc` and `scale`
  arguments. The `event_shape` is given by the last dimension of the matrix
  implied by `scale`; the last dimension of `loc` (if provided) must
  broadcast with it. Any additional leading dimensions index batches.

  Recall that `covariance = scale @ scale.T`.

  Args:
    loc: Floating-point `Tensor`. `None` means an implicit `loc` of `0` (no
      shift bijector is added). When specified, may have shape
      `[B1, ..., Bb, k]` where `b >= 0` and `k` is the event size.
    scale: Instance of `LinearOperator` with the same `dtype` as `loc` and
      shape `[B1, ..., Bb, k, k]`. Required.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False` and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
      NOTE(review): this constructor only records the value in
      `parameters`; it is not forwarded to the base distribution here.
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if `scale` is unspecified.
    TypeError: if not `scale.dtype.is_floating`
  """
  # Must run before any other local is bound so that only the constructor
  # arguments are captured.
  parameters = dict(locals())

  if scale is None:
    raise ValueError('Missing required `scale` parameter.')
  if not dtype_util.is_floating(scale.dtype):
    raise TypeError('`scale` parameter must have floating-point dtype.')

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
    loc = tensor_util.convert_nonref_to_tensor(loc, dtype=dtype, name='loc')
    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, scale)
    self._loc = loc
    self._scale = scale

    # Affine map x -> scale @ x (+ loc when a location was supplied).
    affine = scale_matvec_linear_operator.ScaleMatvecLinearOperator(
        scale, validate_args=validate_args)
    if loc is not None:
      shift = shift_bijector.Shift(shift=loc, validate_args=validate_args)
      affine = shift(affine)

    # Scalar standard normal; batch and event shapes are supplied separately
    # to the parent constructor.
    standard_normal = normal.Normal(
        loc=tf.zeros([], dtype=dtype), scale=tf.ones([], dtype=dtype))
    super(MultivariateNormalLinearOperator, self).__init__(
        distribution=standard_normal,
        bijector=affine,
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
def __init__(self,
             loc=None,
             scale=None,
             validate_args=False,
             allow_nan_stats=True,
             experimental_use_kahan_sum=False,
             name='MultivariateNormalLinearOperator'):
  """Construct Multivariate Normal distribution on `R^k`.

  The distribution is expressed as a `Sample` of a standard `Normal` (drawn
  `event_shape` times) pushed through a `ScaleMatvecLinearOperator(scale)`
  bijector, optionally followed by a `Shift(loc)` when `loc` is given.

  The `batch_shape` is the broadcast shape between the `loc` and `scale`
  arguments. The `event_shape` is given by the last dimension of the matrix
  implied by `scale`; the last dimension of `loc` (if provided) must
  broadcast with it. Any additional leading dimensions index batches.

  Recall that `covariance = scale @ scale.T`.

  Args:
    loc: Floating-point `Tensor`. `None` means an implicit `loc` of `0` (no
      shift bijector is added). When specified, may have shape
      `[B1, ..., Bb, k]` where `b >= 0` and `k` is the event size.
    scale: Instance of `LinearOperator` with the same `dtype` as `loc` and
      shape `[B1, ..., Bb, k, k]`. Required.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False` and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
      NOTE(review): this constructor only records the value in
      `parameters`; it is not forwarded to the base distribution here.
    experimental_use_kahan_sum: Python `bool`. When `True`, Kahan summation
      is used to aggregate the independent underlying log_prob values. For
      best results, Kahan summation should also be applied when computing
      the log-determinant of the `LinearOperator` representing the scale
      matrix. Kahan summation improves on the precision of a naive float32
      sum, which can be noticeable for large dimensions in float32. See the
      CPU caveat on `tfp.math.reduce_kahan_sum`.
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if `scale` is unspecified.
    TypeError: if not `scale.dtype.is_floating`
  """
  # Must run before any other local is bound so that only the constructor
  # arguments are captured.
  parameters = dict(locals())
  self._experimental_use_kahan_sum = experimental_use_kahan_sum

  if scale is None:
    raise ValueError('Missing required `scale` parameter.')
  if not dtype_util.is_floating(scale.dtype):
    raise TypeError('`scale` parameter must have floating-point dtype.')

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
    loc = tensor_util.convert_nonref_to_tensor(loc, dtype=dtype, name='loc')
    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, scale)
    self._loc = loc
    self._scale = scale

    # Affine map x -> scale @ x (+ loc when a location was supplied).
    affine = scale_matvec_linear_operator.ScaleMatvecLinearOperator(
        scale, validate_args=validate_args)
    if loc is not None:
      shift = shift_bijector.Shift(shift=loc, validate_args=validate_args)
      affine = shift(affine)

    # TODO(b/137665504): Use batch-adding meta-distribution to set the batch
    # shape instead of tf.zeros.
    standard_normal = normal.Normal(
        loc=tf.zeros(batch_shape, dtype=dtype),
        scale=tf.ones([], dtype=dtype))
    # `Sample` is used rather than `Independent`: `Independent` would need
    # `batch_shape` and `event_shape` concatenated, and that loses static
    # `batch_shape` information whenever `event_shape` is not statically
    # known.
    base_distribution = sample.Sample(
        standard_normal,
        event_shape,
        experimental_use_kahan_sum=experimental_use_kahan_sum)

    super(MultivariateNormalLinearOperator, self).__init__(
        distribution=base_distribution,
        bijector=affine,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
def __init__(self,
             loc=None,
             precision_factor=None,
             precision=None,
             validate_args=False,
             allow_nan_stats=True,
             name='MultivariateNormalPrecisionFactorLinearOperator'):
  """Initialize distribution.

  Precision is the inverse of the covariance matrix, and
  `precision_factor @ precision_factor.T = precision`.

  The `batch_shape` of this distribution is the broadcast of
  `loc.shape[:-1]` and `precision_factor.batch_shape`.

  The `event_shape` of this distribution is determined by `loc.shape[-1:]`,
  OR `precision_factor.shape[-1:]`, which must match.

  Args:
    loc: Floating-point `Tensor`. `None` means an implicit `loc` of `0` (no
      shift bijector is added). When specified, may have shape
      `[B1, ..., Bb, k]` where `b >= 0` and `k` is the event size.
    precision_factor: Required nonsingular `tf.linalg.LinearOperator`
      instance with same `dtype` and shape compatible with `loc`.
    precision: Optional square `tf.linalg.LinearOperator` instance with same
      `dtype` and shape compatible with `loc` and `precision_factor`. When
      given, its shape is checked for compatibility with
      `precision_factor.shape`.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False` and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
      NOTE(review): this constructor only records the value in
      `parameters`; it is not forwarded to the base distribution here.
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if `precision_factor` is `None`.
  """
  # Must run before any other local is bound so that only the constructor
  # arguments are captured.
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    if precision_factor is None:
      raise ValueError(
          'Argument `precision_factor` must be provided. Found `None`')

    dtype = dtype_util.common_dtype(
        [loc, precision_factor, precision], dtype_hint=tf.float32)
    loc = tensor_util.convert_nonref_to_tensor(loc, dtype=dtype, name='loc')
    self._loc = loc
    self._precision_factor = precision_factor
    self._precision = precision

    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, precision_factor)

    # Proof of factors (used throughout code):
    # Let,
    #   C = covariance,
    #   P = inv(covariance) = precision
    #   P = F @ F.T  (so F is the `precision_factor`).
    #
    # Then, the log prob term is
    #  x.T @ inv(C) @ x
    #  = x.T @ P @ x
    #  = x.T @ F @ F.T @ x
    #  = || F.T @ x ||**2
    # Notice it involves F.T, which is why `adjoint=True` is set in various
    # places.
    #
    # Also, if w ~ Normal(0, I), then we can sample by setting
    #  x = inv(F.T) @ w + loc,
    # since then
    #  E[(x - loc) @ (x - loc).T]
    #  = E[inv(F.T) @ w @ w.T @ inv(F)]
    #  = inv(F.T) @ inv(F)
    #  = inv(F @ F.T)
    #  = inv(P)
    #  = C.
    if precision is not None:
      precision.shape.assert_is_compatible_with(precision_factor.shape)

    # Forward direction of `adjoint_matvec` is w -> F.T @ w; inverting it
    # yields the sampling map w -> inv(F.T) @ w from the derivation above.
    adjoint_matvec = scale_matvec_linear_operator.ScaleMatvecLinearOperator(
        scale=precision_factor,
        validate_args=validate_args,
        adjoint=True)
    bijector = invert.Invert(adjoint_matvec)
    if loc is not None:
      shift = shift_bijector.Shift(shift=loc, validate_args=validate_args)
      bijector = shift(bijector)

    # Zero-mean `MultivariateNormalDiag` whose `loc` carries the full
    # batch + event shape.
    full_shape = ps.concat([batch_shape, event_shape], axis=0)
    base_distribution = mvn_diag.MultivariateNormalDiag(
        loc=tf.zeros(full_shape, dtype=dtype))

    super(MultivariateNormalPrecisionFactorLinearOperator, self).__init__(
        distribution=base_distribution,
        bijector=bijector,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
def __init__(self,
             loc=None,
             scale=None,
             validate_args=False,
             allow_nan_stats=True,
             name='VectorExponentialLinearOperator'):
  """Construct Vector Exponential distribution supported on a subset of `R^k`.

  The distribution is expressed as a `Sample` of a unit-rate `Exponential`
  (drawn `event_shape` times) pushed through
  `Shift(loc)(ScaleMatvecLinearOperator(scale))`.

  The `batch_shape` is the broadcast shape between the `loc` and `scale`
  arguments. The `event_shape` is given by the last dimension of the matrix
  implied by `scale`; the last dimension of `loc` (if provided) must
  broadcast with it. Any additional leading dimensions index batches.

  Recall that `covariance = scale @ scale.T`.

  Args:
    loc: Floating-point `Tensor`. If `None`, it is replaced by `0.0` (kept
      for backwards compatibility), so a shift bijector is always built.
      When specified, may have shape `[B1, ..., Bb, k]` where `b >= 0` and
      `k` is the event size.
    scale: Instance of `LinearOperator` with the same `dtype` as `loc` and
      shape `[B1, ..., Bb, k, k]`. Required.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False` and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic. Forwarded
      to the underlying `Exponential`.
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if `scale` is unspecified.
    TypeError: if not `scale.dtype.is_floating`
  """
  # Must run before any other local is bound so that only the constructor
  # arguments are captured (in particular, before `loc` is defaulted).
  parameters = dict(locals())

  if loc is None:
    loc = 0.0  # Implicit value for backwards compatibility.
  if scale is None:
    raise ValueError('Missing required `scale` parameter.')
  if not dtype_util.is_floating(scale.dtype):
    raise TypeError('`scale` parameter must have floating-point dtype.')

  with tf.name_scope(name) as name:
    # `loc` can no longer be `None` here (it was defaulted above), so it is
    # converted unconditionally.
    loc = tf.convert_to_tensor(loc, name='loc', dtype=scale.dtype)
    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, scale)
    self._loc = loc
    self._scale = scale

    # TODO(b/137665504): Use batch-adding meta-distribution to set the
    # batch shape instead of tf.ones.
    unit_exponential = exponential.Exponential(
        rate=tf.ones(batch_shape, dtype=scale.dtype),
        allow_nan_stats=allow_nan_stats)
    # `Sample` is used rather than `Independent`: `Independent` would need
    # `batch_shape` and `event_shape` concatenated, and that loses static
    # `batch_shape` information whenever `event_shape` is not statically
    # known.
    base_distribution = sample.Sample(unit_exponential, event_shape)

    # `validate_args` is forwarded to every bijector, matching how the
    # default event-space bijector for this distribution builds its `Shift`.
    shift = shift_bijector.Shift(shift=loc, validate_args=validate_args)
    affine = shift(
        scale_matvec_linear_operator.ScaleMatvecLinearOperator(
            scale=scale, validate_args=validate_args))

    super(VectorExponentialLinearOperator, self).__init__(
        distribution=base_distribution,
        bijector=affine,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
def _forward(self, x):
  """Apply the Householder-operator matvec bijector to `x`.

  A `tf.linalg.LinearOperatorHouseholder` is built from
  `self.reflection_axis`, wrapped in a `ScaleMatvecLinearOperator`, and the
  wrapper's `forward` is evaluated on `x`.

  Args:
    x: Input passed unchanged to the wrapped bijector's `forward`.

  Returns:
    The result of the wrapped bijector's `forward(x)`.
  """
  householder = tf.linalg.LinearOperatorHouseholder(self.reflection_axis)
  return smlo.ScaleMatvecLinearOperator(householder).forward(x)