Exemplo n.º 1
0
 def test_softquantile(self, quantile, axis):
     """Compare ops.softquantile with np.quantile on a fixed (2, 3, 5) array.

     The reduced ``axis`` must disappear from the output shape, and the
     soft quantile must match NumPy's exact quantile within rtol=1e-2.
     """
     data = np.array([
         [[7.9, 1.2, 5.5, 9.8, 3.5],
          [7.9, 12.2, 45.5, 9.8, 3.5],
          [17.9, 14.2, 55.5, 9.8, 3.5]],
         [[4.9, 1.2, 15.5, 4.8, 3.5],
          [7.9, 1.2, 5.5, 7.8, 2.5],
          [1.9, 4.2, 55.5, 9.8, 1.5]],
     ])
     result = ops.softquantile(data, quantile, axis=axis)

     # Expected shape: the input shape with `axis` removed. Indexing the
     # list with `axis` also handles negative axes.
     expected_shape = list(data.shape)
     del expected_shape[axis]
     self.assertTupleEqual(result.shape, tuple(expected_shape))

     expected = np.quantile(data, quantile, axis=axis)
     # NOTE(review): the positional True is forwarded unchanged to the base
     # class's assertAllClose (presumably its dtype-check flag) — confirm.
     self.assertAllClose(result, expected, True, rtol=1e-2)
Exemplo n.º 2
0
 def test_softquantile(self, quantile, axis):
     """Compare ops.softquantile with jnp.quantile on a fixed (2, 3, 5) array.

     Uses threshold=1e-3 and epsilon=1e-3 for the soft operator. The reduced
     ``axis`` must disappear from the output shape, and values must match
     the exact quantile within rtol=1e-2.
     """
     data = jnp.array([
         [[7.9, 1.2, 5.5, 9.8, 3.5],
          [7.9, 12.2, 45.5, 9.8, 3.5],
          [17.9, 14.2, 55.5, 9.8, 3.5]],
         [[4.9, 1.2, 15.5, 4.8, 3.5],
          [7.9, 1.2, 5.5, 7.8, 2.5],
          [1.9, 4.2, 55.5, 9.8, 1.5]],
     ])
     result = ops.softquantile(
         data, quantile, axis=axis, threshold=1e-3, epsilon=1e-3)

     # Drop the reduced axis from the input shape; list indexing also
     # accepts negative axes.
     expected_shape = list(data.shape)
     del expected_shape[axis]
     self.assertTupleEqual(result.shape, tuple(expected_shape))

     expected = jnp.quantile(data, quantile, axis=axis)
     np.testing.assert_allclose(result, expected, rtol=1e-2)