Example 1
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    logits = self._logits
    probs = self._probs
    param, name = (probs, 'probs') if logits is None else (logits, 'logits')

    # In init, we can always build shape and dtype checks because
    # we assume shape doesn't change for Variable backed args.
    if is_init:
      if not dtype_util.is_floating(param.dtype):
        raise TypeError('Argument `{}` must have floating type.'.format(name))

      msg = 'Argument `{}` must have rank at least 1.'.format(name)
      shape_static = tensorshape_util.dims(param.shape)
      if shape_static is not None:
        if len(shape_static) < 1:
          raise ValueError(msg)
      elif self.validate_args:
        param = tf.convert_to_tensor(param)
        assertions.append(
            assert_util.assert_rank_at_least(param, 1, message=msg))
        with tf.control_dependencies(assertions):
          param = tf.identity(param)

      msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
      msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
          name, dtype_util.max(tf.int32))
      event_size = shape_static[-1] if shape_static is not None else None
      if event_size is not None:
        if event_size < 1:
          raise ValueError(msg1)
        if event_size > dtype_util.max(tf.int32):
          raise ValueError(msg2)
      elif self.validate_args:
        param = tf.convert_to_tensor(param)
        assertions.append(assert_util.assert_greater_equal(
            tf.shape(param)[-1], 1, message=msg1))
        # NOTE: For now, we leave out a runtime assertion that
        # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
        # will fail before we get to this point.

    if not self.validate_args:
      assert not assertions  # Should never happen.
      return []

    if probs is not None:
      probs = param  # reuse tensor conversion from above
      if is_init != tensor_util.is_ref(probs):
        probs = tf.convert_to_tensor(probs)
        one = tf.ones([], dtype=probs.dtype)
        assertions.extend([
            assert_util.assert_non_negative(probs),
            assert_util.assert_less_equal(probs, one),
            assert_util.assert_near(
                tf.reduce_sum(probs, axis=-1), one,
                message='Argument `probs` must sum to 1.'),
        ])

    return assertions
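
A minimal standalone sketch of the same static check (not from the source; `logits` here is a hypothetical stand-in): the final dimension of the parameter must be at least 1 and must fit in an int32 index, which is where `dtype_util.max(tf.int32)` enters.

import tensorflow as tf
from tensorflow_probability.python.internal import dtype_util

logits = tf.zeros([4, 7])  # hypothetical: a batch of 4 categoricals over 7 classes
event_size = logits.shape[-1]
if event_size < 1:
  raise ValueError('Argument `logits` must have final dimension >= 1.')
if event_size > dtype_util.max(tf.int32):
  raise ValueError('Argument `logits` must have final dimension <= {}.'.format(
      dtype_util.max(tf.int32)))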
Example 2
  def test_assert_all_finite_input_finite(self):
    minval = tf.constant(dtype_util.min(self.dtype), dtype=self.dtype)
    maxval = tf.constant(dtype_util.max(self.dtype), dtype=self.dtype)

    # This tests if the minimum value for the dtype is detected as finite.
    self.assertAllFinite(minval)

    # This tests if the maximum value for the dtype is detected as finite.
    self.assertAllFinite(maxval)

    # This tests if a rank 3 `Tensor` with entries in the range
    # [0.4*minval, 0.4*maxval] is detected as finite.
    # The choice of range helps to avoid overflows or underflows
    # in tf.linspace calculations.
    num_elem = 1000
    shape = (10, 10, 10)
    a = tf.reshape(tf.linspace(0.4*minval, 0.4*maxval, num_elem), shape)
    self.assertAllFinite(a)
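
The test relies on `dtype_util.min`/`dtype_util.max` returning the finite extremes of the dtype. A quick standalone check of that premise, assuming `dtype_util` mirrors `np.finfo` for floating dtypes:

import numpy as np
import tensorflow as tf
from tensorflow_probability.python.internal import dtype_util

# The finite extremes match np.finfo, so tf.linspace between 0.4*min and
# 0.4*max stays finite end to end.
assert dtype_util.max(tf.float32) == np.finfo(np.float32).max
assert dtype_util.min(tf.float32) == np.finfo(np.float32).min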
Example 3
  def testMax(self, dtype, expected_maxval):
    self.assertEqual(dtype_util.max(dtype), expected_maxval)
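
This reads as one case of a parameterized test; a plausible harness (the class name and the (dtype, expected_maxval) pairs are assumptions, not from the source) would be:

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_probability.python.internal import dtype_util

class DtypeUtilTest(parameterized.TestCase):

  @parameterized.parameters(
      (tf.int32, 2**31 - 1),      # assumed cases for illustration
      (tf.int64, 2**63 - 1),
      (tf.float32, np.finfo(np.float32).max),
  )
  def testMax(self, dtype, expected_maxval):
    self.assertEqual(dtype_util.max(dtype), expected_maxval)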
Example 4
    def _sample_n(self, n, seed=None):
        power = tf.convert_to_tensor(self.power)
        shape = ps.concat([[n], ps.shape(power)], axis=0)
        numpy_dtype = dtype_util.as_numpy_dtype(power.dtype)

        seed = samplers.sanitize_seed(seed, salt='zipf')

        # Because `_hat_integral` is monotonically decreasing, the bounds for u
        # switch: the smaller argument yields the larger bound.
        # Compute the hat_integral explicitly here since we can calculate the log of
        # the inputs statically in float64 with numpy.
        maxval_u = tf.math.exp(-(power - 1.) * numpy_dtype(np.log1p(0.5)) -
                               tf.math.log(power - 1.)) + 1.
        minval_u = tf.math.exp(
            -(power - 1.) *
            numpy_dtype(np.log1p(dtype_util.max(self.dtype) - 0.5)) -
            tf.math.log(power - 1.))

        def loop_body(should_continue, k, seed):
            """Resample the non-accepted points."""
            u_seed, next_seed = samplers.split_seed(seed)
            # Uniform variates must be sampled from the open-interval `(0, 1)` rather
            # than `[0, 1)`. To do so, we use
            # `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny`
            # because it is the smallest, positive, 'normal' number. A 'normal' number
            # is such that the mantissa has an implicit leading 1. Normal, positive
            # numbers x, y have the reasonable property that, `x + y >= max(x, y)`. In
            # this case, a subnormal number (i.e., np.nextafter) can cause us to
            # sample 0.
            u = samplers.uniform(
                shape,
                minval=np.finfo(dtype_util.as_numpy_dtype(power.dtype)).tiny,
                maxval=numpy_dtype(1.),
                dtype=power.dtype,
                seed=u_seed)
            # We use (1 - u) * maxval_u + u * minval_u rather than the other way
            # around, since we want to draw samples in (minval_u, maxval_u].
            u = maxval_u + (minval_u - maxval_u) * u
            # set_shape needed here because of b/139013403
            tensorshape_util.set_shape(u, should_continue.shape)

            # Sample the point X from the continuous density h(x) \propto x^(-power).
            x = self._hat_integral_inverse(u, power=power)

            # Rejection-inversion requires a `hat` function, h(x) such that
            # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
            # support. A natural hat function for us is h(x) = x^(-power).
            #
            # After sampling X from h(x), suppose it lies in the interval
            # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
            # it lies to the left of x_K, where x_K is defined by:
            #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
            # where H(x) = \int_x^inf h(x) dx.

            # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
            # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
            # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

            # Update the non-accepted points.
            # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
            k = tf.where(should_continue, tf.floor(x + 0.5), k)
            accept = (u <= self._hat_integral(k + .5, power=power) +
                      tf.exp(self._log_prob(k + 1, power=power)))

            return [should_continue & (~accept), k, next_seed]

        should_continue, samples, _ = tf.while_loop(
            cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
            body=loop_body,
            loop_vars=[
                tf.ones(shape, dtype=tf.bool),  # should_continue
                tf.zeros(shape, dtype=power.dtype),  # k
                seed,  # seed
            ],
            maximum_iterations=self.sample_maximum_iterations,
        )
        samples = samples + 1.

        if self.validate_args and dtype_util.is_integer(self.dtype):
            samples = distribution_util.embed_check_integer_casting_closed(
                samples, target_dtype=self.dtype, assert_positive=True)

        samples = tf.cast(samples, self.dtype)

        if self.validate_args:
            npdt = dtype_util.as_numpy_dtype(self.dtype)
            v = npdt(
                dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan
            )
            samples = tf.where(should_continue, v, samples)

        return samples
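
The comments above lean on the hat integral H(x) = \int_x^inf h(t) dt and its inverse. A minimal NumPy sketch of that pair; the closed form is an assumption inferred from the `log1p` terms in `minval_u`/`maxval_u`, i.e. taking h(x) = (1 + x)^(-power):

import numpy as np

def hat_integral(x, power):
  # H(x) = \int_x^inf (1 + t)^(-power) dt = (1 + x)^(1 - power) / (power - 1).
  t = power - 1.
  return np.exp(-t * np.log1p(x) - np.log(t))

def hat_integral_inverse(u, power):
  # Solve u = (1 + x)^(1 - power) / (power - 1) for x.
  t = power - 1.
  return np.expm1(-(np.log(t) + np.log(u)) / t)

# Round trip: H_inverse(H(x)) == x, which the acceptance condition
# U <= H(K + .5) + pmf(K + 1) relies on.
x = np.linspace(0.5, 50., 5)
u = hat_integral(x, power=2.5)
np.testing.assert_allclose(hat_integral_inverse(u, power=2.5), x, rtol=1e-6)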
Example 5
    def _sample_n(self, n, seed=None):
        power = tf.convert_to_tensor(self.power)
        shape = tf.concat([[n], tf.shape(power)], axis=0)

        has_seed = seed is not None
        seed = SeedStream(seed, salt='zipf')

        minval_u = self._hat_integral(0.5, power=power) + 1.
        maxval_u = self._hat_integral(dtype_util.max(tf.int64) - 0.5,
                                      power=power)

        def loop_body(should_continue, k):
            """Resample the non-accepted points."""
            # The range of U is chosen so that the resulting sample K lies in
            # [0, tf.int64.max). The final sample, if accepted, is K + 1.
            u = tf.random.uniform(shape,
                                  minval=minval_u,
                                  maxval=maxval_u,
                                  dtype=power.dtype,
                                  seed=seed())

            # Sample the point X from the continuous density h(x) \propto x^(-power).
            x = self._hat_integral_inverse(u, power=power)

            # Rejection-inversion requires a `hat` function, h(x) such that
            # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
            # support. A natural hat function for us is h(x) = x^(-power).
            #
            # After sampling X from h(x), suppose it lies in the interval
            # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
            # it lies to the left of x_K, where x_K is defined by:
            #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
            # where H(x) = \int_x^inf h(x) dx.

            # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
            # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
            # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

            # Update the non-accepted points.
            # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
            k = tf.where(should_continue, tf.floor(x + 0.5), k)
            accept = (u <= self._hat_integral(k + .5, power=power) +
                      tf.exp(self._log_prob(k + 1, power=power)))

            return [should_continue & (~accept), k]

        should_continue, samples = tf.while_loop(
            cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
            body=loop_body,
            loop_vars=[
                tf.ones(shape, dtype=tf.bool),  # should_continue
                tf.zeros(shape, dtype=power.dtype),  # k
            ],
            parallel_iterations=1 if has_seed else 10,
            maximum_iterations=self.sample_maximum_iterations,
        )
        samples = samples + 1.

        if self.validate_args and dtype_util.is_integer(self.dtype):
            samples = distribution_util.embed_check_integer_casting_closed(
                samples, target_dtype=self.dtype, assert_positive=True)

        samples = tf.cast(samples, self.dtype)

        if self.validate_args:
            npdt = dtype_util.as_numpy_dtype(self.dtype)
            v = npdt(
                dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan
            )
            samples = tf.where(should_continue, v, samples)

        return samples
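
Both versions of `_sample_n` are reached through the public sampling API. An illustrative usage sketch via `tfp.distributions.Zipf` (treat the exact arguments as assumptions; `sample_maximum_iterations` bounds the rejection loops above):

import tensorflow as tf
import tensorflow_probability as tfp

# Zipf's support is {1, 2, ...}, so every accepted sample is >= 1.
dist = tfp.distributions.Zipf(power=2.5, sample_maximum_iterations=100)
samples = dist.sample(1000, seed=42)
tf.debugging.assert_greater_equal(samples, tf.ones_like(samples))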