def test_ranks_one_based(self):
    """Check that 1-based soft ranks approximate the hard ranks of ``self.x``.

    Soft ranks are computed along the last axis with small ``threshold`` and
    ``epsilon`` and compared, within ``atol=1e-3``, to the exact ranks shifted
    by one.
    """
    # NOTE(review): `test_ranks_one_based` is defined more than once in this
    # chunk; if the definitions share a class, only the last one is collected.
    # argsort of argsort gives each element's position in sorted order, i.e.
    # its 0-based hard rank; adding one makes it 1-based.
    expected = 1 + np.argsort(np.argsort(self.x, axis=-1), axis=-1)
    soft = ops.softranks(
        self.x, axis=-1, zero_based=False, threshold=1e-3, epsilon=1e-3)
    self.assertEqual(soft.shape, self.x.shape)
    self.assertAllClose(soft, expected, check_dtypes=False, atol=1e-3)
def test_ranks_one_based(self):
    """Check 1-based soft ranks against jnp-computed hard ranks of ``self.x``.

    The comparison uses ``np.testing.assert_allclose`` with ``atol=1e-3``.
    """
    # NOTE(review): `test_ranks_one_based` is defined more than once in this
    # chunk; if the definitions share a class, only the last one is collected.
    order = jnp.argsort(self.x, axis=-1)
    # Inverting the sorting permutation yields 0-based hard ranks.
    expected = jnp.argsort(order, axis=-1) + 1
    soft = ops.softranks(self.x, axis=-1, zero_based=False,
                         threshold=1e-3, epsilon=1e-3)
    self.assertEqual(soft.shape, self.x.shape)
    np.testing.assert_allclose(soft, expected, atol=1e-3)
def test_ranks_descending(self):
    """Check 0-based descending soft ranks against mirrored hard ranks.

    The expected value is the ascending hard rank reflected about
    ``n - 1``, where ``n`` is the size of the last axis of ``self.x``.
    """
    # NOTE(review): `test_ranks_descending` is defined more than once in this
    # chunk; if the definitions share a class, only the last one is collected.
    num = self.x.shape[-1]
    ascending = np.argsort(np.argsort(self.x, axis=-1), axis=-1)
    # Descending order gives the largest element rank 0, so the descending
    # rank is (n - 1) minus the ascending rank.
    expected = (num - 1) - ascending
    soft = ops.softranks(
        self.x,
        axis=-1,
        zero_based=True,
        direction='DESCENDING',
        threshold=1e-3,
        epsilon=1e-3,
    )
    self.assertEqual(soft.shape, self.x.shape)
    self.assertAllClose(soft, expected, check_dtypes=False, atol=1e-3)
def test_ranks_descending(self):
    """Check 0-based descending soft ranks against jnp-computed hard ranks.

    The comparison uses ``np.testing.assert_allclose`` with ``atol=1e-3``.
    """
    last_index = self.x.shape[-1] - 1
    order = jnp.argsort(self.x, axis=-1)
    ascending = jnp.argsort(order, axis=-1)
    # Mirror the ascending ranks so the largest element maps to rank 0.
    expected = last_index - ascending
    soft = ops.softranks(self.x, axis=-1, zero_based=True,
                         direction='DESCENDING',
                         threshold=1e-3, epsilon=1e-3)
    self.assertEqual(soft.shape, self.x.shape)
    np.testing.assert_allclose(soft, expected, atol=1e-3)
def test_ranks_one_based(self):
    """Check that 1-based soft ranks approximate the hard ranks of ``self.x``.

    Soft ranks are computed along the last axis and compared, within
    ``atol=1e-3``, to the exact ranks shifted by one.
    """
    # Pass the same small threshold/epsilon as the other softranks tests so
    # the soft ranks are close enough to hard ranks for atol=1e-3; relying on
    # softranks' defaults (not visible here) makes the tolerance unreliable.
    ranks = ops.softranks(
        self.x, axis=-1, zero_based=False, threshold=1e-3, epsilon=1e-3)
    self.assertEqual(ranks.shape, self.x.shape)
    # argsort of argsort gives each element's 0-based hard rank.
    true_ranks = np.argsort(np.argsort(self.x, axis=-1), axis=-1) + 1
    # Pass check_dtypes by keyword: a bare positional `False` binds to
    # whatever the third parameter of assertAllClose happens to be.
    self.assertAllClose(ranks, true_ranks, check_dtypes=False, atol=1e-3)
def test_ranks(self):
    """Check that default (0-based) soft ranks approximate the hard ranks.

    Soft ranks of ``self.x`` along the last axis are compared, within
    ``atol=1e-3``, to the exact ranks obtained by a double argsort.
    """
    ranks = ops.softranks(self.x, axis=-1, threshold=1e-3, epsilon=1e-3)
    self.assertEqual(ranks.shape, self.x.shape)
    # argsort of argsort gives each element's 0-based hard rank.
    true_ranks = np.argsort(np.argsort(self.x, axis=-1), axis=-1)
    # Pass check_dtypes by keyword (as the other tests here do): a bare
    # positional `False` binds to whatever the third parameter of
    # assertAllClose happens to be (e.g. rtol in some test frameworks).
    self.assertAllClose(ranks, true_ranks, check_dtypes=False, atol=1e-3)