Ejemplo n.º 1
0
def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
    """Builds ops asserting that one NUTS step preserves a standard MVN sample.

    `batch_size` starting points are drawn from a zero-mean, identity-covariance
    multivariate normal and advanced by a single step of `run_nuts_chain`.  The
    returned ops check that the moved points still agree in distribution with
    the starting sample, and that the sampler did a non-trivial amount of work.

    Args:
      event_size: Dimension of the multivariate normal.
      batch_size: Number of starting points to draw.
      **kwargs: Forwarded unchanged to `run_nuts_chain`.

    Returns:
      Tuple of assertion ops, in order: CDF agreement, sample shape, leapfrog
      variety, mean leapfrog count, mean movement, and statistical power.
    """
    start_dist = tfd.MultivariateNormalFullCovariance(
        loc=tf.zeros(event_size), covariance_matrix=tf.eye(event_size))
    start_points = start_dist.sample(batch_size, seed=4)
    samples, leapfrogs = run_nuts_chain(
        event_size,
        batch_size,
        num_steps=1,
        initial_state=start_points,
        **kwargs)
    # Last entry along the leading axis of the first element of `samples`
    # (presumably the final chain step -- confirm against `run_nuts_chain`).
    final_points = samples[0][-1]
    leapfrog_counts = leapfrogs[0]

    # The moved points must be indistinguishable from the starting sample.
    cdf_ok = st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
        final_points,
        start_points,
        num_projections=100,
        false_fail_rate=1e-6)
    shape_ok = tf1.assert_equal(tf.shape(input=final_points)[0], batch_size)

    # At least three distinct leapfrog counts must occur across the batch.
    distinct_counts, _ = tf.unique(leapfrog_counts)
    leapfrogs_vary_ok = tf1.assert_greater_equal(
        tf.shape(input=distinct_counts)[0], 3)

    # On average at least four leapfrog steps must have been taken.
    mean_leapfrogs = tf.math.reduce_mean(input_tensor=leapfrog_counts)
    leapfrogs_ok = tf1.assert_greater_equal(
        mean_leapfrogs, tf.constant(4, dtype=mean_leapfrogs.dtype))

    # The movement threshold (0.3) was copied from the univariate case.
    distance_moved = tf.linalg.norm(
        tensor=final_points - start_points, axis=-1)
    movement_ok = tf1.assert_greater_equal(
        tf.reduce_mean(input_tensor=distance_moved), 0.3)

    # NOTE(review): 1e-8 is presumably the 1e-6 budget used above divided over
    # the 100 projections -- confirm against the `st` implementation.
    detectable_discrepancy = (
        st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
            batch_size,
            batch_size,
            false_fail_rate=1e-8,
            false_pass_rate=1e-6))
    power_ok = tf1.assert_less(detectable_discrepancy, 0.055)

    return (cdf_ok, shape_ok, leapfrogs_vary_ok, leapfrogs_ok, movement_ok,
            power_ok)
Ejemplo n.º 2
0
def assert_univariate_target_conservation(test, mk_target, step_size,
                                          stackless):
    """Builds ops asserting that one NUTS step conserves a univariate target.

    A large sample is drawn from the target, moved by a single
    `NoUTurnSampler` step, and then compared against the target's CDF.
    Further checks make sure the sampler actually did some work.

    Args:
      test: Test-case object; its `assertAllEqual` performs the static shape
        checks.
      mk_target: Zero-argument callable returning the target distribution.
        The result is used for `sample`, `log_prob` and `cdf`.
      step_size: Step size handed to the sampler; also the lower bound
        required of the mean absolute movement.
      stackless: Forwarded to `NoUTurnSampler(stackless=...)`.

    Returns:
      Tuple of assertion ops, in order: CDF agreement, statistical power,
      leapfrog variety, mean leapfrog count, and mean movement.
    """
    # The sample count is bounded partly by the memory reliably available on
    # Forge.  The test stays reasonable even when the NUTS recursion limit is
    # severely curtailed (e.g., 3 or 4 levels), so a small tree depth is used
    # to recover memory and afford a larger sample count.
    num_samples = int(5e4)
    num_steps = 1
    target_d = mk_target()
    strm = tfp.util.SeedStream(salt='univariate_nuts_test', seed=1)
    # Bijectors of the same type/parameterization now share a cache; wrapping
    # the initial values in `tf.identity` avoids the broken gradients that a
    # cache hit would cause.
    # TODO(b/72831017): Fix broken gradients caused by bijector caching.
    start_points = tf.identity(target_d.sample([num_samples], seed=strm()))

    def target(*args):
        # TODO(axch): Just use target_d.log_prob directly, and accept target_d
        # itself as an argument instead of a maker function.  Blocked by
        # b/128932888.  It would then also be nice not to eta-expand
        # target_d.log_prob; that was blocked by b/122414321, but maybe tfp's
        # port of value_and_gradients_function fixed that bug.
        return mk_target().log_prob(*args)

    kernel = tfp.experimental.mcmc.NoUTurnSampler(
        target,
        step_size=step_size,
        max_tree_depth=3,
        use_auto_batching=True,
        stackless=stackless,
        unrolled_leapfrog_steps=2,
        seed=strm())
    states, kernel_results = tfp.mcmc.sample_chain(
        num_results=num_steps,
        num_burnin_steps=0,
        current_state=start_points,
        kernel=kernel)
    # Note: sample_chain puts the chain history on the leading axis, ahead of
    # the (independent) chains.
    test.assertAllEqual([num_steps, num_samples], states.shape)
    final_points = states[0]

    cdf_ok = st.assert_true_cdf_equal_by_dkwm(
        final_points, target_d.cdf, false_fail_rate=1e-6)
    detectable_discrepancy = (
        st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
            num_samples, false_fail_rate=1e-6, false_pass_rate=1e-6))
    power_ok = tf1.assert_less(detectable_discrepancy, 0.025)

    leapfrog_counts = kernel_results.leapfrogs_taken[0]
    test.assertAllEqual([num_samples], leapfrog_counts.shape)
    # At least three distinct leapfrog counts must occur across the chains.
    distinct_counts, _ = tf.unique(leapfrog_counts)
    leapfrogs_vary_ok = tf1.assert_greater_equal(
        tf.shape(input=distinct_counts)[0], 3)
    # On average at least four leapfrog steps must have been taken.
    mean_leapfrogs = tf.math.reduce_mean(input_tensor=leapfrog_counts)
    leapfrogs_ok = tf1.assert_greater_equal(
        mean_leapfrogs, tf.constant(4, dtype=mean_leapfrogs.dtype))

    distance_moved = tf.abs(final_points - start_points)
    test.assertAllEqual([num_samples], distance_moved.shape)
    # This movement distance (1 * step_size) was selected by reducing until
    # 100 runs with independent seeds all passed.
    min_mean_movement = 1 * step_size
    movement_ok = tf1.assert_greater_equal(
        tf.reduce_mean(input_tensor=distance_moved), min_mean_movement)

    return (cdf_ok, power_ok, leapfrogs_vary_ok, leapfrogs_ok, movement_ok)
Ejemplo n.º 3
0
def RandomCropImages(images, input_shape,
                     target_shape):
  """Cuts the same randomly placed window out of every image in a list.

  A single (y, x) offset is drawn and applied to all images, so the crops stay
  aligned with each other.

  Args:
    images: List of tensors of shape [batch_size, h, w, c].
    input_shape: Shape [h, w, c] of the input images.
    target_shape: Shape [h, w] of the cropped output.

  Raises:
    ValueError: If either `input_shape` or `target_shape` has the wrong
      length.

  Returns:
    crops: List of cropped tensors of shape [batch_size] + target_shape.
  """
  input_rank = len(input_shape)
  if input_rank != 3:
    raise ValueError(
        'The input shape has to be of the form (height, width, channels) '
        'but has len {}'.format(input_rank))
  target_rank = len(target_shape)
  if target_rank != 2:
    raise ValueError('The target shape has to be of the form (height, width) '
                     'but has len {}'.format(target_rank))

  # Largest admissible offsets; a negative value means the crop window does
  # not fit inside the input, which the assertions below reject.
  max_y = int(input_shape[0]) - int(target_shape[0])
  max_x = int(input_shape[1]) - int(target_shape[1])
  crop_fits = [
      tf.assert_greater_equal(max_x, 0),
      tf.assert_greater_equal(max_y, 0),
  ]
  with tf.control_dependencies(crop_fits):
    # `maxval` is exclusive, hence the +1 so the largest offset is reachable.
    offset_y = tf.random_uniform((), maxval=max_y + 1, dtype=tf.int32)
    offset_x = tf.random_uniform((), maxval=max_x + 1, dtype=tf.int32)
    crops = []
    for img in images:
      crops.append(
          tf.image.crop_to_bounding_box(img, offset_y, offset_x,
                                        int(target_shape[0]),
                                        int(target_shape[1])))
    return crops
Ejemplo n.º 4
0
def area_loss(logits, ranges, length, max_area_width, allow_empty=False):
    """Computes the mean cross-entropy loss of area predictions.

    Each ground-truth range is turned into an area index with
    `area_range_to_index`, and that index is used as the class label for the
    matching row of `logits`.

    Args:
      logits: the predictions of each area
        [batch_size, query_length, num_areas].
      ranges: the groundtruth [batch_size, query_length, 2].
      length: the length of the original tensor.
      max_area_width: the maximum area width.
      allow_empty: whether to allow empty refs. When True, label 0 is used
        for rows whose end is not greater than their start and every other
        label is shifted up by one. When False, such rows contribute zero
        loss but still count towards the mean.

    Returns:
      the loss, a scalar mean over all flattened rows.
    """
    num_areas = common_layers.shape_list(logits)[-1]
    ranges = tf.reshape(ranges, [-1, 2])
    starts = ranges[:, 0]
    ends = ranges[:, 1]
    labels = area_range_to_index(
        area_range=ranges, length=length, max_area_width=max_area_width)
    if allow_empty:
        is_non_empty = tf.greater(ends, starts)
        labels = tf.where(is_non_empty, labels + 1, tf.zeros_like(labels))
    flat_logits = tf.reshape(logits, [-1, num_areas])
    per_row_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=labels,
        logits=flat_logits,
        reduction=tf.losses.Reduction.NONE)
    # NOTE(review): when `allow_empty` is True no op is created inside this
    # scope, so in graph mode nothing depends on the assertion.
    ordered_ranges = tf.assert_greater_equal(ends, starts)
    with tf.control_dependencies([ordered_ranges]):
        if not allow_empty:
            keep = tf.cast(tf.greater(ends, starts), tf.float32)
            per_row_loss = per_row_loss * keep
    return tf.reduce_mean(per_row_loss)
  def pre_attention(self, segment_number, query_antecedent,
                    memory_antecedent, bias):
    """Called prior to self-attention, to incorporate memory items.

    Pads the incoming batch up to `self.batch_size`, resets the entries whose
    incoming segment number is lower than the stored one, records the new
    segment numbers, and reads from memory with the padded queries.

    Args:
      segment_number: an integer Tensor with shape [batch]
      query_antecedent: a Tensor with shape [batch, length_q, channels]
      memory_antecedent: must be None. Attention normally allows this to be a
        Tensor with shape [batch, length_m, channels], but we currently only
        support memory for decoder-side self-attention.
      bias: bias Tensor (see attention_bias())
    Returns:
      (data, new_query_antecedent, new_memory_antecedent, new_bias), where
      `data` is a dict with keys "x", "access_logits" and "retrieved_mem".
      `query_antecedent`, `memory_antecedent` and `bias` are returned
      unchanged.
    """
    with tf.variable_scope(self.name + "/pre_attention", reuse=tf.AUTO_REUSE):
      assert memory_antecedent is None, "We only support language modeling"
      # The incoming batch may be smaller than self.batch_size but never
      # larger; the padding amount below would otherwise be negative.
      with tf.control_dependencies([
          tf.assert_greater_equal(self.batch_size, tf.size(segment_number))]):
        difference = self.batch_size - tf.size(segment_number)
        # Pad the segment numbers at the end up to self.batch_size entries.
        segment_number = tf.pad(segment_number, [[0, difference]])
        # Indices where the incoming segment number is below the stored one
        # are handed to self.reset (presumably these mark sequences that have
        # restarted -- confirm against `reset`).
        reset_op = self.reset(tf.reshape(tf.where(
            tf.less(segment_number, self.segment_number)), [-1]))
      memory_results = {}
      # Order matters: reset first, then store the new segment numbers, and
      # only then build the padded query and read from memory.
      with tf.control_dependencies([reset_op]):
        with tf.control_dependencies([
            self.update_segment_number(segment_number)]):
          # Pad the batch dimension of the query to match the padded segments.
          x = tf.pad(query_antecedent, [
              [0, difference], [0, 0], [0, 0]])
          access_logits, retrieved_mem = self.read(x)
      memory_results["x"] = x
      memory_results["access_logits"] = access_logits
      memory_results["retrieved_mem"] = retrieved_mem
      return memory_results, query_antecedent, memory_antecedent, bias
Ejemplo n.º 6
0
def remidify(pitches):
  """Maps zero-based pitch indices in [0, 88) to MIDI pitches in [21, 108].

  Args:
    pitches: Tensor of pitch indices; every element is asserted to lie in
      [0, 87].

  Returns:
    `pitches + 21`, computed only after the range assertions have passed.
  """
  lower_ok = tf.assert_greater_equal(pitches, 0)
  upper_ok = tf.assert_less_equal(pitches, 87)
  with tf.control_dependencies([lower_ok, upper_ok]):
    midi_pitches = pitches + 21
  return midi_pitches
Ejemplo n.º 7
0
def demidify(pitches):
  """Maps MIDI pitches in [21, 108] to zero-based pitch indices in [0, 88).

  Args:
    pitches: Tensor of MIDI pitches; every element is asserted to lie in
      [21, 108].

  Returns:
    `pitches - 21`, computed only after the range assertions have passed.
  """
  lower_ok = tf.assert_greater_equal(pitches, 21)
  upper_ok = tf.assert_less_equal(pitches, 108)
  with tf.control_dependencies([lower_ok, upper_ok]):
    pitch_indices = pitches - 21
  return pitch_indices
Ejemplo n.º 8
0
def assert_greater_equal(*args, **kwargs):
    """Builds `tf.assert_greater_equal` with the op pinned to the CPU.

    The unwrapped version raises an exception if used with
    tf.device("/GPU:x"), so the assertion is always created under a
    "/CPU:0" device scope, overriding any enclosing tf.device.

    Args:
      *args: Positional arguments passed through to tf.assert_greater_equal.
      **kwargs: Keyword arguments passed through to tf.assert_greater_equal.

    Returns:
      Whatever tf.assert_greater_equal returns.
    """
    cpu_scope = tf.device("/CPU:0")
    with cpu_scope:
        check = tf.assert_greater_equal(*args, **kwargs)
    return check
 def _scan_fn(*_):
     """Proposes one exchange and validates the proposed replica indices.

     All arguments are ignored; `exchange_proposed_fn`, `num_replica` and
     `seed` are taken from the enclosing scope.

     Returns:
       The size of the leading dimension of the proposed exchange, gated on
       every replica index being unique and lying in [0, num_replica).
     """
     exchange = exchange_proposed_fn(num_replica, seed)
     replica_ids = tf.reshape(exchange, [-1])
     distinct_ids = tf.unique(replica_ids)[0]
     # No replica may appear twice: the flattened proposal must be exactly as
     # long as its de-duplicated version.
     no_duplicates = tf1.assert_equal(
         tf.size(input=replica_ids), tf.size(input=distinct_ids))
     non_negative = tf1.assert_greater_equal(replica_ids, 0)
     below_limit = tf1.assert_less(replica_ids, num_replica)
     with tf.control_dependencies(
             [no_duplicates, non_negative, below_limit]):
         leading_dim = tf.shape(input=exchange)[0]
     return leading_dim
Ejemplo n.º 10
0
def _batch_stitch(features, mean_length=4.0, stddev=2.0):
    """Stitches a batch of single-step data to a batch of multi-step data.

    The batch is padded, shuffled and reshaped into
    [num_sequences, max_length, ...].  Sequences that contain only padding
    are dropped, and within each sequence the real steps are moved ahead of
    the padding steps.  The result is then passed to `_stitch`.

    Args:
      features: dict of tensors sharing a leading batch dimension.  It must
        contain the key 'task'.  The dict is modified in place.
      mean_length: mean of the truncated normal from which the sequence
        lengths are drawn; also determines the number of sequences.
      stddev: standard deviation of that truncated normal.

    Returns:
      The value returned by `_stitch` for the regrouped features.
    """
    batch_size = common_layers.shape_list(features['task'])[0]
    num_sequences = tf.maximum(
        tf.to_int32(tf.to_float(batch_size) / mean_length), 1)
    lengths = tf.random.truncated_normal(shape=[num_sequences],
                                         mean=mean_length,
                                         stddev=stddev)
    # Rescale the longest drawn length so that the drawn lengths would sum to
    # the batch size; every sequence gets this many slots.
    max_length = tf.reduce_max(lengths) * (tf.to_float(batch_size) /
                                           tf.reduce_sum(lengths))
    max_length = tf.to_int32(tf.ceil(max_length))
    total_items = max_length * num_sequences
    num_paddings = total_items - batch_size
    indices = tf.random.shuffle(tf.range(total_items))
    # The check does not depend on the feature, so build it once and reuse
    # it for every key.
    paddings_non_negative = tf.assert_greater_equal(
        num_paddings, 0, name='num_paddings_positive')
    for key in features:
        shape_list = common_layers.shape_list(features[key])
        assert len(shape_list) >= 1
        # Pad only the leading (batch) dimension.
        paddings = [[0, num_paddings]] + [[0, 0]] * (len(shape_list) - 1)
        # The pad op has to be created inside the control-dependency scope:
        # in graph mode only ops built within the scope wait for the
        # assertion, and building the Python `paddings` list creates no op.
        with tf.control_dependencies([paddings_non_negative]):
            features[key] = tf.pad(
                features[key],
                paddings,
                constant_values=-1 if key == 'obj_type' else 0)
        features[key] = tf.gather(features[key], indices)
        shape = [num_sequences, max_length]
        if len(shape_list) >= 2:
            shape += shape_list[1:]
        features[key] = tf.reshape(features[key], shape)
    # Remove all-padding seqs
    step_mask = tf.reduce_any(tf.greater(features['task'], 1), axis=-1)
    mask = tf.reduce_any(step_mask, axis=-1)
    step_mask = tf.boolean_mask(step_mask, mask)
    for key in features:
        features[key] = tf.boolean_mask(features[key], mask=mask)
    num_sequences = tf.shape(features['task'])[0]
    # Sort steps within each seq: top_k over the 0/1 step mask places the
    # real steps (mask value 1) ahead of the padding steps (mask value 0).
    _, step_indices = tf.math.top_k(tf.to_int32(step_mask), k=max_length)
    # Turn per-sequence step positions into positions in the flattened
    # [num_sequences * max_length, ...] layout.
    step_indices = step_indices + tf.expand_dims(
        tf.range(num_sequences) * max_length, 1)
    step_indices = tf.reshape(step_indices, [-1])
    for key in features:
        shape_list = common_layers.shape_list(features[key])
        features[key] = tf.gather(
            tf.reshape(features[key], [-1] + shape_list[2:]), step_indices)
        features[key] = tf.reshape(features[key], shape_list)
    features = _stitch(features)
    return features
Ejemplo n.º 11
0
def sparse_softmax_cross_entropy(labels,
                                 logits,
                                 num_classes,
                                 weights=1.0,
                                 label_smoothing=0.1):
    """Softmax cross entropy with example weights and label smoothing.

    The sparse labels are range-checked, flattened and one-hot encoded, then
    handed to `tf.losses.softmax_cross_entropy`.

    Args:
      labels: integer class ids; they are compared against int64 bounds, and
        every id must lie in [0, num_classes).
      logits: unnormalized predictions passed to the dense loss.
      num_classes: number of classes, i.e. the depth of the one-hot encoding.
      weights: example weights forwarded to the dense loss.
      label_smoothing: smoothing factor forwarded to the dense loss.

    Returns:
      The loss tensor produced by `tf.losses.softmax_cross_entropy`.
    """
    lowest_label = tf.cast(0, dtype=tf.int64)
    label_limit = tf.cast(num_classes, dtype=tf.int64)
    label_checks = [
        tf.assert_greater_equal(labels, lowest_label),
        tf.assert_less(labels, label_limit),
    ]
    with tf.control_dependencies(label_checks):
        flat_labels = tf.reshape(labels, [-1])
        one_hot_targets = tf.one_hot(flat_labels, num_classes)
        return tf.losses.softmax_cross_entropy(
            onehot_labels=one_hot_targets,
            logits=logits,
            weights=weights,
            label_smoothing=label_smoothing)
Ejemplo n.º 12
0
def psnr(labels, predictions):
  """Streaming mean of the peak signal-to-noise ratio of `predictions`.

  PSNR is computed with respect to a maximum value of 1, so all image tensors
  are expected to lie in [0, 1].  Only `labels` is actually range-checked
  here.  Infinite PSNR values are left out of the mean.

  Args:
    labels: Tensor of shape [B, H, W, N].
    predictions: Tensor of shape [B, H, W, N].

  Returns:
    Tuple of (psnr, update_op) as returned by tf.metrics.
  """
  predictions.shape.assert_is_compatible_with(labels.shape)
  labels_in_range = [
      tf.assert_greater_equal(labels, 0.0),
      tf.assert_less_equal(labels, 1.0),
  ]
  with tf.control_dependencies(labels_in_range):
    per_image_psnr = tf.image.psnr(labels, predictions, max_val=1.0)
    # Drop infinite values so they cannot dominate the running mean.
    is_finite = tf.logical_not(tf.is_inf(per_image_psnr))
    finite_psnr = tf.boolean_mask(per_image_psnr, is_finite)
    return tf.metrics.mean(finite_psnr, name='psnr')
Ejemplo n.º 13
0
    def pop(self, mask, name=None):
        """Regresses the selected batch members by one frame and reads the top.

        `self` is left untouched; a new stack object carrying the updated
        indices is returned instead.

        Args:
          mask: Boolean `Tensor` of shape `[batch_size]`. The stack frames at
            `True` indices of `mask` are regressed; the others are unchanged.
          name: Optional name for this op.

        Returns:
          new_stack: A new stack whose frames have been regressed where
            indicated by `mask`.
          read: The batch of values at the newly-current stack frame.
        """
        with tf.name_scope(name or 'Stack.pop'):
            index_dtype = self.stack_index.dtype
            mask = tf.convert_to_tensor(value=mask, name='mask')
            popped_index = self.stack_index - tf.cast(mask, index_dtype)
            if self._safety_checks():
                # Route the index through an identity op gated on the check
                # that no stack index has dropped below zero.
                no_underflow = tf1.assert_greater_equal(
                    popped_index, tf.constant(0, popped_index.dtype))
                with tf.control_dependencies([no_underflow]):
                    popped_index = tf.identity(popped_index)
            popped_index.set_shape(self.stack_index.shape)
            # self.stack:      [max_stack_depth * batch_size, ...]
            # self.stack_index:                  [batch_size]
            # returned:                          [batch_size, ...]
            # Use the static batch size when it is known, else the dynamic
            # one.
            static_batch = tf.compat.dimension_value(
                self.stack_index.shape[0])
            batch_size = static_batch or tf.shape(
                input=self.stack_index, out_type=index_dtype)[0]
            # Stack depth and batch share one dimension, stack major, so the
            # flat row of frame `d` for member `b` is `d * batch_size + b`.
            member_offsets = tf.range(batch_size, dtype=popped_index.dtype)
            flat_rows = popped_index * batch_size + member_offsets
            read_value = tf.gather(self.stack, flat_rows)
            read_shape = self.stack_index.shape.concatenate(
                self.stack.shape[1:])
            read_value.set_shape(read_shape)
            popped_stack = type(self)(self.stack, popped_index)
            return popped_stack, read_value
Ejemplo n.º 14
0
def percentile(x,
               q,
               axis=None,
               interpolation=None,
               keep_dims=False,
               validate_args=False,
               preserve_gradients=True,
               name=None):
    """Compute the `q`-th percentile(s) of `x`.

  Given a vector `x`, the `q`-th percentile of `x` is the value `q / 100` of the
  way from the minimum to the maximum in a sorted copy of `x`.

  The values and distances of the two nearest neighbors as well as the
  `interpolation` parameter will determine the percentile if the normalized
  ranking does not match the location of `q` exactly.

  This function is the same as the median if `q = 50`, the same as the minimum
  if `q = 0` and the same as the maximum if `q = 100`.

  Multiple percentiles can be computed at once by using `1-D` vector `q`.
  Dimension zero of the returned `Tensor` will index the different percentiles.

  Compare to `numpy.percentile`.

  Args:
    x:  Numeric `N-D` `Tensor` with `N > 0`.  If `axis` is not `None`,
      `x` must have statically known number of dimensions.
    q:  Scalar or vector `Tensor` with values in `[0, 100]`. The percentile(s).
    axis:  Optional `0-D` or `1-D` integer `Tensor` with constant values. The
      axis that index independent samples over which to return the desired
      percentile.  If `None` (the default), treat every dimension as a sample
      dimension, returning a scalar.
    interpolation : {'nearest', 'linear', 'lower', 'higher', 'midpoint'}.
      Default value: 'nearest'.  This specifies the interpolation method to
      use when the desired quantile lies between two data points `i < j`:
        * linear: i + (j - i) * fraction, where fraction is the fractional part
          of the index surrounded by i and j.
        * lower: `i`.
        * higher: `j`.
        * nearest: `i` or `j`, whichever is nearest.
        * midpoint: (i + j) / 2.
      `linear` and `midpoint` interpolation do not work with integer dtypes.
    keep_dims:  Python `bool`. If `True`, the last dimension is kept with size 1
      If `False`, the last dimension is removed from the output shape.
    validate_args:  Whether to add runtime checks of argument validity. If
      False, and arguments are incorrect, correct behavior is not guaranteed.
    preserve_gradients:  Python `bool`.  If `True`, ensure that gradient w.r.t
      the percentile `q` is preserved in the case of linear interpolation.
      If `False`, the gradient will be (incorrectly) zero when `q` corresponds
      to a point in `x`.
    name:  A Python string name to give this `Op`.  Default is 'percentile'

  Returns:
    A `(rank(q) + N - len(axis))` dimensional `Tensor` of same dtype as `x`, or,
      if `axis` is `None`, a `rank(q)` `Tensor`.  The first `rank(q)` dimensions
      index quantiles for different values of `q`.

  Raises:
    ValueError:  If argument 'interpolation' is not an allowed type.
    TypeError:  If `interpolation` is 'linear' or 'midpoint' and `x` has an
      integer dtype.

  #### Examples

  ```python
  # Get 30th percentile with default ('nearest') interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30.)
  ==> 2.0

  # Get 30th percentile with 'linear' interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30., interpolation='linear')
  ==> 1.9

  # Get 30th and 70th percentiles with 'lower' interpolation
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=[30., 70.], interpolation='lower')
  ==> [1., 3.]

  # Get 100th percentile (maximum).  By default, this is computed over every dim
  x = [[1., 2.],
       [3., 4.]]
  tfp.stats.percentile(x, q=100.)
  ==> 4.

  # Treat the leading dim as indexing samples, and find the 100th quantile (max)
  # over all such samples.
  x = [[1., 2.],
       [3., 4.]]
  tfp.stats.percentile(x, q=100., axis=[0])
  ==> [3., 4.]
  ```

  """
    name = name or 'percentile'
    allowed_interpolations = {
        'linear', 'lower', 'higher', 'nearest', 'midpoint'
    }

    # Validate `interpolation` in plain Python, before any op is built.
    if interpolation is None:
        interpolation = 'nearest'
    else:
        if interpolation not in allowed_interpolations:
            raise ValueError(
                'Argument `interpolation` must be in %s.  Found %s' %
                (allowed_interpolations, interpolation))

    with tf1.name_scope(name, values=[x, q]):
        x = tf.convert_to_tensor(value=x, name='x')

        # 'linear' and 'midpoint' produce fractional results, which an
        # integer dtype cannot represent.
        if interpolation in {'linear', 'midpoint'} and x.dtype.is_integer:
            raise TypeError(
                '{} interpolation not allowed with dtype {}'.format(
                    interpolation, x.dtype))

        # Double is needed here and below, else we get the wrong index if the array
        # is huge along axis.
        q = tf.cast(q, tf.float64)
        _get_static_ndims(q, expect_ndims_no_more_than=1)

        if validate_args:
            q = distribution_util.with_dependencies([
                tf1.assert_rank_in(q, [0, 1]),
                tf1.assert_greater_equal(q, tf.cast(0., tf.float64)),
                tf1.assert_less_equal(q, tf.cast(100., tf.float64))
            ], q)

        # Move `axis` dims of `x` to the rightmost, call it `y`.
        if axis is None:
            y = tf.reshape(x, [-1])
        else:
            x_ndims = _get_static_ndims(x,
                                        expect_static=True,
                                        expect_ndims_at_least=1)
            axis = _make_static_axis_non_negative_list(axis, x_ndims)
            y = _move_dims_to_flat_end(x, axis, x_ndims, right_end=True)

        # Positions are measured from the top of the sorted data, because
        # `sorted_y` is ordered highest to lowest (see `_get_indices`).
        frac_at_q_or_above = 1. - q / 100.

        # Sort everything, not just the top 'k' entries, which allows multiple calls
        # to sort only once (under the hood) and use CSE.
        sorted_y = _sort_tensor(y)

        # Number of samples along the last axis of `y`, as a double.
        d = tf.cast(tf.shape(input=y)[-1], tf.float64)

        def _get_indices(interp_type):
            """Get values of y at the indices implied by interp_type."""
            # Only 'lower', 'higher' and 'nearest' are handled here.
            # Note `lower` <--> ceiling.  Confusing, huh?  Due to the fact that
            # _sort_tensor sorts highest to lowest, tf.ceil corresponds to the higher
            # index, but the lower value of y!
            if interp_type == 'lower':
                indices = tf.math.ceil((d - 1) * frac_at_q_or_above)
            elif interp_type == 'higher':
                indices = tf.floor((d - 1) * frac_at_q_or_above)
            elif interp_type == 'nearest':
                indices = tf.round((d - 1) * frac_at_q_or_above)
            # d - 1 will be distinct from d in int32, but not necessarily double.
            # So clip to avoid out of bounds errors.
            return tf.clip_by_value(tf.cast(indices, tf.int32), 0,
                                    tf.shape(input=y)[-1] - 1)

        if interpolation in ['nearest', 'lower', 'higher']:
            gathered_y = tf.gather(sorted_y,
                                   _get_indices(interpolation),
                                   axis=-1)
        elif interpolation == 'midpoint':
            gathered_y = 0.5 * (
                tf.gather(sorted_y, _get_indices('lower'), axis=-1) +
                tf.gather(sorted_y, _get_indices('higher'), axis=-1))
        elif interpolation == 'linear':
            # Copy-paste of docstring on interpolation:
            # linear: i + (j - i) * fraction, where fraction is the fractional part
            # of the index surrounded by i and j.
            larger_y_idx = _get_indices('lower')
            exact_idx = (d - 1) * frac_at_q_or_above
            if preserve_gradients:
                # If q corresponds to a point in x, we will initially have
                # larger_y_idx == smaller_y_idx.
                # This results in the gradient w.r.t. fraction being zero (recall `q`
                # enters only through `fraction`...and see that things cancel).
                # The fix is to ensure that smaller_y_idx and larger_y_idx are always
                # separated by exactly 1.
                smaller_y_idx = tf.maximum(larger_y_idx - 1, 0)
                larger_y_idx = tf.minimum(smaller_y_idx + 1,
                                          tf.shape(input=y)[-1] - 1)
                fraction = tf.cast(larger_y_idx, tf.float64) - exact_idx
            else:
                smaller_y_idx = _get_indices('higher')
                fraction = tf.math.ceil(
                    (d - 1) * frac_at_q_or_above) - exact_idx

            fraction = tf.cast(fraction, y.dtype)
            gathered_y = (
                tf.gather(sorted_y, larger_y_idx, axis=-1) * (1 - fraction) +
                tf.gather(sorted_y, smaller_y_idx, axis=-1) * fraction)

        # Propagate NaNs
        if x.dtype in (tf.bfloat16, tf.float16, tf.float32, tf.float64):
            # Apparently tf.is_nan doesn't like other dtypes
            nan_batch_members = tf.reduce_any(input_tensor=tf.math.is_nan(x),
                                              axis=axis)
            # Append one size-1 dimension per dimension of `q`, so the NaN
            # mask can broadcast against `gathered_y`.
            right_rank_matched_shape = tf.pad(
                tensor=tf.shape(input=nan_batch_members),
                paddings=[[0, tf.rank(input=q)]],
                constant_values=1)
            nan_batch_members = tf.reshape(nan_batch_members,
                                           shape=right_rank_matched_shape)
            nan = np.array(np.nan, gathered_y.dtype.as_numpy_dtype)
            gathered_y = tf.where(nan_batch_members, nan, gathered_y)

        # Expand dimensions if requested
        if keep_dims:
            if axis is None:
                # `ones_vec` is a shape made of 1s, one entry per dimension of
                # `x` and of `q`; multiplying by ones of that shape broadcasts
                # the result up to that rank.
                ones_vec = tf.ones(shape=[
                    _get_best_effort_ndims(x) + _get_best_effort_ndims(q)
                ],
                                   dtype=tf.int32)
                gathered_y *= tf.ones(ones_vec, dtype=x.dtype)
            else:
                gathered_y = _insert_back_keep_dims(gathered_y, axis)

        # If q is a scalar, then result has the right shape.
        # If q is a vector, then result has trailing dim of shape q.shape, which
        # needs to be rotated to dim 0.
        return distribution_util.rotate_transpose(gathered_y, tf.rank(q))
Ejemplo n.º 15
0
def expected_calibration_error(y_true, y_pred, nbins=20):
    """Streaming metric for the Expected Calibration Error (ECE).

    ECE is a scalar summary statistic of calibration error. It is the
    sample-weighted average of the difference between the predicted and true
    probabilities of a positive detection across uniformly-spaced model
    confidences [0, 1]. See the referenced paper for a thorough explanation.

    Reference:
      Guo, et. al, "On Calibration of Modern Neural Networks"
      Page 2, Expected Calibration Error (ECE).
      https://arxiv.org/pdf/1706.04599.pdf

    Three local metric variables (`bin_counts`, `true_sum` and `preds_sum`)
    hold per-bin running totals.  The `update_op` adds the current batch to
    them and evaluates ECE from the updated totals.

    Args:
      y_true: 1-D tf.int64 Tensor of binarized ground truth, corresponding to
        each prediction in y_pred.
      y_pred: 1-D tf.float32 tensor of model confidence scores in range
        [0.0, 1.0].
      nbins: int specifying the number of uniformly-spaced bins into which
        y_pred will be bucketed.

    Returns:
      value_op: A value metric op that returns ece.
      update_op: An operation that increments the per-bin variables
        appropriately and whose value matches `ece`.

    Raises:
      InvalidArgumentError: if y_pred is not in [0.0, 1.0].
    """
    bin_counts = metrics_impl.metric_variable(
        [nbins], tf.float32, name='bin_counts')
    bin_true_sum = metrics_impl.metric_variable(
        [nbins], tf.float32, name='true_sum')
    bin_preds_sum = metrics_impl.metric_variable(
        [nbins], tf.float32, name='preds_sum')

    # Bucket the predictions only after confirming they lie in [0, 1].
    preds_in_range = [
        tf.assert_greater_equal(y_pred, 0.0),
        tf.assert_less_equal(y_pred, 1.0),
    ]
    with tf.control_dependencies(preds_in_range):
        bin_ids = tf.histogram_fixed_width_bins(
            y_pred, [0.0, 1.0], nbins=nbins)

    def _add_per_bin_totals(accumulator, weights=None):
        # Adds the per-bin (optionally weighted) totals of this batch to
        # `accumulator` and returns the assign op.
        per_bin = tf.bincount(bin_ids, weights=weights, minlength=nbins)
        return tf.assign_add(accumulator, tf.cast(per_bin, dtype=tf.float32))

    with tf.control_dependencies([bin_ids]):
        counts_update = _add_per_bin_totals(bin_counts)
        true_sum_update = _add_per_bin_totals(bin_true_sum, weights=y_true)
        preds_sum_update = _add_per_bin_totals(bin_preds_sum, weights=y_pred)

    # Built from the assign ops, so evaluating it applies the update first.
    ece_update_op = _ece_from_bins(
        counts_update, true_sum_update, preds_sum_update, name='update_op')
    # Built from the variables themselves, so it reads without updating.
    ece = _ece_from_bins(bin_counts, bin_true_sum, bin_preds_sum, name='value')
    return ece, ece_update_op