Example #1
 def testCorrectlyAssertsSmallestPossibleInteger(self):
     with self.assertRaisesOpError('Elements cannot be smaller than 0.'):
         x = tf1.placeholder_with_default(np.array([1, -1], dtype=np.int32),
                                          shape=None)
         x_checked = distribution_util.embed_check_integer_casting_closed(
             x, target_dtype=tf.uint16, assert_nonnegative=False)
         self.evaluate(x_checked)
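
For contrast, a minimal sketch of the passing case, assuming TFP's internal import path tensorflow_probability.python.internal.distribution_util (the values below are hypothetical): when every element survives the round trip to the target dtype, the check op simply passes the input through.

import tensorflow as tf
from tensorflow_probability.python.internal import distribution_util

# Every element is non-negative and fits in uint16, so no assertion fires.
x = tf.constant([0, 1, 65535], dtype=tf.int32)
x_checked = distribution_util.embed_check_integer_casting_closed(
    x, target_dtype=tf.uint16)
print(x_checked.numpy())  # [    0     1 65535]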
Example #2
 def testCorrectlyAssertsIntegerForm(self):
   with self.assertRaisesOpError('Elements must be int16-equivalent.'):
     x = tf1.placeholder_with_default(
         np.array([1, 1.5], dtype=np.float16), shape=None)
     x_checked = distribution_util.embed_check_integer_casting_closed(
         x, target_dtype=tf.int16)
     self.evaluate(x_checked)
Example #3
 def testCorrectlyAssertsLargestPossibleInteger(self):
   with self.assertRaisesOpError('Elements cannot exceed 32767.'):
     x = tf1.placeholder_with_default(
         np.array([1, 2**15], dtype=np.int32), shape=None)
     x_checked = distribution_util.embed_check_integer_casting_closed(
         x, target_dtype=tf.int16)
     self.evaluate(x_checked)
Example #4
 def testCorrectlyAssertsPositive(self):
   with self.assertRaisesOpError('Elements must be positive'):
     x = tf1.placeholder_with_default(
         np.array([1, 0], dtype=np.float16), shape=None)
     x_checked = distribution_util.embed_check_integer_casting_closed(
         x, target_dtype=tf.int16, assert_positive=True)
     self.evaluate(x_checked)
Example #5
  def _log_prob(self, event):
    if self.validate_args:
      event = distribution_util.embed_check_integer_casting_closed(
          event, target_dtype=tf.bool)

    log_probs0, log_probs1 = self._outcome_log_probs()
    event = tf.cast(event, log_probs0.dtype)
    return event * (log_probs1 - log_probs0) + log_probs0
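
The return line is the two-point log-pmf written as a linear interpolation: for event in {0, 1} it evaluates to log_probs1 when event is 1 and to log_probs0 when event is 0. A quick NumPy check with a hypothetical probability:

import numpy as np

p = 0.3                     # hypothetical P(event == 1)
log_probs0 = np.log(1 - p)  # log P(event == 0)
log_probs1 = np.log(p)      # log P(event == 1)

for event in (0., 1.):
    lp = event * (log_probs1 - log_probs0) + log_probs0
    assert np.isclose(lp, log_probs1 if event else log_probs0)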
Example #6
 def _mean(self):
     probs = self.probs
     outcomes = self.outcomes
     if dtype_util.is_integer(outcomes.dtype):
         if self._validate_args:
             outcomes = dist_util.embed_check_integer_casting_closed(
                 outcomes, target_dtype=probs.dtype)
         outcomes = tf.cast(outcomes, dtype=probs.dtype)
     return tf.tensordot(outcomes, probs, axes=[[0], [-1]])
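
Here tf.tensordot(outcomes, probs, axes=[[0], [-1]]) contracts the outcomes vector against the last (category) axis of probs, i.e. a batched expectation sum_i p_i * x_i. A small sketch with hypothetical values:

import tensorflow as tf

outcomes = tf.constant([0., 1., 4.])            # support, shape [3]
probs = tf.constant([[0.2, 0.5, 0.3],
                     [0.1, 0.1, 0.8]])          # batch of 2, shape [2, 3]

mean = tf.tensordot(outcomes, probs, axes=[[0], [-1]])  # shape [2]
# => [0.5*1 + 0.3*4, 0.1*1 + 0.8*4] = [1.7, 3.3]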
Example #7
 def _log_prob(self, k):
     logits = self.logits_parameter()
     if self.validate_args:
         k = distribution_util.embed_check_integer_casting_closed(
             k, target_dtype=tf.int32)
     k, logits = _broadcast_cat_event_and_params(
         k, logits, base_dtype=dtype_util.base_dtype(self.dtype))
     return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=k,
                                                            logits=logits)
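
Negating tf.nn.sparse_softmax_cross_entropy_with_logits recovers log softmax(logits)[k], so the method returns the categorical log-pmf at k. A quick check with hypothetical logits:

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])   # hypothetical
k = tf.constant([0])

via_ce = -tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=k, logits=logits)
via_gather = tf.gather(tf.nn.log_softmax(logits), k, batch_dims=1)
tf.debugging.assert_near(via_ce, via_gather)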
Example #8
 def _log_prob(self, k):
     with tf.name_scope("Cat2log_prob"):
         logits = self.logits_parameter()
         if self.validate_args:
             k = distribution_util.embed_check_integer_casting_closed(
                 k, target_dtype=self.dtype)
         k, logits = _broadcast_cat_event_and_params(
             k, logits, base_dtype=dtype_util.base_dtype(self.dtype))
          # log_softmax is numerically stabler than log(softmax(...)).
          logits_normalised = tf.math.log_softmax(logits)
          return tf.gather(logits_normalised, k, batch_dims=1)
Example #9
    def _log_prob(self, k):
        k = tf.convert_to_tensor(value=k, name="k")
        if self.validate_args:
            k = util.embed_check_integer_casting_closed(k,
                                                        target_dtype=tf.int32)
        k, logits = _broadcast_cat_event_and_params(
            k, self.logits, base_dtype=self.dtype.base_dtype)

        return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=k,
                                                               logits=logits)
Example #10
 def _variance(self):
   probs = self._categorical.probs
   outcomes = tf.broadcast_to(
       self.outcomes, shape=dist_util.prefer_static_shape(probs))
   if dtype_util.is_integer(outcomes.dtype):
     if self._validate_args:
       outcomes = dist_util.embed_check_integer_casting_closed(
           outcomes, target_dtype=probs.dtype)
     outcomes = tf.cast(outcomes, dtype=probs.dtype)
   square_d = tf.math.squared_difference(outcomes,
                                         tf.expand_dims(self.mean(), axis=-1))
   return tf.reduce_sum(input_tensor=probs * square_d, axis=-1)
Example #11
 def _variance(self):
     probs = self._categorical.probs_parameter()
     outcomes = tf.broadcast_to(self.outcomes, shape=ps.shape(probs))
     if dtype_util.is_integer(outcomes.dtype):
         if self._validate_args:
             outcomes = dist_util.embed_check_integer_casting_closed(
                 outcomes, target_dtype=probs.dtype)
         outcomes = tf.cast(outcomes, dtype=probs.dtype)
     square_d = tf.math.squared_difference(
         outcomes,
         self._mean(probs)[..., tf.newaxis])
     return tf.reduce_sum(probs * square_d, axis=-1)
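
Both _variance implementations compute sum_i p_i * (x_i - mu)^2 along the outcome axis. A NumPy sanity check with hypothetical outcomes and probabilities, also confirming the E[X^2] - E[X]^2 identity:

import numpy as np

outcomes = np.array([0., 1., 4.])          # hypothetical support
probs = np.array([0.2, 0.5, 0.3])

mean = np.dot(outcomes, probs)             # tf.tensordot equivalent
square_d = (outcomes - mean) ** 2          # tf.math.squared_difference
variance = np.sum(probs * square_d)

assert np.isclose(variance, np.dot(probs, outcomes**2) - mean**2)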
Example #12
 def _log_prob(self, x):
   # The log probability at positive integer points x is log(x^(-power) / Z)
   # where Z is the normalization constant. For x < 1 and non-integer points,
   # the log-probability is -inf.
   #
   # However, if interpolate_nondiscrete is True, we return the natural
   # continuous relaxation for x >= 1 which agrees with the log probability at
   # positive integer points.
   #
   # If interpolate_nondiscrete is False and validate_args is True, we check
   # that the sample point x is in the support. That is, x is equivalent to a
   # positive integer.
   x = tf.cast(x, self.power.dtype)
   if self.validate_args and not self.interpolate_nondiscrete:
     x = distribution_util.embed_check_integer_casting_closed(
         x, target_dtype=self.dtype, assert_positive=True)
   return self._log_unnormalized_prob(x) - self._log_normalization()
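
For the standard Zipf pmf p(x) = x^(-power) / zeta(power) on positive integers, the two private helpers presumably reduce to -power * log(x) and log zeta(power); the sketch below is an inference from that pmf, not code copied from the source.

import numpy as np
from scipy.special import zeta

power = 2.5                                    # hypothetical
x = np.arange(1., 5.)

log_unnormalized_prob = -power * np.log(x)     # log x^(-power)
log_normalization = np.log(zeta(power, 1.))    # log zeta(power)
log_prob = log_unnormalized_prob - log_normalization

assert np.allclose(np.exp(log_prob), x**-power / zeta(power, 1.))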
Example #13
    def _cdf(self, k):
        k = tf.convert_to_tensor(k, name="k")
        if self.validate_args:
            k = util.embed_check_integer_casting_closed(k,
                                                        target_dtype=tf.int32)

        k, probs = _broadcast_cat_event_and_params(
            k, self.probs, base_dtype=self.dtype.base_dtype)

        # batch-flatten everything in order to use `sequence_mask()`.
        batch_flattened_probs = tf.reshape(probs, (-1, self._event_size))
        batch_flattened_k = tf.reshape(k, [-1])

        to_sum_over = tf.where(
            tf.sequence_mask(batch_flattened_k, self._event_size),
            batch_flattened_probs, tf.zeros_like(batch_flattened_probs))
        batch_flattened_cdf = tf.reduce_sum(to_sum_over, axis=-1)
        # Reshape back to the shape of the argument.
        return tf.reshape(batch_flattened_cdf, tf.shape(k))
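
tf.sequence_mask(k, event_size) marks the first k positions True, so the masked reduction returns the sum of the probabilities at indices strictly below k. A minimal sketch with hypothetical values:

import tensorflow as tf

probs = tf.constant([[0.1, 0.2, 0.3, 0.4]])   # hypothetical, event_size = 4
k = tf.constant([2])

mask = tf.sequence_mask(k, 4)                 # [[True, True, False, False]]
to_sum = tf.where(mask, probs, tf.zeros_like(probs))
cdf = tf.reduce_sum(to_sum, axis=-1)          # 0.1 + 0.2 = 0.3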
Example #14
  def _log_prob(self, event):
    if self.validate_args:
      event = util.embed_check_integer_casting_closed(
          event, target_dtype=tf.bool)

    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent behavior for logits = inf/-inf.
    event = tf.cast(event, self.logits.dtype)
    logits = self.logits
    # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
    # so we do this here.

    def _broadcast(logits, event):
      return (tf.ones_like(event) * logits,
              tf.ones_like(logits) * event)

    if not (event.shape.is_fully_defined() and
            logits.shape.is_fully_defined() and
            event.shape == logits.shape):
      logits, event = _broadcast(logits, event)
    return -tf.nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
Example #15
  def _sample_n(self, n, seed=None):
    power = tf.convert_to_tensor(self.power)
    shape = tf.concat([[n], tf.shape(power)], axis=0)

    has_seed = seed is not None
    seed = SeedStream(seed, salt='zipf')

    minval_u = self._hat_integral(0.5, power=power) + 1.
    maxval_u = self._hat_integral(tf.int64.max - 0.5, power=power)

    def loop_body(should_continue, k):
      """Resample the non-accepted points."""
      # The range of U is chosen so that the resulting sample K lies in
      # [0, tf.int64.max). The final sample, if accepted, is K + 1.
      u = tf.random.uniform(
          shape,
          minval=minval_u,
          maxval=maxval_u,
          dtype=power.dtype,
          seed=seed())

      # Sample the point X from the continuous density h(x) \propto x^(-power).
      x = self._hat_integral_inverse(u, power=power)

      # Rejection-inversion requires a `hat` function, h(x) such that
      # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
      # support. A natural hat function for us is h(x) = x^(-power).
      #
      # After sampling X from h(x), suppose it lies in the interval
      # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
      # it lies to the left of x_K, where x_K is defined by:
      #   \int_{x_k}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
      # where H(x) = \int_x^inf h(x) dx.

      # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
      # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
      # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

      # Update the non-accepted points.
      # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
      k = tf.where(should_continue, tf.floor(x + 0.5), k)
      accept = (u <= self._hat_integral(k + .5, power=power) + tf.exp(
          self._log_prob(k + 1, power=power)))

      return [should_continue & (~accept), k]

    should_continue, samples = tf.while_loop(
        cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
        body=loop_body,
        loop_vars=[
            tf.ones(shape, dtype=tf.bool),  # should_continue
            tf.zeros(shape, dtype=power.dtype),  # k
        ],
        parallel_iterations=1 if has_seed else 10,
        maximum_iterations=self.sample_maximum_iterations,
    )
    samples = samples + 1.

    if self.validate_args and dtype_util.is_integer(self.dtype):
      samples = distribution_util.embed_check_integer_casting_closed(
          samples, target_dtype=self.dtype, assert_positive=True)

    samples = tf.cast(samples, self.dtype)

    if self.validate_args:
      npdt = dtype_util.as_numpy_dtype(self.dtype)
      v = npdt(dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan)
      samples = tf.where(should_continue, v, samples)

    return samples
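
The loop leans on _hat_integral and _hat_integral_inverse, which the code never shows. Their closed forms can be read off the explicit exp/log1p bounds in Example #16 below: H(x) = (1 + x)^(1 - power) / (power - 1) and H_inverse(u) = ((power - 1) * u)^(1 / (1 - power)) - 1. A NumPy sketch of that inference (these helper definitions are reconstructed, not copied from the source):

import numpy as np

power = 2.0   # hypothetical

def hat_integral(x):
    # H(x) = \int_x^inf (1 + t)^(-power) dt = (1 + x)^(1 - power) / (power - 1)
    return np.exp(-(power - 1.) * np.log1p(x) - np.log(power - 1.))

def hat_integral_inverse(u):
    # Solve H(x) = u for x.
    return np.expm1(np.log((power - 1.) * u) / (1. - power))

x = 3.7
assert np.isclose(hat_integral_inverse(hat_integral(x)), x)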
Example #16
    def _sample_n(self, n, seed=None):
        power = tf.convert_to_tensor(self.power)
        shape = ps.concat([[n], ps.shape(power)], axis=0)
        numpy_dtype = dtype_util.as_numpy_dtype(power.dtype)

        seed = samplers.sanitize_seed(seed, salt='zipf')

        # Because `_hat_integral` is monotonically decreasing, the bounds for u
        # will switch.
        # Compute the hat_integral explicitly here since we can calculate the log of
        # the inputs statically in float64 with numpy.
        maxval_u = tf.math.exp(-(power - 1.) * numpy_dtype(np.log1p(0.5)) -
                               tf.math.log(power - 1.)) + 1.
        minval_u = tf.math.exp(
            -(power - 1.) *
            numpy_dtype(np.log1p(dtype_util.max(self.dtype) - 0.5)) -
            tf.math.log(power - 1.))

        def loop_body(should_continue, k, seed):
            """Resample the non-accepted points."""
            u_seed, next_seed = samplers.split_seed(seed)
            # Uniform variates must be sampled from the open-interval `(0, 1)` rather
            # than `[0, 1)`. To do so, we use
            # `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny`
            # because it is the smallest, positive, 'normal' number. A 'normal' number
            # is such that the mantissa has an implicit leading 1. Normal, positive
            # numbers x, y have the reasonable property that, `x + y >= max(x, y)`. In
            # this case, a subnormal number (i.e., np.nextafter) can cause us to
            # sample 0.
            u = samplers.uniform(
                shape,
                minval=np.finfo(dtype_util.as_numpy_dtype(power.dtype)).tiny,
                maxval=numpy_dtype(1.),
                dtype=power.dtype,
                seed=u_seed)
            # We use (1 - u) * maxval_u + u * minval_u rather than the other way
            # around, since we want to draw samples in (minval_u, maxval_u].
            u = maxval_u + (minval_u - maxval_u) * u
            # set_shape needed here because of b/139013403
            tensorshape_util.set_shape(u, should_continue.shape)

            # Sample the point X from the continuous density h(x) \propto x^(-power).
            x = self._hat_integral_inverse(u, power=power)

            # Rejection-inversion requires a `hat` function, h(x) such that
            # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
            # support. A natural hat function for us is h(x) = x^(-power).
            #
            # After sampling X from h(x), suppose it lies in the interval
            # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
            # it lies to the left of x_K, where x_K is defined by:
            #   \int_{x_k}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
            # where H(x) = \int_x^inf h(x) dx.

            # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
            # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
            # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

            # Update the non-accepted points.
            # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
            k = tf.where(should_continue, tf.floor(x + 0.5), k)
            accept = (u <= self._hat_integral(k + .5, power=power) +
                      tf.exp(self._log_prob(k + 1, power=power)))

            return [should_continue & (~accept), k, next_seed]

        should_continue, samples, _ = tf.while_loop(
            cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
            body=loop_body,
            loop_vars=[
                tf.ones(shape, dtype=tf.bool),  # should_continue
                tf.zeros(shape, dtype=power.dtype),  # k
                seed,  # seed
            ],
            maximum_iterations=self.sample_maximum_iterations,
        )
        samples = samples + 1.

        if self.validate_args and dtype_util.is_integer(self.dtype):
            samples = distribution_util.embed_check_integer_casting_closed(
                samples, target_dtype=self.dtype, assert_positive=True)

        samples = tf.cast(samples, self.dtype)

        if self.validate_args:
            npdt = dtype_util.as_numpy_dtype(self.dtype)
            v = npdt(
                dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan
            )
            samples = tf.where(should_continue, v, samples)

        return samples
Example #17
  def _sample_n(self, n, seed=None):
    shape = tf.concat([[n], self.batch_shape_tensor()], axis=0)

    has_seed = seed is not None
    seed = SeedStream(seed, salt="zipf")

    minval_u = self._hat_integral(0.5) + 1.
    maxval_u = self._hat_integral(tf.int64.max - 0.5)

    def loop_body(should_continue, k):
      """Resample the non-accepted points."""
      # The range of U is chosen so that the resulting sample K lies in
      # [0, tf.int64.max). The final sample, if accepted, is K + 1.
      u = tf.random_uniform(
          shape,
          minval=minval_u,
          maxval=maxval_u,
          dtype=self.power.dtype,
          seed=seed())

      # Sample the point X from the continuous density h(x) \propto x^(-power).
      x = self._hat_integral_inverse(u)

      # Rejection-inversion requires a `hat` function, h(x) such that
      # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
      # support. A natural hat function for us is h(x) = x^(-power).
      #
      # After sampling X from h(x), suppose it lies in the interval
      # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
      # it lies to the left of x_K, where x_K is defined by:
      #   \int_{x_k}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
      # where H(x) = \int_x^inf h(x) dx.

      # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
      # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
      # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

      # Update the non-accepted points.
      # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
      k = tf.where(should_continue, tf.floor(x + 0.5), k)
      accept = (u <= self._hat_integral(k + .5) + tf.exp(self._log_prob(k + 1)))

      return [should_continue & (~accept), k]

    should_continue, samples = tf.while_loop(
        cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
        body=loop_body,
        loop_vars=[
            tf.ones(shape, dtype=tf.bool),  # should_continue
            tf.zeros(shape, dtype=self.power.dtype),  # k
        ],
        parallel_iterations=1 if has_seed else 10,
        maximum_iterations=self.sample_maximum_iterations,
    )
    samples = samples + 1.

    if self.validate_args and self.dtype.is_integer:
      samples = distribution_util.embed_check_integer_casting_closed(
          samples, target_dtype=self.dtype, assert_positive=True)

    samples = tf.cast(samples, self.dtype)

    if self.validate_args:
      dt = self.dtype.as_numpy_dtype
      if self.dtype.is_integer:
        mask = tf.fill(shape, value=np.array(np.iinfo(dt).min, dtype=dt))
        samples = tf.where(should_continue, mask, samples)
      else:
        mask = tf.fill(shape, value=np.array(np.nan, dtype=dt))
        samples = tf.where(should_continue, mask, samples)

    return samples
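
The final validate_args block overwrites any draw whose chain never accepted with a sentinel (the integer minimum, or NaN for floats), so callers can detect non-convergence. A tiny sketch of that masking with hypothetical values:

import numpy as np
import tensorflow as tf

should_continue = tf.constant([False, True])  # second draw never accepted
samples = tf.constant([3, 7], dtype=tf.int32)

sentinel = np.array(np.iinfo(np.int32).min, dtype=np.int32)
masked = tf.where(should_continue, tf.fill([2], sentinel), samples)
print(masked.numpy())  # [3, -2147483648]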