Example #1
 def get_dataset(self, shuffle=True):
   """Returns dataset with slight restructuring of feature dictionary."""
   def preprocess_ex(ex):
     return {
         'pitch': ex['pitch'],
         'audio': ex['audio'],
         'instrument_source': ex['instrument']['source'],
         'instrument_family': ex['instrument']['family'],
         'instrument': ex['instrument']['label'],
         'f0_hz': ex['f0']['hz'],
         'f0_confidence': ex['f0']['confidence'],
         'loudness_db': ex['loudness']['db'],
     }
   dataset = super().get_dataset(shuffle)
   dataset = dataset.map(preprocess_ex, num_parallel_calls=_AUTOTUNE)
   if self.pitch_subset:
     dataset = dataset.filter(lambda ex: tf.reduce_any(
       ex['pitch'] == self.pitch_subset
     ))
   if self.instrument_subset:
     dataset = dataset.filter(lambda ex: tf.reduce_any(
       ex['instrument'] == self.instrument_subset
     ))
   return dataset
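A minimal sketch of how the subset filters above behave (pitch_subset and the values below are illustrative, not taken from the original class): comparing a scalar feature against a 1-D subset broadcasts to a boolean vector, and tf.reduce_any collapses it to a single "is this value in the subset" flag.

import tensorflow as tf

pitch_subset = tf.constant([60, 64, 67])  # hypothetical pitches to keep
for pitch in [60, 61, 67, 70]:
    keep = tf.reduce_any(tf.equal(pitch, pitch_subset))
    print(pitch, bool(keep))  # 60 True, 61 False, 67 True, 70 False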
Example #2
def _secant2_inner(value_and_gradients_function,
                   initial_args,
                   val_0,
                   val_c,
                   f_lim,
                   sufficient_decrease_param,
                   curvature_param):
  """Helper function for secant square."""
  # Apply the `update` function on active branch members to squeeze their
  # bracketing interval.
  update_result = update(value_and_gradients_function,
                         initial_args.left,
                         initial_args.right,
                         val_c,
                         f_lim,
                         active=initial_args.active)

  # Update active and failed flags, update left/right on non-failed entries.
  active = initial_args.active & ~update_result.failed
  failed = initial_args.failed | update_result.failed
  val_left = val_where(active, update_result.left, initial_args.left)
  val_right = val_where(active, update_result.right, initial_args.right)

  # Check if new `c` points should be generated.
  updated_left = active & tf.equal(val_left.x, val_c.x)
  updated_right = active & tf.equal(val_right.x, val_c.x)
  is_new = updated_left | updated_right

  next_c = tf.where(
      updated_left, _secant(initial_args.left, val_left), val_c.x)
  next_c = tf.where(
      updated_right, _secant(initial_args.right, val_right), next_c)
  in_range = (val_left.x <= next_c) & (next_c <= val_right.x)

  # Figure out if an extra function evaluation is needed for new `c` points.
  needs_extra_eval = tf.reduce_any(in_range & is_new)
  num_evals = initial_args.num_evals + update_result.num_evals
  num_evals = num_evals + tf.cast(needs_extra_eval, num_evals.dtype)

  next_args = _Secant2Result(
      active=active & in_range,  # No longer active if `c` is out of range.
      converged=initial_args.converged,
      failed=failed,
      num_evals=num_evals,
      left=val_left,
      right=val_right)

  def _apply_inner_update():
    next_val_c = prefer_static.cond(
        needs_extra_eval,
        (lambda: value_and_gradients_function(next_c)),
        (lambda: val_c))
    return _secant2_inner_update(
        value_and_gradients_function, next_args, val_0, next_val_c, f_lim,
        sufficient_decrease_param, curvature_param)

  return prefer_static.cond(
      tf.reduce_any(next_args.active),
      _apply_inner_update,
      lambda: next_args)
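The final prefer_static.cond above illustrates a pattern that recurs throughout these examples: per-batch-member work is gated by a boolean active mask, and the potentially expensive branch is skipped entirely when no member needs it. A toy sketch with plain tf.cond standing in for prefer_static.cond (all names below are illustrative):

import tensorflow as tf

active = tf.constant([False, True, False])
values = tf.constant([1., 2., 3.])

def expensive_update():
    # Only active members change; inactive ones pass through untouched.
    return tf.where(active, values * 10., values)

result = tf.cond(tf.reduce_any(active), expensive_update, lambda: values)
print(result.numpy())  # [ 1. 20.  3.]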
Example #3
def prune_completely_outside_window(boxlist, window, scope=None):
    """Prunes bounding boxes that fall completely outside of the given window.

  The function clip_to_window prunes bounding boxes that fall
  completely outside the window, but also clips any bounding boxes that
  partially overflow. This function does not clip partially overflowing boxes.

  Args:
    boxlist: a BoxList holding M_in boxes.
    window: a float tensor of shape [4] representing [ymin, xmin, ymax, xmax]
      of the window
    scope: name scope.

  Returns:
    pruned_boxlist: a new BoxList with all bounding boxes partially or fully in
      the window.
    valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes
     in the input tensor.
  """
    with tf.name_scope(scope, 'PruneCompletelyOutsideWindow'):
        y_min, x_min, y_max, x_max = tf.split(value=boxlist.get(),
                                              num_or_size_splits=4,
                                              axis=1)
        win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
        coordinate_violations = tf.concat([
            tf.greater_equal(y_min, win_y_max),
            tf.greater_equal(x_min, win_x_max),
            tf.less_equal(y_max, win_y_min),
            tf.less_equal(x_max, win_x_min)
        ], 1)
        valid_indices = tf.reshape(
            tf.where(tf.logical_not(tf.reduce_any(coordinate_violations, 1))),
            [-1])
        return gather(boxlist, valid_indices), valid_indices
Example #4
def prune_outside_window(boxlist, window, scope=None):
    """Prunes bounding boxes that fall outside a given window.

  This function prunes bounding boxes that even partially fall outside the given
  window. See also clip_to_window which only prunes bounding boxes that fall
  completely outside the window, and clips any bounding boxes that partially
  overflow.

  Args:
    boxlist: a BoxList holding M_in boxes.
    window: a float tensor of shape [4] representing [ymin, xmin, ymax, xmax]
      of the window
    scope: name scope.

  Returns:
    pruned_corners: a tensor with shape [M_out, 4] where M_out <= M_in
    valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes
     in the input tensor.
  """
    with tf.name_scope(scope, 'PruneOutsideWindow'):
        y_min, x_min, y_max, x_max = tf.split(value=boxlist.get(),
                                              num_or_size_splits=4,
                                              axis=1)
        win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
        coordinate_violations = tf.concat([
            tf.less(y_min, win_y_min),
            tf.less(x_min, win_x_min),
            tf.greater(y_max, win_y_max),
            tf.greater(x_max, win_x_max)
        ], 1)
        valid_indices = tf.reshape(
            tf.where(tf.logical_not(tf.reduce_any(coordinate_violations, 1))),
            [-1])
        return gather(boxlist, valid_indices), valid_indices
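A quick sketch of the violation test shared by the two functions above, run on hand-made boxes (box and window values are illustrative). A box is pruned by prune_completely_outside_window only when it lies entirely past one edge of the window:

import tensorflow as tf

boxes = tf.constant([[0.1, 0.1, 0.4, 0.4],    # fully inside
                     [0.9, 0.9, 1.5, 1.5],    # partially overflowing
                     [2.0, 2.0, 3.0, 3.0]])   # completely outside
window = tf.constant([0.0, 0.0, 1.0, 1.0])

y_min, x_min, y_max, x_max = tf.split(boxes, num_or_size_splits=4, axis=1)
win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
violations = tf.concat([tf.greater_equal(y_min, win_y_max),
                        tf.greater_equal(x_min, win_x_max),
                        tf.less_equal(y_max, win_y_min),
                        tf.less_equal(x_max, win_x_min)], 1)
valid = tf.logical_not(tf.reduce_any(violations, 1))
print(valid.numpy())  # [ True  True False]: only the fully-outside box is dropped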
Example #5
def _owens_t_method4(h, a, m):
    """OwensT Method T4, which is a reordered evaluation of method T2."""
    dtype = dtype_util.common_dtype([h, a], tf.float32)
    h_squared = tf.math.square(h)
    nega_squared = -tf.math.square(a)
    num_iterations = 2 * m + 1.

    def series_evaluation(should_stop, index, term, coeff, series_sum):
        new_coeff = (1. - h_squared * coeff) / index
        new_term = nega_squared * term
        new_series_sum = tf.where(should_stop, series_sum,
                                  series_sum + new_coeff * new_term)
        should_stop = index >= num_iterations
        return should_stop, index + 2., new_term, new_coeff, new_series_sum

    broadcast_shape = prefer_static.broadcast_shape(prefer_static.shape(h),
                                                    prefer_static.shape(a))
    initial_term = a * tf.math.exp(-0.5 * h_squared *
                                   (1 - nega_squared)) / (2 * np.pi)
    initial_sum = initial_term

    (_, _, _, _, series_sum) = tf.while_loop(
        cond=lambda stop, *_: tf.reduce_any(~stop),
        body=series_evaluation,
        loop_vars=(tf.zeros(broadcast_shape, dtype=tf.bool),
                   tf.cast(3., dtype=dtype),
                   initial_term,
                   tf.ones(broadcast_shape, dtype=dtype),
                   initial_sum))
    return series_sum
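The while_loop above runs a batch of series evaluations where each element converges on its own schedule: converged elements are frozen with tf.where, and the loop continues while tf.reduce_any(~stop) holds. A toy version of that pattern (the halving update is made up for illustration):

import tensorflow as tf

def body(should_stop, x):
    new_x = tf.where(should_stop, x, x / 2.)  # freeze converged elements
    return new_x < 1., new_x                  # element-wise stopping test

_, x = tf.while_loop(cond=lambda stop, *_: tf.reduce_any(~stop),
                     body=body,
                     loop_vars=(tf.zeros([3], dtype=tf.bool),
                                tf.constant([1.5, 8., 100.])))
print(x.numpy())  # [0.75  0.5  0.78125]: each element halved until it is < 1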
Example #6
def has_not_u_turn_at_all_index(read_indexes, direction, momentum_state_memory,
                                momentum_right, state_right,
                                no_u_turns_within_tree, log_prob_rank):
    """Check u turn for early stopping."""
    def _get_left_state_and_check_u_turn(left_current_index, no_u_turns_last):
        """Check U turn on a single index."""
        momentum_left = [
            tf.gather(x, left_current_index, axis=0)
            for x in momentum_state_memory.momentum_swap
        ]
        state_left = [
            tf.gather(x, left_current_index, axis=0)
            for x in momentum_state_memory.state_swap
        ]
        # Note that in the generalized U turn, state_diff is actually the
        # cumulative sum of the momentum.
        state_diff = [s1 - s2 for s1, s2 in zip(state_right, state_left)]
        if not GENERALIZED_UTURN:
            state_diff = [
                tf.where(d, m, -m) for d, m in zip(direction, state_diff)
            ]

        no_u_turns_current = has_not_u_turn(state_diff, momentum_left,
                                            momentum_right, log_prob_rank)
        return left_current_index + 1, no_u_turns_current & no_u_turns_last

    # Note that we don't need to set the parallel_iterations arg in the
    # while_loop below, as there are no random ops in
    # `_get_left_state_and_check_u_turn`.
    _, no_u_turns_within_tree = tf.while_loop(
        cond=lambda i, no_u_turn: ((i < tf.gather(read_indexes, 1)) &  # pylint: disable=g-long-lambda
                                   tf.reduce_any(no_u_turn)),
        body=_get_left_state_and_check_u_turn,
        loop_vars=(tf.gather(read_indexes, 0), no_u_turns_within_tree))
    return no_u_turns_within_tree
Example #7
0
def _line_search_inner_bisection(value_and_gradients_function, search_interval,
                                 active, f_lim):
    """Performs bisection and updates the interval."""
    midpoint = (search_interval.left.x + search_interval.right.x) / 2
    val_mid = value_and_gradients_function(midpoint)
    is_valid_mid = hzl.is_finite(val_mid)

    still_active = active & is_valid_mid
    new_failed = active & ~is_valid_mid
    next_interval = search_interval._replace(
        failed=search_interval.failed | new_failed,
        func_evals=search_interval.func_evals + 1)

    def _apply_update():
        update_result = hzl.update(value_and_gradients_function,
                                   next_interval.left,
                                   next_interval.right,
                                   val_mid,
                                   f_lim,
                                   active=still_active)
        return HagerZhangLineSearchResult(
            converged=next_interval.converged,
            failed=next_interval.failed | update_result.failed,
            iterations=next_interval.iterations + update_result.iteration,
            func_evals=next_interval.func_evals + update_result.num_evals,
            left=update_result.left,
            right=update_result.right)

    return prefer_static.cond(tf.reduce_any(still_active), _apply_update,
                              lambda: next_interval)
Example #8
    def _do_check_shrinkage():
      """Check if interval has shrinked enough."""
      old_width = curr_interval.right.x - curr_interval.left.x
      new_width = secant2_result.right.x - secant2_result.left.x
      sufficient_shrinkage = new_width < old_width * shrinkage_param
      func_is_flat = (
          _very_close(curr_interval.left.f, curr_interval.right.f) &
          _very_close(secant2_result.left.f, secant2_result.right.f))

      new_converged = (
          should_check_shrinkage & sufficient_shrinkage & func_is_flat)
      needs_inner_bisect = should_check_shrinkage & ~sufficient_shrinkage

      inner_bisect_args = secant2_result._replace(
          converged=secant2_result.converged | new_converged)

      def _apply_inner_bisect():
        return _line_search_inner_bisection(
            value_and_gradients_function, inner_bisect_args,
            needs_inner_bisect, f_lim)

      return prefer_static.cond(
          tf.reduce_any(needs_inner_bisect),
          _apply_inner_bisect,
          lambda: inner_bisect_args)
Example #9
def has_not_u_turn_at_all_index(read_indexes, direction, velocity_state_memory,
                                velocity_right, state_right,
                                no_u_turns_within_tree, log_prob_rank):
    """Check u turn for early stopping."""
    def _get_left_state_and_check_u_turn(left_current_index, no_u_turns_last):
        """Check U turn on a single index."""
        velocity_left = [
            tf.gather(x, left_current_index, axis=0)
            for x in velocity_state_memory.velocity_swap
        ]
        state_left = [
            tf.gather(x, left_current_index, axis=0)
            for x in velocity_state_memory.state_swap
        ]
        # Note that in the generalized U turn, state_diff is actually the
        # cumulative sum of the momentum.
        state_diff = [s1 - s2 for s1, s2 in zip(state_right, state_left)]
        if not GENERALIZED_UTURN:
            state_diff = [
                tf.where(d, m, -m) for d, m in zip(direction, state_diff)
            ]

        no_u_turns_current = has_not_u_turn(state_diff, velocity_left,
                                            velocity_right, log_prob_rank)
        return left_current_index + 1, no_u_turns_current & no_u_turns_last

    _, no_u_turns_within_tree = tf.while_loop(
        cond=lambda i, no_u_turn: ((i < read_indexes[1]) &  # pylint: disable=g-long-lambda
                                   tf.reduce_any(no_u_turn)),
        body=_get_left_state_and_check_u_turn,
        loop_vars=(read_indexes[0], no_u_turns_within_tree))
    return no_u_turns_within_tree
Example #10
def _owens_t_method2(h, a, m):
    """OwensT Method T2 using Power series."""
    # Method T2 approximates the (1 + x^2)^-1 term in the denominator of the
    # OwensT integrand via a power series, and integrates term by term to get
    # a series expansion.
    dtype = dtype_util.common_dtype([h, a], tf.float32)
    numpy_dtype = dtype_util.as_numpy_dtype(dtype)
    h_squared = tf.math.square(h)
    nega_squared = -tf.math.square(a)
    num_iterations = 2 * m + 1.
    y = tf.math.reciprocal(h_squared)

    def series_evaluation(should_stop, index, summand, term, series_sum):
        new_summand = y * (term - index * summand)
        new_term = nega_squared * term
        new_series_sum = tf.where(should_stop, series_sum,
                                  series_sum + new_summand)
        should_stop = index >= num_iterations
        return should_stop, index + 2., new_summand, new_term, new_series_sum

    broadcast_shape = prefer_static.broadcast_shape(prefer_static.shape(h),
                                                    prefer_static.shape(a))
    initial_summand = -0.5 * tf.math.erf(a * h) / h
    initial_sum = initial_summand
    initial_term = a * tf.math.exp(-0.5 * tf.math.square(a * h)) / numpy_dtype(
        np.sqrt(2 * np.pi))

    (_, _, _, _, series_sum) = tf.while_loop(
        cond=lambda stop, *_: tf.reduce_any(~stop),
        body=series_evaluation,
        loop_vars=(tf.zeros(broadcast_shape, dtype=tf.bool),
                   tf.cast(1., dtype=dtype),
                   initial_summand,
                   initial_term,
                   initial_sum))
    return (series_sum * tf.math.exp(-0.5 * h_squared) /
            numpy_dtype(np.sqrt(2 * np.pi)))
Example #11
def reduce_any(input_tensor, axis=None, keepdims=False):
    """A version of tf.reduce_any that eagerly evaluates if possible."""
    v = get_static_value(input_tensor)
    if v is None:
        return tf.reduce_any(input_tensor, axis=axis, keepdims=keepdims)
    else:
        return v.any(axis=axis, keepdims=keepdims)
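A hedged usage sketch, assuming get_static_value above behaves like tf.get_static_value: for a constant input the value is known at trace time and the reduction can be folded with NumPy, while inside a traced function the argument has no static value and the wrapper falls back to the regular op.

import tensorflow as tf

x = tf.constant([False, True, False])
v = tf.get_static_value(x)  # ndarray: the value is known statically
print(v.any())              # True, computed eagerly with NumPy

@tf.function
def traced(t):
    # Here tf.get_static_value(t) would be None, so the wrapper above
    # dispatches to tf.reduce_any(t) instead.
    return tf.reduce_any(t)

print(traced(x).numpy())    # True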
Example #12
    def test_valid_gradients(self):
        """Tests none of the gradients is nan."""

        # In this example, `x[0]` and `x[1]` are both less than or equal to
        # `x_data[0]`. `x[-2]` and `x[-1]` are both greater than or equal to
        # `x_data[-1]`. They are set up this way to test that none of the
        # tf.where branches of the implementation produce any NaN. An
        # unselected NaN could still propagate through the gradient
        # calculation, with the end result being NaN.
        x = [[-10.0, -1.0, 1.0, 3.0, 6.0, 7.0],
             [8.0, 15.0, 18.0, 25.0, 30.0, 35.0]]
        x_data = [[-1.0, 2.0, 6.0], [8.0, 18.0, 30.0]]

        def _value_helper_fn(y_data):
            """A helper function that returns sum of squared interplated values."""

            interpolated_values = tff.math.interpolation.linear.interpolate(
                x, x_data, y_data, dtype=tf.float64)
            return tf.reduce_sum(tf.math.square(interpolated_values))

        y_data = tf.convert_to_tensor([[10.0, -1.0, -5.0], [7.0, 9.0, 20.0]],
                                      dtype=tf.float64)
        if tf.executing_eagerly():
            with tf.GradientTape(watch_accessed_variables=False) as tape:
                tape.watch(y_data)
                value = _value_helper_fn(y_data=y_data)
                gradients = tape.gradient(value, y_data)
        else:
            value = _value_helper_fn(y_data=y_data)
            gradients = tf.gradients(value, y_data)[0]

        gradients = tf.convert_to_tensor(gradients)

        self.assertFalse(
            self.evaluate(tf.reduce_any(tf.math.is_nan(gradients))))
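The comment in this test refers to a well-known tf.where pitfall; a minimal reproduction (independent of the interpolation code) showing how a NaN on the branch that was not selected still leaks into the gradient:

import tensorflow as tf

x = tf.constant(0.)
with tf.GradientTape() as tape:
    tape.watch(x)
    # The sqrt branch is NOT selected at x == 0, but its gradient (infinite
    # at 0) still enters the chain rule, producing 0 * inf = nan overall.
    y = tf.where(x > 0., tf.sqrt(x), x)
print(tape.gradient(y, x).numpy())  # nan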
Example #13
def _update_inv_hessian(prev_state, next_state):
  """Update the BGFS state by computing the next inverse hessian estimate."""
  # Only update the inverse Hessian if not already failed or converged.
  should_update = ~next_state.converged & ~next_state.failed

  # Compute the normalization term (y^T . s); do not update if it is singular.
  gradient_delta = next_state.objective_gradient - prev_state.objective_gradient
  position_delta = next_state.position - prev_state.position
  normalization_factor = tf.reduce_sum(gradient_delta * position_delta, axis=-1)
  should_update = should_update & ~tf.equal(normalization_factor, 0)

  def _do_update_inv_hessian():
    next_inv_hessian = _bfgs_inv_hessian_update(
        gradient_delta, position_delta, normalization_factor,
        prev_state.inverse_hessian_estimate)
    return bfgs_utils.update_fields(
        next_state,
        inverse_hessian_estimate=tf.where(
            should_update[..., tf.newaxis, tf.newaxis],
            next_inv_hessian,
            prev_state.inverse_hessian_estimate))

  return prefer_static.cond(
      tf.reduce_any(should_update),
      _do_update_inv_hessian,
      lambda: next_state)
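The should_update[..., tf.newaxis, tf.newaxis] expansion above lets one boolean per batch member select a whole matrix. A small sketch of just that broadcast (values are illustrative):

import tensorflow as tf

should_update = tf.constant([True, False])        # one flag per batch member
new_inv_hessian = 5. * tf.eye(2, batch_shape=[2])
old_inv_hessian = tf.eye(2, batch_shape=[2])
# [2] -> [2, 1, 1], broadcast against the [2, 2, 2] batch of matrices.
out = tf.where(should_update[..., tf.newaxis, tf.newaxis],
               new_inv_hessian, old_inv_hessian)
print(out.numpy())  # first matrix becomes 5*I, second stays I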
Example #14
    def testGumbelGumbelKL(self):
        a_loc = np.arange(-2.0, 3.0, 1.0)
        a_scale = np.arange(0.5, 2.5, 0.5)
        b_loc = 2 * np.arange(-2.0, 3.0, 1.0)
        b_scale = np.arange(0.5, 2.5, 0.5)

        # This reshape is intended to expand the number of test cases.
        a_loc = a_loc.reshape((len(a_loc), 1, 1, 1))
        a_scale = a_scale.reshape((1, len(a_scale), 1, 1))
        b_loc = b_loc.reshape((1, 1, len(b_loc), 1))
        b_scale = b_scale.reshape((1, 1, 1, len(b_scale)))

        a = tfd.Gumbel(loc=a_loc, scale=a_scale, validate_args=True)
        b = tfd.Gumbel(loc=b_loc, scale=b_scale, validate_args=True)

        true_kl = (
            np.log(b_scale) - np.log(a_scale) + np.euler_gamma *
            (a_scale / b_scale - 1.) +
            np.expm1((b_loc - a_loc) / b_scale +
                     np.vectorize(np.math.lgamma)(a_scale / b_scale + 1.)) +
            (a_loc - b_loc) / b_scale)

        kl = tfd.kl_divergence(a, b)

        x = a.sample(int(1e5), seed=test_util.test_seed())
        kl_sample = tf.reduce_mean(input_tensor=a.log_prob(x) - b.log_prob(x),
                                   axis=0)

        # As noted in the Gumbel-Gumbel KL divergence implementation, there is an
        # error in the reference paper we use to implement our divergence. This
        # error is a missing summand, (a.loc - b.loc) / b.scale. To ensure that we
        # are adequately testing this difference in the below tests, we compute the
        # relative error between kl_sample_ and kl_ and check that it is "much less"
        # than this missing summand.
        summand = (a_loc - b_loc) / b_scale
        relative_error = (tf.abs(kl - kl_sample) /
                          tf.minimum(tf.abs(kl), tf.abs(kl_sample)))
        exists_missing_summand_test = tf.reduce_any(
            input_tensor=summand > 2 * relative_error)
        exists_missing_summand_test_ = self.evaluate(
            exists_missing_summand_test)
        self.assertTrue(
            exists_missing_summand_test_,
            msg=('No test case exists where (a.loc - b.loc) / b.scale '
                 'is much less than the relative error between kl as '
                 'computed in closed form, and kl as computed by '
                 'sampling. Failing to include such a test case makes '
                 'it difficult to detect regressions where this '
                 'summand (which is missing in our reference paper) '
                 'is omitted.'))

        kl_, kl_sample_ = self.evaluate([kl, kl_sample])
        self.assertAllClose(true_kl, kl_, atol=0.0, rtol=1e-12)
        self.assertAllClose(true_kl, kl_sample_, atol=0.0, rtol=1e-1)

        zero_kl = tfd.kl_divergence(a, a)
        true_zero_kl_, zero_kl_ = self.evaluate(
            [tf.zeros_like(zero_kl), zero_kl])
        self.assertAllEqual(true_zero_kl_, zero_kl_)
Example #15
def _hyp2f1_large_negative_c(a, b, c, z):
    """Compute 2F1(a, b, c, z) when c < 0 and |c| large."""

    # The recurrences here are based on Gauss' contiguous recurrence relations,
    # as given in [1].
    # References
    # [1] M. Abramowitz, I. Stegun. Handbook of Mathematical Functions with
    #     Formulas, Graphs and Mathematical Tables.
    with tf.name_scope('hyp2f1_large_negative_c'):
        dtype = dtype_util.common_dtype([a, b, c, z], tf.float32)
        a = tf.convert_to_tensor(a, dtype=dtype)
        b = tf.convert_to_tensor(b, dtype=dtype)
        c = tf.convert_to_tensor(c, dtype=dtype)
        z = tf.convert_to_tensor(z, dtype=dtype)

        # We assume that c < 0 and a, b > 0.

        d = c - a - b
        integer_d = tf.math.floor(d)
        e = c + 2 - integer_d

        # If |a| >> |e|, use the recurrence for large a.
        recurrence_for_large_a = ((tf.math.abs(a) > tf.math.abs(e)) &
                                  (tf.math.abs(e - a) > 2))
        second_result = tf.where(recurrence_for_large_a,
                                 _hyp2f1_large_a(a, b, e, z),
                                 _hyp2f1_small_parameters(a, b, e, z))

        first_result = tf.where(recurrence_for_large_a,
                                _hyp2f1_large_a(a, b, e + 1., z),
                                _hyp2f1_small_parameters(a, b, e + 1., z))

        broadcast_shape = functools.reduce(ps.broadcast_shape,
                                           [ps.shape(x) for x in [a, b, c]])

        # We use recurrence 15.2.27 in [1]:
        # w * 2F1(a, b, c, z) + x * 2F1(a, b, c + 1, z) = y * 2F1(a, b, c - 1, z)
        # where w, x and y are coefficients for the recurrence (as specified
        # in [1]).
        def hypergeometric_recurrence(should_stop, index, term, result,
                                      previous_result):
            c = term
            new_result = ((c * (c - 1 - (2 * c - a - b - 1.) * z) * result +
                           (c - a) * (c - b) * z * previous_result) /
                          (c * (c - 1) * (1. - z)))
            should_stop = index >= 2 - integer_d
            new_term = tf.where(should_stop, term, term - 1.)
            new_result = tf.where(should_stop, result, new_result)

            return should_stop, index + 1, new_term, new_result, result

        (_, _, _, result, _) = tf.while_loop(
            cond=lambda stop, *_: tf.reduce_any(~stop),
            body=hypergeometric_recurrence,
            loop_vars=(tf.zeros(broadcast_shape, dtype=tf.bool),
                       tf.cast(0., dtype=dtype),
                       e,
                       second_result,
                       first_result))
        return result
Example #16
def has_fully_masked_sequence(mask):
    # See https://github.com/tensorflow/tensorflow/issues/33148 for more details.
    # The cudnn kernel errors out if the input sequence contains any fully
    # masked data. We work around this issue by rerouting the computation to
    # the standard kernel until the issue on the cudnn side has been fixed.
    # A fully masked sequence contains all Falses. To make the check easy, we
    # invert the booleans and test whether any sequence is all True.
    return tf.reduce_any(tf.reduce_all(tf.logical_not(mask), axis=1))
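A quick check of that logic on a toy mask in which the second sequence is fully masked (all False): inverting gives all True, reduce_all detects the row, and reduce_any reports that at least one such row exists.

import tensorflow as tf

mask = tf.constant([[True, True, False],
                    [False, False, False]])  # second sequence fully masked
print(tf.reduce_any(
    tf.reduce_all(tf.logical_not(mask), axis=1)).numpy())  # True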
Example #17
 def call(self, inputs):
     boolean_mask = tf.reduce_any(tf.not_equal(inputs, self.mask_value),
                                  axis=-1,
                                  keepdims=True)
     outputs = inputs * tf.cast(boolean_mask, inputs.dtype)
     # Compute the mask and outputs simultaneously.
     outputs._keras_mask = tf.squeeze(boolean_mask, axis=-1)
     return outputs
Example #18
 def is_last_day_of_season(t):
     t_ = dist_util.maybe_get_static_value(t)
     if t_ is not None:  # static case
         step_in_cycle = t_ % num_steps_per_cycle
         return any(step_in_cycle == changepoints)
     else:
         step_in_cycle = tf.math.floormod(t, num_steps_per_cycle)
         return tf.reduce_any(tf.equal(step_in_cycle, changepoints))
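A sketch of the dynamic branch above on concrete values (num_steps_per_cycle and changepoints are made up for illustration): with a 12-step cycle and changepoints at steps 2, 5 and 11, step 14 wraps to step 2 and matches.

import numpy as np
import tensorflow as tf

num_steps_per_cycle = 12
changepoints = np.array([2, 5, 11], dtype=np.int32)
t = tf.constant(14)
step_in_cycle = tf.math.floormod(t, num_steps_per_cycle)
print(tf.reduce_any(tf.equal(step_in_cycle, changepoints)).numpy())  # True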
Example #19
 def body(unused_keep_going, geom_sum, num_geom, seed):
     u_seed, next_seed = samplers.split_seed(seed)
     u = samplers.uniform(full_shape, seed=u_seed, dtype=counts.dtype)
     geom = tf.math.ceil(tf.math.log(u) / log1minusprob)
     geom_sum += geom
     keep_going = (geom_sum <= counts)
     num_geom = tf.where(keep_going, num_geom + 1, num_geom)
     return tf.reduce_any(keep_going), geom_sum, num_geom, next_seed
Example #20
def _temme_expansion(v, x):
    """Compute modified bessel functions using Temme's method."""
    # The implementation of this is based on [1].
    # [1] N. Temme, On the Numerical Evaluation of the Modified Bessel Function
    #   of the Third Kind. Journal of Computational Physics 19, 1975.
    dtype = dtype_util.common_dtype([v, x], tf.float32)
    numpy_dtype = dtype_util.as_numpy_dtype(dtype)
    v_less_than_zero = v < 0.
    v = tf.math.abs(v)
    n = tf.math.round(v)
    # Use this to compute Kv(u, x) and Kv(u + 1., x)
    u = v - n
    x_abs = tf.math.abs(x)

    small_x = tf.where(x_abs <= 2., x_abs, numpy_dtype(0.1))
    large_x = tf.where(x_abs > 2., x_abs, numpy_dtype(1000.))
    temme_ku, temme_kup1 = _temme_series(u, small_x)
    cf_ku, cf_kup1 = _continued_fraction_kv(u, large_x)

    ku = tf.where(x_abs <= 2., temme_ku, cf_ku)
    kup1 = tf.where(x_abs <= 2., temme_kup1, cf_kup1)

    # Now use the forward recurrence for modified bessel functions
    # to compute Kv(v, x). That is,
    # K_{v + 1}(z) - (2v / z) K_v(z) - K_{v - 1}(z) = 0.
    # This is known to be forward numerically stable.

    def bessel_recurrence(index, kv, kvp1):
        next_kvp1 = 2 * (u + index) * kvp1 / x_abs + kv
        kv = tf.where(index > n, kv, kvp1)
        kvp1 = tf.where(index > n, kvp1, next_kvp1)
        return index + 1., kv, kvp1

    _, kv, kvp1 = tf.while_loop(cond=lambda i, *_: tf.reduce_any(i <= n),
                                body=bessel_recurrence,
                                loop_vars=(tf.cast(1., dtype=dtype), ku, kup1))

    # Finally, it is known that the Wronskian
    # det(I_v * K'_v - K_v * I'_v) = - 1. / x. We can
    # use this to evaluate I_v by taking advantage of identities of Bessel
    # derivatives.

    iv = tf.math.reciprocal(x_abs * (kv * bessel_iv_ratio(v + 1., x) + kvp1))
    z = u + tf.math.mod(n, 2.)

    iv = tf.where(v_less_than_zero,
                  iv + 2. / np.pi * tf.math.sin(np.pi * z) * kv, iv)
    iv = tf.where(
        tf.math.equal(x, 0.),
        tf.where(tf.math.equal(v, 0.), numpy_dtype(1.), numpy_dtype(0.)), iv)
    iv = tf.where(
        tf.math.equal(x, 0.) & v_less_than_zero,
        tf.where(tf.math.equal(z, tf.math.floor(z)), iv, numpy_dtype(np.inf)),
        iv)
    kv = tf.where(tf.math.equal(x, 0.), numpy_dtype(np.inf), kv)
    iv = tf.where(x < 0., numpy_dtype(np.nan), iv)
    kv = tf.where(x < 0., numpy_dtype(np.nan), kv)
    return iv, kv
Example #21
def kl_divergence(distribution_a,
                  distribution_b,
                  allow_nan_stats=True,
                  name=None):
    """Get the KL-divergence KL(distribution_a || distribution_b).
    If there is no KL method registered specifically for `type(distribution_a)`
    and `type(distribution_b)`, then the class hierarchies of these types are
    searched.
    If one KL method is registered between any pairs of classes in these two
    parent hierarchies, it is used.
    If more than one such registered method exists, the method whose registered
    classes have the shortest sum MRO paths to the input types is used.
    If more than one such shortest path exists, the first method
    identified in the search is used (favoring a shorter MRO distance to
    `type(distribution_a)`).
    Args:
      distribution_a: The first distribution.
      distribution_b: The second distribution.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    Returns:
      A Tensor with the batchwise KL-divergence between `distribution_a`
      and `distribution_b`.
    Raises:
      NotImplementedError: If no KL method is defined for distribution types
        of `distribution_a` and `distribution_b`.
    """
    kl_fn = _registered_kl(type(distribution_a), type(distribution_b))
    if kl_fn is None:
        raise NotImplementedError(
            "No KL(distribution_a || distribution_b) registered for distribution_a "
            "type {} and distribution_b type {}".format(
                type(distribution_a).__name__,
                type(distribution_b).__name__))

    name = name or "KullbackLeibler"
    with tf.name_scope(name):
        # pylint: disable=protected-access
        with distribution_a._name_and_control_scope(name + "_a"):
            with distribution_b._name_and_control_scope(name + "_b"):
                kl_t = kl_fn(distribution_a, distribution_b, name=name)
                if allow_nan_stats:
                    return kl_t

    # Check KL for NaNs
    kl_t = tf.identity(kl_t, name="kl")

    with tf.control_dependencies([
            tf.debugging.Assert(
                tf.logical_not(tf.reduce_any(tf.math.is_nan(kl_t))),
                [("KL calculation between {} and {} returned NaN values "
                  "(and was called with allow_nan_stats=False). Values:".
                  format(distribution_a.name, distribution_b.name)), kl_t])
    ]):
        return tf.identity(kl_t, name="checked_kl")
Example #22
 def chain_not_done(
     seed,
     angle,
     angle_min,
     angle_max,
     current_state_parts,
     current_log_likelihood):
   del seed, angle, angle_min, angle_max, current_state_parts
   return tf.reduce_any(current_log_likelihood < threshold)
Example #23
    def __call__(self,
                 logits,
                 scaled_labels,
                 classes,
                 category_loss=True,
                 mse_loss=False):
        """Compute instance segmentation loss.

    Args:
      logits: A Tensor of shape [batch_size * num_points, height, width,
        num_classes]. The logits are not necessarily between 0 and 1.
      scaled_labels: A float16 Tensor of shape [batch_size, num_instances,
          mask_size, mask_size], where mask_size =
          mask_crop_size * gt_upsample_scale for fine mask, or mask_crop_size
          for coarse masks and shape priors.
      classes: An int tensor of shape [batch_size, num_instances].
      category_loss: whether to use class-specific mask prediction.
      mse_loss: whether to use mean squared error for the mask loss.

    Returns:
      mask_loss: a float tensor representing total mask classification loss.
      iou: a float tensor representing the IoU between target and prediction.
    """
        classes = tf.reshape(classes, [-1])
        _, _, height, width = scaled_labels.get_shape().as_list()
        scaled_labels = tf.reshape(scaled_labels, [-1, height, width])

        if not category_loss:
            logits = logits[:, :, :, 0]
        else:
            logits = tf.transpose(a=logits, perm=(0, 3, 1, 2))
            gather_idx = tf.stack(
                [tf.range(tf.size(input=classes)), classes - 1], axis=1)
            logits = tf.gather_nd(logits, gather_idx)

        # Ignore loss on empty mask targets.
        valid_labels = tf.reduce_any(input_tensor=tf.greater(scaled_labels, 0),
                                     axis=[1, 2])
        if mse_loss:
            # Logits are probabilities in the case of shape prior prediction.
            logits *= tf.reshape(tf.cast(valid_labels, logits.dtype),
                                 [-1, 1, 1])
            weighted_loss = tf.nn.l2_loss(scaled_labels - logits)
            probs = logits
        else:
            weighted_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=scaled_labels, logits=logits)
            probs = tf.sigmoid(logits)
            weighted_loss *= tf.reshape(
                tf.cast(valid_labels, weighted_loss.dtype), [-1, 1, 1])

        iou = tf.reduce_sum(
            input_tensor=tf.minimum(scaled_labels, probs)) / tf.reduce_sum(
                input_tensor=tf.maximum(scaled_labels, probs))
        mask_loss = tf.reduce_sum(input_tensor=weighted_loss) / tf.reduce_sum(
            input_tensor=scaled_labels)
        return tf.cast(mask_loss, tf.float32), tf.cast(iou, tf.float32)
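The valid_labels line above zeroes out the loss contribution of empty mask targets. A small sketch of just that masking step on toy labels:

import tensorflow as tf

scaled_labels = tf.constant([[[0., 1.], [0., 0.]],
                             [[0., 0.], [0., 0.]]])  # second target is empty
valid_labels = tf.reduce_any(scaled_labels > 0, axis=[1, 2])
print(valid_labels.numpy())  # [ True False]: empty targets get zero weight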
Example #24
def _random_poisson_low_rate(sample_shape,
                             rate,
                             internal_dtype=tf.float64,
                             seed=None):
    """Samples from the Poisson distribution using Knuth's algorithm.

  We use an algorithm attributed to Knuth: Seminumerical Algorithms. Art of
  Computer Programming, Volume 2. This algorithm runs in O(rate) time and
  requires O(rate) uniform variates, so it is practical only for rate up to
  about 10.

  Given a Poisson process with rate lambda, the time between events is
  exponentially distributed as Exp(lambda). If X ~ Uniform(0, 1), then
  Y = -log(X) / lambda satisfies Y ~ Exp(lambda). Thus, to simulate a Poisson
  draw, we can sample X_i ~ Exp(lambda), and we will have N ~ Poisson(lambda),
  where N is the smallest number such that sum_i^N X_i > 1.

  Args:
    sample_shape: The output sample shape. Must broadcast with `rate`.
    rate: Floating point tensor, rate.
    internal_dtype: (optional) dtype to use for internal computations.
    seed: (optional) The random seed.

  Returns:
    Samples from the poisson distribution.
  """
    exp_neg_rate = tf.math.exp(-rate)

    def loop_body(should_continue, samples, prod, num_iters, seed):
        u_seed, next_seed = samplers.split_seed(seed)
        prod = prod * samplers.uniform(
            sample_shape, dtype=internal_dtype, seed=u_seed)
        accept = should_continue & (prod <= exp_neg_rate)
        samples = tf.where(accept, num_iters, samples)
        return [
            should_continue & (~accept), samples, prod, num_iters + 1,
            next_seed
        ]

    _, samples, _, _, _ = tf.while_loop(
        cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
        body=loop_body,
        loop_vars=[
            tf.ones(sample_shape, dtype=tf.bool),  # should_continue
            tf.zeros(sample_shape, dtype=tf.int32),  # samples
            tf.ones(sample_shape, dtype=internal_dtype),  # prod
            tf.zeros([], dtype=tf.int32),  # num_iters
            seed,  # seed
        ],
        # Using a Chernoff-like bound, we can show that for lambda < 10,
        # Pr[X >= lambda + n] <= exp(-n^2 / 2(lambda + n)) < exp(-90). Hence,
        # there is a minuscule probability that, even after a union bound over
        # the batch size, a Poisson sample with rate < 10 would exceed 200.
        maximum_iterations=200,
    )
    return samples
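For reference, a plain-Python sketch of the product form of Knuth's algorithm that the docstring describes (multiply uniforms until the running product drops below exp(-rate)); this mirrors what loop_body above does for a whole batch at once.

import math
import random

def knuth_poisson(rate):
    limit = math.exp(-rate)
    num_iters, prod = 0, random.random()
    while prod > limit:          # same test as `prod <= exp_neg_rate` above
        prod *= random.random()
        num_iters += 1
    return num_iters

print(knuth_poisson(4.))  # one draw, approximately Poisson(4) distributed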
Example #25
def _temme_series(v, z):
    """Computes Kve(v, z) and Kve(v + 1., z) via Power series expansion."""
    # This is based on:
    # [1] N. Temme, On the Numerical Evaluation of the Modified Bessel Function
    #   of the Third Kind. Journal of Computational Physics 19, 1975.
    # [2] Numerical Recipes in C. The Art of Scientific Computing,
    #   2nd Edition, 1992
    # We will assume |z| <= 2. and |v| < 0.5 for fast convergence.
    dtype = dtype_util.common_dtype([v, z], tf.float32)
    numpy_dtype = dtype_util.as_numpy_dtype(dtype)
    tol = tf.cast(np.finfo(numpy_dtype).eps, dtype=dtype)

    # The initial series term is defined by 6.7.39 in [2]. We compute
    # related coefficients and quantities.
    coeff1, coeff2, gamma1pv_inv, gamma1mv_inv = _evaluate_temme_coeffs(v)

    z_sq = tf.math.square(z)

    logzo2 = tf.math.log(z / 2.)
    mu = -v * logzo2
    sinc_v = tf.where(tf.math.equal(v, 0.), numpy_dtype(1.),
                      tf.math.sin(np.pi * v) / (np.pi * v))
    sinhc_mu = tf.where(tf.math.equal(mu, 0.), numpy_dtype(1.),
                        tf.math.sinh(mu) / mu)
    # These are defined in 6.7.17 in [2].
    initial_f = (coeff1 * tf.math.cosh(mu) +
                 coeff2 * -logzo2 * sinhc_mu) / sinc_v
    initial_p = 0.5 * tf.math.exp(mu) / gamma1pv_inv
    initial_q = 0.5 * tf.math.exp(-mu) / gamma1mv_inv
    max_iterations = 1000

    def body_fn(should_stop, index, f, p, q, coeff, kv_sum, kvp1_sum):
        f = tf.where(should_stop, f, (index * f + p + q) /
                     (tf.math.square(index) - tf.math.square(v)))
        p = tf.where(should_stop, p, p / (index - v))
        q = tf.where(should_stop, q, q / (index + v))
        h = p - index * f
        # c_k = (z ** 2 / 4) ** k / (k!)
        coeff = tf.where(should_stop, coeff, coeff * z_sq / (4 * index))
        kv_sum = tf.where(should_stop, kv_sum, kv_sum + coeff * f)
        kvp1_sum = tf.where(should_stop, kvp1_sum, kvp1_sum + coeff * h)
        index = index + 1
        should_stop = (tf.math.abs(coeff * f) <
                       tf.math.abs(kv_sum) * tol) | (index > max_iterations)
        return should_stop, index, f, p, q, coeff, kv_sum, kvp1_sum

    _, _, _, _, _, _, kv_sum, kvp1_sum = tf.while_loop(
        cond=lambda stop, *_: tf.reduce_any(~stop),
        body=body_fn,
        loop_vars=(tf.zeros_like(initial_f, dtype=tf.bool),
                   tf.cast(1., dtype), initial_f, initial_p, initial_q,
                   tf.ones_like(initial_p), initial_f, initial_p))

    log_kve = tf.math.log(kv_sum) + z
    log_kvep1 = tf.math.log(2. * kvp1_sum) + z - tf.math.log(z)
    return tf.math.exp(log_kve), tf.math.exp(log_kvep1)
Example #26
 def _loop_cond(iter_, x_update_diff_norm_sq, x_update,
                hess_matmul_x_update):
   del x_update
   del hess_matmul_x_update
   sweep_complete = (iter_ > 0) & tf.equal(iter_ % dims, 0)
   small_delta = (
       x_update_diff_norm_sq < x_update_diff_norm_sq_convergence_threshold)
   converged = sweep_complete & small_delta
   allowed_more_iterations = iter_ < maximum_full_sweeps * dims
   return allowed_more_iterations & tf.reduce_any(~converged)
Example #27
def _suppression_loop_body(boxes, iou_threshold, output_size, idx):
    """Process boxes in the range [idx*NMS_TILE_SIZE, (idx+1)*NMS_TILE_SIZE).

  Args:
    boxes: a tensor with a shape of [batch_size, anchors, 4].
    iou_threshold: a float representing the threshold for deciding whether boxes
      overlap too much with respect to IOU.
    output_size: an int32 tensor of size [batch_size]. Representing the number
      of selected boxes for each batch.
    idx: an integer scalar representing the induction variable.

  Returns:
    boxes: updated boxes.
    iou_threshold: pass down iou_threshold to the next iteration.
    output_size: the updated output_size.
    idx: the updated induction variable.
  """
    num_tiles = tf.shape(boxes)[1] // NMS_TILE_SIZE
    batch_size = tf.shape(boxes)[0]

    # Iterates over tiles that can possibly suppress the current tile.
    box_slice = tf.slice(boxes, [0, idx * NMS_TILE_SIZE, 0],
                         [batch_size, NMS_TILE_SIZE, 4])
    _, box_slice, _, _ = tf.while_loop(
        lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx,
        _cross_suppression, [boxes, box_slice, iou_threshold,
                             tf.constant(0)])

    # Iterates over the current tile to compute self-suppression.
    iou = box_utils.bbox_overlap(box_slice, box_slice)
    mask = tf.expand_dims(
        tf.reshape(tf.range(NMS_TILE_SIZE), [1, -1]) > tf.reshape(
            tf.range(NMS_TILE_SIZE), [-1, 1]), 0)
    iou *= tf.cast(tf.logical_and(mask, iou >= iou_threshold), iou.dtype)
    suppressed_iou, _, _ = tf.while_loop(
        lambda _iou, loop_condition, _iou_sum: loop_condition,
        _self_suppression,
        [iou, tf.constant(True),
         tf.reduce_sum(iou, [1, 2])])
    suppressed_box = tf.reduce_sum(suppressed_iou, 1) > 0
    box_slice *= tf.expand_dims(1.0 - tf.cast(suppressed_box, box_slice.dtype),
                                2)

    # Uses box_slice to update the input boxes.
    mask = tf.reshape(tf.cast(tf.equal(tf.range(num_tiles), idx), boxes.dtype),
                      [1, -1, 1, 1])
    boxes = tf.tile(tf.expand_dims(
        box_slice, [1]), [1, num_tiles, 1, 1]) * mask + tf.reshape(
            boxes, [batch_size, num_tiles, NMS_TILE_SIZE, 4]) * (1 - mask)
    boxes = tf.reshape(boxes, [batch_size, -1, 4])

    # Updates output_size.
    output_size += tf.reduce_sum(
        tf.cast(tf.reduce_any(box_slice > 0, [2]), tf.int32), [1])
    return boxes, iou_threshold, output_size, idx + 1
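The output_size update at the end counts surviving boxes by exploiting the fact that suppressed boxes were zeroed out, so any strictly positive coordinate marks a live box. A toy check:

import tensorflow as tf

box_slice = tf.constant([[[0., 0., 0., 0.],          # suppressed (zeroed out)
                          [0.1, 0.1, 0.5, 0.5]]])    # survivor
count = tf.reduce_sum(
    tf.cast(tf.reduce_any(box_slice > 0, [2]), tf.int32), [1])
print(count.numpy())  # [1]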
Example #28
    def _loop_body(curr_interval):
        """The loop body."""
        active = ~(curr_interval.converged | curr_interval.failed)
        # TODO(b/208441613): Skip updates for batch members that are not active?
        secant2_raw_result = hzl.secant2(value_and_gradients_function, val_0,
                                         curr_interval, f_lim,
                                         sufficient_decrease_param,
                                         curvature_param)
        secant2_result = HagerZhangLineSearchResult(
            ## TODO(b/208441613): `& ~curr_interval.failed` should not be needed.
            converged=secant2_raw_result.converged & ~curr_interval.failed,
            ## TODO(b/208441613): `| curr_interval.failed` should not be needed.
            failed=secant2_raw_result.failed | curr_interval.failed,
            iterations=curr_interval.iterations + tf.cast(active, tf.int32),
            func_evals=secant2_raw_result.num_evals,
            left=secant2_raw_result.left,
            right=secant2_raw_result.right)

        should_check_shrinkage = ~(secant2_result.converged
                                   | secant2_result.failed)

        def _do_check_shrinkage():
            """Check if interval has shrinked enough."""
            old_width = curr_interval.right.x - curr_interval.left.x
            new_width = secant2_result.right.x - secant2_result.left.x
            sufficient_shrinkage = new_width < old_width * shrinkage_param
            func_is_flat = (
                _very_close(curr_interval.left.f, curr_interval.right.f)
                & _very_close(secant2_result.left.f, secant2_result.right.f))

            new_converged = (should_check_shrinkage & sufficient_shrinkage
                             & func_is_flat)
            needs_inner_bisect = should_check_shrinkage & ~sufficient_shrinkage

            inner_bisect_args = secant2_result._replace(
                converged=secant2_result.converged | new_converged)

            def _apply_inner_bisect():
                return _line_search_inner_bisection(
                    value_and_gradients_function, inner_bisect_args,
                    needs_inner_bisect, f_lim)

            return prefer_static.cond(tf.reduce_any(needs_inner_bisect),
                                      _apply_inner_bisect,
                                      lambda: inner_bisect_args)

        next_args = prefer_static.cond(tf.reduce_any(should_check_shrinkage),
                                       _do_check_shrinkage,
                                       lambda: secant2_result)

        interval_shrunk = (~next_args.failed
                           & _very_close(next_args.left.x, next_args.right.x))
        return [
            next_args._replace(converged=next_args.converged | interval_shrunk)
        ]
Example #29
def _self_suppression(iou, _, iou_sum):
  batch_size = tf.shape(iou)[0]
  can_suppress_others = tf.cast(
      tf.reshape(tf.reduce_max(iou, 1) <= 0.5, [batch_size, -1, 1]), iou.dtype)
  iou_suppressed = tf.reshape(
      tf.cast(tf.reduce_max(can_suppress_others * iou, 1) <= 0.5, iou.dtype),
      [batch_size, -1, 1]) * iou
  iou_sum_new = tf.reduce_sum(iou_suppressed, [1, 2])
  return [
      iou_suppressed,
      tf.reduce_any(iou_sum - iou_sum_new > 0.5), iou_sum_new
  ]
Example #30
def _hyp2f1_taylor_series(a, b, c, z):
    """Compute Hyp2F1(a, b, c, z) via the Taylor Series expansion."""
    with tf.name_scope('hyp2f1_taylor_series'):
        dtype = dtype_util.common_dtype([a, b, c, z], tf.float32)
        a = tf.convert_to_tensor(a, dtype=dtype)
        b = tf.convert_to_tensor(b, dtype=dtype)
        c = tf.convert_to_tensor(c, dtype=dtype)
        z = tf.convert_to_tensor(z, dtype=dtype)
        np_finfo = np.finfo(dtype_util.as_numpy_dtype(dtype))
        tolerance = tf.cast(np_finfo.resolution, dtype=dtype)

        broadcast_shape = functools.reduce(ps.broadcast_shape,
                                           [ps.shape(x) for x in [a, b, c, z]])

        def taylor_series(should_stop, index, term, taylor_sum, previous_term,
                          previous_taylor_sum, two_before_taylor_sum):
            new_term = term * (a + index) * (b + index) * z / ((c + index) *
                                                               (index + 1.))
            new_term = tf.where(should_stop, term, new_term)
            new_taylor_sum = tf.where(should_stop, taylor_sum,
                                      taylor_sum + new_term)

            # When a or b is near a negative integer -n, the term can be
            # spuriously small because we are computing (a + n) * (b + n) in
            # the numerator. Checking that three consecutive terms are small
            # compared to their corresponding sums avoids stopping too early.
            should_stop = (
                (tf.math.abs(new_term) < tolerance * tf.math.abs(taylor_sum)) &
                (tf.math.abs(term) <
                 tolerance * tf.math.abs(previous_taylor_sum)) &
                (tf.math.abs(previous_term) <
                 tolerance * tf.math.abs(two_before_taylor_sum)))
            return (tf.logical_or(should_stop,
                                  index > 2000.), index + 1., new_term,
                    new_taylor_sum, term, taylor_sum, previous_taylor_sum)

        (_, _, _, taylor_sum, _, _, _) = tf.while_loop(
            cond=lambda stop, *_: tf.reduce_any(~stop),
            body=taylor_series,
            loop_vars=(
                tf.zeros(broadcast_shape, dtype=tf.bool),
                tf.cast(0., dtype=dtype),
                # Only the previous term and taylor sum are used for computation.
                # The rest are used for checking convergence. We can safely set
                # these to zero.
                tf.ones(broadcast_shape, dtype=dtype),
                tf.ones(broadcast_shape, dtype=dtype),
                tf.zeros(broadcast_shape, dtype=dtype),
                tf.zeros(broadcast_shape, dtype=dtype),
                tf.zeros(broadcast_shape, dtype=dtype)))
        return taylor_sum
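For intuition, a scalar plain-Python version of the same Taylor recurrence (term ratio (a + n)(b + n)z / ((c + n)(n + 1))), without the three-consecutive-terms safeguard that the batched version needs; it can be checked against the identity 2F1(1, 1; 2; z) = -log(1 - z) / z.

def hyp2f1_taylor(a, b, c, z, tol=1e-12, max_terms=2000):
    term, total = 1., 1.
    for n in range(max_terms):
        term *= (a + n) * (b + n) * z / ((c + n) * (n + 1.))
        total += term
        if abs(term) < tol * abs(total):
            break
    return total

print(hyp2f1_taylor(1., 1., 2., 0.5))  # 1.3862943611... == 2 * log(2)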