def test_argsort():
    """Compare `argsort` against `np.argsort` across several configurations.

    Covers: default axis on a matrix, a symbolic axis, a vector input, an
    explicit sort ``kind``, `ArgSortOp` equality, and ``axis=None``.
    """
    # Set up
    rng = np.random.default_rng(seed=utt.fetch_seed())
    m_val = rng.random((3, 2))
    v_val = rng.random(4)

    # Example 1: matrix input, default axis
    a = dmatrix()
    w = argsort(a)
    f = aesara.function([a], w)
    gv = f(m_val)
    gt = np.argsort(m_val)
    utt.assert_allclose(gv, gt)

    # Example 2: matrix input, axis supplied as a symbolic scalar
    a = dmatrix()
    axis = lscalar()
    w = argsort(a, axis)
    f = aesara.function([a, axis], w)
    for axis_val in 0, 1:
        gv = f(m_val, axis_val)
        gt = np.argsort(m_val, axis_val)
        utt.assert_allclose(gv, gt)

    # Example 3: vector input
    a = dvector()
    w2 = argsort(a)
    f = aesara.function([a], w2)
    gv = f(v_val)
    gt = np.argsort(v_val)
    utt.assert_allclose(gv, gt)

    # Example 4: explicit sort kind. The NumPy reference uses the same kind
    # so that tie-breaking (mergesort is stable) matches on equal elements.
    a = dmatrix()
    axis = lscalar()
    w = argsort(a, axis, "mergesort")
    f = aesara.function([a, axis], w)
    for axis_val in 0, 1:
        gv = f(m_val, axis_val)
        gt = np.argsort(m_val, axis_val, kind="mergesort")
        utt.assert_allclose(gv, gt)

    # Example 5: ArgSortOp instances compare equal only when their
    # constructor arguments match.
    a1 = ArgSortOp("mergesort", [])
    a2 = ArgSortOp("quicksort", [])
    # All the below should give true
    assert a1 != a2
    assert a1 == ArgSortOp("mergesort", [])
    assert a2 == ArgSortOp("quicksort", [])

    # Example 6: Testing axis=None (sort over the flattened input)
    a = dmatrix()
    w2 = argsort(a, None)
    f = aesara.function([a], w2)
    gv = f(m_val)
    gt = np.argsort(m_val, None)
    utt.assert_allclose(gv, gt)
def test_argsort_grad():
    """Numerically verify the gradient of `argsort` along several axes."""
    # NOTE(review): this SOURCE chunk contains a second definition named
    # `test_argsort_grad`. If both live in the same module, the later one
    # shadows this one and pytest collects only the last. Remove whichever
    # copy is stale.
    # Draw from a seeded Generator rather than the unseeded global
    # `np.random` state, so a failing gradient check can be reproduced.
    rng = np.random.default_rng(seed=utt.fetch_seed())

    # Testing grad of argsort
    data = rng.random((2, 3)).astype(aesara.config.floatX)
    utt.verify_grad(lambda x: argsort(x, axis=-1), [data])

    data = rng.random((2, 3, 4, 5)).astype(aesara.config.floatX)
    utt.verify_grad(lambda x: argsort(x, axis=-3), [data])

    data = rng.random((2, 3, 3)).astype(aesara.config.floatX)
    utt.verify_grad(lambda x: argsort(x, axis=2), [data])
def test_argsort_grad():
    """Check `argsort` gradients numerically for several shapes and axes."""
    generator = np.random.default_rng(seed=utt.fetch_seed())
    float_dtype = aesara.config.floatX

    # (input shape, axis) pairs. They are visited in this order, so the
    # random draws happen in the same sequence as three separate checks.
    cases = (
        ((2, 3), -1),
        ((2, 3, 4, 5), -3),
        ((2, 3, 3), 2),
    )
    for shape, ax in cases:
        sample = generator.random(shape).astype(float_dtype)
        # Bind `ax` as a default so the closure does not late-bind the loop
        # variable.
        utt.verify_grad(lambda x, ax=ax: argsort(x, axis=ax), [sample])
def argsort(self, axis=-1, kind="quicksort", order=None):
    """Delegate to `aesara.tensor.sort.argsort` with ``self`` as the input.

    ``axis``, ``kind`` and ``order`` are forwarded unchanged; see
    `aesara.tensor.sort.argsort` for their meaning.
    """
    # Imported at call time rather than at module level; presumably this
    # avoids a circular import — confirm before hoisting.
    from aesara.tensor.sort import argsort as _sort_argsort

    return _sort_argsort(self, axis, kind, order)