def test_softsort(self):
  """Checks that softsort approximates a hard sort in both directions.

  For each direction, the soft-sorted output must:
    * keep the input's shape,
    * be strictly monotonic (increasing for ASCENDING, decreasing for
      DESCENDING),
    * match `tf.sort` up to a loose tolerance.
  """
  x = tf.constant([3, 4, 1, 5, 2], dtype=tf.float32)
  eps = 1e-3
  sinkhorn_threshold = 1e-3
  # epsilon is small enough to give a correctly ordered output, but not small
  # enough to recover the hard-sorted values with high precision, hence the
  # loose tolerance (used for both rtol and atol).
  tolerance = 1e-1

  cases = (
      ('ASCENDING', self.assertAllGreater),
      ('DESCENDING', self.assertAllLess),
  )
  for direction, assert_strictly_monotonic in cases:
    values = ops.softsort(
        x,
        direction=direction,
        epsilon=eps,
        sinkhorn_threshold=sinkhorn_threshold)
    self.assertEqual(values.shape, x.shape)
    # Consecutive differences are all > 0 (ascending) or all < 0 (descending).
    assert_strictly_monotonic(np.diff(values), 0.0)
    self.assertAllClose(
        tf.sort(x, direction=direction), values, tolerance, tolerance)
def test_softsort(self, topk):
  """Checks softsort, optionally truncated to `topk`, against a hard sort.

  For both ASCENDING and DESCENDING directions, the output must:
    * have shape `x.shape`, or `(topk,)` when `topk` is given,
    * be strictly monotonic in the requested direction,
    * match the first `topk` entries (or all entries) of `tf.sort` up to a
      loose tolerance.

  Args:
    topk: Number of leading sorted values to request from softsort, or None
      for a full sort.
  """
  x = tf.constant([3, 4, 1, 5, 2, 9, 12, 11, 8, 15], dtype=tf.float32)
  eps = 1e-3
  sinkhorn_threshold = 1e-3
  # epsilon is small enough to give a correctly ordered output, but not small
  # enough to recover the hard-sorted values with high precision, hence the
  # loose tolerance (used for both rtol and atol).
  tolerance = 1e-1
  expect_shape = x.shape if topk is None else (topk,)

  cases = (
      ('ASCENDING', self.assertAllGreater),
      ('DESCENDING', self.assertAllLess),
  )
  for direction, assert_strictly_monotonic in cases:
    values = ops.softsort(
        x,
        direction=direction,
        topk=topk,
        epsilon=eps,
        threshold=sinkhorn_threshold)

    reference = tf.sort(x, direction=direction)
    if topk is not None:
      # Only the first `topk` sorted values are produced by softsort.
      reference = reference[:topk]

    self.assertEqual(values.shape, expect_shape)
    # Consecutive differences are all > 0 (ascending) or all < 0 (descending).
    assert_strictly_monotonic(np.diff(values), 0.0)
    self.assertAllClose(reference, values, tolerance, tolerance)
def call(self, y_true, y_pred):
  """Returns soft quantile(s) of the powered absolute prediction error.

  The error |squeeze(y_pred) - y_true| ** self._power is soft-sorted along
  axis 0 toward the target weights returned by
  `self._get_target_weights_and_indices()`. The entries at the returned
  target index are then gathered along axis 0.

  Args:
    y_true: Ground-truth values.
    y_pred: Predictions. `tf.squeeze` drops all size-1 dimensions before the
      subtraction.

  Returns:
    `tf.gather(soft_sorted, target_index, axis=0)` of the soft-sorted errors.
  """
  # NOTE(review): tf.squeeze without an axis drops every size-1 dim. If y_true
  # arrives as (batch, 1), the subtraction below would broadcast to
  # (batch, batch). Presumably y_true is (batch,); confirm with callers.
  residual = tf.squeeze(y_pred) - y_true
  error = tf.pow(tf.abs(residual), self._power)

  weights, target_index = self._get_target_weights_and_indices()
  soft_sorted = ops.softsort(
      error, axis=0, target_weights=weights, **self._kwargs)
  return tf.gather(soft_sorted, target_index, axis=0)
def call(self, y_true, y_pred):
  """Returns a soft quantile-range statistic of the powered absolute error.

  The error |squeeze(y_pred) - y_true| ** self._power is soft-sorted along
  axis 0 toward three target weights:
  [start, end - start, 1 - end], where start/end are
  `self._start_quantile` / `self._end_quantile`. The entry at index 1 along
  axis 0 is returned.

  Args:
    y_true: Ground-truth values.
    y_pred: Predictions. `tf.squeeze` drops all size-1 dimensions before the
      subtraction.

  Returns:
    The index-1 slice (axis 0) of the soft-sorted errors. This presumably
    corresponds to the middle target weight (end - start); confirm against
    ops.softsort.
  """
  error = tf.pow(tf.abs(tf.squeeze(y_pred) - y_true), self._power)

  start = self._start_quantile
  end = self._end_quantile
  # Mass below `start`, between `start` and `end`, and above `end`.
  below, between, above = start, end - start, 1.0 - end
  target_weights = [below, between, above]

  quantiles = ops.softsort(
      error, axis=0, target_weights=target_weights, **self._kwargs)
  # Pick the entry associated with the middle (`between`) weight.
  return quantiles[1]
def call(self, inputs):
  """Soft-sorts `inputs` along `self._axis`.

  Args:
    inputs: Tensor to soft-sort.

  Returns:
    The output of `ops.softsort`. When `self._topk` is set, it is returned
    as-is. For a full sort (`self._topk is None`), it is reshaped to
    `tf.shape(inputs)`.
  """
  sorted_values = ops.softsort(
      inputs, axis=self._axis, topk=self._topk, **self._kwargs)

  if self._topk is None:
    # For a full sort, TF has trouble computing the output tensor's shape.
    # Reshaping to the input's shape tells TF what the shape is.
    return tf.reshape(sorted_values, tf.shape(inputs))
  return sorted_values