def remove_squeezable_dimensions(labels, predictions, expected_rank_diff=0, name=None):
  """Squeeze last dim if ranks differ from expected by exactly 1.

  In the common case where we expect shapes to match, `expected_rank_diff`
  defaults to 0, and we squeeze the last dimension of the larger rank if they
  differ by 1.

  But, for example, if `labels` contains class IDs and `predictions` contains 1
  probability per class, we expect `predictions` to have 1 more dimension than
  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
  `labels` if `rank(predictions) - rank(labels) == 0`, and
  `predictions` if `rank(predictions) - rank(labels) == 2`.

  A last dimension is only squeezed when its static size is compatible with 1
  (i.e. it is 1 or statically unknown). Inputs whose static rank is 0
  (scalars) have no last dimension and are returned unsqueezed.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
      Values that are not already a tensor or extension type are converted
      with `tf.convert_to_tensor`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
      Values that are not already a tensor or extension type are converted
      with `tf.convert_to_tensor`.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed.
  """
  with backend.name_scope(name or 'remove_squeezable_dimensions'):
    if not tf_utils.is_tensor_or_extension_type(predictions):
      predictions = tf.convert_to_tensor(predictions)
    if not tf_utils.is_tensor_or_extension_type(labels):
      labels = tf.convert_to_tensor(labels)
    predictions_shape = predictions.shape
    predictions_rank = predictions_shape.ndims
    labels_shape = labels.shape
    labels_rank = labels_shape.ndims

    if (labels_rank is not None) and (predictions_rank is not None):
      # Use static rank.
      rank_diff = predictions_rank - labels_rank
      # The `rank > 0` guards prevent indexing `dims[-1]` on an empty dims
      # list (scalar input), which would raise IndexError; a scalar has no
      # last dimension to squeeze anyway.
      if (rank_diff == expected_rank_diff + 1 and predictions_rank > 0 and
          predictions_shape.dims[-1].is_compatible_with(1)):
        predictions = tf.squeeze(predictions, [-1])
      elif (rank_diff == expected_rank_diff - 1 and labels_rank > 0 and
            labels_shape.dims[-1].is_compatible_with(1)):
        labels = tf.squeeze(labels, [-1])
      return labels, predictions

    # Use dynamic rank: at least one static rank is unknown, so the decision
    # is deferred to graph execution via `tf.cond`.
    rank_diff = tf.rank(predictions) - tf.rank(labels)
    # Only build the conditional squeeze when the last dim could be 1. A
    # statically known rank of 0 means there is no last dim to inspect.
    if (predictions_rank is None) or (
        predictions_rank > 0 and
        predictions_shape.dims[-1].is_compatible_with(1)):
      predictions = tf.cond(
          tf.equal(expected_rank_diff + 1, rank_diff),
          lambda: tf.squeeze(predictions, [-1]),
          lambda: predictions)
    if (labels_rank is None) or (
        labels_rank > 0 and labels_shape.dims[-1].is_compatible_with(1)):
      labels = tf.cond(
          tf.equal(expected_rank_diff - 1, rank_diff),
          lambda: tf.squeeze(labels, [-1]),
          lambda: labels)
    return labels, predictions
def test_is_tensor_or_extension_type_return_true_for_custom_ext_types(
    self):
  """An instance of a user-defined `tf.experimental.ExtensionType` is accepted."""

  class DummyExtensionType(tf.experimental.ExtensionType):
    pass

  candidate = DummyExtensionType()
  self.assertTrue(tf_utils.is_tensor_or_extension_type(candidate))
def test_is_tensor_or_extension_type_return_false_for_list(self):
  """A plain Python list must not be reported as a tensor/extension type."""
  plain_list = [1.0, 2.0, 3.0]
  verdict = tf_utils.is_tensor_or_extension_type(plain_list)
  self.assertFalse(verdict)
def test_is_tensor_or_extension_type_return_true_for_dense_tensor(self):
  """A dense `tf.Tensor` built by `tf.constant` is recognized."""
  dense = tf.constant([[1, 2], [3, 4]])
  verdict = tf_utils.is_tensor_or_extension_type(dense)
  self.assertTrue(verdict)
def test_is_tensor_or_extension_type_return_true_for_sparse_tensor(self):
  """A `tf.SparseTensor` produced from a dense value is recognized."""
  sparse = tf.sparse.from_dense([[1, 2], [3, 4]])
  verdict = tf_utils.is_tensor_or_extension_type(sparse)
  self.assertTrue(verdict)
def test_is_tensor_or_extension_type_return_true_for_ragged_tensor(self):
  """A `tf.RaggedTensor` with rows of differing length is recognized."""
  ragged = tf.ragged.constant([[1, 2], [3]])
  verdict = tf_utils.is_tensor_or_extension_type(ragged)
  self.assertTrue(verdict)