def test_softsort(self):
        """Checks that a low-epsilon soft sort approximates tf.sort both ways.

        With epsilon=1e-3 the soft-sorted output must keep the input's shape,
        be strictly monotone in the requested direction, and agree with the
        exact sort within 0.1 (used as both rtol and atol).
        """
        inputs = tf.constant([3, 4, 1, 5, 2], dtype=tf.float32)
        epsilon = 1e-3
        threshold = 1e-3
        # The entropic regularization still blurs the result, so only a loose
        # match against the exact sort is required.
        tol = 1e-1

        # Ascending: values must be strictly increasing.
        ascending = ops.softsort(inputs,
                                 direction='ASCENDING',
                                 epsilon=epsilon,
                                 sinkhorn_threshold=threshold)
        self.assertEqual(ascending.shape, inputs.shape)
        self.assertAllGreater(np.diff(ascending), 0.0)
        self.assertAllClose(tf.sort(inputs), ascending, tol, tol)

        # Descending: values must be strictly decreasing.
        order = 'DESCENDING'
        descending = ops.softsort(inputs,
                                  direction=order,
                                  epsilon=epsilon,
                                  sinkhorn_threshold=threshold)
        self.assertEqual(descending.shape, inputs.shape)
        self.assertAllLess(np.diff(descending), 0.0)
        self.assertAllClose(tf.sort(inputs, direction=order), descending,
                            tol, tol)
Example #2
0
    def test_softsort(self, topk):
        """Checks soft sorting, optionally truncated to `topk`, in both directions.

        For each direction the output must have shape `inputs.shape` (or
        `(topk,)` when `topk` is given), be strictly monotone in that
        direction, and match the first `topk` entries of tf.sort within 0.1
        (used as both rtol and atol).
        """
        inputs = tf.constant([3, 4, 1, 5, 2, 9, 12, 11, 8, 15],
                             dtype=tf.float32)
        epsilon = 1e-3
        threshold = 1e-3
        # epsilon is not tiny, so only a loose match with the exact sort is
        # expected.
        tol = 1e-1
        wanted_shape = inputs.shape if topk is None else (topk, )

        # Each direction is paired with the monotonicity assertion it needs.
        checks = (
            ('ASCENDING', self.assertAllGreater),
            ('DESCENDING', self.assertAllLess),
        )
        for order, assert_monotone in checks:
            result = ops.softsort(inputs,
                                  direction=order,
                                  topk=topk,
                                  epsilon=epsilon,
                                  threshold=threshold)
            reference = tf.sort(inputs, direction=order)
            if topk is not None:
                reference = reference[:topk]
            self.assertEqual(result.shape, wanted_shape)
            assert_monotone(np.diff(result), 0.0)
            self.assertAllClose(reference, result, tol, tol)
 def call(self, y_true, y_pred):
     """Returns selected soft-sorted entries of the prediction error.

     The error is |squeeze(y_pred) - y_true| raised to `self._power`. It is
     passed to ops.softsort along axis 0 with the target weights from
     `self._get_target_weights_and_indices()`. The entries at the returned
     target indices are then gathered along axis 0.
     """
     # NOTE(review): tf.squeeze without an axis drops every size-1 dimension,
     # including a batch dimension of 1 — confirm that is intended.
     residual = tf.squeeze(y_pred) - y_true
     error = tf.pow(tf.abs(residual), self._power)
     weights, index = self._get_target_weights_and_indices()
     sorted_buckets = ops.softsort(
         error, axis=0, target_weights=weights, **self._kwargs)
     return tf.gather(sorted_buckets, index, axis=0)
Example #4
0
 def call(self, y_true, y_pred):
   """Returns the middle soft-sort bucket of the prediction error.

   The error is |squeeze(y_pred) - y_true| raised to `self._power`. It is
   soft-sorted along axis 0 against three target weights: the mass below
   `self._start_quantile`, the mass between the start and end quantiles, and
   the mass above `self._end_quantile`. Only the entry for the middle weight
   (index 1) is returned.
   """
   error = tf.pow(tf.abs(tf.squeeze(y_pred) - y_true), self._power)
   lower = self._start_quantile
   upper = self._end_quantile
   weights = [lower, upper - lower, 1.0 - upper]
   buckets = ops.softsort(
       error, axis=0, target_weights=weights, **self._kwargs)
   # Discard the lower and upper tails; keep the middle bucket.
   return buckets[1]
Example #5
0
  def call(self, inputs):
    """Soft-sorts `inputs` along `self._axis`.

    When `self._topk` is set, the ops.softsort output is returned unchanged.
    For a full sort, the result is reshaped to `tf.shape(inputs)`.
    """
    sorted_values = ops.softsort(
        inputs, axis=self._axis, topk=self._topk, **self._kwargs)
    if self._topk is None:
      # A full sort keeps the input's shape. The reshape does not change the
      # values; it only tells tf the output shape, which the original author
      # observed tf has trouble inferring for a full sort.
      return tf.reshape(sorted_values, tf.shape(inputs))
    return sorted_values