Example #1
    def _forward_log_det_jacobian(self, x, **kwargs):
        x = tf.convert_to_tensor(x, name="x")

        fldj = tf.cast(0., dtype=dtype_util.base_dtype(x.dtype))

        if not self.bijectors:
            return fldj

        event_ndims = self._maybe_get_static_event_ndims(
            self.forward_min_event_ndims)

        if _use_static_shape(x, event_ndims):
            event_shape = x.shape[tensorshape_util.rank(x.shape) -
                                  event_ndims:]
        else:
            event_shape = tf.shape(x)[tf.rank(x) - event_ndims:]

        # TODO(b/129973548): Document and simplify.
        for b in reversed(self.bijectors):
            fldj = fldj + b.forward_log_det_jacobian(
                x, event_ndims=event_ndims, **kwargs.get(b.name, {}))
            if _use_static_shape(x, event_ndims):
                event_shape = b.forward_event_shape(event_shape)
                event_ndims = self._maybe_get_static_event_ndims(
                    tensorshape_util.rank(event_shape))
            else:
                event_shape = b.forward_event_shape_tensor(event_shape)
                event_shape_ = distribution_util.maybe_get_static_value(
                    event_shape)
                event_ndims = tf.size(event_shape)
                event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)

                if event_ndims_ is not None and event_shape_ is not None:
                    event_ndims = event_ndims_
                    event_shape = event_shape_

            x = b.forward(x, **kwargs.get(b.name, {}))

        return fldj
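For orientation, a minimal sketch of what this loop computes, assuming `tensorflow_probability` is installed (the two bijectors are arbitrary illustrative choices): the chain's forward log-det-Jacobian is the sum of per-bijector terms, each evaluated at the running intermediate value.

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Chain([Exp, Softplus]) applies Softplus first, then Exp (cf. the reversed loop above).
chain = tfb.Chain([tfb.Exp(), tfb.Softplus()])
x = tf.constant([0.5, 1.5])

fldj_chain = chain.forward_log_det_jacobian(x, event_ndims=0)

# Manual accumulation mirroring the loop: add each bijector's contribution at the
# current intermediate value, then push the value forward.
fldj_manual = tfb.Softplus().forward_log_det_jacobian(x, event_ndims=0)
y = tfb.Softplus().forward(x)
fldj_manual = fldj_manual + tfb.Exp().forward_log_det_jacobian(y, event_ndims=0)

tf.debugging.assert_near(fldj_chain, fldj_manual)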
Example #2
def _bessel_ive(v, z, cache=None):
  """Computes I_v(z)*exp(-abs(z)) using a recurrence relation, where z > 0."""
  # TODO(b/67497980): Switch to a more numerically faithful implementation.
  z = tf.convert_to_tensor(z)

  wrap = lambda result: tf.debugging.check_numerics(
      result, 'besseli{}'.format(v))

  if float(v) >= 2:
    raise ValueError(
        'Evaluating bessel_i by recurrence becomes imprecise for large v')

  cache = cache or {}
  safe_z = tf.where(z > 0, z, tf.ones_like(z))
  if v in cache:
    return wrap(cache[v])
  if v == 0:
    cache[v] = tf.math.bessel_i0e(z)
  elif v == 1:
    cache[v] = tf.math.bessel_i1e(z)
  elif v == 0.5:
    # sinh(x)*exp(-abs(x)), sinh(x) = (e^x - e^{-x}) / 2
    sinhe = lambda x: (tf.exp(x - tf.abs(x)) - tf.exp(-x - tf.abs(x))) / 2
    cache[v] = (
        np.sqrt(2 / np.pi) * sinhe(z) *
        tf.where(z > 0, tf.math.rsqrt(safe_z), tf.ones_like(safe_z)))
  elif v == -0.5:
    # cosh(x)*exp(-abs(x)), cosh(x) = (e^x + e^{-x}) / 2
    coshe = lambda x: (tf.exp(x - tf.abs(x)) + tf.exp(-x - tf.abs(x))) / 2
    cache[v] = (
        np.sqrt(2 / np.pi) * coshe(z) *
        tf.where(z > 0, tf.math.rsqrt(safe_z), tf.ones_like(safe_z)))
  if v <= 1:
    return wrap(cache[v])
  # Recurrence relation:
  cache[v] = (_bessel_ive(v - 2, z, cache) -
              (2 * (v - 1)) * _bessel_ive(v - 1, z, cache) / z)
  return wrap(cache[v])
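A quick sanity check of the exponentially scaled Bessel primitives used above, assuming `scipy` is available for reference values (`scipy.special.ive(v, z)` computes `I_v(z) * exp(-|z|)`):

import numpy as np
import tensorflow as tf
from scipy import special

z = np.array([0.5, 1.0, 3.0])
print(tf.math.bessel_i0e(tf.constant(z)).numpy())  # I_0(z) * exp(-|z|)
print(special.ive(0, z))                            # reference values from SciPy
print(tf.math.bessel_i1e(tf.constant(z)).numpy())  # I_1(z) * exp(-|z|)
print(special.ive(1, z))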
Example #3
    def _inverse_log_det_jacobian(self, y, **kwargs):
        y = tf.convert_to_tensor(y, name="y")
        ildj = tf.cast(0., dtype=dtype_util.base_dtype(y.dtype))

        if not self.bijectors:
            return ildj

        event_ndims = self._maybe_get_static_event_ndims(
            self.inverse_min_event_ndims)

        if _use_static_shape(y, event_ndims):
            event_shape = y.shape[tensorshape_util.rank(y.shape) -
                                  event_ndims:]
        else:
            event_shape = tf.shape(y)[tf.rank(y) - event_ndims:]

        # TODO(b/129973548): Document and simplify.
        for b in self.bijectors:
            ildj = ildj + b.inverse_log_det_jacobian(
                y, event_ndims=event_ndims, **kwargs.get(b.name, {}))

            if _use_static_shape(y, event_ndims):
                event_shape = b.inverse_event_shape(event_shape)
                event_ndims = self._maybe_get_static_event_ndims(
                    tensorshape_util.rank(event_shape))
            else:
                event_shape = b.inverse_event_shape_tensor(event_shape)
                event_shape_ = distribution_util.maybe_get_static_value(
                    event_shape)
                event_ndims = tf.size(event_shape)
                event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)

                if event_ndims_ is not None and event_shape_ is not None:
                    event_ndims = event_ndims_
                    event_shape = event_shape_

            y = b.inverse(y, **kwargs.get(b.name, {}))
        return ildj
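The same kind of check for the inverse direction, using the standard bijector identity `inverse_log_det_jacobian(y) == -forward_log_det_jacobian(inverse(y))` (a sketch, assuming `tensorflow_probability`):

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

chain = tfb.Chain([tfb.Exp(), tfb.Softplus()])
y = tf.constant([1.5, 2.5])  # must lie in the chain's range, here (1, inf)

ildj = chain.inverse_log_det_jacobian(y, event_ndims=0)
fldj_at_x = chain.forward_log_det_jacobian(chain.inverse(y), event_ndims=0)
tf.debugging.assert_near(ildj, -fldj_at_x)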
Example #4
def _sparse_block_diag(sp_a):
    """Returns a block diagonal rank 2 SparseTensor from a batch of SparseTensors.

  Args:
    sp_a: A rank 3 `SparseTensor` representing a batch of matrices.

  Returns:
    sp_block_diag_a: matrix-shaped, `float` `SparseTensor` with the same dtype
    as `sp_a`, of shape [B * M, B * N] where `sp_a` has shape
    [B, M, N]. Each [M, N] batch of `sp_a` is lined up along the diagonal.
  """
    # Construct the matrix [[M, N], [1, 0], [0, 1]] which would map the index
    # (b, i, j) to (Mb + i, Nb + j). This effectively creates a block-diagonal
    # matrix of dense shape [B * M, B * N].
    # Note that this transformation doesn't increase the number of non-zero
    # entries in the SparseTensor.
    sp_a_shape = tf.convert_to_tensor(_get_shape(sp_a, tf.int64))
    ind_mat = tf.concat([[sp_a_shape[-2:]], tf.eye(2, dtype=tf.int64)], axis=0)
    indices = tf.matmul(sp_a.indices, ind_mat)
    dense_shape = sp_a_shape[0] * sp_a_shape[1:]
    return tf.SparseTensor(indices=indices,
                           values=sp_a.values,
                           dense_shape=dense_shape)
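A self-contained toy run of the same index trick (all names here are local to this sketch, not the library helper), showing a batch of `B = 2` sparse `2 x 2` matrices landing on the diagonal of a `4 x 4` result:

import tensorflow as tf

sp_a = tf.SparseTensor(
    indices=[[0, 0, 0], [0, 1, 1], [1, 0, 1]],  # (b, i, j)
    values=[1., 2., 3.],
    dense_shape=[2, 2, 2])                      # [B, M, N]

shape = sp_a.dense_shape                        # int64 vector [B, M, N]
ind_mat = tf.concat([[shape[-2:]], tf.eye(2, dtype=tf.int64)], axis=0)
indices = tf.matmul(sp_a.indices, ind_mat)      # (b, i, j) -> (M*b + i, N*b + j)
block_diag = tf.SparseTensor(indices, sp_a.values, shape[0] * shape[1:])
print(tf.sparse.to_dense(block_diag).numpy())
# [[1. 0. 0. 0.]
#  [0. 2. 0. 0.]
#  [0. 0. 0. 3.]
#  [0. 0. 0. 0.]]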
Example #5
    def _entropy(self):
        samples = tf.convert_to_tensor(self.samples)
        num_samples = self._compute_num_samples(samples)
        entropy_shape = self._batch_shape_tensor(samples)

        # Flatten samples for each batch.
        if self._event_ndims == 0:
            samples = tf.reshape(samples, [-1, num_samples])
        else:
            event_size = tf.reduce_prod(self.event_shape_tensor())
            samples = tf.reshape(samples, [-1, num_samples, event_size])

        # Use map_fn to compute entropy for each batch separately.
        def _get_entropy(samples):
            # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
            count = gen_array_ops.unique_with_counts_v2(samples,
                                                        axis=[0]).count
            prob = tf.cast(count / num_samples, dtype=self.dtype)
            entropy = tf.reduce_sum(-prob * tf.math.log(prob))
            return entropy

        entropy = tf.map_fn(_get_entropy, samples, dtype=self.dtype)
        return tf.reshape(entropy, entropy_shape)
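A standalone sketch of the per-batch computation inside `_get_entropy`, using the public `tf.unique_with_counts` on a single flat sample vector (the axis-aware `_v2` op above also covers the multivariate case):

import tensorflow as tf

samples = tf.constant([0, 1, 1, 2, 2, 2])
_, _, counts = tf.unique_with_counts(samples)
probs = tf.cast(counts, tf.float32) / tf.cast(tf.size(samples), tf.float32)
entropy = tf.reduce_sum(-probs * tf.math.log(probs))
print(entropy.numpy())  # ~1.011 nats for empirical probabilities 1/6, 2/6, 3/6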
Example #6
  def _sample_n(self, n, seed=None):
    temperature = tf.convert_to_tensor(self.temperature)
    logits = self._logits_parameter_no_checks()

    # Uniform variates must be sampled from the open-interval `(0, 1)` rather
    # than `[0, 1)`. To do so, we use
    # `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny` because it is the
    # smallest, positive, 'normal' number. A 'normal' number is such that the
    # mantissa has an implicit leading 1. Normal, positive numbers x, y have the
    # reasonable property that, `x + y >= max(x, y)`. In this case, a subnormal
    # number (i.e., np.nextafter) can cause us to sample 0.
    uniform_shape = tf.concat(
        [[n],
         self._batch_shape_tensor(temperature=temperature, logits=logits),
         self._event_shape_tensor(logits=logits)], 0)
    uniform = tf.random.uniform(
        shape=uniform_shape,
        minval=np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny,
        maxval=1.,
        dtype=self.dtype,
        seed=seed)
    gumbel = -tf.math.log(-tf.math.log(uniform))
    noisy_logits = (gumbel + logits) / temperature[..., tf.newaxis]
    return tf.math.log_softmax(noisy_logits)
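The reparameterization in isolation, with hypothetical standalone `logits` and `temperature` values not tied to any distribution instance; note the open-interval uniform draw discussed in the comment above:

import numpy as np
import tensorflow as tf

logits = tf.constant([1.0, 0.0, -1.0])
temperature = tf.constant(0.5)

tiny = np.finfo(np.float32).tiny                    # smallest positive normal float
u = tf.random.uniform([3], minval=tiny, maxval=1.)  # open interval (0, 1)
gumbel = -tf.math.log(-tf.math.log(u))
relaxed = tf.math.log_softmax((logits + gumbel) / temperature)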
Example #7
def _slice_params_to_dict(dist, params_event_ndims, slices):
  """Computes the override dictionary of sliced parameters.

  Args:
    dist: The tfd.Distribution being batch-sliced.
    params_event_ndims: Per-event parameter ranks, a `str->int` `dict`.
    slices: Slices as received by __getitem__.

  Returns:
    overrides: `str->Tensor` `dict` of batch-sliced parameter overrides.
  """
  override_dict = {}
  for param_name, param_event_ndims in six.iteritems(params_event_ndims):
    # Verify that either None or a legit value is in the parameters dict.
    if param_name not in dist.parameters:
      raise ValueError('Distribution {} is missing advertised '
                       'parameter {}'.format(dist, param_name))
    param = dist.parameters[param_name]
    if param is None:
      # some distributions have multiple possible parameterizations; this
      # param was not provided
      continue
    dtype = None
    if hasattr(dist, param_name):
      attr = getattr(dist, param_name)
      dtype = getattr(attr, 'dtype', None)
    if dtype is None:
      dtype = dist.dtype
      warnings.warn('Unable to find property getter for parameter Tensor {} '
                    'on {}, falling back to Distribution.dtype {}'.format(
                        param_name, dist, dtype))
    param = tf.convert_to_tensor(value=param, dtype=dtype)
    override_dict[param_name] = _slice_single_param(param, param_event_ndims,
                                                    slices,
                                                    dist.batch_shape_tensor())
  return override_dict
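For orientation, the user-facing behaviour this helper supports: slicing a distribution's batch dimensions slices its batched parameters accordingly (a sketch, assuming `tensorflow_probability`):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.Normal(loc=tf.zeros([4, 3]), scale=tf.ones([4, 3]))
sliced = dist[1:3, ..., :2]  # __getitem__ ultimately drives the override dict above
print(dist.batch_shape, sliced.batch_shape)  # (4, 3) (2, 2)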
Example #8
 def _validate_dimension(self, x):
     x = tf.convert_to_tensor(x, name='x')
     if tensorshape_util.is_fully_defined(x.shape[-2:]):
         if (tensorshape_util.dims(x.shape)[-2] == tensorshape_util.dims(
                 x.shape)[-1] == self.dimension):
             pass
         else:
             raise ValueError(
                 'Input dimension mismatch: expected [..., {}, {}], got {}'.
                 format(self.dimension, self.dimension,
                        tensorshape_util.dims(x.shape)))
     elif self.validate_args:
         msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
             self.dimension, self.dimension, tf.shape(x))
         with tf.control_dependencies([
                 assert_util.assert_equal(tf.shape(x)[-2],
                                          self.dimension,
                                          message=msg),
                 assert_util.assert_equal(tf.shape(x)[-1],
                                          self.dimension,
                                          message=msg)
         ]):
             x = tf.identity(x)
     return x
Example #9
def quadrature_scheme_softmaxnormal_gauss_hermite(normal_loc,
                                                  normal_scale,
                                                  quadrature_size,
                                                  validate_args=False,
                                                  name=None):
    """Use Gauss-Hermite quadrature to form quadrature on `K - 1` simplex.

  A `SoftmaxNormal` random variable `Y` may be generated via

  ```
  Y = SoftmaxCentered(X),
  X = Normal(normal_loc, normal_scale)
  ```

  Note: for a given `quadrature_size`, this method is generally less accurate
  than `quadrature_scheme_softmaxnormal_quantiles`.

  Args:
    normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
      The location parameter of the Normal used to construct the SoftmaxNormal.
    normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
      The scale parameter of the Normal used to construct the SoftmaxNormal.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      convex combination of affine parameters for `K` components.
      `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1` simplex.
    probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      normalized weights associated with each grid point.
  """
    with tf.name_scope(name
                       or "quadrature_scheme_softmaxnormal_gauss_hermite"):
        normal_loc = tf.convert_to_tensor(normal_loc, name="normal_loc")
        npdt = dtype_util.as_numpy_dtype(normal_loc.dtype)
        normal_scale = tf.convert_to_tensor(normal_scale,
                                            dtype=npdt,
                                            name="normal_scale")

        normal_scale = maybe_check_quadrature_param(normal_scale,
                                                    "normal_scale",
                                                    validate_args)

        grid, probs = np.polynomial.hermite.hermgauss(deg=quadrature_size)
        grid = grid.astype(npdt)
        probs = probs.astype(npdt)
        probs /= np.linalg.norm(probs, ord=1, keepdims=True)
        probs = tf.convert_to_tensor(probs, name="probs", dtype=npdt)

        grid = softmax(-distribution_util.pad(
            (normal_loc[..., tf.newaxis] +
             np.sqrt(2.) * normal_scale[..., tf.newaxis] * grid),
            axis=-2,
            front=True),
                       axis=-2)  # shape: [B, components, deg]

        return grid, probs
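Underlying the grid construction is the standard Gauss-Hermite change of variables: for `X ~ Normal(loc, scale)`, `E[f(X)] ~= sum_i w_i / sqrt(pi) * f(loc + sqrt(2) * scale * x_i)`. A plain-NumPy sketch with made-up scalar values:

import numpy as np

loc, scale = 0.3, 1.2
grid, weights = np.polynomial.hermite.hermgauss(deg=10)
approx = np.sum(weights / np.sqrt(np.pi) * (loc + np.sqrt(2.) * scale * grid) ** 2)
print(approx, loc**2 + scale**2)  # quadrature estimate of E[X^2] vs. the exact value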
Example #10
    def __init__(self,
                 mix_loc,
                 temperature,
                 distribution,
                 loc=None,
                 scale=None,
                 quadrature_size=8,
                 quadrature_fn=quadrature_scheme_softmaxnormal_quantiles,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="VectorDiffeomixture"):
        """Constructs the VectorDiffeomixture on `R^d`.

    The vector diffeomixture (VDM) approximates the compound distribution

    ```none
    p(x) = int p(x | z) p(z) dz,
    where z is in the K-simplex, and
    p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
    ```

    Args:
      mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
        In terms of samples, larger `mix_loc[..., k]` ==>
        `Z` is more likely to put more weight on its `kth` component.
      temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`.
        In terms of samples, smaller `temperature` means one component is more
        likely to dominate.  I.e., smaller `temperature` makes the VDM look more
        like a standard mixture of `K` components.
      distribution: `tfp.distributions.Distribution`-like instance. Distribution
        from which `d` iid samples are used as input to the selected affine
        transformation. Must be a scalar-batch, scalar-event distribution.
        Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED`
        or it is a function of non-trainable parameters. WARNING: If you
        backprop through a VectorDiffeomixture sample and the `distribution`
        is not `FULLY_REPARAMETERIZED` yet is a function of trainable variables,
        then the gradient will be incorrect!
      loc: Length-`K` list of `float`-type `Tensor`s. The `k`-th element
        represents the `shift` used for the `k`-th affine transformation.  If
        the `k`-th item is `None`, `loc` is implicitly `0`.  When specified,
        must have shape `[B1, ..., Bb, d]` where `b >= 0` and `d` is the event
        size.
      scale: Length-`K` list of `LinearOperator`s. Each should be
        positive-definite and operate on a `d`-dimensional vector space. The
        `k`-th element represents the `scale` used for the `k`-th affine
        transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
        `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices
      quadrature_size: Python `int` scalar representing number of
        quadrature points.  Larger `quadrature_size` means `q_N(x)` better
        approximates `p(x)`.
      quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
        `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
        representing the SoftmaxNormal grid and corresponding normalized weight.
        Default value: `quadrature_scheme_softmaxnormal_quantiles`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if `not scale or len(scale) < 2`.
      ValueError: if `len(loc) != len(scale)`
      ValueError: if `quadrature_grid_and_probs is not None` and
        `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
      ValueError: if `validate_args` and any not scale.is_positive_definite.
      TypeError: if any scale.dtype != scale[0].dtype.
      TypeError: if any loc.dtype != scale[0].dtype.
      NotImplementedError: if `len(scale) != 2`.
      ValueError: if `not distribution.is_scalar_batch`.
      ValueError: if `not distribution.is_scalar_event`.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            if not scale or len(scale) < 2:
                raise ValueError(
                    "Must specify list (or list-like object) of scale "
                    "LinearOperators, one for each component with "
                    "num_component >= 2.")

            if loc is None:
                loc = [None] * len(scale)

            if len(loc) != len(scale):
                raise ValueError("loc/scale must be same-length lists "
                                 "(or same-length list-like objects).")

            dtype = dtype_util.base_dtype(scale[0].dtype)

            loc = [
                tf.convert_to_tensor(loc_, dtype=dtype, name="loc{}".format(k))
                if loc_ is not None else None for k, loc_ in enumerate(loc)
            ]

            for k, scale_ in enumerate(scale):
                if validate_args and not scale_.is_positive_definite:
                    raise ValueError(
                        "scale[{}].is_positive_definite = {} != True".format(
                            k, scale_.is_positive_definite))
                if dtype_util.base_dtype(scale_.dtype) != dtype:
                    raise TypeError(
                        "dtype mismatch; scale[{}].base_dtype=\"{}\" != \"{}\""
                        .format(k, dtype_util.name(scale_.dtype),
                                dtype_util.name(dtype)))

            self._endpoint_affine = [
                affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
                    shift=loc_,
                    scale=scale_,
                    validate_args=validate_args,
                    name="endpoint_affine_{}".format(k))
                for k, (loc_, scale_) in enumerate(zip(loc, scale))
            ]

            # TODO(jvdillon): Remove once we support k-mixtures.
            # We make this assertion here because otherwise `grid` would need to be a
            # vector not a scalar.
            if len(scale) != 2:
                raise NotImplementedError(
                    "Currently only bimixtures are supported; "
                    "len(scale)={} is not 2.".format(len(scale)))

            mix_loc = tf.convert_to_tensor(mix_loc,
                                           dtype=dtype,
                                           name="mix_loc")
            temperature = tf.convert_to_tensor(temperature,
                                               dtype=dtype,
                                               name="temperature")
            self._grid, probs = tuple(
                quadrature_fn(mix_loc / temperature, 1. / temperature,
                              quadrature_size, validate_args))

            # Note: by creating the logits as `log(prob)` we ensure that
            # `self.mixture_distribution.logits` is equivalent to
            # `math_ops.log(self.mixture_distribution.probs)`.
            self._mixture_distribution = categorical.Categorical(
                logits=tf.math.log(probs),
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats)

            asserts = distribution_util.maybe_check_scalar_distribution(
                distribution, dtype, validate_args)
            if asserts:
                self._grid = distribution_util.with_dependencies(
                    asserts, self._grid)
            self._distribution = distribution

            self._interpolated_affine = [
                affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
                    shift=loc_,
                    scale=scale_,
                    validate_args=validate_args,
                    name="interpolated_affine_{}".format(k))
                for k, (loc_, scale_) in enumerate(
                    zip(interpolate_loc(self._grid, loc),
                        interpolate_scale(self._grid, scale)))
            ]

            [
                self._batch_shape_,
                self._batch_shape_tensor_,
                self._event_shape_,
                self._event_shape_tensor_,
            ] = determine_batch_event_shapes(self._grid, self._endpoint_affine)

            super(VectorDiffeomixture, self).__init__(
                dtype=dtype,
                # We hard-code `FULLY_REPARAMETERIZED` because when
                # `validate_args=True` we verify that indeed
                # `distribution.reparameterization_type == FULLY_REPARAMETERIZED`. A
                # distribution which is a function of only non-trainable parameters
                # also implies we can use `FULLY_REPARAMETERIZED`. However, we cannot
                # easily test for that possibility thus we use `validate_args=False`
                # as a "back-door" to allow users a way to use non
                # `FULLY_REPARAMETERIZED` distribution. In such cases IT IS THE USERS
                # RESPONSIBILITY to verify that the base distribution is a function of
                # non-trainable parameters.
                reparameterization_type=reparameterization.
                FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
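A hedged construction sketch, loosely following the class documentation; the parameter values are illustrative only, and it assumes a TFP version that still exports `VectorDiffeomixture`:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
dims = 5
vdm = tfd.VectorDiffeomixture(
    mix_loc=[[0.], [1.]],
    temperature=[1.],
    distribution=tfd.Normal(loc=0., scale=1.),
    loc=[None,                          # None is an implicit zero shift
         np.float32([2.] * dims)],
    scale=[tf.linalg.LinearOperatorScaledIdentity(
               num_rows=dims, multiplier=1.1, is_positive_definite=True),
           tf.linalg.LinearOperatorDiag(
               diag=np.linspace(2.5, 3.5, dims, dtype=np.float32),
               is_positive_definite=True)],
    validate_args=True)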
Example #11
def quadrature_scheme_softmaxnormal_quantiles(normal_loc,
                                              normal_scale,
                                              quadrature_size,
                                              validate_args=False,
                                              name=None):
    """Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex.

  A `SoftmaxNormal` random variable `Y` may be generated via

  ```
  Y = SoftmaxCentered(X),
  X = Normal(normal_loc, normal_scale)
  ```

  Args:
    normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
      The location parameter of the Normal used to construct the SoftmaxNormal.
    normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
      The scale parameter of the Normal used to construct the SoftmaxNormal.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      convex combination of affine parameters for `K` components.
      `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1` simplex.
    probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      normalized weights associated with each grid point.
  """
    with tf.name_scope(name or "softmax_normal_grid_and_probs"):
        normal_loc = tf.convert_to_tensor(normal_loc, name="normal_loc")
        dt = dtype_util.base_dtype(normal_loc.dtype)
        normal_scale = tf.convert_to_tensor(normal_scale,
                                            dtype=dt,
                                            name="normal_scale")

        normal_scale = maybe_check_quadrature_param(normal_scale,
                                                    "normal_scale",
                                                    validate_args)

        dist = normal.Normal(loc=normal_loc, scale=normal_scale)

        def _get_batch_ndims():
            """Helper to get rank(dist.batch_shape), statically if possible."""
            ndims = tensorshape_util.rank(dist.batch_shape)
            if ndims is None:
                ndims = tf.shape(dist.batch_shape_tensor())[0]
            return ndims

        batch_ndims = _get_batch_ndims()

        def _get_final_shape(qs):
            """Helper to build `TensorShape`."""
            bs = tensorshape_util.with_rank_at_least(dist.batch_shape, 1)
            num_components = tf.compat.dimension_value(bs[-1])
            if num_components is not None:
                num_components += 1
            tail = tf.TensorShape([num_components, qs])
            return bs[:-1].concatenate(tail)

        def _compute_quantiles():
            """Helper to build quantiles."""
            # Omit {0, 1} since they might lead to Inf/NaN.
            zero = tf.zeros([], dtype=dist.dtype)
            edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
            # Expand edges so it broadcasts across batch dims.
            edges = tf.reshape(
                edges,
                shape=tf.concat(
                    [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
            quantiles = dist.quantile(edges)
            quantiles = softmax_centered_bijector.SoftmaxCentered().forward(
                quantiles)
            # Cyclically permute left by one.
            perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
            quantiles = tf.transpose(a=quantiles, perm=perm)
            tensorshape_util.set_shape(quantiles,
                                       _get_final_shape(quadrature_size + 1))
            return quantiles

        quantiles = _compute_quantiles()

        # Compute grid as quantile midpoints.
        grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
        # Set shape hints.
        tensorshape_util.set_shape(grid, _get_final_shape(quadrature_size))

        # By construction probs is constant, i.e., `1 / quadrature_size`. This is
        # important, because non-constant probs leads to non-reparameterizable
        # samples.
        probs = tf.fill(dims=[quadrature_size],
                        value=1. / tf.cast(quadrature_size, dist.dtype))

        return grid, probs
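The quantile-midpoint construction, reduced to a plain scalar Normal with NumPy/SciPy (a simplified sketch; the real routine maps the quantiles through `SoftmaxCentered` onto the simplex):

import numpy as np
from scipy import stats

quadrature_size = 8
edges = np.linspace(0., 1., quadrature_size + 3)[1:-1]   # drop {0, 1} to avoid +/-inf
quantiles = stats.norm.ppf(edges)                         # quadrature_size + 1 quantiles
grid = (quantiles[:-1] + quantiles[1:]) / 2.              # midpoints: quadrature_size points
probs = np.full([quadrature_size], 1. / quadrature_size)  # constant weights, as noted above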
Example #12
 def _log_survival_function(self, value):
     rate = tf.convert_to_tensor(self._rate)
     return self._log_prob(value, rate=rate) - tf.math.log(rate)
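This one-liner relies on an identity of the exponential density: with `p(x) = rate * exp(-rate * x)` and `S(x) = exp(-rate * x)`, we have `log S(x) = log p(x) - log rate`, which is exactly what is returned. A quick numeric check, assuming the surrounding class is TFP's `Exponential` (consistent with that formula):

import tensorflow as tf
import tensorflow_probability as tfp

d = tfp.distributions.Exponential(rate=2.)
x = tf.constant(1.5)
print(d.log_survival_function(x).numpy())                      # -3.0
print((d.log_prob(x) - tf.math.log(tf.constant(2.))).numpy())  # -3.0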
Example #13
 def _logits_parameter_no_checks(self):
     if self._logits is None:
         probs = tf.convert_to_tensor(self._probs)
         return tf.math.log(probs) - tf.math.log1p(-probs)
     return tf.identity(self._logits)
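The `probs`-to-`logits` branch is just the log-odds (logit) transform; a quick numeric check:

import tensorflow as tf

probs = tf.constant([0.2, 0.5, 0.9])
logits = tf.math.log(probs) - tf.math.log1p(-probs)
print(logits.numpy())                   # [-1.386  0.     2.197]
print(tf.math.sigmoid(logits).numpy())  # recovers [0.2, 0.5, 0.9]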
Example #14
    def _sample_n(self, num_samples, seed=None, name=None):
        """Returns a Tensor of samples from an LKJ distribution.

    Args:
      num_samples: Python `int`. The number of samples to draw.
      seed: Python integer seed for RNG
      name: Python `str` name prefixed to Ops created by this function.

    Returns:
      samples: A Tensor of correlation matrices with shape `[n, B, D, D]`,
        where `B` is the shape of the `concentration` parameter, and `D`
        is the `dimension`.

    Raises:
      ValueError: If `dimension` is negative.
    """
        if self.dimension < 0:
            raise ValueError(
                'Cannot sample negative-dimension correlation matrices.')
        # Notation below: B is the batch shape, i.e., tf.shape(concentration)
        seed = SeedStream(seed, 'sample_lkj')
        with tf.name_scope(name or 'sample_lkj'):
            concentration = tf.convert_to_tensor(self.concentration)
            if not dtype_util.is_floating(concentration.dtype):
                raise TypeError(
                    'The concentration argument should have floating type, not '
                    '{}'.format(dtype_util.name(concentration.dtype)))

            concentration = _replicate(num_samples, concentration)
            concentration_shape = tf.shape(concentration)
            if self.dimension <= 1:
                # For any dimension <= 1, there is only one possible correlation matrix.
                shape = tf.concat(
                    [concentration_shape, [self.dimension, self.dimension]],
                    axis=0)
                return tf.ones(shape=shape, dtype=concentration.dtype)
            beta_conc = concentration + (self.dimension - 2.) / 2.
            beta_dist = beta.Beta(concentration1=beta_conc,
                                  concentration0=beta_conc)

            # Note that the sampler below deviates from [1], by doing the sampling in
            # cholesky space. This does not change the fundamental logic of the
            # sampler, but does speed up the sampling.

            # This is the correlation coefficient between the first two dimensions.
            # This is also `r` in reference [1].
            corr12 = 2. * beta_dist.sample(seed=seed()) - 1.

            # Below we construct the Cholesky of the initial 2x2 correlation matrix,
            # which is of the form:
            # [[1, 0], [r, sqrt(1 - r**2)]], where r is the correlation between the
            # first two dimensions.
            # This is the top-left corner of the cholesky of the final sample.
            first_row = tf.concat([
                tf.ones_like(corr12)[..., tf.newaxis],
                tf.zeros_like(corr12)[..., tf.newaxis]
            ],
                                  axis=-1)
            second_row = tf.concat([
                corr12[..., tf.newaxis],
                tf.sqrt(1 - corr12**2)[..., tf.newaxis]
            ],
                                   axis=-1)

            chol_result = tf.concat([
                first_row[..., tf.newaxis, :], second_row[..., tf.newaxis, :]
            ],
                                    axis=-2)

            for n in range(2, self.dimension):
                # Loop invariant: on entry, result has shape B + [n, n]
                beta_conc = beta_conc - 0.5
                # norm is y in reference [1].
                norm = beta.Beta(concentration1=n / 2.,
                                 concentration0=beta_conc).sample(seed=seed())
                # distance shape: B + [1] for broadcast
                distance = tf.sqrt(norm)[..., tf.newaxis]
                # direction is u in reference [1].
                # direction shape: B + [n]
                direction = _uniform_unit_norm(n, concentration_shape,
                                               concentration.dtype, seed)
                # raw_correlation is w in reference [1].
                raw_correlation = distance * direction  # shape: B + [n]

                # This is the next row in the cholesky of the result,
                # which differs from the construction in reference [1].
                # In the reference, the new row `z` = chol_result @ raw_correlation^T
                # = C @ raw_correlation^T (where as short hand we use C = chol_result).
                # We prove that the below equation is the right row to add to the
                # cholesky, by showing equality with reference [1].
                # Let S be the sample constructed so far, and let `z` be as in
                # reference [1]. Then at this iteration, the new sample S' will be
                # [[S z^T]
                #  [z 1]]
                # In our case we have the cholesky decomposition factor C, so
                # we want our new row x (same size as z) to satisfy:
                #  [[S z^T]  [[C 0]    [[C^T  x^T]         [[CC^T  Cx^T]
                #   [z 1]] =  [x k]]    [0     k]]  =       [xC^t   xx^T + k**2]]
                # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
                # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
                # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
                # distance**2).
                new_row = tf.concat(
                    [raw_correlation,
                     tf.sqrt(1. - norm[..., tf.newaxis])],
                    axis=-1)

                # Finally add this new row, by growing the cholesky of the result.
                chol_result = tf.concat([
                    chol_result,
                    tf.zeros_like(chol_result[..., 0][..., tf.newaxis])
                ],
                                        axis=-1)

                chol_result = tf.concat(
                    [chol_result, new_row[..., tf.newaxis, :]], axis=-2)

            if self.input_output_cholesky:
                return chol_result

            result = tf.matmul(chol_result, chol_result, transpose_b=True)
            # The diagonal for a correlation matrix should always be ones. Due to
            # numerical instability the matmul might not achieve that, so manually set
            # these to ones.
            result = tf.linalg.set_diag(
                result, tf.ones(shape=tf.shape(result)[:-1],
                                dtype=result.dtype))
            # This sampling algorithm can produce near-PSD matrices on which standard
            # algorithms such as `tf.cholesky` or `tf.linalg.self_adjoint_eigvals`
            # fail. Specifically, as documented in b/116828694, around 2% of trials
            # of 900,000 5x5 matrices (distributed according to 9 different
            # concentration parameter values) contained at least one matrix on which
            # the Cholesky decomposition failed.
            return result
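For completeness, end-to-end use of this sampler through the public distribution (a sketch assuming `tensorflow_probability`; the Cholesky-space construction above is an implementation detail):

import tensorflow as tf
import tensorflow_probability as tfp

lkj = tfp.distributions.LKJ(dimension=4, concentration=2.)
corr = lkj.sample(3, seed=42)  # shape [3, 4, 4]
# Each sample is a correlation matrix: symmetric with a unit diagonal.
print(tf.linalg.diag_part(corr).numpy())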
Example #15
 def _param_shapes(sample_shape):
     return dict(
         zip(('concentration', 'rate'),
             ([tf.convert_to_tensor(sample_shape, dtype=tf.int32)] * 2)))
Example #16
 def _stddev(self):
     samples = tf.convert_to_tensor(self._samples)
     axis = self._samples_axis
     r = samples - tf.expand_dims(self._mean(samples), axis=axis)
     var = tf.reduce_mean(tf.square(r), axis=axis)
     return tf.sqrt(var)
Example #17
 def _mean(self, samples=None):
     if samples is None:
         samples = tf.convert_to_tensor(self._samples)
     return tf.reduce_mean(samples, axis=self._samples_axis)
Example #18
 def _event_shape_tensor(self, samples=None):
     if samples is None:
         samples = tf.convert_to_tensor(self.samples)
     return tf.shape(samples)[self._samples_axis + 1:]
Example #19
 def _batch_shape_tensor(self, samples=None):
     if samples is None:
         samples = tf.convert_to_tensor(self.samples)
     return tf.shape(samples)[:self._samples_axis]
Example #20
 def _compute_num_samples(self, samples):
     samples_shape = distribution_util.prefer_static_shape(samples)
     return tf.convert_to_tensor(samples_shape[self._samples_axis],
                                 dtype_hint=tf.int32,
                                 name='num_samples')
Example #21
 def _param_shapes(sample_shape):
     return dict(
         zip(('df', 'loc', 'scale'),
             ([tf.convert_to_tensor(sample_shape, dtype=tf.int32)] * 3)))
Example #22
 def _param_shapes(sample_shape):
     return dict(
         zip(('low', 'high'),
             ([tf.convert_to_tensor(sample_shape, dtype=tf.int32)] * 2)))
Example #23
 def _mode(self):
     loc = tf.convert_to_tensor(self.loc)
     return tf.broadcast_to(loc, self._batch_shape_tensor(loc=loc))
Example #24
 def _convert_to_tensor(x, name, dtype=None):
     return None if x is None else tf.convert_to_tensor(
         x, name=name, dtype=dtype)
Example #25
 def _entropy(self):
     concentration = tf.convert_to_tensor(self.concentration)
     return (concentration - tf.math.log(self.rate) +
             tf.math.lgamma(concentration) +
             ((1. - concentration) * tf.math.digamma(concentration)))
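This is the standard Gamma entropy `a - log(rate) + log Gamma(a) + (1 - a) * digamma(a)`; a cross-check against SciPy, which parameterizes by `scale = 1 / rate`:

import numpy as np
from scipy import special, stats

a, rate = 3.0, 2.0
entropy_formula = a - np.log(rate) + special.gammaln(a) + (1. - a) * special.digamma(a)
print(entropy_formula)
print(stats.gamma(a, scale=1. / rate).entropy())  # should agree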
Example #26
 def _variance(self):
     p = self._probs_parameter_no_checks()
     k = tf.convert_to_tensor(self.total_count)
     return k[..., tf.newaxis] * p * (1. - p)
Example #27
    def __init__(self,
                 distribution,
                 low=None,
                 high=None,
                 validate_args=False,
                 name="QuantizedDistribution"):
        """Construct a Quantized Distribution representing `Y = ceiling(X)`.

    Some properties are inherited from the distribution defining `X`. Example:
    `allow_nan_stats` is determined for this `QuantizedDistribution` by reading
    the `distribution`.

    Args:
      distribution:  The base distribution class to transform. Typically an
        instance of `Distribution`.
      low: `Tensor` with same `dtype` as this distribution and shape
        able to be added to samples. Should be a whole number. Default `None`.
        If provided, base distribution's `prob` should be defined at
        `low`.
      high: `Tensor` with same `dtype` as this distribution and shape
        able to be added to samples. Should be a whole number. Default `None`.
        If provided, base distribution's `prob` should be defined at
        `high - 1`.
        `high` must be strictly greater than `low`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: If `distribution` is not a `Distribution` instance or is not
          continuous.
      NotImplementedError:  If the base distribution does not implement `cdf`.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._dist = distribution

            if low is not None:
                low = tf.convert_to_tensor(low,
                                           name="low",
                                           dtype=distribution.dtype)
            if high is not None:
                high = tf.convert_to_tensor(high,
                                            name="high",
                                            dtype=distribution.dtype)
            dtype_util.assert_same_float_dtype(
                tensors=[self.distribution, low, high])

            checks = []
            if validate_args and low is not None and high is not None:
                message = "low must be strictly less than high."
                checks.append(
                    assert_util.assert_less(low, high, message=message))
            self._validate_args = validate_args  # self._check_integer uses this.
            with tf.control_dependencies(checks if validate_args else []):
                if low is not None:
                    self._low = self._check_integer(low)
                else:
                    self._low = None
                if high is not None:
                    self._high = self._check_integer(high)
                else:
                    self._high = None

        super(QuantizedDistribution, self).__init__(
            dtype=self._dist.dtype,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=self._dist.allow_nan_stats,
            parameters=parameters,
            name=name)
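A hedged usage sketch, assuming `tensorflow_probability`: quantizing a Normal onto integer support clipped to `[low, high]` (the values below are illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
qdist = tfd.QuantizedDistribution(
    distribution=tfd.Normal(loc=5., scale=2.),
    low=0.,
    high=10.)
# For low < j < high, P[Y = j] = cdf(j) - cdf(j - 1) under Y = ceiling(X).
print(qdist.prob(tf.constant([4., 5., 6.])).numpy())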
Example #28
 def _prob(self, x):
     loc = tf.convert_to_tensor(self.loc)
     # Enforces dtype of probability to be float, when self.dtype is not.
     prob_dtype = self.dtype if self.dtype.is_floating else tf.float32
     return tf.cast(tf.abs(x - loc) <= self._slack(loc), dtype=prob_dtype)
Example #29
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        logits = self._logits
        probs = self._probs
        param, name = (probs, 'probs') if logits is None else (logits,
                                                               'logits')

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init:
            if not dtype_util.is_floating(param.dtype):
                raise TypeError(
                    'Argument `{}` must have floating type.'.format(name))

            msg = 'Argument `{}` must have rank at least 1.'.format(name)
            shape_static = tensorshape_util.dims(param.shape)
            if shape_static is not None:
                if len(shape_static) < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_rank_at_least(param, 1, message=msg))
                with tf.control_dependencies(assertions):
                    param = tf.identity(param)

            msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
            msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
                name, tf.int32.max)
            event_size = shape_static[-1] if shape_static is not None else None
            if event_size is not None:
                if event_size < 1:
                    raise ValueError(msg1)
                if event_size > tf.int32.max:
                    raise ValueError(msg2)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_greater_equal(tf.shape(param)[-1],
                                                     1,
                                                     message=msg1))
                # NOTE: For now, we leave out a runtime assertion that
                # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
                # will fail before we get to this point.

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if probs is not None:
            probs = param  # reuse tensor conversion from above
            if is_init != tensor_util.is_ref(probs):
                probs = tf.convert_to_tensor(probs)
                one = tf.ones([], dtype=probs.dtype)
                assertions.extend([
                    assert_util.assert_non_negative(probs),
                    assert_util.assert_less_equal(probs, one),
                    assert_util.assert_near(
                        tf.reduce_sum(probs, axis=-1),
                        one,
                        message='Argument `probs` must sum to 1.'),
                ])

        return assertions
Example #30
 def _param_shapes(sample_shape):
     return {"rate": tf.convert_to_tensor(sample_shape, dtype=tf.int32)}