Code Example #1
 def _inverse_log_det_jacobian(self, y):
     # If event_ndims = 2,
     # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1),
     # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0].
     with tf.control_dependencies(self._assertions(y)):
         zero = tf.zeros([], dtype=dtype_util.base_dtype(y.dtype))
         return zero, zero
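
The pair of zeros mirrors the two-valued inverse described in the comment. A minimal usage sketch, assuming the snippet comes from a bijector such as `tfp.bijectors.AbsoluteValue` (an assumption; the snippet itself does not name its class):

```python
import tensorflow_probability as tfp
tfb = tfp.bijectors

b = tfb.AbsoluteValue()
b.inverse(2.)                                   # ==> (-2.0, 2.0), both preimages
b.inverse_log_det_jacobian(2., event_ndims=0)   # ==> (0.0, 0.0), one zero per preimage
```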
Code Example #2
    def _cdf(self, k):
        # TODO(b/135263541): Improve numerical precision of categorical.cdf.
        probs = self.probs_parameter()
        num_categories = self._num_categories(probs)

        k, probs = _broadcast_cat_event_and_params(
            k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

        # Since the lowest number in the support is 0, any k < 0 should be zero in
        # the output.
        should_be_zero = k < 0

        # Will use k as an index in the gather below, so clip it to {0,...,K-1}.
        k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)

        batch_shape = tf.shape(k)

        # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg, so
        # to handle the case where the batch shape is dynamic, flatten the batch
        # dims (so we know batch_dims=1).
        k_flat_batch = tf.reshape(k, [-1])
        probs_flat_batch = tf.reshape(
            probs, tf.concat(([-1], [num_categories]), axis=0))

        cdf_flat = tf.gather(tf.cumsum(probs_flat_batch, axis=-1),
                             k_flat_batch[..., tf.newaxis],
                             batch_dims=1)

        cdf = tf.reshape(cdf_flat, shape=batch_shape)

        zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
        return tf.where(should_be_zero, zero, cdf)
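
A short usage sketch of the resulting public method, `tfd.Categorical(...).cdf`, showing how out-of-support values are masked to zero while in-support values gather the inclusive cumulative sum over `probs`:

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Categorical(probs=[0.2, 0.3, 0.5])
# k < 0 is masked to zero; k in {0, ..., K-1} returns the cumulative probability.
dist.cdf([-1, 0, 1, 2])   # ==> [0.0, 0.2, 0.5, 1.0]
```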
Code Example #3
 def _log_prob(self, k):
     logits = self.logits_parameter()
     if self.validate_args:
         k = distribution_util.embed_check_integer_casting_closed(
             k, target_dtype=tf.int32)
     k, logits = _broadcast_cat_event_and_params(
         k, logits, base_dtype=dtype_util.base_dtype(self.dtype))
     return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=k,
                                                            logits=logits)
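
Complementing the CDF sketch above, a small example of the public `log_prob`, which is just the log-softmax of the logits gathered at `k`, computed via the cross-entropy kernel:

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Categorical(logits=[-1., 0., 1.])
dist.log_prob(2)   # ==> 1. - log(exp(-1.) + exp(0.) + exp(1.))
```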
Code Example #4
 def _forward_log_det_jacobian(self, x):
     # For a discussion of this (non-obvious) result, see Note 7.2.2 (and the
     # sections leading up to it, for context) in
     # http://neutrino.aquaphoenix.com/ReactionDiffusion/SERC5chap7.pdf
     with tf.control_dependencies(self._assertions(x)):
         matrix_dim = tf.cast(
             tf.shape(x)[-1], dtype_util.base_dtype(x.dtype))
         return -(matrix_dim + 1) * tf.reduce_sum(
             tf.math.log(tf.abs(tf.linalg.diag_part(x))), axis=-1)
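
A hedged numeric check of the formula above, assuming the method belongs to a lower-triangular matrix-inversion bijector such as `tfp.bijectors.MatrixInverseTriL` (an assumption; the snippet does not name its class). For an `n x n` lower-triangular input, the log det Jacobian is `-(n + 1) * sum(log|diag(x)|)`:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

x = tf.constant([[2., 0.],
                 [1., 0.5]])
b = tfb.MatrixInverseTriL()
b.forward(x)                                   # ==> inverse of the lower-triangular x
b.forward_log_det_jacobian(x, event_ndims=2)   # ==> -(2 + 1) * (log 2 + log 0.5) = 0.0
```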
Code Example #5
 def _inverse(self, y):
     # As specified in the Stan reference manual, the procedure is as follows:
     # N = y.shape[-1]
     # z_k = y_k / (1 - sum_{i=1 to k-1} y_i)
     # x_k = logit(z_k) - log(1 / (N - k))
     offset = tf.math.log(
         tf.cast(tf.range(tf.shape(y)[-1] - 1, 0, delta=-1),
                 dtype=dtype_util.base_dtype(y.dtype)))
     z = y / (1. - tf.math.cumsum(y, axis=-1, exclusive=True))
     return tf.math.log(z[..., :-1]) - tf.math.log1p(-z[..., :-1]) + offset
Code Example #6
    def _forward_log_det_jacobian(self, x):
        # is_constant_jacobian = True for this bijector, hence the
        # `log_det_jacobian` need only be specified for a single input, as this will
        # be tiled to match `event_ndims`.
        if self.scale is None:
            return tf.constant(0., dtype=dtype_util.base_dtype(x.dtype))

        with tf.control_dependencies(
                self._maybe_collect_assertions() if self.validate_args else []):
            return self.scale.log_abs_determinant()
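
Because `is_constant_jacobian` is `True`, the reported value does not depend on the input. A tiny sketch with a scale bijector (assuming `tfb.Scale` is available in the installed TFP version; older releases spell it `tfb.AffineScalar(scale=...)`):

```python
import tensorflow_probability as tfp
tfb = tfp.bijectors

b = tfb.Scale(3.)
b.forward_log_det_jacobian(0., event_ndims=0)    # ==> log(3)
b.forward_log_det_jacobian(100., event_ndims=0)  # ==> log(3), same constant value
```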
Code Example #7
 def _forward(self, x):
   with tf.control_dependencies(self._assertions(x)):
     x_shape = tf.shape(x)
     identity_matrix = tf.eye(
         x_shape[-1],
         batch_shape=x_shape[:-2],
         dtype=dtype_util.base_dtype(x.dtype))
     # Note `matrix_triangular_solve` implicitly zeros upper triangular of `x`.
     y = tf.linalg.triangular_solve(x, identity_matrix)
     y = tf.matmul(y, y, adjoint_a=True)
     return tf.linalg.cholesky(y)
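
A hedged usage sketch, assuming this is the forward pass of `tfp.bijectors.CholeskyToInvCholesky`: given `chol(M)` it returns `chol(M^{-1})` without ever forming `M` explicitly.

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

m = tf.constant([[4., 2.],
                 [2., 5.]])
x = tf.linalg.cholesky(m)
y = tfb.CholeskyToInvCholesky().forward(x)
tf.matmul(y, y, adjoint_b=True)   # ==> approximately tf.linalg.inv(m)
```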
Code Example #8
 def _forward(self, x):
     # As specified in the Stan reference manual, the procedure is as follows:
     # N = x.shape[-1] + 1
     # z_k = sigmoid(x + log(1 / (N - k)))
     # y_1 = z_1
     # y_k = (1 - sum_{i=1 to k-1} y_i) * z_k
     # y_N = 1 - sum_{i=1 to N-1} y_i
     # TODO(b/128857065): The numerics can possibly be improved here with a
     # log-space computation.
     offset = -tf.math.log(
         tf.cast(tf.range(tf.shape(x)[-1], 0, delta=-1),
                 dtype=dtype_util.base_dtype(x.dtype)))
     z = tf.math.sigmoid(x + offset)
     y = z * tf.math.cumprod(1 - z, axis=-1, exclusive=True)
     return tf.concat([y, 1. - tf.reduce_sum(y, axis=-1, keepdims=True)],
                      axis=-1)
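
A worked round trip of the stick-breaking construction in Code Examples #5 and #8, assuming the snippets come from `tfp.bijectors.IteratedSigmoidCentered` (an assumption; the class is not named). With `x = [0., 0.]` we have `N = 3`, the offsets give `z = [1/3, 1/2]`, and the output is the simplex point `[1/3, 1/3, 1/3]`:

```python
import tensorflow_probability as tfp
tfb = tfp.bijectors

b = tfb.IteratedSigmoidCentered()
y = b.forward([0., 0.])   # ==> [1/3, 1/3, 1/3], a point on the 2-simplex
b.inverse(y)              # ==> approximately [0., 0.]
```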
Code Example #9
    def _entropy(self, **kwargs):
        if not self.bijector.is_constant_jacobian:
            raise NotImplementedError("entropy is not implemented")
        if not self.bijector._is_injective:  # pylint: disable=protected-access
            raise NotImplementedError("entropy is not implemented when "
                                      "bijector is not injective.")
        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
        # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
        # can be shown that:
        #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
        # If is_constant_jacobian then:
        #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
        # where c can be anything.
        entropy = self.distribution.entropy(**distribution_kwargs)
        if self._is_maybe_event_override:
            # H[X] = sum_i H[X_i] if X_i are mutually independent.
            # This means that a reduce_sum is a simple rescaling.
            entropy = entropy * tf.cast(
                tf.reduce_prod(self._override_event_shape),
                dtype=dtype_util.base_dtype(entropy.dtype))
        if self._is_maybe_batch_override:
            new_shape = tf.concat([
                prefer_static.ones_like(self._override_batch_shape),
                self.distribution.batch_shape_tensor()
            ], 0)
            entropy = tf.reshape(entropy, new_shape)
            multiples = tf.concat([
                self._override_batch_shape,
                prefer_static.ones_like(self.distribution.batch_shape_tensor())
            ], 0)
            entropy = tf.tile(entropy, multiples)
        dummy = prefer_static.zeros(
            shape=tf.concat([self.batch_shape_tensor(),
                             self.event_shape_tensor()], 0),
            dtype=self.dtype)
        event_ndims = (tensorshape_util.rank(self.event_shape)
                       if tensorshape_util.rank(self.event_shape) is not None
                       else tf.size(self.event_shape_tensor()))
        ildj = self.bijector.inverse_log_det_jacobian(dummy,
                                                      event_ndims=event_ndims,
                                                      **bijector_kwargs)

        entropy = entropy - tf.cast(ildj, entropy.dtype)
        tensorshape_util.set_shape(entropy, self.batch_shape)
        return entropy
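
A minimal sketch of the identity in the comment: for a constant-Jacobian bijector, `H[Y] = H[X] + log|det J|`, so the transformed entropy exceeds the base entropy by the (constant) forward log det Jacobian. `tfb.Scale` is assumed here; older TFP releases spell it `tfb.AffineScalar(scale=...)`.

```python
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

base = tfd.Normal(loc=0., scale=1.)
y = tfd.TransformedDistribution(distribution=base, bijector=tfb.Scale(2.))
y.entropy()   # ==> base.entropy() + log(2)
```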
Code Example #10
File: chain.py  Project: HackerShohag/SuggestBot-bn
    def _forward_log_det_jacobian(self, x, **kwargs):
        x = tf.convert_to_tensor(x, name="x")

        fldj = tf.cast(0., dtype=dtype_util.base_dtype(x.dtype))

        if not self.bijectors:
            return fldj

        event_ndims = self._maybe_get_static_event_ndims(
            self.forward_min_event_ndims)

        if _use_static_shape(x, event_ndims):
            event_shape = x.shape[tensorshape_util.rank(x.shape) -
                                  event_ndims:]
        else:
            event_shape = tf.shape(x)[tf.rank(x) - event_ndims:]

        # TODO(b/129973548): Document and simplify.
        for b in reversed(self.bijectors):
            fldj = fldj + b.forward_log_det_jacobian(
                x, event_ndims=event_ndims, **kwargs.get(b.name, {}))
            if _use_static_shape(x, event_ndims):
                event_shape = b.forward_event_shape(event_shape)
                event_ndims = self._maybe_get_static_event_ndims(
                    tensorshape_util.rank(event_shape))
            else:
                event_shape = b.forward_event_shape_tensor(event_shape)
                event_shape_ = distribution_util.maybe_get_static_value(
                    event_shape)
                event_ndims = tf.size(event_shape)
                event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)

                if event_ndims_ is not None and event_shape_ is not None:
                    event_ndims = event_ndims_
                    event_shape = event_shape_

            x = b.forward(x, **kwargs.get(b.name, {}))

        return fldj
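
A small sketch of the accumulation loop above: for `Chain([Exp(), Scale(3.)])` the forward map is `y = exp(3 x)`, so the FLDJ at `x` is `log(3) + 3 x` (again assuming `tfb.Scale` exists in the installed TFP version):

```python
import tensorflow_probability as tfp
tfb = tfp.bijectors

chain = tfb.Chain([tfb.Exp(), tfb.Scale(3.)])
chain.forward_log_det_jacobian(1., event_ndims=0)   # ==> log(3) + 3
```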
Code Example #11
File: chain.py  Project: HackerShohag/SuggestBot-bn
    def _inverse_log_det_jacobian(self, y, **kwargs):
        y = tf.convert_to_tensor(y, name="y")
        ildj = tf.cast(0., dtype=dtype_util.base_dtype(y.dtype))

        if not self.bijectors:
            return ildj

        event_ndims = self._maybe_get_static_event_ndims(
            self.inverse_min_event_ndims)

        if _use_static_shape(y, event_ndims):
            event_shape = y.shape[tensorshape_util.rank(y.shape) -
                                  event_ndims:]
        else:
            event_shape = tf.shape(y)[tf.rank(y) - event_ndims:]

        # TODO(b/129973548): Document and simplify.
        for b in self.bijectors:
            ildj = ildj + b.inverse_log_det_jacobian(
                y, event_ndims=event_ndims, **kwargs.get(b.name, {}))

            if _use_static_shape(y, event_ndims):
                event_shape = b.inverse_event_shape(event_shape)
                event_ndims = self._maybe_get_static_event_ndims(
                    tensorshape_util.rank(event_shape))
            else:
                event_shape = b.inverse_event_shape_tensor(event_shape)
                event_shape_ = distribution_util.maybe_get_static_value(
                    event_shape)
                event_ndims = tf.size(event_shape)
                event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)

                if event_ndims_ is not None and event_shape_ is not None:
                    event_ndims = event_ndims_
                    event_shape = event_shape_

            y = b.inverse(y, **kwargs.get(b.name, {}))
        return ildj
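
The mirror image of the previous sketch: for the same chain the inverse is `x = log(y) / 3`, so the ILDJ at `y` is `-log(3) - log(y)`, i.e. the negative FLDJ evaluated at the preimage of `y`:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

chain = tfb.Chain([tfb.Exp(), tfb.Scale(3.)])   # assumes tfb.Scale, as above
chain.inverse_log_det_jacobian(tf.exp(3.), event_ndims=0)   # ==> -(log(3) + 3)
```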
Code Example #12
    def __init__(self,
                 mix_loc,
                 temperature,
                 distribution,
                 loc=None,
                 scale=None,
                 quadrature_size=8,
                 quadrature_fn=quadrature_scheme_softmaxnormal_quantiles,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="VectorDiffeomixture"):
        """Constructs the VectorDiffeomixture on `R^d`.

    The vector diffeomixture (VDM) approximates the compound distribution

    ```none
    p(x) = int p(x | z) p(z) dz,
    where z is in the K-simplex, and
    p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
    ```

    Args:
      mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
        In terms of samples, larger `mix_loc[..., k]` ==>
        `Z` is more likely to put more weight on its `kth` component.
      temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`.
        In terms of samples, smaller `temperature` means one component is more
        likely to dominate.  I.e., smaller `temperature` makes the VDM look more
        like a standard mixture of `K` components.
      distribution: `tfp.distributions.Distribution`-like instance. Distribution
        from which `d` iid samples are used as input to the selected affine
        transformation. Must be a scalar-batch, scalar-event distribution.
        Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED`
        or it is a function of non-trainable parameters. WARNING: If you
        backprop through a VectorDiffeomixture sample and the `distribution`
        is not `FULLY_REPARAMETERIZED` yet is a function of trainable variables,
        then the gradient will be incorrect!
      loc: Length-`K` list of `float`-type `Tensor`s. The `k`-th element
        represents the `shift` used for the `k`-th affine transformation.  If
        the `k`-th item is `None`, `loc` is implicitly `0`.  When specified,
        must have shape `[B1, ..., Bb, d]` where `b >= 0` and `d` is the event
        size.
      scale: Length-`K` list of `LinearOperator`s. Each should be
        positive-definite and operate on a `d`-dimensional vector space. The
        `k`-th element represents the `scale` used for the `k`-th affine
        transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
        `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices.
      quadrature_size: Python `int` scalar representing number of
        quadrature points.  Larger `quadrature_size` means `q_N(x)` better
        approximates `p(x)`.
      quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
        `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
        representing the SoftmaxNormal grid and corresponding normalized weights.
        Default value: `quadrature_scheme_softmaxnormal_quantiles`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if `not scale or len(scale) < 2`.
      ValueError: if `len(loc) != len(scale)`
      ValueError: if `quadrature_grid_and_probs is not None` and
        `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
      ValueError: if `validate_args` and any not scale.is_positive_definite.
      TypeError: if any scale.dtype != scale[0].dtype.
      TypeError: if any loc.dtype != scale[0].dtype.
      NotImplementedError: if `len(scale) != 2`.
      ValueError: if `not distribution.is_scalar_batch`.
      ValueError: if `not distribution.is_scalar_event`.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            if not scale or len(scale) < 2:
                raise ValueError(
                    "Must specify list (or list-like object) of scale "
                    "LinearOperators, one for each component with "
                    "num_component >= 2.")

            if loc is None:
                loc = [None] * len(scale)

            if len(loc) != len(scale):
                raise ValueError("loc/scale must be same-length lists "
                                 "(or same-length list-like objects).")

            dtype = dtype_util.base_dtype(scale[0].dtype)

            loc = [
                tf.convert_to_tensor(loc_, dtype=dtype, name="loc{}".format(k))
                if loc_ is not None else None for k, loc_ in enumerate(loc)
            ]

            for k, scale_ in enumerate(scale):
                if validate_args and not scale_.is_positive_definite:
                    raise ValueError(
                        "scale[{}].is_positive_definite = {} != True".format(
                            k, scale_.is_positive_definite))
                if dtype_util.base_dtype(scale_.dtype) != dtype:
                    raise TypeError(
                        "dtype mismatch; scale[{}].base_dtype=\"{}\" != \"{}\""
                        .format(k, dtype_util.name(scale_.dtype),
                                dtype_util.name(dtype)))

            self._endpoint_affine = [
                affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
                    shift=loc_,
                    scale=scale_,
                    validate_args=validate_args,
                    name="endpoint_affine_{}".format(k))
                for k, (loc_, scale_) in enumerate(zip(loc, scale))
            ]

            # TODO(jvdillon): Remove once we support k-mixtures.
            # We make this assertion here because otherwise `grid` would need to be a
            # vector not a scalar.
            if len(scale) != 2:
                raise NotImplementedError(
                    "Currently only bimixtures are supported; "
                    "len(scale)={} is not 2.".format(len(scale)))

            mix_loc = tf.convert_to_tensor(mix_loc,
                                           dtype=dtype,
                                           name="mix_loc")
            temperature = tf.convert_to_tensor(temperature,
                                               dtype=dtype,
                                               name="temperature")
            self._grid, probs = tuple(
                quadrature_fn(mix_loc / temperature, 1. / temperature,
                              quadrature_size, validate_args))

            # Note: by creating the logits as `log(prob)` we ensure that
            # `self.mixture_distribution.logits` is equivalent to
            # `math_ops.log(self.mixture_distribution.probs)`.
            self._mixture_distribution = categorical.Categorical(
                logits=tf.math.log(probs),
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats)

            asserts = distribution_util.maybe_check_scalar_distribution(
                distribution, dtype, validate_args)
            if asserts:
                self._grid = distribution_util.with_dependencies(
                    asserts, self._grid)
            self._distribution = distribution

            self._interpolated_affine = [
                affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
                    shift=loc_,
                    scale=scale_,
                    validate_args=validate_args,
                    name="interpolated_affine_{}".format(k))
                for k, (loc_, scale_) in enumerate(
                    zip(interpolate_loc(self._grid, loc),
                        interpolate_scale(self._grid, scale)))
            ]

            [
                self._batch_shape_,
                self._batch_shape_tensor_,
                self._event_shape_,
                self._event_shape_tensor_,
            ] = determine_batch_event_shapes(self._grid, self._endpoint_affine)

            super(VectorDiffeomixture, self).__init__(
                dtype=dtype,
                # We hard-code `FULLY_REPARAMETERIZED` because when
                # `validate_args=True` we verify that indeed
                # `distribution.reparameterization_type == FULLY_REPARAMETERIZED`. A
                # distribution which is a function of only non-trainable parameters
                # also implies we can use `FULLY_REPARAMETERIZED`. However, we cannot
                # easily test for that possibility thus we use `validate_args=False`
                # as a "back-door" to allow users a way to use non-
                # `FULLY_REPARAMETERIZED` distributions. In such cases IT IS THE USER'S
                # RESPONSIBILITY to verify that the base distribution is a function of
                # non-trainable parameters.
                reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
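
A construction sketch adapted from the class docstring, assuming the installed TFP version still exposes `tfd.VectorDiffeomixture`. It builds a bimixture (`K = 2`) on `R^5` with one scaled-identity and one diagonal scale operator:

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

dims = 5
vdm = tfd.VectorDiffeomixture(
    mix_loc=[[0.], [1.]],
    temperature=[1.],
    distribution=tfd.Normal(loc=0., scale=1.),
    loc=[
        None,                     # Implicitly zeros(dims).
        np.float32([2.] * dims),
    ],
    scale=[
        tf.linalg.LinearOperatorScaledIdentity(
            num_rows=dims, multiplier=np.float32(1.1),
            is_positive_definite=True),
        tf.linalg.LinearOperatorDiag(
            diag=np.linspace(2.5, 3.5, dims, dtype=np.float32),
            is_positive_definite=True),
    ],
    validate_args=True)
vdm.sample(3)   # ==> Tensor of shape [3, 2, 5]: 3 draws from each of 2 batch members.
```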
Code Example #13
def quadrature_scheme_softmaxnormal_quantiles(normal_loc,
                                              normal_scale,
                                              quadrature_size,
                                              validate_args=False,
                                              name=None):
    """Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex.

  A `SoftmaxNormal` random variable `Y` may be generated via

  ```
  Y = SoftmaxCentered(X),
  X = Normal(normal_loc, normal_scale)
  ```

  Args:
    normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
      The location parameter of the Normal used to construct the SoftmaxNormal.
    normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
      The scale parameter of the Normal used to construct the SoftmaxNormal.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      convex combination of affine parameters for `K` components.
      `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1` simplex.
    probs:  Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      normalized weight associated with each grid point.
  """
    with tf.name_scope(name or "softmax_normal_grid_and_probs"):
        normal_loc = tf.convert_to_tensor(normal_loc, name="normal_loc")
        dt = dtype_util.base_dtype(normal_loc.dtype)
        normal_scale = tf.convert_to_tensor(normal_scale,
                                            dtype=dt,
                                            name="normal_scale")

        normal_scale = maybe_check_quadrature_param(normal_scale,
                                                    "normal_scale",
                                                    validate_args)

        dist = normal.Normal(loc=normal_loc, scale=normal_scale)

        def _get_batch_ndims():
            """Helper to get rank(dist.batch_shape), statically if possible."""
            ndims = tensorshape_util.rank(dist.batch_shape)
            if ndims is None:
                ndims = tf.shape(dist.batch_shape_tensor())[0]
            return ndims

        batch_ndims = _get_batch_ndims()

        def _get_final_shape(qs):
            """Helper to build `TensorShape`."""
            bs = tensorshape_util.with_rank_at_least(dist.batch_shape, 1)
            num_components = tf.compat.dimension_value(bs[-1])
            if num_components is not None:
                num_components += 1
            tail = tf.TensorShape([num_components, qs])
            return bs[:-1].concatenate(tail)

        def _compute_quantiles():
            """Helper to build quantiles."""
            # Omit {0, 1} since they might lead to Inf/NaN.
            zero = tf.zeros([], dtype=dist.dtype)
            edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
            # Expand edges so they broadcast across batch dims.
            edges = tf.reshape(
                edges,
                shape=tf.concat(
                    [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
            quantiles = dist.quantile(edges)
            quantiles = softmax_centered_bijector.SoftmaxCentered().forward(
                quantiles)
            # Cyclically permute left by one.
            perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
            quantiles = tf.transpose(a=quantiles, perm=perm)
            tensorshape_util.set_shape(quantiles,
                                       _get_final_shape(quadrature_size + 1))
            return quantiles

        quantiles = _compute_quantiles()

        # Compute grid as quantile midpoints.
        grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
        # Set shape hints.
        tensorshape_util.set_shape(grid, _get_final_shape(quadrature_size))

        # By construction probs is constant, i.e., `1 / quadrature_size`. This is
        # important, because non-constant probs leads to non-reparameterizable
        # samples.
        probs = tf.fill(dims=[quadrature_size],
                        value=1. / tf.cast(quadrature_size, dist.dtype))

        return grid, probs
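
A minimal call sketch using the function defined above (relying on the same module-level imports it already assumes): with `K - 1 = 1` component in `normal_loc` and `quadrature_size = 8`, `grid` has shape `[2, 8]` and `probs` is the uniform weight `1/8`:

```python
grid, probs = quadrature_scheme_softmaxnormal_quantiles(
    normal_loc=[0.], normal_scale=[1.], quadrature_size=8)
# grid[..., n] is the n-th quantile midpoint on the 1-simplex;
# probs is tf.fill([8], 0.125), i.e. constant by construction.
```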
Code Example #14
File: wishart.py  Project: HackerShohag/SuggestBot-bn
    def _sample_n(self, n, seed):
        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape_tensor()
        batch_ndims = tf.shape(batch_shape)[0]

        ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
        shape = tf.concat([[n], batch_shape, event_shape], 0)
        stream = SeedStream(seed, salt="Wishart")

        # Complexity: O(nbk**2)
        x = tf.random.normal(shape=shape,
                             mean=0.,
                             stddev=1.,
                             dtype=self.dtype,
                             seed=stream())

        # Complexity: O(nbk)
        # This parametrization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
        expanded_df = self.df * tf.ones(
            self.scale_operator.batch_shape_tensor(),
            dtype=dtype_util.base_dtype(self.df.dtype))

        g = tf.random.gamma(shape=[n],
                            alpha=self._multi_gamma_sequence(
                                0.5 * expanded_df, self.dimension),
                            beta=0.5,
                            dtype=self.dtype,
                            seed=stream())

        # Complexity: O(nbk**2)
        x = tf.linalg.band_part(x, -1, 0)  # Tri-lower.

        # Complexity: O(nbk)
        x = tf.linalg.set_diag(x, tf.sqrt(g))

        # Make batch-op ready.
        # Complexity: O(nbk**2)
        perm = tf.concat([tf.range(1, ndims), [0]], 0)
        x = tf.transpose(a=x, perm=perm)
        shape = tf.concat(
            [batch_shape, [event_shape[0]], [event_shape[1] * n]], 0)
        x = tf.reshape(x, shape)

        # Complexity: O(nbM) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each matmul is O(k^3) so
        # this step has complexity O(nbk^3).
        x = self.scale_operator.matmul(x)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        shape = tf.concat([batch_shape, event_shape, [n]], 0)
        x = tf.reshape(x, shape)
        perm = tf.concat([[ndims - 1], tf.range(0, ndims - 1)], 0)
        x = tf.transpose(a=x, perm=perm)

        if not self.input_output_cholesky:
            # Complexity: O(nbk**3)
            x = tf.matmul(x, x, adjoint_b=True)

        return x
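
A hedged usage sketch of the public sampler built on this routine; the exact class name depends on the TFP version (`tfd.WishartTriL` in recent releases, `tfd.Wishart` with `scale_tril` in older ones):

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

scale_tril = tf.linalg.cholesky([[2.0, 0.5],
                                 [0.5, 1.0]])
w = tfd.WishartTriL(df=5., scale_tril=scale_tril)
w.sample(3, seed=42)   # ==> Tensor of shape [3, 2, 2]: batches of 2x2 PSD matrices.
```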
Code Example #15
 def _inverse_log_det_jacobian(self, y):
     return tf.constant(0., dtype=dtype_util.base_dtype(y.dtype))
Code Example #16
 def _forward_log_det_jacobian(self, x):
     return tf.constant(0., dtype=dtype_util.base_dtype(x.dtype))
Code Example #17
 def _inverse_log_det_jacobian(self, y):
     # is_constant_jacobian = True for this bijector, hence the
     # `log_det_jacobian` need only be specified for a single input, as this will
     # be tiled to match `event_ndims`.
     return tf.constant(0., dtype=dtype_util.base_dtype(y.dtype))
Code Example #18
    def __init__(self,
                 distributions,
                 dtype_override=None,
                 validate_args=False,
                 allow_nan_stats=False,
                 name='Blockwise'):
        """Construct the `Blockwise` distribution.

    Args:
      distributions: Python `list` of `tfp.distributions.Distribution`
        instances. All distribution instances must have the same `batch_shape`
        and all must have `event_ndims==1`, i.e., be vector-variate
        distributions.
      dtype_override: samples of `distributions` will be cast to this `dtype`.
        If unspecified, all `distributions` must have the same `dtype`.
        Default value: `None` (i.e., do not cast).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `False`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or more
        of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._distributions = distributions
            if dtype_override is not None:
                distributions = tf.nest.map_structure(
                    lambda d: _Cast(d, dtype_override), distributions)
            if _is_iterable(distributions):
                self._distribution = (
                    joint_distribution_sequential.JointDistributionSequential(
                        list(distributions)))
            else:
                self._distribution = distributions

            # Need to cache these for JointDistributions as the batch shape of that
            # distribution can change after `_sample` calls.
            self._cached_batch_shape_tensor = (
                self._distribution.batch_shape_tensor())
            self._cached_batch_shape = self._distribution.batch_shape

            if dtype_override is not None:
                dtype = dtype_override
            else:
                dtype = set(
                    dtype_util.base_dtype(dtype)
                    for dtype in tf.nest.flatten(self._distribution.dtype)
                    if dtype is not None)
                if len(dtype) == 0:  # pylint: disable=g-explicit-length-test
                    dtype = tf.float32
                elif len(dtype) == 1:
                    dtype = dtype.pop()
                else:
                    raise TypeError(
                        'Distributions must have same dtype; found: {}.'.
                        format(self._distribution.dtype))

            reparameterization_type = set(
                tf.nest.flatten(self._distribution.reparameterization_type))
            reparameterization_type = (reparameterization_type.pop() if
                                       len(reparameterization_type) == 1 else
                                       reparameterization.NOT_REPARAMETERIZED)

            super(Blockwise, self).__init__(
                dtype=dtype,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                reparameterization_type=reparameterization_type,
                parameters=parameters,
                name=name)
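
A hedged sketch of the resulting distribution: `Blockwise` concatenates vector-variate distributions along the event dimension, so the event sizes add (3 + 4 = 7 below) while the batch shapes must match:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

d = tfd.Blockwise([
    tfd.MultivariateNormalDiag(loc=tf.zeros([3])),
    tfd.Dirichlet(concentration=tf.ones([4])),
])
x = d.sample(seed=42)   # ==> Tensor of shape [7]
d.log_prob(x)           # ==> scalar log density of the concatenated event
```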
Code Example #19
def convert_nonref_to_tensor(value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor` if input is nonreference type.

  This function converts Python objects of various types to `Tensor` objects
  only if the input has nonreference semantics. Reference semantics are
  characterized by `tensor_util.is_ref`: any object which is a `tf.Variable`
  or an instance of `tf.Module`. This function accepts any input which
  `tf.convert_to_tensor` would also accept.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor,
      used when dtype is None. In some cases, a caller may not have a
      dtype in mind when converting to a tensor, so dtype_hint
      can be used as a soft preference.  If the conversion to
      `dtype_hint` is not possible, this argument has no effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    tensor: A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.


  #### Examples:

  ```python
  from tensorflow_probability.python.experimental.substrates.jax.internal import tensor_util

  x = tf.Variable(0.)
  y = tensor_util.convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = tf.constant(0.)
  y = tensor_util.convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = np.array(0.)
  y = tensor_util.convert_nonref_to_tensor(x)
  x is y
  # ==> False
  tf.is_tensor(y)
  # ==> True

  x = tfp.util.DeferredTensor(lambda x: x, 13.37)
  y = tensor_util.convert_nonref_to_tensor(x)
  x is y
  # ==> True
  tf.is_tensor(y)
  # ==> True
  tf.equal(y, 13.37)
  # ==> True
  ```

  """
  # We explicitly do not use a tf.name_scope to avoid graph clutter.
  if value is None:
    return None
  if is_ref(value):
    if dtype is None:
      return value
    dtype_base = dtype_util.base_dtype(dtype)
    value_dtype_base = dtype_util.base_dtype(value.dtype)
    if dtype_base != value_dtype_base:
      raise TypeError('Mutable type must be of dtype "{}" but is "{}".'.format(
          dtype_util.name(dtype_base), dtype_util.name(value_dtype_base)))
    return value
  return tf.convert_to_tensor(
      value, dtype=dtype, dtype_hint=dtype_hint, name=name)