def __init__(self,
             loc=None,
             scale=None,
             validate_args=False,
             allow_nan_stats=True,
             name="MultivariateNormalLinearOperator"):
  """Construct the distribution from a location and a `scale` operator.

  The distribution is assembled by applying an `AffineLinearOperator`
  bijector (with `shift=loc` and `scale=scale`) to a scalar `Normal` base
  distribution created with zero `loc` and unit `scale`. The batch and event
  shapes handed to the parent constructor are the ones returned by
  `distribution_util.shapes_from_loc_and_scale(loc, scale)`.

  Args:
    loc: Optional value convertible to a `Tensor` of dtype `scale.dtype`.
      When `None` it is passed through unchanged as the bijector `shift`.
    scale: Required. Object exposing `dtype` and `graph_parents`
      (presumably a `LinearOperator` -- confirm against callers). Its dtype
      must be floating-point.
    validate_args: Python `bool`, default `False`. Forwarded to the bijector
      and to the parent constructor.
    allow_nan_stats: Python `bool`, default `True`. Forwarded to the base
      `Normal` distribution.
    name: Name used for the enclosing name scope and passed to the parent
      constructor.

  Raises:
    ValueError: if `scale` is `None`.
    TypeError: if `scale.dtype.is_floating` is false.
  """
  # Must stay the first statement so that only the constructor arguments
  # (and `self`) are captured.
  parameters = dict(locals())
  if scale is None:
    raise ValueError("Missing required `scale` parameter.")
  if not scale.dtype.is_floating:
    raise TypeError(
        "`scale` parameter must have floating-point dtype.")

  with tf.name_scope(name, values=[loc] + scale.graph_parents) as name:
    # Since expand_dims doesn't preserve constant-ness, we obtain the
    # non-dynamic value if possible.
    loc = loc if loc is None else tf.convert_to_tensor(
        loc, name="loc", dtype=scale.dtype)
    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, scale)

    # NOTE(review): the class named in this `super` call differs from the
    # default `name` above ("MultivariateNormalLinearOperator"); confirm it
    # matches the enclosing class.
    super(MultivariateNormalLogLinearOperator, self).__init__(
        distribution=normal.Normal(
            loc=tf.zeros([], dtype=scale.dtype),
            scale=tf.ones([], dtype=scale.dtype),
            # The parent constructor call below takes no `allow_nan_stats`
            # argument, so the caller's setting is handed to the base
            # distribution instead of being dropped.
            allow_nan_stats=allow_nan_stats),
        bijector=affine_linear_operator_bijector.AffineLinearOperator(
            shift=loc, scale=scale, validate_args=validate_args),
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=validate_args,
        name=name)
  self._parameters = parameters
def __init__(self,
             loc=None,
             scale=None,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorLaplaceLinearOperator"):
  """Construct Vector Laplace distribution on `R^k`.

  The `batch_shape` is the broadcast shape between `loc` and `scale`
  arguments.

  The `event_shape` is given by last dimension of the matrix implied by
  `scale`. The last dimension of `loc` (if provided) must broadcast with this.

  Recall that `covariance = 2 * scale @ scale.T`.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
      `b >= 0` and `k` is the event size.
    scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
      `[B1, ..., Bb, k, k]`.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic. The value is
      forwarded to the base `Laplace` distribution.
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if `scale` is unspecified.
    TypeError: if not `scale.dtype.is_floating`.
  """
  # Must stay the first statement so that only the constructor arguments
  # (and `self`) are captured.
  parameters = dict(locals())
  if scale is None:
    raise ValueError("Missing required `scale` parameter.")
  if not dtype_util.is_floating(scale.dtype):
    raise TypeError("`scale` parameter must have floating-point dtype.")

  with tf.name_scope(name):
    # Since expand_dims doesn't preserve constant-ness, we obtain the
    # non-dynamic value if possible.
    loc = loc if loc is None else tf.convert_to_tensor(
        loc, name="loc", dtype=scale.dtype)
    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, scale)

    super(VectorLaplaceLinearOperator, self).__init__(
        distribution=laplace.Laplace(
            loc=tf.zeros([], dtype=scale.dtype),
            scale=tf.ones([], dtype=scale.dtype),
            # The parent constructor call below takes no `allow_nan_stats`
            # argument, so the caller's setting is handed to the base
            # distribution instead of being dropped.
            allow_nan_stats=allow_nan_stats),
        bijector=affine_linear_operator_bijector.AffineLinearOperator(
            shift=loc, scale=scale, validate_args=validate_args),
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=validate_args,
        name=name)
  self._parameters = parameters
def __init__(self,
             mix_loc,
             temperature,
             distribution,
             loc=None,
             scale=None,
             quadrature_size=8,
             quadrature_fn=quadrature_scheme_softmaxnormal_quantiles,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorDiffeomixture"):
  """Constructs the VectorDiffeomixture on `R^d`.

  The vector diffeomixture (VDM) approximates the compound distribution

  ```none
  p(x) = int p(x | z) p(z) dz,
  where z is in the K-simplex, and
  p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
  ```

  Args:
    mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
      In terms of samples, larger `mix_loc[..., k]` ==>
      `Z` is more likely to put more weight on its `kth` component.
    temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`.
      In terms of samples, smaller `temperature` means one component is more
      likely to dominate. I.e., smaller `temperature` makes the VDM look more
      like a standard mixture of `K` components.
    distribution: `tfp.distributions.Distribution`-like instance.
      Distribution from which `d` iid samples are used as input to the
      selected affine transformation. Must be a scalar-batch, scalar-event
      distribution. Typically
      `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
      a function of non-trainable parameters. WARNING: If you backprop through
      a VectorDiffeomixture sample and the `distribution` is not
      `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
      the gradient will be incorrect!
    loc: Length-`K` list of `float`-type `Tensor`s. The `k`-th element
      represents the `shift` used for the `k`-th affine transformation. If
      the `k`-th item is `None`, `loc` is implicitly `0`. When specified,
      must have shape `[B1, ..., Bb, d]` where `b >= 0` and `d` is the event
      size. Passing `None` for the whole argument is treated as a list of
      `None`s with the same length as `scale`.
    scale: Length-`K` list of `LinearOperator`s. Each should be
      positive-definite and operate on a `d`-dimensional vector space. The
      `k`-th element represents the `scale` used for the `k`-th affine
      transformation. `LinearOperator`s must have shape
      `[B1, ..., Bb, d, d]`, `b >= 0`, i.e., characterizes `b`-batches of
      `d x d` matrices.
    quadrature_size: Python `int` scalar representing number of quadrature
      points. Larger `quadrature_size` means `q_N(x)` better approximates
      `p(x)`.
    quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
      `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
      representing the SoftmaxNormal grid and corresponding normalized
      weight.
      Default value: `quadrature_scheme_softmaxnormal_quantiles`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if `not scale or len(scale) < 2`.
    ValueError: if `len(loc) != len(scale)`.
    ValueError: if `validate_args` and any not scale.is_positive_definite.
    TypeError: if any scale.dtype != scale[0].dtype (base dtypes are
      compared).
    NotImplementedError: if `len(scale) != 2`.
    ValueError: if `not distribution.is_scalar_batch` or
      `not distribution.is_scalar_event` (these checks are delegated to
      `distribution_util.maybe_check_scalar_distribution`).
  """
  # Must stay the first statement so that only the constructor arguments
  # (and `self`) are captured.
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    if not scale or len(scale) < 2:
      raise ValueError(
          "Must specify list (or list-like object) of scale "
          "LinearOperators, one for each component with "
          "num_component >= 2.")
    num_components = len(scale)

    if loc is None:
      loc = [None] * num_components
    if len(loc) != num_components:
      raise ValueError("loc/scale must be same-length lists "
                       "(or same-length list-like objects).")

    # Every component is converted to / compared against the base dtype of
    # the first scale operator.
    dtype = dtype_util.base_dtype(scale[0].dtype)

    converted_loc = []
    for index, component_loc in enumerate(loc):
      if component_loc is not None:
        component_loc = tf.convert_to_tensor(
            value=component_loc, dtype=dtype, name="loc{}".format(index))
      converted_loc.append(component_loc)
    loc = converted_loc

    for index, component_scale in enumerate(scale):
      if validate_args and not component_scale.is_positive_definite:
        raise ValueError(
            "scale[{}].is_positive_definite = {} != True".format(
                index, component_scale.is_positive_definite))
      if dtype_util.base_dtype(component_scale.dtype) != dtype:
        raise TypeError(
            "dtype mismatch; scale[{}].base_dtype=\"{}\" != \"{}\""
            .format(index,
                    dtype_util.name(component_scale.dtype),
                    dtype_util.name(dtype)))

    # One affine bijector per (loc, scale) endpoint.
    endpoint_affine = []
    for index, (component_loc, component_scale) in enumerate(zip(loc, scale)):
      endpoint_affine.append(
          affine_linear_operator_bijector.AffineLinearOperator(
              shift=component_loc,
              scale=component_scale,
              validate_args=validate_args,
              name="endpoint_affine_{}".format(index)))
    self._endpoint_affine = endpoint_affine

    # TODO(jvdillon): Remove once we support k-mixtures.
    # We make this assertion here because otherwise `grid` would need to be a
    # vector not a scalar.
    if num_components != 2:
      raise NotImplementedError(
          "Currently only bimixtures are supported; "
          "len(scale)={} is not 2.".format(num_components))

    mix_loc = tf.convert_to_tensor(value=mix_loc, dtype=dtype, name="mix_loc")
    temperature = tf.convert_to_tensor(
        value=temperature, dtype=dtype, name="temperature")

    normal_loc = mix_loc / temperature
    normal_scale = 1. / temperature
    grid, probs = tuple(
        quadrature_fn(normal_loc, normal_scale, quadrature_size,
                      validate_args))
    self._grid = grid

    # Note: by creating the logits as `log(prob)` we ensure that
    # `self.mixture_distribution.logits` is equivalent to
    # `math_ops.log(self.mixture_distribution.probs)`.
    self._mixture_distribution = categorical.Categorical(
        logits=tf.math.log(probs),
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)

    asserts = distribution_util.maybe_check_scalar_distribution(
        distribution, dtype, validate_args)
    if asserts:
      # Tie the returned assertions to the grid via `with_dependencies`.
      self._grid = distribution_util.with_dependencies(asserts, self._grid)
    self._distribution = distribution

    # One affine bijector per interpolated (loc, scale) pair.
    interpolated_locs = interpolate_loc(self._grid, loc)
    interpolated_scales = interpolate_scale(self._grid, scale)
    interpolated_affine = []
    for index, (point_loc, point_scale) in enumerate(
        zip(interpolated_locs, interpolated_scales)):
      interpolated_affine.append(
          affine_linear_operator_bijector.AffineLinearOperator(
              shift=point_loc,
              scale=point_scale,
              validate_args=validate_args,
              name="interpolated_affine_{}".format(index)))
    self._interpolated_affine = interpolated_affine

    (self._batch_shape_,
     self._batch_shape_tensor_,
     self._event_shape_,
     self._event_shape_tensor_) = determine_batch_event_shapes(
         self._grid, self._endpoint_affine)

    specified_locs = [
        component_loc for component_loc in loc if component_loc is not None
    ]
    scale_parents = []
    for component_scale in scale:
      for parent in component_scale.graph_parents:
        scale_parents.append(parent)
    graph_parents = (
        distribution._graph_parents  # pylint: disable=protected-access
        + specified_locs
        + scale_parents)

    super(VectorDiffeomixture, self).__init__(
        dtype=dtype,
        # We hard-code `FULLY_REPARAMETERIZED` because when
        # `validate_args=True` we verify that indeed
        # `distribution.reparameterization_type == FULLY_REPARAMETERIZED`. A
        # distribution which is a function of only non-trainable parameters
        # also implies we can use `FULLY_REPARAMETERIZED`. However, we cannot
        # easily test for that possibility thus we use `validate_args=False`
        # as a "back-door" to allow users a way to use non
        # `FULLY_REPARAMETERIZED` distribution. In such cases IT IS THE USERS
        # RESPONSIBILITY to verify that the base distribution is a function of
        # non-trainable parameters.
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=graph_parents,
        name=name)
def __init__(self,
             loc=None,
             scale=None,
             validate_args=False,
             allow_nan_stats=True,
             name='VectorExponentialLinearOperator'):
  """Construct Vector Exponential distribution supported on a subset of `R^k`.

  The `batch_shape` is the broadcast shape between `loc` and `scale`
  arguments.

  The `event_shape` is given by last dimension of the matrix implied by
  `scale`. The last dimension of `loc` (if provided) must broadcast with this.

  Recall that `covariance = scale @ scale.T`.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
      `b >= 0` and `k` is the event size.
    scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
      `[B1, ..., Bb, k, k]`.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if `scale` is unspecified.
    TypeError: if not `scale.dtype.is_floating`.
  """
  # Must stay the first statement so that only the constructor arguments
  # (and `self`) are captured.
  parameters = dict(locals())

  if scale is None:
    raise ValueError('Missing required `scale` parameter.')
  scale_dtype = scale.dtype
  if not dtype_util.is_floating(scale_dtype):
    raise TypeError(
        '`scale` parameter must have floating-point dtype.')

  with tf.name_scope(name) as name:
    # Since expand_dims doesn't preserve constant-ness, we obtain the
    # non-dynamic value if possible.
    if loc is not None:
      loc = tf.convert_to_tensor(loc, name='loc', dtype=scale_dtype)
    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, scale)

    # TODO(b/137665504): Use batch-adding meta-distribution to set the
    # batch shape instead of tf.ones.
    # We use `Sample` instead of `Independent` because `Independent`
    # requires concatenating `batch_shape` and `event_shape`, which loses
    # static `batch_shape` information when `event_shape` is not
    # statically known.
    unit_rate = tf.ones(batch_shape, dtype=scale_dtype)
    base_distribution = sample.Sample(
        exponential.Exponential(rate=unit_rate,
                                allow_nan_stats=allow_nan_stats),
        event_shape)
    affine = affine_linear_operator_bijector.AffineLinearOperator(
        shift=loc, scale=scale, validate_args=validate_args)

    super(VectorExponentialLinearOperator, self).__init__(
        distribution=base_distribution,
        bijector=affine,
        validate_args=validate_args,
        name=name)
  self._parameters = parameters