Example #1
    def _sample_n(self, n, seed=None):
        power = tf.convert_to_tensor(self.power)
        shape = ps.concat([[n], ps.shape(power)], axis=0)
        numpy_dtype = dtype_util.as_numpy_dtype(power.dtype)

        seed = samplers.sanitize_seed(seed, salt='zipf')

        # Because `_hat_integral` is monotonically decreasing, the bounds for u will
        # switch.
        # Compute the hat_integral explicitly here since we can calculate the log of
        # the inputs statically in float64 with numpy.
        maxval_u = tf.math.exp(-(power - 1.) * numpy_dtype(np.log1p(0.5)) -
                               tf.math.log(power - 1.)) + 1.
        minval_u = tf.math.exp(
            -(power - 1.) *
            numpy_dtype(np.log1p(dtype_util.max(self.dtype) - 0.5)) -
            tf.math.log(power - 1.))

        def loop_body(should_continue, k, seed):
            """Resample the non-accepted points."""
            u_seed, next_seed = samplers.split_seed(seed)
            # Uniform variates must be sampled from the open-interval `(0, 1)` rather
            # than `[0, 1)`. To do so, we use
            # `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny`
            # because it is the smallest, positive, 'normal' number. A 'normal' number
            # is such that the mantissa has an implicit leading 1. Normal, positive
            # numbers x, y have the reasonable property that, `x + y >= max(x, y)`. In
            # this case, a subnormal number (i.e., np.nextafter) can cause us to
            # sample 0.
            u = samplers.uniform(
                shape,
                minval=np.finfo(dtype_util.as_numpy_dtype(power.dtype)).tiny,
                maxval=numpy_dtype(1.),
                dtype=power.dtype,
                seed=u_seed)
            # We use (1 - u) * maxval_u + u * minval_u rather than the other way
            # around, since we want to draw samples in (minval_u, maxval_u].
            u = maxval_u + (minval_u - maxval_u) * u
            # set_shape needed here because of b/139013403
            tensorshape_util.set_shape(u, should_continue.shape)

            # Sample the point X from the continuous density h(x) \propto x^(-power).
            x = self._hat_integral_inverse(u, power=power)

            # Rejection-inversion requires a `hat` function, h(x) such that
            # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
            # support. A natural hat function for us is h(x) = x^(-power).
            #
            # After sampling X from h(x), suppose it lies in the interval
            # (K - .5, K + .5) for integer K. Then the corresponding K is
            # accepted if it lies to the left of x_K, where x_K is defined by:
            #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
            # where H(x) = \int_x^inf h(x) dx.

            # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
            # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
            # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

            # Update the non-accepted points.
            # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
            k = tf.where(should_continue, tf.floor(x + 0.5), k)
            accept = (u <= self._hat_integral(k + .5, power=power) +
                      tf.exp(self._log_prob(k + 1, power=power)))

            return [should_continue & (~accept), k, next_seed]

        should_continue, samples, _ = tf.while_loop(
            cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
            body=loop_body,
            loop_vars=[
                tf.ones(shape, dtype=tf.bool),  # should_continue
                tf.zeros(shape, dtype=power.dtype),  # k
                seed,  # seed
            ],
            maximum_iterations=self.sample_maximum_iterations,
        )
        samples = samples + 1.

        if self.validate_args and dtype_util.is_integer(self.dtype):
            samples = distribution_util.embed_check_integer_casting_closed(
                samples, target_dtype=self.dtype, assert_positive=True)

        samples = tf.cast(samples, self.dtype)

        if self.validate_args:
            npdt = dtype_util.as_numpy_dtype(self.dtype)
            v = npdt(
                dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan
            )
            samples = tf.where(should_continue, v, samples)

        return samples
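
The comments above outline the rejection-inversion scheme: a hat density that dominates the pmf, its tail integral H, and the acceptance test U <= H(K + .5) + pmf(K + 1). A minimal scalar NumPy sketch of the same scheme for Zipf(power) with power > 1 (the helper names are illustrative, not TFP's, and SciPy's `zeta` is assumed for the normalizer):

import numpy as np
from scipy.special import zeta  # Riemann zeta: Zipf's normalizing constant.

def zipf_sample(power, rng, max_iter=100):
    # Illustrative sketch. H(x) = \int_x^inf (1 + t)^(-power) dt, mirroring
    # `_hat_integral` above; the returned sample is K + 1, hence the 1 + t shift.
    hat_integral = lambda x: (1. + x)**(1. - power) / (power - 1.)
    hat_integral_inv = lambda u: ((power - 1.) * u)**(1. / (1. - power)) - 1.
    pmf = lambda k: k**(-power) / zeta(power)
    minval_u = hat_integral(np.iinfo(np.int64).max - 0.5)
    maxval_u = hat_integral(0.5) + 1.
    for _ in range(max_iter):
        # Draw U between the (switched) bounds, then invert to get X ~ h.
        u = maxval_u + (minval_u - maxval_u) * rng.uniform()
        x = hat_integral_inv(u)
        k = np.floor(x + 0.5)
        if u <= hat_integral(k + 0.5) + pmf(k + 1.):  # Accept K.
            return k + 1.  # The Zipf support starts at 1.
    return np.nan  # No acceptance within max_iter.

# e.g. zipf_sample(2.5, np.random.default_rng(0))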
Example #2
 def _cast_dtype(dtype):
     if dtype_util.as_numpy_dtype(dtype) is np.int64:
         return tf.float64
     elif dtype_util.is_integer(dtype):
         return tf.float32
     return dtype
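
A quick check of the promotion rule this helper encodes: `int64` is widened to `float64` (a `float32` can only represent integers exactly up to 2**24), narrower integer dtypes fit in `float32`, and floating dtypes pass through unchanged. A hedged usage sketch:

import tensorflow as tf

assert _cast_dtype(tf.int64) == tf.float64
assert _cast_dtype(tf.int32) == tf.float32
assert _cast_dtype(tf.float16) == tf.float16  # Non-integer: returned as-is.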
Example #3
 def testIsInteger(self):
   self.assertFalse(dtype_util.is_integer(np.float64))
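
For contrast, the complementary cases that the other examples' branches rely on, written in the same style (a hedged addition, assuming the test module's usual `np` and `tf` imports):

 def testIsIntegerTrueCases(self):
   self.assertTrue(dtype_util.is_integer(np.int32))
   self.assertTrue(dtype_util.is_integer(tf.int64))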
Example #4
    def __init__(self, permutation, axis=-1, validate_args=False, name=None):
        """Creates the `Permute` bijector.

    Args:
      permutation: An `int`-like vector-shaped `Tensor` representing the
        permutation to apply to the `axis` dimension of the transformed
        `Tensor`.
      axis: Scalar `int` `Tensor` representing the dimension over which to
        `tf.gather`. `axis` must be relative to the end (reading left to right)
        thus must be negative.
        Default value: `-1` (i.e., right-most).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str`, name given to ops managed by this object.

    Raises:
      TypeError: if `not dtype_util.is_integer(permutation.dtype)`.
      ValueError: if `permutation` does not contain exactly one of each of
        `{0, 1, ..., d}`.
      NotImplementedError: if `axis` is not known prior to graph execution.
      NotImplementedError: if `axis` is not negative.
    """
        with tf.name_scope(name or "permute") as name:
            axis = tf.convert_to_tensor(axis, name="axis")
            if not dtype_util.is_integer(axis.dtype):
                raise TypeError("axis.dtype ({}) should be `int`-like.".format(
                    dtype_util.name(axis.dtype)))
            permutation = tf.convert_to_tensor(permutation, name="permutation")
            if not dtype_util.is_integer(permutation.dtype):
                raise TypeError(
                    "permutation.dtype ({}) should be `int`-like.".format(
                        dtype_util.name(permutation.dtype)))
            p = tf.get_static_value(permutation)
            if p is not None:
                if set(p) != set(np.arange(p.size)):
                    raise ValueError(
                        "Permutation over `d` must contain exactly one of "
                        "each of `{0, 1, ..., d}`.")
            elif validate_args:
                p, _ = tf.math.top_k(-permutation,
                                     k=tf.shape(permutation)[-1],
                                     sorted=True)
                permutation = distribution_util.with_dependencies([
                    assert_util.assert_equal(
                        -p,
                        tf.range(tf.size(p)),
                        message=(
                            "Permutation over `d` must contain exactly one of "
                            "each of `{0, 1, ..., d}`.")),
                ], permutation)
            axis_ = tf.get_static_value(axis)
            if axis_ is None:
                raise NotImplementedError(
                    "`axis` must be known prior to graph "
                    "execution.")
            elif axis_ >= 0:
                raise NotImplementedError(
                    "`axis` must be relative the rightmost "
                    "dimension, i.e., negative.")
            else:
                forward_min_event_ndims = int(np.abs(axis_))
            self._permutation = permutation
            self._axis = axis
            super(Permute, self).__init__(
                forward_min_event_ndims=forward_min_event_ndims,
                is_constant_jacobian=True,
                validate_args=validate_args,
                name=name)
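
Typical usage of the bijector defined above; `forward` is a `tf.gather` along the last axis, so a reversing permutation flips the vector (the values follow from the documented semantics):

import tensorflow_probability as tfp
tfb = tfp.bijectors

reverse = tfb.Permute(permutation=[2, 1, 0])
reverse.forward([-1., 0., 1.])   # ==> [1., 0., -1.]
reverse.inverse([1., 0., -1.])   # ==> [-1., 0., 1.]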
Example #5
def percentile(x,
               q,
               axis=None,
               interpolation=None,
               keepdims=False,
               validate_args=False,
               preserve_gradients=True,
               keep_dims=None,
               name=None):
  """Compute the `q`-th percentile(s) of `x`.

  Given a vector `x`, the `q`-th percentile of `x` is the value `q / 100` of the
  way from the minimum to the maximum in a sorted copy of `x`.

  The values and distances of the two nearest neighbors as well as the
  `interpolation` parameter will determine the percentile if the normalized
  ranking does not match the location of `q` exactly.

  This function is the same as the median if `q = 50`, the same as the minimum
  if `q = 0` and the same as the maximum if `q = 100`.

  Multiple percentiles can be computed at once by passing a `1-D` vector `q`.
  Dimension zero of the returned `Tensor` will index the different percentiles.

  Compare to `numpy.percentile`.

  Args:
    x:  Numeric `N-D` `Tensor` with `N > 0`.  If `axis` is not `None`,
      `x` must have statically known number of dimensions.
    q:  Scalar or vector `Tensor` with values in `[0, 100]`. The percentile(s).
    axis:  Optional `0-D` or `1-D` integer `Tensor` with constant values. The
      axes that index independent samples over which to return the desired
      percentile.  If `None` (the default), treat every dimension as a sample
      dimension, returning a scalar.
    interpolation: {'nearest', 'linear', 'lower', 'higher', 'midpoint'}.
      Default value: 'nearest'.  This specifies the interpolation method to
      use when the desired quantile lies between two data points `i < j`:
        * linear: i + (j - i) * fraction, where fraction is the fractional part
          of the index surrounded by i and j.
        * lower: `i`.
        * higher: `j`.
        * nearest: `i` or `j`, whichever is nearest.
        * midpoint: (i + j) / 2.
      `linear` and `midpoint` interpolation do not work with integer dtypes.
    keepdims:  Python `bool`. If `True`, the last dimension is kept with size 1.
      If `False`, the last dimension is removed from the output shape.
    validate_args:  Whether to add runtime checks of argument validity. If
      False, and arguments are incorrect, correct behavior is not guaranteed.
    preserve_gradients:  Python `bool`.  If `True`, ensure that the gradient
      w.r.t. the percentile `q` is preserved in the case of linear interpolation.
      If `False`, the gradient will be (incorrectly) zero when `q` corresponds
      to a point in `x`.
    keep_dims: deprecated, use keepdims instead.
    name:  A Python string name to give this `Op`.  Default is 'percentile'.

  Returns:
    A `(rank(q) + N - len(axis))` dimensional `Tensor` of same dtype as `x`, or,
      if `axis` is `None`, a `rank(q)` `Tensor`.  The first `rank(q)` dimensions
      index quantiles for different values of `q`.

  Raises:
    ValueError:  If argument 'interpolation' is not an allowed value.
    ValueError:  If interpolation type not compatible with `dtype`.

  #### Examples

  ```python
  # Get 30th percentile with default ('nearest') interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30.)
  ==> 2.0

  # Get 30th percentile with 'linear' interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30., interpolation='linear')
  ==> 1.9

  # Get 30th and 70th percentiles with 'lower' interpolation
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=[30., 70.], interpolation='lower')
  ==> [1., 3.]

  # Get 100th percentile (maximum).  By default, this is computed over every dim
  x = [[1., 2.],
       [3., 4.]]
  tfp.stats.percentile(x, q=100.)
  ==> 4.

  # Treat the leading dim as indexing samples, and find the 100th percentile
  # (max) over all such samples.
  x = [[1., 2.],
       [3., 4.]]
  tfp.stats.percentile(x, q=100., axis=[0])
  ==> [3., 4.]
  ```

  """
  keepdims = keepdims if keep_dims is None else keep_dims
  del keep_dims
  name = name or 'percentile'
  allowed_interpolations = {'linear', 'lower', 'higher', 'nearest', 'midpoint'}

  if interpolation is None:
    interpolation = 'nearest'
  else:
    if interpolation not in allowed_interpolations:
      raise ValueError(
          'Argument `interpolation` must be in {}. Found {}.'.format(
              allowed_interpolations, interpolation))

  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, name='x')

    if (interpolation in {'linear', 'midpoint'} and
        dtype_util.is_integer(x.dtype)):
      raise TypeError('{} interpolation not allowed with dtype {}'.format(
          interpolation, x.dtype))

    # Double is needed here and below, else we get the wrong index if the array
    # is huge along axis.
    q = tf.cast(q, tf.float64)
    _get_static_ndims(q, expect_ndims_no_more_than=1)

    if validate_args:
      q = distribution_util.with_dependencies([
          assert_util.assert_rank_in(q, [0, 1]),
          assert_util.assert_greater_equal(q, tf.cast(0., tf.float64)),
          assert_util.assert_less_equal(q, tf.cast(100., tf.float64))
      ], q)

    # Move `axis` dims of `x` to the rightmost, call it `y`.
    if axis is None:
      y = tf.reshape(x, [-1])
    else:
      x_ndims = _get_static_ndims(
          x, expect_static=True, expect_ndims_at_least=1)
      axis = _make_static_axis_non_negative_list(axis, x_ndims)
      y = _move_dims_to_flat_end(x, axis, x_ndims, right_end=True)

    frac_at_q_or_above = 1. - q / 100.

    # Sort everything, not just the top 'k' entries. This allows multiple calls
    # to share a single sort (under the hood) via common-subexpression
    # elimination.
    sorted_y = _sort_tensor(y)

    d = tf.cast(tf.shape(y)[-1], tf.float64)

    def _get_indices(interp_type):
      """Get values of y at the indices implied by interp_type."""
      # Note `lower` <--> ceiling.  Confusing, huh?  Due to the fact that
      # _sort_tensor sorts highest to lowest, tf.ceil corresponds to the higher
      # index, but the lower value of y!
      if interp_type == 'lower':
        indices = tf.math.ceil((d - 1) * frac_at_q_or_above)
      elif interp_type == 'higher':
        indices = tf.floor((d - 1) * frac_at_q_or_above)
      elif interp_type == 'nearest':
        indices = tf.round((d - 1) * frac_at_q_or_above)
      # d - 1 will be distinct from d in int32, but not necessarily double.
      # So clip to avoid out of bounds errors.
      return tf.clip_by_value(
          tf.cast(indices, tf.int32), 0,
          tf.shape(y)[-1] - 1)

    if interpolation in ['nearest', 'lower', 'higher']:
      gathered_y = tf.gather(sorted_y, _get_indices(interpolation), axis=-1)
    elif interpolation == 'midpoint':
      gathered_y = 0.5 * (
          tf.gather(sorted_y, _get_indices('lower'), axis=-1) +
          tf.gather(sorted_y, _get_indices('higher'), axis=-1))
    elif interpolation == 'linear':
      # Copy-paste of docstring on interpolation:
      # linear: i + (j - i) * fraction, where fraction is the fractional part
      # of the index surrounded by i and j.
      larger_y_idx = _get_indices('lower')
      exact_idx = (d - 1) * frac_at_q_or_above
      if preserve_gradients:
        # If q corresponds to a point in x, we will initially have
        # larger_y_idx == smaller_y_idx.
        # This results in the gradient w.r.t. fraction being zero (recall `q`
        # enters only through `fraction`...and see that things cancel).
        # The fix is to ensure that smaller_y_idx and larger_y_idx are always
        # separated by exactly 1.
        smaller_y_idx = tf.maximum(larger_y_idx - 1, 0)
        larger_y_idx = tf.minimum(smaller_y_idx + 1, tf.shape(y)[-1] - 1)
        fraction = tf.cast(larger_y_idx, tf.float64) - exact_idx
      else:
        smaller_y_idx = _get_indices('higher')
        fraction = tf.math.ceil((d - 1) * frac_at_q_or_above) - exact_idx

      fraction = tf.cast(fraction, y.dtype)
      gathered_y = (
          tf.gather(sorted_y, larger_y_idx, axis=-1) * (1 - fraction) +
          tf.gather(sorted_y, smaller_y_idx, axis=-1) * fraction)

    # Propagate NaNs
    if x.dtype in (tf.bfloat16, tf.float16, tf.float32, tf.float64):
      # Apparently tf.is_nan doesn't like other dtypes
      nan_batch_members = tf.reduce_any(tf.math.is_nan(x), axis=axis)
      right_rank_matched_shape = tf.pad(
          tf.shape(nan_batch_members),
          paddings=[[0, tf.rank(q)]],
          constant_values=1)
      nan_batch_members = tf.reshape(
          nan_batch_members, shape=right_rank_matched_shape)
      nan = np.array(np.nan, dtype_util.as_numpy_dtype(gathered_y.dtype))
      gathered_y = tf.where(nan_batch_members, nan, gathered_y)

    # Expand dimensions if requested
    if keepdims:
      if axis is None:
        ones_vec = tf.ones(
            shape=[_get_best_effort_ndims(x) + _get_best_effort_ndims(q)],
            dtype=tf.int32)
        gathered_y *= tf.ones(ones_vec, dtype=x.dtype)
      else:
        gathered_y = _insert_back_keepdims(gathered_y, axis)

    # If q is a scalar, then result has the right shape.
    # If q is a vector, then result has trailing dim of shape q.shape, which
    # needs to be rotated to dim 0.
    return distribution_util.rotate_transpose(gathered_y, tf.rank(q))
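
A worked check of the docstring's 'linear' example, restated with NumPy: the fractional index (n - 1) * q / 100 is the quantity called `exact_idx` above, up to the descending-sort reflection.

import numpy as np

x = np.array([1., 2., 3., 4.])
q = 30.
idx = (len(x) - 1) * q / 100.              # 3 * 0.3 = 0.9
lo = int(np.floor(idx))                    # Interpolate between x[0] and x[1].
frac = idx - lo
print(x[lo] + frac * (x[lo + 1] - x[lo]))  # ==> 1.9
print(np.percentile(x, q))                 # NumPy's default is 'linear': 1.9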
Example #6
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        # Check num_steps is a scalar that's at least 1.
        if is_init != tensor_util.is_ref(self.num_steps):
            num_steps = tf.convert_to_tensor(self.num_steps)
            num_steps_ = tf.get_static_value(num_steps)
            if num_steps_ is not None:
                if np.ndim(num_steps_) != 0:
                    raise ValueError(
                        '`num_steps` must be a scalar but it has rank {}'.
                        format(np.ndim(num_steps_)))
                if num_steps_ < 1:
                    raise ValueError('`num_steps` must be at least 1.')
            elif self.validate_args:
                message = '`num_steps` must be a scalar'
                assertions.append(
                    assert_util.assert_rank_at_most(self.num_steps,
                                                    0,
                                                    message=message))
                assertions.append(
                    assert_util.assert_greater_equal(
                        num_steps,
                        1,
                        message='`num_steps` must be at least 1.'))

        # Check that the initial distribution has scalar events over the
        # integers.
        if is_init and not dtype_util.is_integer(
                self.initial_distribution.dtype):
            raise ValueError(
                '`initial_distribution.dtype` ({}) is not over integers'.
                format(dtype_util.name(self.initial_distribution.dtype)))

        if tensorshape_util.rank(
                self.initial_distribution.event_shape) is not None:
            if tensorshape_util.rank(
                    self.initial_distribution.event_shape) != 0:
                raise ValueError(
                    '`initial_distribution` must have scalar `event_dim`s')
        elif self.validate_args:
            assertions += [
                assert_util.assert_equal(
                    tf.size(self.initial_distribution.event_shape_tensor()),
                    0,
                    message=
                    '`initial_distribution` must have scalar `event_dim`s'),
            ]

        # Check that the transition distribution is over the integers.
        if (is_init and
                not dtype_util.is_integer(self.transition_distribution.dtype)):
            raise ValueError(
                '`transition_distribution.dtype` ({}) is not over integers'.
                format(dtype_util.name(self.transition_distribution.dtype)))

        # Check observations have non-scalar batches.
        # The graph version of this assertion is incorporated as
        # a control dependency of the transition/observation
        # compatibility test.
        if tensorshape_util.rank(
                self.observation_distribution.batch_shape) == 0:
            raise ValueError(
                "`observation_distribution` can't have scalar batches")

        # Check transitions have non-scalar batches.
        # The graph version of this assertion is incorporated as
        # a control dependency of the transition/observation
        # compatibility test.
        if tensorshape_util.rank(
                self.transition_distribution.batch_shape) == 0:
            raise ValueError(
                "`transition_distribution` can't have scalar batches")

        # Check compatibility of transition distribution and observation
        # distribution.
        tdbs = self.transition_distribution.batch_shape
        odbs = self.observation_distribution.batch_shape
        if (tensorshape_util.dims(tdbs) is not None
                and tf.compat.dimension_value(odbs[-1]) is not None):
            if (tf.compat.dimension_value(tdbs[-1]) !=
                    tf.compat.dimension_value(odbs[-1])):
                raise ValueError(
                    '`transition_distribution` and `observation_distribution` '
                    'must agree on last dimension of batch size')
        elif self.validate_args:
            tdbs = self.transition_distribution.batch_shape_tensor()
            odbs = self.observation_distribution.batch_shape_tensor()
            transition_precondition = assert_util.assert_greater(
                tf.size(tdbs),
                0,
                message=('`transition_distribution` can\'t have scalar '
                         'batches'))
            observation_precondition = assert_util.assert_greater(
                tf.size(odbs),
                0,
                message=('`observation_distribution` can\'t have scalar '
                         'batches'))
            with tf.control_dependencies(
                [transition_precondition, observation_precondition]):
                assertions += [
                    assert_util.assert_equal(
                        tdbs[-1],
                        odbs[-1],
                        message=('`transition_distribution` and '
                                 '`observation_distribution` '
                                 'must agree on last dimension of batch size'))
                ]

        return assertions
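
The recurring pattern in this method (validate statically when the value is known at trace time, otherwise append a graph-time assertion when `validate_args` is set) distilled into a hedged sketch; `_check_scalar` is an illustrative name, reusing the `tf`/`np`/`assert_util` imports of the snippet above:

def _check_scalar(x, validate_args, assertions):
    x_ = tf.get_static_value(tf.convert_to_tensor(x))
    if x_ is not None:
        # Static path: fail immediately while the graph is being built.
        if np.ndim(x_) != 0:
            raise ValueError('Expected a scalar, saw rank {}.'.format(np.ndim(x_)))
    elif validate_args:
        # Dynamic path: defer the check to graph execution.
        assertions.append(
            assert_util.assert_rank_at_most(x, 0, message='Expected a scalar.'))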
Example #7
    def _sample_n(self, n, seed=None):
        shape = tf.concat([[n], self.batch_shape_tensor()], axis=0)

        has_seed = seed is not None
        seed = SeedStream(seed, salt="zipf")

        minval_u = self._hat_integral(0.5) + 1.
        maxval_u = self._hat_integral(tf.int64.max - 0.5)

        def loop_body(should_continue, k):
            """Resample the non-accepted points."""
            # The range of U is chosen so that the resulting sample K lies in
            # [0, tf.int64.max). The final sample, if accepted, is K + 1.
            u = tf.random.uniform(shape,
                                  minval=minval_u,
                                  maxval=maxval_u,
                                  dtype=self.power.dtype,
                                  seed=seed())

            # Sample the point X from the continuous density h(x) \propto x^(-power).
            x = self._hat_integral_inverse(u)

            # Rejection-inversion requires a `hat` function, h(x) such that
            # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
            # support. A natural hat function for us is h(x) = x^(-power).
            #
            # After sampling X from h(x), suppose it lies in the interval
            # (K - .5, K + .5) for integer K. Then the corresponding K is
            # accepted if it lies to the left of x_K, where x_K is defined by:
            #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
            # where H(x) = \int_x^inf h(x) dx.

            # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
            # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
            # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

            # Update the non-accepted points.
            # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
            k = tf.where(should_continue, tf.floor(x + 0.5), k)
            accept = (u <= self._hat_integral(k + .5) +
                      tf.exp(self._log_prob(k + 1)))

            return [should_continue & (~accept), k]

        should_continue, samples = tf.while_loop(
            cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
            body=loop_body,
            loop_vars=[
                tf.ones(shape, dtype=tf.bool),  # should_continue
                tf.zeros(shape, dtype=self.power.dtype),  # k
            ],
            parallel_iterations=1 if has_seed else 10,
            maximum_iterations=self.sample_maximum_iterations,
        )
        samples = samples + 1.

        if self.validate_args and dtype_util.is_integer(self.dtype):
            samples = distribution_util.embed_check_integer_casting_closed(
                samples, target_dtype=self.dtype, assert_positive=True)

        samples = tf.cast(samples, self.dtype)

        if self.validate_args:
            npdt = dtype_util.as_numpy_dtype(self.dtype)
            v = npdt(
                dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan
            )
            samples = tf.where(should_continue, v, samples)

        return samples
Example #8
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    if is_init and not dtype_util.is_integer(self.mixture_distribution.dtype):
      raise ValueError(
          '`mixture_distribution.dtype` ({}) is not over integers'.format(
              dtype_util.name(self.mixture_distribution.dtype)))

    if tensorshape_util.rank(self.mixture_distribution.event_shape) is not None:
      if tensorshape_util.rank(self.mixture_distribution.event_shape) != 0:
        raise ValueError('`mixture_distribution` must have scalar `event_dim`s')
    elif self.validate_args:
      assertions += [
          assert_util.assert_equal(
              tf.size(self.mixture_distribution.event_shape_tensor()),
              0,
              message='`mixture_distribution` must have scalar `event_dim`s'),
      ]

    # pylint: disable=protected-access
    mixture_dist_param = (self.mixture_distribution._probs
                          if self.mixture_distribution._logits is None
                          else self.mixture_distribution._logits)
    km = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(mixture_dist_param.shape, 1)[-1])
    kc = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(
            self.components_distribution.batch_shape, 1)[-1])
    component_bst = None
    if km is not None and kc is not None:
      if km != kc:
        raise ValueError('`mixture_distribution` component count ({}) does '
                         'not equal `components_distribution.batch_shape[-1]` '
                         '({})'.format(km, kc))
    elif self.validate_args:
      if km is None:
        mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
        km = tf.shape(mixture_dist_param)[-1]
      if kc is None:
        component_bst = self.components_distribution.batch_shape_tensor()
        kc = component_bst[-1]
      assertions += [
          assert_util.assert_equal(
              km,
              kc,
              message=('`mixture_distribution` component count does not equal '
                       '`components_distribution.batch_shape[-1]`')),
      ]

    mdbs = self.mixture_distribution.batch_shape
    cdbs = tensorshape_util.with_rank_at_least(
        self.components_distribution.batch_shape, 1)[:-1]
    if (tensorshape_util.is_fully_defined(mdbs)
        and tensorshape_util.is_fully_defined(cdbs)):
      if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
        raise ValueError(
            '`mixture_distribution.batch_shape` (`{}`) is not '
            'compatible with `components_distribution.batch_shape` '
            '(`{}`)'.format(tensorshape_util.as_list(mdbs),
                            tensorshape_util.as_list(cdbs)))
    elif self.validate_args:
      if not tensorshape_util.is_fully_defined(mdbs):
        mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
        mdbs = tf.shape(mixture_dist_param)[:-1]
      if not tensorshape_util.is_fully_defined(cdbs):
        if component_bst is None:
          component_bst = self.components_distribution.batch_shape_tensor()
        cdbs = component_bst[:-1]
      assertions += [
          assert_util.assert_equal(
              distribution_utils.pick_vector(
                  tf.equal(tf.shape(mdbs)[0], 0), cdbs, mdbs),
              cdbs,
              message=(
                  '`mixture_distribution.batch_shape` is not '
                  'compatible with `components_distribution.batch_shape`'))
      ]

    return assertions
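
Concretely, `km` and `kc` above compare the mixture's category count with the components' rightmost batch dimension; a mismatch such as the following (an illustrative sketch with standard TFP distributions) is what the `ValueError` guards against:

import tensorflow_probability as tfp
tfd = tfp.distributions

mixture = tfd.Categorical(probs=[0.2, 0.3, 0.5])          # km = 3 categories.
components = tfd.Normal(loc=[-1., 1.], scale=[0.1, 0.5])  # kc = batch_shape[-1] = 2.
# km != kc, so pairing these in a mixture would fail the check above.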
Example #9
    def __init__(self,
                 mixture_distribution,
                 components_distribution,
                 reparameterize=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MixtureSameFamily"):
        """Construct a `MixtureSameFamily` distribution.

    Args:
      mixture_distribution: `tfp.distributions.Categorical`-like instance.
        Manages the probability of selecting components. The number of
        categories must match the rightmost batch dimension of the
        `components_distribution`. Must have either scalar `batch_shape` or
        `batch_shape` matching `components_distribution.batch_shape[:-1]`.
      components_distribution: `tfp.distributions.Distribution`-like instance.
        Right-most batch dimension indexes components.
      reparameterize: Python `bool`, default `False`. Whether to reparameterize
        samples of the distribution using implicit reparameterization gradients
        [(Figurnov et al., 2018)][1]. The gradients for the mixture logits are
        equivalent to the ones described by [(Graves, 2016)][2]. The gradients
        for the components parameters are also computed using implicit
        reparameterization (as opposed to ancestral sampling), meaning that
        all components are updated every step.
        Only works when:
          (1) components_distribution is fully reparameterized;
          (2) components_distribution is either a scalar distribution or
          fully factorized (tfd.Independent applied to a scalar distribution);
          (3) batch shape has a known rank.
        Experimental, may be slow and produce infs/NaNs.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: `if not dtype_util.is_integer(mixture_distribution.dtype)`.
      ValueError: if mixture_distribution does not have scalar `event_shape`.
      ValueError: if `mixture_distribution.batch_shape` and
        `components_distribution.batch_shape[:-1]` are both fully defined and
        the former is neither scalar nor equal to the latter.
      ValueError: if `mixture_distribution` categories does not equal
        `components_distribution` rightmost batch shape.

    #### References

    [1]: Michael Figurnov, Shakir Mohamed and Andriy Mnih. Implicit
         reparameterization gradients. In _Neural Information Processing
         Systems_, 2018. https://arxiv.org/abs/1805.08498

    [2]: Alex Graves. Stochastic Backpropagation through Mixture Density
         Distributions. _arXiv_, 2016. https://arxiv.org/abs/1607.05690
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._mixture_distribution = mixture_distribution
            self._components_distribution = components_distribution
            self._runtime_assertions = []

            s = components_distribution.event_shape_tensor()
            self._event_ndims = tf.compat.dimension_value(s.shape[0])
            if self._event_ndims is None:
                self._event_ndims = tf.size(input=s)
            self._event_size = tf.reduce_prod(input_tensor=s)

            if not dtype_util.is_integer(mixture_distribution.dtype):
                raise ValueError(
                    "`mixture_distribution.dtype` ({}) is not over integers".
                    format(dtype_util.name(mixture_distribution.dtype)))

            if (tensorshape_util.rank(mixture_distribution.event_shape)
                    is not None and tensorshape_util.rank(
                        mixture_distribution.event_shape) != 0):
                raise ValueError(
                    "`mixture_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.size(
                            input=mixture_distribution.event_shape_tensor()),
                        0,
                        message=
                        "`mixture_distribution` must have scalar `event_dim`s"
                    ),
                ]

            mdbs = mixture_distribution.batch_shape
            cdbs = tensorshape_util.with_rank_at_least(
                components_distribution.batch_shape, 1)[:-1]
            if tensorshape_util.is_fully_defined(
                    mdbs) and tensorshape_util.is_fully_defined(cdbs):
                if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
                    raise ValueError(
                        "`mixture_distribution.batch_shape` (`{}`) is not "
                        "compatible with `components_distribution.batch_shape` "
                        "(`{}`)".format(tensorshape_util.as_list(mdbs),
                                        tensorshape_util.as_list(cdbs)))
            elif validate_args:
                mdbs = mixture_distribution.batch_shape_tensor()
                cdbs = components_distribution.batch_shape_tensor()[:-1]
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        distribution_utils.pick_vector(
                            mixture_distribution.is_scalar_batch(), cdbs,
                            mdbs),
                        cdbs,
                        message=
                        ("`mixture_distribution.batch_shape` is not "
                         "compatible with `components_distribution.batch_shape`"
                         ))
                ]

            km = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(
                    mixture_distribution.logits.shape, 1)[-1])
            kc = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(
                    components_distribution.batch_shape, 1)[-1])
            if km is not None and kc is not None and km != kc:
                raise ValueError(
                    "`mixture_distribution components` ({}) does not "
                    "equal `components_distribution.batch_shape[-1]` "
                    "({})".format(km, kc))
            elif validate_args:
                km = tf.shape(input=mixture_distribution.logits)[-1]
                kc = components_distribution.batch_shape_tensor()[-1]
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        km,
                        kc,
                        message=(
                            "`mixture_distribution components` does not equal "
                            "`components_distribution.batch_shape[-1:]`")),
                ]
            elif km is None:
                km = tf.shape(input=mixture_distribution.logits)[-1]

            self._num_components = km

            self._reparameterize = reparameterize
            if reparameterize:
                # Note: tfd.Independent passes through the reparameterization
                # type, hence we do not need separate logic for Independent.
                if (self._components_distribution.reparameterization_type !=
                        reparameterization.FULLY_REPARAMETERIZED):
                    raise ValueError("Cannot reparameterize a mixture of "
                                     "non-reparameterized components.")
                reparameterization_type = reparameterization.FULLY_REPARAMETERIZED
            else:
                reparameterization_type = reparameterization.NOT_REPARAMETERIZED

            super(MixtureSameFamily, self).__init__(
                dtype=self._components_distribution.dtype,
                reparameterization_type=reparameterization_type,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=(
                    self._mixture_distribution._graph_parents  # pylint: disable=protected-access
                    + self._components_distribution._graph_parents),  # pylint: disable=protected-access
                name=name)
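
A minimal construction that exercises this initializer, mirroring the standard `MixtureSameFamily` usage:

import tensorflow_probability as tfp
tfd = tfp.distributions

# A 30%/70% mixture of two scalar normals: batch_shape=[], event_shape=[].
gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(loc=[-1., 1.], scale=[0.1, 0.5]))

gm.sample(5)       # Five draws, each from one sampled component.
gm.log_prob(0.5)   # Log-density of the mixture at 0.5.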
Example #10
 def _is_equal_or_close(self, a, b):
     if dtype_util.is_integer(self.outcomes.dtype):
         return tf.equal(a, b)
     return tf.abs(a - b) < self._atol + self._rtol * tf.abs(b)
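
The floating-point branch applies the same `atol + rtol * |b|` tolerance rule as `np.isclose`; a quick numeric check with illustrative tolerances:

import numpy as np

atol, rtol = 1e-8, 1e-5
a, b = 1.0000099, 1.0
print(abs(a - b) < atol + rtol * abs(b))       # True: inside the tolerance.
print(np.isclose(a, b, rtol=rtol, atol=atol))  # np.isclose: same rule (with <=).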
Example #11
def _float_dtype_like(dtype):
  if dtype_util.as_numpy_dtype(dtype) == np.int64:
    return tf.float64
  if dtype_util.is_integer(dtype):
    return tf.float32
  return dtype
Example #12
    def __init__(self,
                 perm=None,
                 rightmost_transposed_ndims=None,
                 validate_args=False,
                 name='transpose'):
        """Instantiates the `Transpose` bijector.

    Args:
      perm: Positive `int32` vector-shaped `Tensor` representing permutation of
        rightmost dims (for forward transformation).  Note that the `0`th index
        represents the first of the rightmost dims and the largest value must be
        `rightmost_transposed_ndims - 1` and corresponds to `tf.rank(x) - 1`.
        Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
        specified. The number of elements in `perm` must have a value that can
        be determined statically.
        Default value:
        `tf.range(start=rightmost_transposed_ndims, limit=-1, delta=-1)`.
      rightmost_transposed_ndims: Positive `int32` scalar-shaped `Tensor`
        representing the number of rightmost dimensions to permute.
        Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
        specified. If `rightmost_transposed_ndims` is specified, the rightmost
        dims are reversed. This argument must have a value that can be
        determined statically.
        Default value: `tf.size(perm)`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.

    Raises:
      ValueError: if both or neither `perm` and `rightmost_transposed_ndims` are
        specified.
      NotImplementedError: if `rightmost_transposed_ndims` is not known prior to
        graph execution.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            # We need to determine `forward_min_event_ndims` statically, which
            # requires that we know `rightmost_transposed_ndims` statically.
            # So the corresponding assertions go here rather than in
            # `_parameter_control_dependencies`.
            if (rightmost_transposed_ndims is None) == (perm is None):
                raise ValueError('Must specify exactly one of '
                                 '`rightmost_transposed_ndims` and `perm`.')
            if rightmost_transposed_ndims is not None:
                rightmost_transposed_ndims = tensor_util.convert_nonref_to_tensor(
                    rightmost_transposed_ndims, dtype_hint=np.int32)
                if not dtype_util.is_integer(rightmost_transposed_ndims.dtype):
                    raise TypeError(
                        '`rightmost_transposed_ndims` must be integer type.')
                rightmost_transposed_ndims_ = tf.get_static_value(
                    rightmost_transposed_ndims)
                if rightmost_transposed_ndims_ is None:
                    raise NotImplementedError(
                        '`rightmost_transposed_ndims` must be '
                        'known prior to graph execution.')
                msg = '`rightmost_transposed_ndims` must be non-negative.'
                if rightmost_transposed_ndims_ < 0:
                    raise ValueError(
                        msg[:-1] +
                        ', saw: {}.'.format(rightmost_transposed_ndims_))

                perm_start = (distribution_util.prefer_static_value(
                    rightmost_transposed_ndims) - 1)
                perm = tf.range(start=perm_start,
                                limit=-1,
                                delta=-1,
                                name='perm')
            else:  # perm is not None:
                perm = tensor_util.convert_nonref_to_tensor(
                    perm, dtype_hint=np.int32, name='perm')
                rightmost_transposed_ndims = tf.size(
                    perm, name='rightmost_transposed_ndims')
                rightmost_transposed_ndims_ = tf.get_static_value(
                    rightmost_transposed_ndims)

            # TODO(b/110828604): If bijector base class ever supports dynamic
            # `min_event_ndims`, then this class already works dynamically and the
            # following five lines can be removed.
            if rightmost_transposed_ndims_ is None:
                raise NotImplementedError(
                    '`rightmost_transposed_ndims` must be '
                    'known prior to graph execution.')
            else:
                rightmost_transposed_ndims_ = int(rightmost_transposed_ndims_)

            self._perm = perm
            self._rightmost_transposed_ndims = rightmost_transposed_ndims
            self._initial_rightmost_transposed_ndims = rightmost_transposed_ndims_
            super(Transpose, self).__init__(
                forward_min_event_ndims=rightmost_transposed_ndims_,
                is_constant_jacobian=True,
                validate_args=validate_args,
                parameters=parameters,
                name=name)
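
Usage of the bijector above, showing both ways to specify the transposition (consistent with the documented defaults):

import tensorflow_probability as tfp
tfb = tfp.bijectors

bij = tfb.Transpose(perm=[1, 0])
bij.forward([[1., 2.], [3., 4.]])   # ==> [[1., 3.], [2., 4.]]

# Equivalent via the default perm: reverse the rightmost two dims.
tfb.Transpose(rightmost_transposed_ndims=2).forward([[1., 2.], [3., 4.]])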
Example #13
def _float_dtype_like(dtype):
    if dtype is tf.int64:
        return tf.float64
    if dtype_util.is_integer(dtype):
        return tf.float32
    return dtype
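
Note the subtle difference from the variants above: `dtype is tf.int64` matches only the TensorFlow dtype object, whereas `dtype_util.as_numpy_dtype(dtype) == np.int64` (Example #11) also maps NumPy `int64` inputs to `tf.float64`; here a NumPy `int64` falls through to the `is_integer` branch and gets `tf.float32`.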
Example #14
def _potential_scale_reduction_single_state(state, independent_chain_ndims,
                                            split_chains, validate_args):
    """potential_scale_reduction for one single state `Tensor`."""
    # Cast integers to floats for floating-point division. Inspect the numpy
    # dtype so the check also works when `state` is a NumPy object under the
    # numpy test suite.
    if dtype_util.as_numpy_dtype(state.dtype) is np.int64:
        state = tf.cast(state, tf.float64)
    elif dtype_util.is_integer(state.dtype):
        state = tf.cast(state, tf.float32)
    with tf.name_scope('potential_scale_reduction_single_state'):
        # We assume exactly one leading dimension indexes e.g. correlated samples
        # from each Markov chain.
        state = tf.convert_to_tensor(state, name='state')

        n_samples_ = tf.compat.dimension_value(state.shape[0])
        if n_samples_ is not None:  # If available statically.
            if split_chains and n_samples_ < 4:
                raise ValueError(
                    'Must provide at least 4 samples when splitting chains. '
                    'Found {}'.format(n_samples_))
            if not split_chains and n_samples_ < 2:
                raise ValueError(
                    'Must provide at least 2 samples.  Found {}'.format(
                        n_samples_))
        elif validate_args:
            if split_chains:
                assertions = [
                    assert_util.assert_greater_equal(
                        ps.shape(state)[0],
                        4,
                        message=
                        'Must provide at least 4 samples when splitting chains.'
                    )
                ]
                with tf.control_dependencies(assertions):
                    state = tf.identity(state)
            else:
                assertions = [
                    assert_util.assert_greater_equal(
                        ps.shape(state)[0],
                        2,
                        message='Must provide at least 2 samples.')
                ]
                with tf.control_dependencies(assertions):
                    state = tf.identity(state)

        # Define so it's not a magic number.
        # Warning!  `if split_chains` logic assumes this is 1!
        sample_ndims = 1

        if split_chains:
            # Split the sample dimension in half, doubling the number of
            # independent chains.

            # For odd number of samples, keep all but the last sample.
            state_shape = ps.shape(state)
            n_samples = state_shape[0]
            state = state[:n_samples - n_samples % 2]

            # Suppose state = [0, 1, 2, 3, 4, 5]
            # Step 1: reshape into [[0, 1, 2], [3, 4, 5]]
            # E.g. reshape states of shape [a, b] into [2, a//2, b].
            state = tf.reshape(
                state, ps.concat([[2, n_samples // 2], state_shape[1:]],
                                 axis=0))
            # Step 2: Put the size `2` dimension in the right place to be treated as a
            # chain, changing [[0, 1, 2], [3, 4, 5]] into [[0, 3], [1, 4], [2, 5]],
            # reshaping [2, a//2, b] into [a//2, 2, b].
            state = tf.transpose(
                a=state,
                perm=ps.concat([[1, 0], tf.range(2, tf.rank(state))], axis=0))

            # We're treating the new dim as indexing 2 chains, so increment.
            independent_chain_ndims += 1

        sample_axis = tf.range(0, sample_ndims)
        chain_axis = tf.range(sample_ndims,
                              sample_ndims + independent_chain_ndims)
        sample_and_chain_axis = tf.range(
            0, sample_ndims + independent_chain_ndims)

        n = _axis_size(state, sample_axis)
        m = _axis_size(state, chain_axis)

        # In the language of Brooks and Gelman (1998),
        # B / n is the between chain variance, the variance of the chain means.
        # W is the within sequence variance, the mean of the chain variances.
        b_div_n = _reduce_variance(tf.reduce_mean(state,
                                                  axis=sample_axis,
                                                  keepdims=True),
                                   sample_and_chain_axis,
                                   biased=False)
        w = tf.reduce_mean(_reduce_variance(state,
                                            sample_axis,
                                            keepdims=True,
                                            biased=False),
                           axis=sample_and_chain_axis)

        # sigma^2_+ is an estimate of the true variance, which would be unbiased if
        # each chain was drawn from the target.  c.f. "law of total variance."
        sigma_2_plus = ((n - 1) / n) * w + b_div_n
        return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
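
A compact NumPy restatement of the estimator assembled above (a hedged sketch for a single scalar state of shape [n_samples, n_chains], without chain splitting):

import numpy as np

def potential_scale_reduction(state):
    n, m = state.shape
    chain_means = state.mean(axis=0)
    b_div_n = chain_means.var(ddof=1)          # Between-chain variance of means.
    w = state.var(axis=0, ddof=1).mean()       # Mean within-chain variance.
    sigma_2_plus = (n - 1) / n * w + b_div_n   # c.f. law of total variance.
    return (m + 1) / m * sigma_2_plus / w - (n - 1) / (m * n)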