def test_failing_different_ragged_and_dense_ranks(self, x_list, y_list):
    """Inputs whose ragged/dense ranks are incompatible must raise ValueError."""
    ragged_x = tf.ragged.constant(x_list)
    ragged_y = tf.ragged.constant(y_list)
    # pylint: disable=g-error-prone-assert-raises
    with self.assertRaises(ValueError):
        [ragged_x, ragged_y], _ = (
            metrics_utils.ragged_assert_compatible_and_get_flat_values(
                [ragged_x, ragged_y]))
def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates metric statistics.

    `y_true` and `y_pred` should have the same shape.

    Args:
      y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
      y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
      sample_weight: Optional `sample_weight` acts as a coefficient for the
        metric. If a scalar is provided, then the metric is simply scaled by
        the given value. If `sample_weight` is a tensor of size
        `[batch_size]`, then the metric for each sample of the batch is
        rescaled by the corresponding element in the `sample_weight` vector.
        If the shape of `sample_weight` is `[batch_size, d0, .. dN-1]` (or
        can be broadcasted to this shape), then each metric element of
        `y_pred` is scaled by the corresponding value of `sample_weight`.
        (Note on `dN-1`: all metric functions reduce by 1 dimension, usually
        the last axis (-1)).

    Returns:
      Update op.
    """
    y_true = tf.cast(y_true, self._dtype)
    y_pred = tf.cast(y_pred, self._dtype)
    # Ragged inputs are flattened to their values so the wrapped metric
    # function only ever sees dense tensors.
    flat_values, sample_weight = (
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [y_true, y_pred], sample_weight))
    y_true, y_pred = flat_values
    y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true)

    # Run the user metric function under autograph so it works both eagerly
    # and inside a tf.function.
    ctx = tf.__internal__.autograph.control_status_ctx()
    wrapped_fn = tf.__internal__.autograph.tf_convert(self._fn, ctx)
    matches = wrapped_fn(y_true, y_pred, **self._fn_kwargs)
    return super().update_state(matches, sample_weight=sample_weight)
def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates the connectivity-error statistic for a batch.

    Integer-typed images are rescaled from their full dtype range into
    [0, 1]; floating-point inputs are used as-is.

    Args:
      y_true: Ground truth values.
      y_pred: The predicted values.
      sample_weight: Optional weighting of each example.

    Returns:
      Update op.
    """

    def _to_unit_range(tensor):
        # Divide integer dtypes by their max so pixel values land in [0, 1];
        # non-integer dtypes are only cast.
        dtype = tf.dtypes.as_dtype(tensor.dtype)
        divisor = dtype.max if dtype.is_integer else 1.
        return tf.cast(tensor, self._dtype) / divisor

    y_true = _to_unit_range(y_true)
    y_pred = _to_unit_range(y_pred)
    if sample_weight is not None:
        sample_weight = tf.cast(sample_weight, self._dtype)

    [y_true, y_pred], sample_weight = (
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [y_true, y_pred], sample_weight))
    # squeeze_or_expand_dimensions returns two values when no weights are
    # given and three when they are, hence the branch.
    if sample_weight is None:
        y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
            y_pred, y_true, sample_weight)
    else:
        y_pred, y_true, sample_weight = (
            losses_utils.squeeze_or_expand_dimensions(
                y_pred, y_true, sample_weight))

    # Weighting is applied inside connectivity_error, so it is not passed on
    # to the base update_state.
    values = connectivity_error(y_true, y_pred, self.step, sample_weight)
    return super().update_state(values)
def test_failing_different_ragged_ranks(self):
    """Two ragged tensors with different ragged ranks must raise ValueError."""
    dense = tf.constant([[[1, 2]]])
    # Wrapping the dense tensor adds exactly one ragged dimension.
    shallow = tf.RaggedTensor.from_row_splits(dense, row_splits=[0, 1])
    deep = tf.ragged.constant([[[[1, 2]]]])
    # pylint: disable=g-error-prone-assert-raises
    with self.assertRaises(ValueError):
        [shallow, deep], _ = (
            metrics_utils.ragged_assert_compatible_and_get_flat_values(
                [shallow, deep]))
def test_passing_both_ragged_with_mask(self, x_list, y_list, mask_list):
    """Compatible ragged values and mask flatten to mutually compatible shapes."""
    ragged_x = tf.ragged.constant(x_list)
    ragged_y = tf.ragged.constant(y_list)
    ragged_mask = tf.ragged.constant(mask_list)
    [flat_x, flat_y], flat_mask = (
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [ragged_x, ragged_y], ragged_mask))
    flat_x.shape.assert_is_compatible_with(flat_y.shape)
    flat_y.shape.assert_is_compatible_with(flat_mask.shape)
def test_failing_different_mask_ranks(self, x_list, y_list, mask_list):
    """A mask whose ragged rank differs from the values must raise ValueError."""
    ragged_x = tf.ragged.constant(x_list)
    ragged_y = tf.ragged.constant(y_list)
    ragged_mask = tf.ragged.constant(mask_list)
    with self.assertRaises(ValueError):
        [ragged_x, ragged_y], _ = (
            metrics_utils.ragged_assert_compatible_and_get_flat_values(
                [ragged_x, ragged_y], ragged_mask))
def update_state(self, values, sample_weight=None):
    """Accumulates statistics for computing the metric.

    Args:
      values: Per-example value.
      sample_weight: Optional weighting of each example. Defaults to 1.

    Returns:
      Update op.

    Raises:
      RuntimeError: If `values` cannot be cast to a single tensor (e.g. the
        metric function returned a dict).
      NotImplementedError: If `self.reduction` is not one of the supported
        `metrics_utils.Reduction` values.
    """
    # Flatten any ragged inputs to their dense flat values first.
    [values], sample_weight = \
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [values], sample_weight)
    try:
        values = tf.cast(values, self._dtype)
    except (ValueError, TypeError):
        # The cast fails when the metric function returned something that is
        # not a single tensor (e.g. a dict of tensors).
        msg = (
            'The output of a metric function can only be a single Tensor. '
            f'Received: {values}. ')
        if isinstance(values, dict):
            msg += (
                'To return a dict of values, implement a custom Metric '
                'subclass.')
        raise RuntimeError(msg)
    if sample_weight is not None:
        sample_weight = tf.cast(sample_weight, self._dtype)
        # Update dimensions of weights to match with values if possible.
        values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions(
            values, sample_weight=sample_weight)
        try:
            # Broadcast weights if possible.
            sample_weight = tf.__internal__.ops.broadcast_weights(
                sample_weight, values)
        except ValueError:
            # Reduce values to same ndim as weight array: collapse the
            # trailing axes of `values` so it can be multiplied by the
            # lower-rank weights.
            ndim = backend.ndim(values)
            weight_ndim = backend.ndim(sample_weight)
            if self.reduction == metrics_utils.Reduction.SUM:
                values = tf.reduce_sum(
                    values, axis=list(range(weight_ndim, ndim)))
            else:
                values = tf.reduce_mean(
                    values, axis=list(range(weight_ndim, ndim)))
        values = tf.multiply(values, sample_weight)
    value_sum = tf.reduce_sum(values)
    # Force the sum to be computed before mutating `total` so the update is
    # well-ordered.
    with tf.control_dependencies([value_sum]):
        update_total_op = self.total.assign_add(value_sum)

    # Exit early if the reduction doesn't have a denominator.
    if self.reduction == metrics_utils.Reduction.SUM:
        return update_total_op

    # Update `count` for reductions that require a denominator.
    if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE:
        num_values = tf.cast(tf.size(values), self._dtype)
    elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN:
        if sample_weight is None:
            num_values = tf.cast(tf.size(values), self._dtype)
        else:
            num_values = tf.reduce_sum(sample_weight)
    else:
        raise NotImplementedError(
            f'Reduction "{self.reduction}" not implemented. Expected '
            '"sum", "weighted_mean", or "sum_over_batch_size".')

    # `count` is only updated after `total`, keeping the two variables
    # consistent with each other.
    with tf.control_dependencies([update_total_op]):
        return self.count.assign_add(num_values)
def test_passing_one_dense_tensor(self, x_list):
    """A single dense tensor passes through the compatibility check unchanged."""
    dense = tf.constant(x_list)
    [dense], _ = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        [dense])
def test_passing_dense_tensors(self, x_list, y_list):
    """Two compatible dense tensors pass through with compatible shapes."""
    dense_x = tf.constant(x_list)
    dense_y = tf.constant(y_list)
    [dense_x, dense_y], _ = (
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [dense_x, dense_y]))
    dense_x.shape.assert_is_compatible_with(dense_y.shape)