def test_failing_different_ragged_and_dense_ranks(self, x_list, y_list):
    """Checks that the helper raises ValueError for the parameterized pair.

    Both `x_list` and `y_list` are turned into ragged constants and handed
    to `ragged_assert_compatible_and_get_flat_values` as a two-element list.
    """
    lhs = ragged_factory_ops.constant(x_list)
    rhs = ragged_factory_ops.constant(y_list)
    with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
        (lhs, rhs), _ = (
            metrics_utils.ragged_assert_compatible_and_get_flat_values(
                [lhs, rhs]))
def test_failing_different_ragged_ranks(self):
    """Checks that the helper raises ValueError for the two ragged inputs.

    The first input is a dense constant wrapped with `from_row_splits`; the
    second is built directly by `ragged_factory_ops.constant`.
    """
    dense = constant_op.constant([[[1, 2]]])
    # Wrap the dense tensor with one ragged dimension.
    lhs = ragged_tensor.RaggedTensor.from_row_splits(dense, row_splits=[0, 1])
    rhs = ragged_factory_ops.constant([[[[1, 2]]]])
    with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
        (lhs, rhs), _ = (
            metrics_utils.ragged_assert_compatible_and_get_flat_values(
                [lhs, rhs]))
def test_passing_both_ragged_with_mask(self, x_list, y_list, mask_list):
    """Checks that values and mask returned by the helper have compatible shapes.

    All three inputs are built as ragged constants; the helper's outputs are
    then compared pairwise with `TensorShape.assert_is_compatible_with`.
    """
    ragged_x = ragged_factory_ops.constant(x_list)
    ragged_y = ragged_factory_ops.constant(y_list)
    ragged_mask = ragged_factory_ops.constant(mask_list)

    values, out_mask = (
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [ragged_x, ragged_y], ragged_mask))
    out_x, out_y = values

    out_x.shape.assert_is_compatible_with(out_y.shape)
    out_y.shape.assert_is_compatible_with(out_mask.shape)
def update_state(self, y_true, y_pred, sample_weight=None):
    """Feeds the (processed) predictions into the parent metric's state.

    Args:
        y_true: Labels; cast to `self._dtype` and used only to align shapes
            with `y_pred`.
        y_pred: Predictions; cast to `self._dtype`. These are the values
            passed to the parent `update_state`.
        sample_weight: Optional weights, passed through the
            ragged-compatibility helper and then to the parent.

    Returns:
        Whatever the parent class's `update_state` returns.
    """
    labels = math_ops.cast(y_true, self._dtype)
    preds = math_ops.cast(y_pred, self._dtype)

    values, weights = (
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [labels, preds], sample_weight))
    labels, preds = values

    preds, labels = tf_losses_utils.squeeze_or_expand_dimensions(
        preds, labels)
    # Only the predictions are accumulated; labels are not forwarded.
    return super(PCTR, self).update_state(preds, sample_weight=weights)
def update_state(self, y_true, y_pred, w, sample_weight=None):
    """Evaluates `self._fn` on the inputs and accumulates the result.

    Args:
        y_true: Labels; cast to `self._dtype`.
        y_pred: Predictions; cast to `self._dtype`.
        w: Extra argument for `self._fn`; cast to `self._dtype` and passed
            through unchanged otherwise.
        sample_weight: Optional weights, passed through the
            ragged-compatibility helper and then to the parent.

    Returns:
        Whatever the parent class's `update_state` returns.
    """
    labels = math_ops.cast(y_true, self._dtype)
    preds = math_ops.cast(y_pred, self._dtype)
    extra = math_ops.cast(w, self._dtype)

    values, weights = (
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [labels, preds], sample_weight))
    labels, preds = values

    preds, labels = tf_losses_utils.squeeze_or_expand_dimensions(
        preds, labels)

    per_example = self._fn(labels, preds, extra, **self._fn_kwargs)
    return super(CustomMeanMetricWrapper2, self).update_state(
        per_example, sample_weight=weights)
def update_state(self, y_true, y_pred, sample_weight=None):
    """Adds the sums of labels and predictions to the running totals.

    `self.ctr_total` is incremented by the sum of `y_true` and
    `self.pctr_total` by the sum of `y_pred`.

    Args:
        y_true: Labels; cast to `self._dtype`.
        y_pred: Predictions; cast to `self._dtype`.
        sample_weight: Passed to the ragged-compatibility helper; the
            returned value is not used in the sums.

    Returns:
        A grouped op covering both `assign_add` updates.
    """
    labels = math_ops.cast(y_true, self._dtype)
    preds = math_ops.cast(y_pred, self._dtype)

    values, sample_weight = (
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [labels, preds], sample_weight))
    labels, preds = values

    preds, labels = tf_losses_utils.squeeze_or_expand_dimensions(
        preds, labels)

    label_total = math_ops.reduce_sum(labels)
    with ops.control_dependencies([label_total]):
        add_labels = self.ctr_total.assign_add(label_total)

    pred_total = math_ops.reduce_sum(preds)
    with ops.control_dependencies([pred_total]):
        add_preds = self.pctr_total.assign_add(pred_total)

    return control_flow_ops.group(add_labels, add_preds)
def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates metric statistics.

    Args:
        y_true: Ground truth values. Cast to `self._dtype`.
        y_pred: The predicted values, given as a mapping: `.items()` is
            called on it and every value is cast to `self._dtype`. A plain
            tensor is not handled by this method.
        sample_weight: Optional coefficient for the metric. It goes through
            the ragged-compatibility helper and is then forwarded to the
            parent `update_state`. A scalar scales the metric by that value.
            A tensor of size `[batch_size]` rescales the metric of each
            sample by the matching element. A tensor of shape
            `[batch_size, d0, .. dN-1]` (or broadcastable to it) scales each
            metric element by the matching value. (Note on `dN-1`: metric
            functions reduce by one dimension, usually the last axis.)

    Returns:
        Update op.
    """
    labels = math_ops.cast(y_true, self._dtype)
    # NOTE: only the mapping form of `y_pred` is supported here; the
    # plain-tensor cast and the squeeze_or_expand_dimensions step are
    # intentionally not applied in this version.
    preds = {
        name: math_ops.cast(tensor, self._dtype)
        for name, tensor in y_pred.items()
    }

    values, weights = (
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [labels, preds], sample_weight))
    labels, preds = values

    converted_fn = autograph.tf_convert(
        self._fn, ag_ctx.control_status_ctx())
    per_example = converted_fn(labels, preds, **self._fn_kwargs)
    return super(MeanMetricWrapper, self).update_state(
        per_example, sample_weight=weights)
def test_passing_one_dense_tensor(self, x_list):
    """Checks that the helper accepts a single dense tensor in a list.

    The result must unpack into a one-element values sequence plus a mask.
    """
    dense = constant_op.constant(x_list)
    values, _ = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        [dense])
    (dense,) = values
def test_passing_dense_tensors(self, x_list, y_list):
    """Checks that two dense tensors come back with compatible shapes."""
    dense_x = constant_op.constant(x_list)
    dense_y = constant_op.constant(y_list)

    values, _ = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        [dense_x, dense_y])
    out_x, out_y = values

    out_x.shape.assert_is_compatible_with(out_y.shape)
def test_passing_one_ragged(self, x_list):
    """Checks that the helper accepts a single ragged tensor in a list.

    The result must unpack into a one-element values sequence plus a mask.
    """
    ragged = ragged_factory_ops.constant(x_list)
    values, _ = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        [ragged])
    (ragged,) = values
def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates true-positive, false-negative and false-positive counts.

    Every prediction is compared against every entry of `self.thresholds`;
    the (optionally weighted) matches are summed over `self.axis` and added
    to `self.true_positives`, `self.false_negatives` and
    `self.false_positives`.

    Args:
        y_true: Labels; cast to the dtype of `y_pred`, later to bool.
        y_pred: Predictions; assertions requiring values in [0, 1] are
            attached as control dependencies.
        sample_weight: Optional weights, broadcast against `y_pred`.

    Returns:
        A `tf.group` of the three `assign_add` update ops.
    """
    # Cast inputs
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, dtype=y_pred.dtype)

    # Transform inputs. Keep the sample_weight returned by the helper (it was
    # previously discarded) so the weights stay consistent with whatever
    # processing the helper applied to the values.
    [y_pred, y_true], sample_weight = (
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [y_pred, y_true], sample_weight))

    # Get threshold properties
    if isinstance(self.thresholds, list):
        num_thresholds = len(self.thresholds)
    else:
        num_thresholds = len(list(self.thresholds))

    # Check input values and adjust shapes
    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            y_pred,
            tf.cast(0.0, dtype=y_pred.dtype),
            message='predictions must be >= 0'),
        check_ops.assert_less_equal(
            y_pred,
            tf.cast(1.0, dtype=y_pred.dtype),
            message='predictions must be <= 1')
    ]):
        if sample_weight is None:
            y_pred, y_true = tf_losses_utils.squeeze_or_expand_dimensions(
                y_pred, y_true)
        else:
            y_pred, y_true, sample_weight = (
                tf_losses_utils.squeeze_or_expand_dimensions(
                    y_pred, y_true, sample_weight=sample_weight))

    # Check shape compatibility
    y_pred.shape.assert_is_compatible_with(y_true.shape)

    # Check if num_classes corresponds to y_pred
    if self.average != 'micro':
        tf.debugging.assert_shapes(
            shapes=[(y_pred, (..., self.num_classes))],
            data=y_pred,
            summarize=10,
            message='num_classes must correspond to the prediction')

    # Filter top k
    if self.top_k is not None:
        y_pred = metrics_utils._filter_top_k(y_pred, self.top_k)

    # Select class id
    if self.class_id is not None:
        y_true = y_true[..., self.class_id]
        y_pred = y_pred[..., self.class_id]

    # Get prediction shape
    pred_shape = tf.shape(y_pred)
    num_predictions = pred_shape[0]

    # Set label shapes
    if y_pred.shape.ndims == 1:
        num_labels = 1
    else:
        num_labels = K.prod(pred_shape[1:], axis=0)

    # Flatten predictions and labels into a single row each
    predictions_extra_dim = tf.reshape(y_pred, [1, -1])
    labels_extra_dim = tf.reshape(tf.cast(y_true, dtype=tf.bool), [1, -1])

    # Tile the thresholds for every prediction
    thresh_pretile_shape = [num_thresholds, -1]
    thresh_tiles = [1, num_predictions * num_labels]
    data_tiles = [num_thresholds, 1]

    thresh_tiled = tf.tile(
        tf.reshape(
            tf.constant(self.thresholds, dtype=y_pred.dtype),
            thresh_pretile_shape),
        tf.stack(thresh_tiles))

    # Tile the predictions for every threshold
    preds_tiled = tf.tile(predictions_extra_dim, data_tiles)

    # Compare predictions and threshold
    pred_is_pos = tf.greater(preds_tiled, thresh_tiled)

    # Tile labels by number of thresholds
    label_is_pos = tf.tile(labels_extra_dim, data_tiles)

    # Set sample weights
    if sample_weight is not None:
        sample_weight = weights_broadcast_ops.broadcast_weights(
            tf.cast(sample_weight, dtype=y_pred.dtype), y_pred)
        weights_tiled = tf.tile(
            tf.reshape(sample_weight, thresh_tiles), data_tiles)
    else:
        weights_tiled = None

    def _weighted_assign_add(label, pred, weights, var):
        # Add the (weighted) count of positions where both masks are true.
        label_and_pred = tf.cast(
            tf.logical_and(label, pred), dtype=y_pred.dtype)
        if weights is not None:
            label_and_pred *= weights
        if self.average != 'micro':
            # One column per class so the reduction can be per class.
            label_and_pred = tf.reshape(
                label_and_pred, shape=[-1, self.num_classes])
        return var.assign_add(tf.reduce_sum(label_and_pred, self.axis))

    # Set return value
    update_ops = []

    # Update true positives
    update_ops.append(
        _weighted_assign_add(label_is_pos, pred_is_pos, weights_tiled,
                             self.true_positives))

    # Update false negatives
    pred_is_neg = tf.logical_not(pred_is_pos)
    update_ops.append(
        _weighted_assign_add(label_is_pos, pred_is_neg, weights_tiled,
                             self.false_negatives))

    # Update false positives
    label_is_neg = tf.logical_not(label_is_pos)
    update_ops.append(
        _weighted_assign_add(label_is_neg, pred_is_pos, weights_tiled,
                             self.false_positives))

    return tf.group(update_ops)
def update_state(self, y_true, y_pred, sample_weight=None):
    """Adds this batch's confusion matrix to `self.confusion_matrix`.

    Class indices are taken as the argmax over the last axis of `y_true`
    and `y_pred`, then passed to `tf.math.confusion_matrix`.

    Args:
        y_true: Labels; cast to the dtype of `y_pred`.
        y_pred: Predictions; assertions requiring values in [0, 1] are
            attached as control dependencies.
        sample_weight: Optional weights, broadcast against `y_pred`. One
            weight is gathered at every `self.num_classes`-th position of
            the flattened result.

    Returns:
        A `tf.group` wrapping the `assign_add` update op.
    """
    # Cast inputs
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, dtype=y_pred.dtype)

    # Transform inputs. Keep the sample_weight returned by the helper (it was
    # previously discarded) so the weights stay consistent with whatever
    # processing the helper applied to the values.
    [y_pred, y_true], sample_weight = (
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [y_pred, y_true], sample_weight))

    # Check input values and adjust shapes
    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            y_pred,
            tf.cast(0.0, dtype=y_pred.dtype),
            message='predictions must be >= 0'),
        check_ops.assert_less_equal(
            y_pred,
            tf.cast(1.0, dtype=y_pred.dtype),
            message='predictions must be <= 1')
    ]):
        if sample_weight is None:
            y_pred, y_true = tf_losses_utils.squeeze_or_expand_dimensions(
                y_pred, y_true)
        else:
            y_pred, y_true, sample_weight = (
                tf_losses_utils.squeeze_or_expand_dimensions(
                    y_pred, y_true, sample_weight=sample_weight))

    # Check shape compatibility
    y_pred.shape.assert_is_compatible_with(y_true.shape)

    # Get prediction shape
    pred_shape = tf.shape(y_pred)
    num_predictions = pred_shape[0]

    # Get labels (decode one-hot)
    y_pred_labels = K.flatten(tf.argmax(y_pred, axis=-1))
    y_true_labels = K.flatten(tf.argmax(y_true, axis=-1))

    # Set sample weights
    if sample_weight is not None:
        sample_weight = weights_broadcast_ops.broadcast_weights(
            tf.cast(sample_weight, dtype=y_pred.dtype), y_pred)
        # NOTE(review): the stride of num_classes presumes y_pred is
        # [num_predictions, num_classes] -- confirm for higher ranks.
        weights_tiled = tf.gather(
            K.flatten(sample_weight),
            tf.range(start=0,
                     limit=num_predictions * self.num_classes,
                     delta=self.num_classes,
                     dtype=tf.int64))
    else:
        weights_tiled = None

    def _weighted_assign_add(label, pred, weights, var):
        # Add the (weighted) confusion matrix of this batch to `var`.
        return var.assign_add(
            tf.math.confusion_matrix(labels=label,
                                     predictions=pred,
                                     num_classes=self.num_classes,
                                     weights=weights,
                                     dtype=self.dtype))

    # Set return value
    update_ops = []

    # Update confusion matrix
    update_ops.append(
        _weighted_assign_add(y_true_labels, y_pred_labels, weights_tiled,
                             self.confusion_matrix))

    return tf.group(update_ops)