def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
  initialization = tfd.MultivariateNormalFullCovariance(
      loc=tf.zeros(event_size),
      covariance_matrix=tf.eye(event_size)).sample(batch_size, seed=4)
  samples, leapfrogs = run_nuts_chain(
      event_size, batch_size, num_steps=1,
      initial_state=initialization, **kwargs)
  answer = samples[0][-1]
  check_cdf_agrees = (
      st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
          answer, initialization, num_projections=100, false_fail_rate=1e-6))
  check_sample_shape = tf1.assert_equal(
      tf.shape(input=answer)[0], batch_size)
  unique, _ = tf.unique(leapfrogs[0])
  check_leapfrogs_vary = tf1.assert_greater_equal(
      tf.shape(input=unique)[0], 3)
  avg_leapfrogs = tf.math.reduce_mean(input_tensor=leapfrogs[0])
  check_leapfrogs = tf1.assert_greater_equal(
      avg_leapfrogs, tf.constant(4, dtype=avg_leapfrogs.dtype))
  movement = tf.linalg.norm(tensor=answer - initialization, axis=-1)
  # This movement distance (0.3) was copied from the univariate case.
  check_movement = tf1.assert_greater_equal(
      tf.reduce_mean(input_tensor=movement), 0.3)
  check_enough_power = tf1.assert_less(
      st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
          batch_size, batch_size, false_fail_rate=1e-8, false_pass_rate=1e-6),
      0.055)
  return (check_cdf_agrees, check_sample_shape, check_leapfrogs_vary,
          check_leapfrogs, check_movement, check_enough_power)

def assert_univariate_target_conservation(test, mk_target, step_size,
                                          stackless):
  # Sample count limited partly by memory reliably available on Forge.  The
  # test remains reasonable even if the nuts recursion limit is severely
  # curtailed (e.g., 3 or 4 levels), so use that to recover some memory
  # footprint and bump the sample count.
  num_samples = int(5e4)
  num_steps = 1
  target_d = mk_target()
  strm = tfp.util.SeedStream(salt='univariate_nuts_test', seed=1)
  # We wrap the initial values in `tf.identity` to avoid broken gradients
  # resulting from a bijector cache hit, since bijectors of the same
  # type/parameterization now share a cache.
  # TODO(b/72831017): Fix broken gradients caused by bijector caching.
  initialization = tf.identity(target_d.sample([num_samples], seed=strm()))

  def target(*args):
    # TODO(axch): Just use target_d.log_prob directly, and accept target_d
    # itself as an argument instead of a maker function.  Blocked by
    # b/128932888.  It would then also be nice not to eta-expand
    # target_d.log_prob; that was blocked by b/122414321, but maybe tfp's
    # port of value_and_gradients_function fixed that bug.
    return mk_target().log_prob(*args)

  operator = tfp.experimental.mcmc.NoUTurnSampler(
      target,
      step_size=step_size,
      max_tree_depth=3,
      use_auto_batching=True,
      stackless=stackless,
      unrolled_leapfrog_steps=2,
      seed=strm())
  result, extra = tfp.mcmc.sample_chain(
      num_results=num_steps,
      num_burnin_steps=0,
      current_state=initialization,
      kernel=operator)
  # Note: sample_chain puts the chain history on top, not the (independent)
  # chains.
  test.assertAllEqual([num_steps, num_samples], result.shape)
  answer = result[0]
  check_cdf_agrees = st.assert_true_cdf_equal_by_dkwm(
      answer, target_d.cdf, false_fail_rate=1e-6)
  check_enough_power = tf1.assert_less(
      st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
          num_samples, false_fail_rate=1e-6, false_pass_rate=1e-6), 0.025)
  test.assertAllEqual([num_samples], extra.leapfrogs_taken[0].shape)
  unique, _ = tf.unique(extra.leapfrogs_taken[0])
  check_leapfrogs_vary = tf1.assert_greater_equal(
      tf.shape(input=unique)[0], 3)
  avg_leapfrogs = tf.math.reduce_mean(input_tensor=extra.leapfrogs_taken[0])
  check_leapfrogs = tf1.assert_greater_equal(
      avg_leapfrogs, tf.constant(4, dtype=avg_leapfrogs.dtype))
  movement = tf.abs(answer - initialization)
  test.assertAllEqual([num_samples], movement.shape)
  # This movement distance (1 * step_size) was selected by reducing until 100
  # runs with independent seeds all passed.
  check_movement = tf1.assert_greater_equal(
      tf.reduce_mean(input_tensor=movement), 1 * step_size)
  return (check_cdf_agrees, check_enough_power, check_leapfrogs_vary,
          check_leapfrogs, check_movement)

def RandomCropImages(images, input_shape, target_shape):
  """Crops a patch of the given shape from a random location in a list of images.

  Args:
    images: List of tensors of shape [batch_size, h, w, c].
    input_shape: Shape [h, w, c] of the input images.
    target_shape: Shape [h, w] of the cropped output.

  Raises:
    ValueError: If either `input_shape` or `target_shape` has the wrong
      length.

  Returns:
    crops: List of cropped tensors of shape [batch_size] + target_shape.
  """
  if len(input_shape) != 3:
    raise ValueError(
        'The input shape has to be of the form (height, width, channels) '
        'but has len {}'.format(len(input_shape)))
  if len(target_shape) != 2:
    raise ValueError('The target shape has to be of the form (height, width) '
                     'but has len {}'.format(len(target_shape)))
  max_y = int(input_shape[0]) - int(target_shape[0])
  max_x = int(input_shape[1]) - int(target_shape[1])
  with tf.control_dependencies(
      [tf.assert_greater_equal(max_x, 0),
       tf.assert_greater_equal(max_y, 0)]):
    offset_y = tf.random_uniform((), maxval=max_y + 1, dtype=tf.int32)
    offset_x = tf.random_uniform((), maxval=max_x + 1, dtype=tf.int32)
    return [
        tf.image.crop_to_bounding_box(img, offset_y, offset_x,
                                      int(target_shape[0]),
                                      int(target_shape[1])) for img in images
    ]

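# Usage sketch (not part of the original file; shapes are hypothetical): all
# tensors in the list receive the same random offsets, so paired inputs such
# as an RGB image and its depth map stay spatially aligned after cropping.
def _example_random_crop():
  rgb = tf.zeros([8, 64, 64, 3])
  depth = tf.zeros([8, 64, 64, 1])
  return RandomCropImages([rgb, depth],
                          input_shape=(64, 64, 3),
                          target_shape=(56, 56))  # each crop: [8, 56, 56, c]
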
def area_loss(logits, ranges, length, max_area_width, allow_empty=False):
  """Computes the loss regarding areas.

  Args:
    logits: the predictions of each area [batch_size, query_length,
      num_areas].
    ranges: the groundtruth [batch_size, query_length, 2].
    length: the length of the original tensor.
    max_area_width: the maximum area width.
    allow_empty: whether to allow empty refs.

  Returns:
    the loss.
  """
  num_areas = common_layers.shape_list(logits)[-1]
  ranges = tf.reshape(ranges, [-1, 2])
  indices = area_range_to_index(area_range=ranges, length=length,
                                max_area_width=max_area_width)
  if allow_empty:
    indices = tf.where(
        tf.greater(ranges[:, 1], ranges[:, 0]), indices + 1,
        tf.zeros_like(indices))
  logits = tf.reshape(logits, [-1, num_areas])
  losses = tf.losses.sparse_softmax_cross_entropy(
      labels=indices, logits=logits,
      reduction=tf.losses.Reduction.NONE)
  with tf.control_dependencies(
      [tf.assert_greater_equal(ranges[:, 1], ranges[:, 0])]):
    if not allow_empty:
      mask = tf.greater(ranges[:, 1], ranges[:, 0])
      losses = losses * tf.cast(mask, tf.float32)
    return tf.reduce_mean(losses)

def pre_attention(self, segment_number, query_antecedent,
                  memory_antecedent, bias):
  """Called prior to self-attention, to incorporate memory items.

  Args:
    segment_number: an integer Tensor with shape [batch]
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: must be None. Attention normally allows this to be a
      Tensor with shape [batch, length_m, channels], but we currently only
      support memory for decoder-side self-attention.
    bias: bias Tensor (see attention_bias())

  Returns:
    (data, new_query_antecedent, new_memory_antecedent, new_bias)
  """
  with tf.variable_scope(self.name + "/pre_attention", reuse=tf.AUTO_REUSE):
    assert memory_antecedent is None, "We only support language modeling"
    with tf.control_dependencies([
        tf.assert_greater_equal(self.batch_size, tf.size(segment_number))]):
      difference = self.batch_size - tf.size(segment_number)
      segment_number = tf.pad(segment_number, [[0, difference]])
      reset_op = self.reset(tf.reshape(tf.where(
          tf.less(segment_number, self.segment_number)), [-1]))
    memory_results = {}
    with tf.control_dependencies([reset_op]):
      with tf.control_dependencies([
          self.update_segment_number(segment_number)]):
        x = tf.pad(query_antecedent, [
            [0, difference], [0, 0], [0, 0]])
        access_logits, retrieved_mem = self.read(x)
    memory_results["x"] = x
    memory_results["access_logits"] = access_logits
    memory_results["retrieved_mem"] = retrieved_mem
    return memory_results, query_antecedent, memory_antecedent, bias

def remidify(pitches):
  """Transforms [0, 88) to MIDI pitches [21, 108]."""
  assertions = [
      tf.assert_greater_equal(pitches, 0),
      tf.assert_less_equal(pitches, 87)
  ]
  with tf.control_dependencies(assertions):
    return pitches + 21

def demidify(pitches):
  """Transforms MIDI pitches [21, 108] to [0, 88)."""
  assertions = [
      tf.assert_greater_equal(pitches, 21),
      tf.assert_less_equal(pitches, 108)
  ]
  with tf.control_dependencies(assertions):
    return pitches - 21

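# Usage sketch (not part of the original file; example values are
# hypothetical): round-trips pitches through demidify/remidify.  The
# assertions above enforce the valid input ranges at graph-execution time.
def _example_pitch_roundtrip():
  midi = tf.constant([21, 60, 108])
  zero_based = demidify(midi)   # -> [0, 39, 87]
  return remidify(zero_based)   # -> [21, 60, 108]
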
def assert_greater_equal(*args, **kwargs):
  """Wrapper for tf.assert_greater_equal.

  Overrides tf.device so that the assert always goes on CPU. The unwrapped
  version raises an exception if used with tf.device("/GPU:x").
  """
  with tf.device("/CPU:0"):
    return tf.assert_greater_equal(*args, **kwargs)

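# Usage sketch (hypothetical caller, not from the original file): inside a
# GPU device scope the wrapper keeps the assert op itself on CPU while the
# guarded computation stays wherever the outer scope placed it.
def _example_guarded_relu(x):
  with tf.device("/GPU:0"):
    check = assert_greater_equal(tf.size(x), 1)
    with tf.control_dependencies([check]):
      return tf.nn.relu(x)
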
def _scan_fn(*_):
  exchange = exchange_proposed_fn(num_replica, seed)
  flat_replicas = tf.reshape(exchange, [-1])
  with tf.control_dependencies([
      tf1.assert_equal(
          tf.size(input=flat_replicas),
          tf.size(input=tf.unique(flat_replicas)[0])),
      tf1.assert_greater_equal(flat_replicas, 0),
      tf1.assert_less(flat_replicas, num_replica),
  ]):
    return tf.shape(input=exchange)[0]

def _batch_stitch(features, mean_length=4.0, stddev=2.0):
  """Stitches a batch of single-step data to a batch of multi-step data."""
  batch_size = common_layers.shape_list(features['task'])[0]
  num_sequences = tf.maximum(
      tf.to_int32(tf.to_float(batch_size) / mean_length), 1)
  lengths = tf.random.truncated_normal(shape=[num_sequences],
                                       mean=mean_length, stddev=stddev)
  max_length = tf.reduce_max(lengths) * (tf.to_float(batch_size) /
                                         tf.reduce_sum(lengths))
  max_length = tf.to_int32(tf.ceil(max_length))
  total_items = max_length * num_sequences
  num_paddings = total_items - batch_size
  indices = tf.random.shuffle(tf.range(total_items))
  for key in features:
    shape_list = common_layers.shape_list(features[key])
    assert len(shape_list) >= 1
    with tf.control_dependencies([
        tf.assert_greater_equal(num_paddings, 0,
                                name='num_paddings_positive')
    ]):
      paddings = [[0, num_paddings]] + [[0, 0]] * (len(shape_list) - 1)
      features[key] = tf.pad(features[key], paddings,
                             constant_values=-1 if key == 'obj_type' else 0)
    features[key] = tf.gather(features[key], indices)
    shape = [num_sequences, max_length]
    if len(shape_list) >= 2:
      shape += shape_list[1:]
    features[key] = tf.reshape(features[key], shape)
  # Remove all-padding seqs
  step_mask = tf.reduce_any(tf.greater(features['task'], 1), axis=-1)
  mask = tf.reduce_any(step_mask, axis=-1)
  step_mask = tf.boolean_mask(step_mask, mask)
  for key in features:
    features[key] = tf.boolean_mask(features[key], mask=mask)
  num_sequences = tf.shape(features['task'])[0]
  # Sort steps within each seq
  _, step_indices = tf.math.top_k(tf.to_int32(step_mask), k=max_length)
  step_indices = step_indices + tf.expand_dims(
      tf.range(num_sequences) * max_length, 1)
  step_indices = tf.reshape(step_indices, [-1])
  for key in features:
    shape_list = common_layers.shape_list(features[key])
    features[key] = tf.gather(
        tf.reshape(features[key], [-1] + shape_list[2:]),
        step_indices)
    features[key] = tf.reshape(features[key], shape_list)
  features = _stitch(features)
  return features

def sparse_softmax_cross_entropy(labels, logits, num_classes, weights=1.0,
                                 label_smoothing=0.1):
  """Softmax cross entropy with example weights, label smoothing."""
  assert_valid_label = [
      tf.assert_greater_equal(labels, tf.cast(0, dtype=tf.int64)),
      tf.assert_less(labels, tf.cast(num_classes, dtype=tf.int64))
  ]
  with tf.control_dependencies(assert_valid_label):
    labels = tf.reshape(labels, [-1])
    dense_labels = tf.one_hot(labels, num_classes)
    loss = tf.losses.softmax_cross_entropy(
        onehot_labels=dense_labels,
        logits=logits,
        weights=weights,
        label_smoothing=label_smoothing)
  return loss

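# Usage sketch (hypothetical 3-class batch, not from the original file):
# labels must be int64 so the validity asserts compare like dtypes, and
# logits are [batch, num_classes].
def _example_smoothed_loss():
  logits = tf.constant([[2.0, 0.5, -1.0],
                        [0.1, 1.2, 0.3]])
  labels = tf.constant([0, 1], dtype=tf.int64)
  return sparse_softmax_cross_entropy(labels, logits, num_classes=3)
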
def psnr(labels, predictions):
  """Computes average peak signal-to-noise ratio of `predictions`.

  Here PSNR is defined with respect to the maximum value of 1. All image
  tensors must be within the range [0, 1].

  Args:
    labels: Tensor of shape [B, H, W, N].
    predictions: Tensor of shape [B, H, W, N].

  Returns:
    Tuple of (psnr, update_op) as returned by tf.metrics.
  """
  predictions.shape.assert_is_compatible_with(labels.shape)
  with tf.control_dependencies([tf.assert_greater_equal(labels, 0.0),
                                tf.assert_less_equal(labels, 1.0)]):
    psnrs = tf.image.psnr(labels, predictions, max_val=1.0)
    psnrs = tf.boolean_mask(psnrs, tf.logical_not(tf.is_inf(psnrs)))
    return tf.metrics.mean(psnrs, name='psnr')

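# Usage sketch (assumes uint8 image batches; not from the original file):
# convert_image_dtype rescales integer images to [0, 1], so the range
# assertions inside psnr hold.
def _example_psnr_metric(uint8_labels, uint8_predictions):
  labels = tf.image.convert_image_dtype(uint8_labels, tf.float32)
  predictions = tf.image.convert_image_dtype(uint8_predictions, tf.float32)
  return psnr(labels, predictions)  # (mean_psnr, update_op)
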
def pop(self, mask, name=None):
  """Pops each indicated batch member, returns the new top of the stack.

  Does not mutate `self`.

  Args:
    mask: Boolean `Tensor` of shape `[batch_size]`.  The stack frames at
      `True` indices of `mask` are regressed; the others are unchanged.
    name: Optional name for this op.

  Returns:
    new_stack: A new stack whose frames have been regressed where indicated
      by `mask`.
    read: The batch of values at the newly-current stack frame.
  """
  with tf.name_scope(name or 'Stack.pop'):
    mask = tf.convert_to_tensor(value=mask, name='mask')
    new_stack_index = self.stack_index - tf.cast(
        mask, self.stack_index.dtype)
    if self._safety_checks():
      with tf.control_dependencies([
          tf1.assert_greater_equal(
              new_stack_index, tf.constant(0, new_stack_index.dtype))
      ]):
        new_stack_index = tf.identity(new_stack_index)
    new_stack_index.set_shape(self.stack_index.shape)
    # self.stack: [max_stack_depth * batch_size, ...]
    # self.stack_index: [batch_size]
    # returned: [batch_size, ...]
    batch_size = (
        tf.compat.dimension_value(self.stack_index.shape[0]) or
        tf.shape(input=self.stack_index,
                 out_type=self.stack_index.dtype)[0])
    # Note that stack depth and batch are in a single dimension, stack major.
    gather_indices = (
        new_stack_index * batch_size +
        tf.range(batch_size, dtype=new_stack_index.dtype))
    read_value = tf.gather(self.stack, gather_indices)
    read_value.set_shape(
        self.stack_index.shape.concatenate(self.stack.shape[1:]))
    return type(self)(self.stack, new_stack_index), read_value

def percentile(x,
               q,
               axis=None,
               interpolation=None,
               keep_dims=False,
               validate_args=False,
               preserve_gradients=True,
               name=None):
  """Compute the `q`-th percentile(s) of `x`.

  Given a vector `x`, the `q`-th percentile of `x` is the value `q / 100` of
  the way from the minimum to the maximum in a sorted copy of `x`.

  The values and distances of the two nearest neighbors as well as the
  `interpolation` parameter will determine the percentile if the normalized
  ranking does not match the location of `q` exactly.

  This function is the same as the median if `q = 50`, the same as the
  minimum if `q = 0` and the same as the maximum if `q = 100`.

  Multiple percentiles can be computed at once by using `1-D` vector `q`.
  Dimension zero of the returned `Tensor` will index the different
  percentiles.

  Compare to `numpy.percentile`.

  Args:
    x:  Numeric `N-D` `Tensor` with `N > 0`.  If `axis` is not `None`, `x`
      must have statically known number of dimensions.
    q:  Scalar or vector `Tensor` with values in `[0, 100]`. The
      percentile(s).
    axis:  Optional `0-D` or `1-D` integer `Tensor` with constant values. The
      axis that index independent samples over which to return the desired
      percentile.  If `None` (the default), treat every dimension as a sample
      dimension, returning a scalar.
    interpolation : {'nearest', 'linear', 'lower', 'higher', 'midpoint'}.
      Default value: 'nearest'.  This specifies the interpolation method to
      use when the desired quantile lies between two data points `i < j`:
        * linear: i + (j - i) * fraction, where fraction is the fractional
          part of the index surrounded by i and j.
        * lower: `i`.
        * higher: `j`.
        * nearest: `i` or `j`, whichever is nearest.
        * midpoint: (i + j) / 2.
      `linear` and `midpoint` interpolation do not work with integer dtypes.
    keep_dims:  Python `bool`. If `True`, the last dimension is kept with
      size 1.  If `False`, the last dimension is removed from the output
      shape.
    validate_args:  Whether to add runtime checks of argument validity. If
      False, and arguments are incorrect, correct behavior is not guaranteed.
    preserve_gradients:  Python `bool`.  If `True`, ensure that gradient
      w.r.t the percentile `q` is preserved in the case of linear
      interpolation.  If `False`, the gradient will be (incorrectly) zero
      when `q` corresponds to a point in `x`.
    name:  A Python string name to give this `Op`.  Default is 'percentile'.

  Returns:
    A `(rank(q) + N - len(axis))` dimensional `Tensor` of same dtype as `x`,
    or, if `axis` is `None`, a `rank(q)` `Tensor`.  The first `rank(q)`
    dimensions index quantiles for different values of `q`.

  Raises:
    ValueError:  If argument 'interpolation' is not an allowed type.
    ValueError:  If interpolation type not compatible with `dtype`.

  #### Examples

  ```python
  # Get 30th percentile with default ('nearest') interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30.)
  ==> 2.0

  # Get 30th percentile with 'linear' interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30., interpolation='linear')
  ==> 1.9

  # Get 30th and 70th percentiles with 'lower' interpolation
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=[30., 70.], interpolation='lower')
  ==> [1., 3.]

  # Get 100th percentile (maximum).  By default, this is computed over every
  # dim.
  x = [[1., 2.]
       [3., 4.]]
  tfp.stats.percentile(x, q=100.)
  ==> 4.

  # Treat the leading dim as indexing samples, and find the 100th quantile
  # (max) over all such samples.
  x = [[1., 2.]
       [3., 4.]]
  tfp.stats.percentile(x, q=100., axis=[0])
  ==> [3., 4.]
  ```

  """
  name = name or 'percentile'
  allowed_interpolations = {
      'linear', 'lower', 'higher', 'nearest', 'midpoint'
  }

  if interpolation is None:
    interpolation = 'nearest'
  else:
    if interpolation not in allowed_interpolations:
      raise ValueError(
          'Argument `interpolation` must be in %s.  Found %s' %
          (allowed_interpolations, interpolation))

  with tf1.name_scope(name, values=[x, q]):
    x = tf.convert_to_tensor(value=x, name='x')

    if interpolation in {'linear', 'midpoint'} and x.dtype.is_integer:
      raise TypeError('{} interpolation not allowed with dtype {}'.format(
          interpolation, x.dtype))

    # Double is needed here and below, else we get the wrong index if the
    # array is huge along axis.
    q = tf.cast(q, tf.float64)
    _get_static_ndims(q, expect_ndims_no_more_than=1)

    if validate_args:
      q = distribution_util.with_dependencies([
          tf1.assert_rank_in(q, [0, 1]),
          tf1.assert_greater_equal(q, tf.cast(0., tf.float64)),
          tf1.assert_less_equal(q, tf.cast(100., tf.float64))
      ], q)

    # Move `axis` dims of `x` to the rightmost, call it `y`.
    if axis is None:
      y = tf.reshape(x, [-1])
    else:
      x_ndims = _get_static_ndims(
          x, expect_static=True, expect_ndims_at_least=1)
      axis = _make_static_axis_non_negative_list(axis, x_ndims)
      y = _move_dims_to_flat_end(x, axis, x_ndims, right_end=True)

    frac_at_q_or_above = 1. - q / 100.

    # Sort everything, not just the top 'k' entries, which allows multiple
    # calls to sort only once (under the hood) and use CSE.
    sorted_y = _sort_tensor(y)

    d = tf.cast(tf.shape(input=y)[-1], tf.float64)

    def _get_indices(interp_type):
      """Get values of y at the indices implied by interp_type."""
      # Note `lower` <--> ceiling.  Confusing, huh?  Due to the fact that
      # _sort_tensor sorts highest to lowest, tf.ceil corresponds to the
      # higher index, but the lower value of y!
      if interp_type == 'lower':
        indices = tf.math.ceil((d - 1) * frac_at_q_or_above)
      elif interp_type == 'higher':
        indices = tf.floor((d - 1) * frac_at_q_or_above)
      elif interp_type == 'nearest':
        indices = tf.round((d - 1) * frac_at_q_or_above)
      # d - 1 will be distinct from d in int32, but not necessarily double.
      # So clip to avoid out of bounds errors.
      return tf.clip_by_value(
          tf.cast(indices, tf.int32), 0,
          tf.shape(input=y)[-1] - 1)

    if interpolation in ['nearest', 'lower', 'higher']:
      gathered_y = tf.gather(sorted_y, _get_indices(interpolation), axis=-1)
    elif interpolation == 'midpoint':
      gathered_y = 0.5 * (
          tf.gather(sorted_y, _get_indices('lower'), axis=-1) +
          tf.gather(sorted_y, _get_indices('higher'), axis=-1))
    elif interpolation == 'linear':
      # Copy-paste of docstring on interpolation:
      # linear: i + (j - i) * fraction, where fraction is the fractional part
      # of the index surrounded by i and j.
      larger_y_idx = _get_indices('lower')
      exact_idx = (d - 1) * frac_at_q_or_above
      if preserve_gradients:
        # If q corresponds to a point in x, we will initially have
        # larger_y_idx == smaller_y_idx.
        # This results in the gradient w.r.t. fraction being zero (recall `q`
        # enters only through `fraction`...and see that things cancel).
        # The fix is to ensure that smaller_y_idx and larger_y_idx are always
        # separated by exactly 1.
        smaller_y_idx = tf.maximum(larger_y_idx - 1, 0)
        larger_y_idx = tf.minimum(smaller_y_idx + 1,
                                  tf.shape(input=y)[-1] - 1)
        fraction = tf.cast(larger_y_idx, tf.float64) - exact_idx
      else:
        smaller_y_idx = _get_indices('higher')
        fraction = tf.math.ceil((d - 1) * frac_at_q_or_above) - exact_idx

      fraction = tf.cast(fraction, y.dtype)
      gathered_y = (
          tf.gather(sorted_y, larger_y_idx, axis=-1) * (1 - fraction) +
          tf.gather(sorted_y, smaller_y_idx, axis=-1) * fraction)

    # Propagate NaNs
    if x.dtype in (tf.bfloat16, tf.float16, tf.float32, tf.float64):
      # Apparently tf.is_nan doesn't like other dtypes
      nan_batch_members = tf.reduce_any(
          input_tensor=tf.math.is_nan(x), axis=axis)
      right_rank_matched_shape = tf.pad(
          tensor=tf.shape(input=nan_batch_members),
          paddings=[[0, tf.rank(input=q)]],
          constant_values=1)
      nan_batch_members = tf.reshape(
          nan_batch_members, shape=right_rank_matched_shape)
      nan = np.array(np.nan, gathered_y.dtype.as_numpy_dtype)
      gathered_y = tf.where(nan_batch_members, nan, gathered_y)

    # Expand dimensions if requested
    if keep_dims:
      if axis is None:
        ones_vec = tf.ones(
            shape=[_get_best_effort_ndims(x) + _get_best_effort_ndims(q)],
            dtype=tf.int32)
        gathered_y *= tf.ones(ones_vec, dtype=x.dtype)
      else:
        gathered_y = _insert_back_keep_dims(gathered_y, axis)

    # If q is a scalar, then result has the right shape.
    # If q is a vector, then result has trailing dim of shape q.shape, which
    # needs to be rotated to dim 0.
    return distribution_util.rotate_transpose(gathered_y, tf.rank(q))

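# Usage sketch (reuses the data from the docstring examples above; not part
# of the original file): per-column medians over axis 0 with 'midpoint'
# interpolation, keeping the reduced axis.
def _example_percentile():
  x = tf.constant([[1., 2.],
                   [3., 4.]])
  return percentile(x, q=50., axis=[0], interpolation='midpoint',
                    keep_dims=True)  # -> [[2., 3.]]
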
def expected_calibration_error(y_true, y_pred, nbins=20):
  """Calculates Expected Calibration Error (ECE).

  ECE is a scalar summary statistic of calibration error. It is the
  sample-weighted average of the difference between the predicted and true
  probabilities of a positive detection across uniformly-spaced model
  confidences [0, 1]. See referenced paper for a thorough explanation.

  Reference:
    Guo, et. al, "On Calibration of Modern Neural Networks"
    Page 2, Expected Calibration Error (ECE).
    https://arxiv.org/pdf/1706.04599.pdf

  This function creates three local variables, `bin_counts`, `bin_true_sum`,
  and `bin_preds_sum` that are used to compute ECE.  For estimation of the
  metric over a stream of data, the function creates an `update_op` operation
  that updates these variables and returns the ECE.

  Args:
    y_true: 1-D tf.int64 Tensor of binarized ground truth, corresponding to
      each prediction in y_pred.
    y_pred: 1-D tf.float32 tensor of model confidence scores in range
      [0.0, 1.0].
    nbins: int specifying the number of uniformly-spaced bins into which
      y_pred will be bucketed.

  Returns:
    value_op: A value metric op that returns ece.
    update_op: An operation that increments the `bin_counts`, `bin_true_sum`,
      and `bin_preds_sum` variables appropriately and whose value matches
      `ece`.

  Raises:
    InvalidArgumentError: if y_pred is not in [0.0, 1.0].
  """
  bin_counts = metrics_impl.metric_variable(
      [nbins], tf.float32, name='bin_counts')
  bin_true_sum = metrics_impl.metric_variable(
      [nbins], tf.float32, name='true_sum')
  bin_preds_sum = metrics_impl.metric_variable(
      [nbins], tf.float32, name='preds_sum')

  with tf.control_dependencies([
      tf.assert_greater_equal(y_pred, 0.0),
      tf.assert_less_equal(y_pred, 1.0),
  ]):
    bin_ids = tf.histogram_fixed_width_bins(y_pred, [0.0, 1.0], nbins=nbins)

  with tf.control_dependencies([bin_ids]):
    update_bin_counts_op = tf.assign_add(
        bin_counts,
        tf.cast(tf.bincount(bin_ids, minlength=nbins), dtype=tf.float32))
    update_bin_true_sum_op = tf.assign_add(
        bin_true_sum,
        tf.cast(tf.bincount(bin_ids, weights=y_true, minlength=nbins),
                dtype=tf.float32))
    update_bin_preds_sum_op = tf.assign_add(
        bin_preds_sum,
        tf.cast(tf.bincount(bin_ids, weights=y_pred, minlength=nbins),
                dtype=tf.float32))

  ece_update_op = _ece_from_bins(
      update_bin_counts_op,
      update_bin_true_sum_op,
      update_bin_preds_sum_op,
      name='update_op')
  ece = _ece_from_bins(bin_counts, bin_true_sum, bin_preds_sum, name='value')
  return ece, ece_update_op

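# Usage sketch with tiny hypothetical inputs (not part of the original file;
# assumes a TF 1.x streaming-metric evaluation): initialize local variables,
# run the update op once per batch, then read the value op for the final ECE.
def _example_ece():
  y_true = tf.constant([1, 0, 1, 1], dtype=tf.int64)
  y_pred = tf.constant([0.9, 0.2, 0.7, 0.4], dtype=tf.float32)
  ece, ece_update_op = expected_calibration_error(y_true, y_pred, nbins=10)
  # In a tf.Session: run tf.local_variables_initializer(), then
  # ece_update_op per batch, and finally fetch ece.
  return ece, ece_update_op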