def test_softquantile(self, quantile, axis):
    """Compare ``ops.softquantile`` with ``np.quantile`` on a fixed array.

    The result must have the input's shape with ``axis`` removed and agree
    with the NumPy reference to a relative tolerance of 1e-2.

    Args:
        quantile: Quantile level forwarded to both implementations.
        axis: Axis of the (2, 3, 5) input along which the quantile is taken.
    """
    # NOTE(review): this file contains another definition named
    # ``test_softquantile``; if both live in the same class the later one
    # shadows this one and this test never runs -- confirm.
    x = np.array([
        [
            [7.9, 1.2, 5.5, 9.8, 3.5],
            [7.9, 12.2, 45.5, 9.8, 3.5],
            [17.9, 14.2, 55.5, 9.8, 3.5],
        ],
        [
            [4.9, 1.2, 15.5, 4.8, 3.5],
            [7.9, 1.2, 5.5, 7.8, 2.5],
            [1.9, 4.2, 55.5, 9.8, 1.5],
        ],
    ])
    qs = ops.softquantile(x, quantile, axis=axis)

    # Expected shape: the input shape with the reduced axis dropped.
    expected_shape = list(x.shape)
    expected_shape.pop(axis)
    self.assertTupleEqual(qs.shape, tuple(expected_shape))

    # The flag is passed by keyword rather than as a bare positional ``True``:
    # presumably this is JaxTestCase.assertAllClose, whose ``check_dtypes``
    # parameter is keyword-only in newer JAX releases -- confirm against the
    # enclosing test class.
    self.assertAllClose(
        qs, np.quantile(x, quantile, axis=axis), check_dtypes=True, rtol=1e-2
    )
def test_softquantile(self, quantile, axis):
    """Check ``ops.softquantile`` against ``jnp.quantile`` on a fixed input.

    Two properties are asserted: the output shape equals the input shape
    with ``axis`` removed, and the values match the ``jnp.quantile``
    reference within a relative tolerance of 1e-2.

    Args:
        quantile: Quantile level handed to both implementations.
        axis: Axis of the (2, 3, 5) input that is reduced.
    """
    first_slab = [
        [7.9, 1.2, 5.5, 9.8, 3.5],
        [7.9, 12.2, 45.5, 9.8, 3.5],
        [17.9, 14.2, 55.5, 9.8, 3.5],
    ]
    second_slab = [
        [4.9, 1.2, 15.5, 4.8, 3.5],
        [7.9, 1.2, 5.5, 7.8, 2.5],
        [1.9, 4.2, 55.5, 9.8, 1.5],
    ]
    x = jnp.array([first_slab, second_slab])

    # threshold and epsilon are forwarded unchanged to the op under test.
    qs = ops.softquantile(
        x, quantile, axis=axis, threshold=1e-3, epsilon=1e-3
    )

    # Drop the reduced axis from the input shape to get the expected shape.
    remaining_dims = list(x.shape)
    del remaining_dims[axis]
    self.assertTupleEqual(qs.shape, tuple(remaining_dims))

    reference = jnp.quantile(x, quantile, axis=axis)
    np.testing.assert_allclose(qs, reference, rtol=1e-2)