Example #1
import tensorflow as tf

# `backend` and `tf_utils` are the Keras-internal modules this snippet was
# extracted from (keras.backend and keras.utils.tf_utils).
from keras import backend
from keras.utils import tf_utils


def remove_squeezable_dimensions(labels,
                                 predictions,
                                 expected_rank_diff=0,
                                 name=None):
    """Squeeze last dim if ranks differ from expected by exactly 1.

  In the common case where we expect shapes to match, `expected_rank_diff`
  defaults to 0, and we squeeze the last dimension of the larger rank if they
  differ by 1.

  But, for example, if `labels` contains class IDs and `predictions` contains 1
  probability per class, we expect `predictions` to have 1 more dimension than
  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
  `labels` if `rank(predictions) - rank(labels) == 0`, and
  `predictions` if `rank(predictions) - rank(labels) == 2`.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed.
  """
    with backend.name_scope(name or 'remove_squeezable_dimensions'):
        if not tf_utils.is_tensor_or_extension_type(predictions):
            predictions = tf.convert_to_tensor(predictions)
        if not tf_utils.is_tensor_or_extension_type(labels):
            labels = tf.convert_to_tensor(labels)
        predictions_shape = predictions.shape
        predictions_rank = predictions_shape.ndims
        labels_shape = labels.shape
        labels_rank = labels_shape.ndims
        if (labels_rank is not None) and (predictions_rank is not None):
            # Use static rank.
            rank_diff = predictions_rank - labels_rank
            if (rank_diff == expected_rank_diff + 1
                    and predictions_shape.dims[-1].is_compatible_with(1)):
                predictions = tf.squeeze(predictions, [-1])
            elif (rank_diff == expected_rank_diff - 1
                  and labels_shape.dims[-1].is_compatible_with(1)):
                labels = tf.squeeze(labels, [-1])
            return labels, predictions

        # Use dynamic rank.
        rank_diff = tf.rank(predictions) - tf.rank(labels)
        if (predictions_rank is None
                or predictions_shape.dims[-1].is_compatible_with(1)):
            predictions = tf.cond(tf.equal(expected_rank_diff + 1, rank_diff),
                                  lambda: tf.squeeze(predictions, [-1]),
                                  lambda: predictions)
        if (labels_rank is None
                or labels_shape.dims[-1].is_compatible_with(1)):
            labels = tf.cond(tf.equal(expected_rank_diff - 1, rank_diff),
                             lambda: tf.squeeze(labels, [-1]),
                             lambda: labels)
        return labels, predictions
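
A minimal usage sketch for the function above (not part of the original
snippet), assuming TensorFlow 2.x eager execution and the imports shown at the
top of this example:

# With the default expected_rank_diff=0, the trailing size-1 dimension of
# whichever tensor has the larger rank is squeezed so the shapes line up.
labels = tf.constant([[1.], [0.], [1.]])          # shape (3, 1)
predictions = tf.constant([0.9, 0.2, 0.8])        # shape (3,)
labels, predictions = remove_squeezable_dimensions(labels, predictions)
# labels.shape == (3,), predictions.shape == (3,)

# Conversely, a trailing size-1 dimension on `predictions` is squeezed when
# its rank exceeds the expected difference by exactly one.
labels = tf.constant([1., 0., 1.])                # shape (3,)
predictions = tf.constant([[0.9], [0.2], [0.8]])  # shape (3, 1)
labels, predictions = remove_squeezable_dimensions(labels, predictions)
# predictions.shape == (3,)
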
Example #2
    def test_is_tensor_or_extension_type_return_true_for_custom_ext_types(
            self):
        class DummyExtensionType(tf.experimental.ExtensionType):
            ...

        self.assertTrue(
            tf_utils.is_tensor_or_extension_type(DummyExtensionType()))
Example #3
    def test_is_tensor_or_extension_type_return_false_for_list(self):
        self.assertFalse(tf_utils.is_tensor_or_extension_type([1., 2., 3.]))
Example #4
    def test_is_tensor_or_extension_type_return_true_for_dense_tensor(self):
        self.assertTrue(tf_utils.is_tensor_or_extension_type(
            tf.constant([[1, 2], [3, 4]])))
Example #5
    def test_is_tensor_or_extension_type_return_true_for_sparse_tensor(self):
        self.assertTrue(tf_utils.is_tensor_or_extension_type(
            tf.sparse.from_dense([[1, 2], [3, 4]])))
Example #6
    def test_is_tensor_or_extension_type_return_true_for_ragged_tensor(self):
        self.assertTrue(tf_utils.is_tensor_or_extension_type(
            tf.ragged.constant([[1, 2], [3]])))
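
Taken together, Examples #2 through #6 pin down what
tf_utils.is_tensor_or_extension_type returns for common inputs. A minimal
standalone sketch of the same checks (not from the original tests), assuming
the keras.utils import path used in Example #1:

import tensorflow as tf
from keras.utils import tf_utils

# Dense, sparse, and ragged tensors (and custom ExtensionTypes) count as
# tensor-like...
print(tf_utils.is_tensor_or_extension_type(tf.constant([[1, 2], [3, 4]])))           # True
print(tf_utils.is_tensor_or_extension_type(tf.sparse.from_dense([[1, 2], [3, 4]])))  # True
print(tf_utils.is_tensor_or_extension_type(tf.ragged.constant([[1, 2], [3]])))       # True
# ...while a plain Python list does not.
print(tf_utils.is_tensor_or_extension_type([1., 2., 3.]))                            # False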