Example #1
0
 def test_sort(self):
     """softsort along the last axis must produce a strictly increasing output."""
     sorted_x = ops.softsort(self.x, axis=-1, threshold=1e-3, epsilon=1e-3)
     self.assertEqual(sorted_x.shape, self.x.shape)
     # Strictly increasing <=> every consecutive difference is positive.
     is_increasing = np.diff(sorted_x, axis=-1) > 0
     all_true = np.ones(is_increasing.shape, dtype=bool)
     self.assertAllClose(is_increasing, all_true, check_dtypes=True)
Example #2
0
 def test_sort_descending(self):
   """softsort with direction='DESCENDING' must produce strictly decreasing output."""
   sub_x = self.x[0][0]
   result = ops.softsort(
       sub_x,
       axis=-1,
       direction='DESCENDING',
       threshold=1e-3,
       epsilon=1e-3,
   )
   self.assertEqual(result.shape, sub_x.shape)
   # Strictly decreasing <=> every consecutive difference is negative.
   is_decreasing = np.diff(result, axis=-1) < 0
   self.assertAllClose(
       is_decreasing,
       np.ones(is_decreasing.shape, dtype=bool),
       check_dtypes=True,
   )
Example #3
0
 def test_sort_descending(self):
     """softsort with direction='DESCENDING' must produce strictly decreasing output."""
     x = self.x[0][0]
     s = ops.softsort(x,
                      axis=-1,
                      direction='DESCENDING',
                      threshold=1e-3,
                      epsilon=1e-3)
     self.assertEqual(s.shape, x.shape)
     # Strictly decreasing <=> every consecutive difference is negative.
     deltas = jnp.diff(s, axis=-1) < 0
     # `deltas` is a boolean predicate, so compare it exactly; assert_allclose
     # is a tolerance-based numeric comparison and is the wrong tool here.
     np.testing.assert_array_equal(deltas, jnp.ones(deltas.shape, dtype=bool))
Example #4
0
 def test_sort_descending(self):
     """softsort with direction='DESCENDING' must produce strictly decreasing output."""
     x = self.x[0][0]
     # NOTE(review): unlike sibling tests, this relies on softsort's default
     # threshold/epsilon — confirm the defaults are tight enough for strictness.
     s = ops.softsort(x, axis=-1, direction='DESCENDING')
     self.assertEqual(s.shape, x.shape)
     # Strictly decreasing <=> every consecutive difference is negative.
     deltas = np.diff(s, axis=-1) < 0
     # Pass check_dtypes by keyword (as the other tests do): assertAllClose
     # helpers that declare it keyword-only reject a positional `True`.
     self.assertAllClose(deltas, np.ones(deltas.shape, dtype=bool),
                         check_dtypes=True)
Example #5
0
 def test_sort(self):
     """softsort along the last axis must produce a strictly increasing output."""
     # NOTE(review): unlike sibling tests, this relies on softsort's default
     # threshold/epsilon — confirm the defaults are tight enough for strictness.
     s = ops.softsort(self.x, axis=-1)
     self.assertEqual(s.shape, self.x.shape)
     # Strictly increasing <=> every consecutive difference is positive.
     deltas = np.diff(s, axis=-1) > 0
     # Pass check_dtypes by keyword (as the other tests do): assertAllClose
     # helpers that declare it keyword-only reject a positional `True`.
     self.assertAllClose(deltas, np.ones(deltas.shape, dtype=bool),
                         check_dtypes=True)
Example #6
0
 def test_sort(self):
     """softsort along the last axis must produce a strictly increasing output."""
     s = ops.softsort(self.x, axis=-1, threshold=1e-3, epsilon=1e-3)
     self.assertEqual(s.shape, self.x.shape)
     # Strictly increasing <=> every consecutive difference is positive.
     deltas = jnp.diff(s, axis=-1) > 0
     # `deltas` is a boolean predicate, so compare it exactly; assert_allclose
     # is a tolerance-based numeric comparison and is the wrong tool here.
     np.testing.assert_array_equal(deltas, jnp.ones(deltas.shape, dtype=bool))