Esempio n. 1
0
    def test_argtopk_nd(self, shp, k_, dtype, sorted, idx_dtype):
        """Check ``argtopk`` along every axis of an N-d input of shape ``shp``.

        For each ``axis`` in ``range(-ndim, ndim)`` the graph is compiled with
        ``self.mode``, checked to contain a ``self.op_class`` node, and run on
        an array produced by ``gen_unique_vector`` reshaped to ``shp``.  The
        returned indices are compared (after sorting along ``axis``) with the
        indices selected from ``np.argsort``: the last ``k`` entries for
        ``k > 0`` (largest values) or the first ``-k`` entries for ``k < 0``
        (smallest values).

        ``k_`` is either an int or a string expression in which ``n`` is
        replaced by the length of the current axis and then evaluated; an axis
        for which ``k`` evaluates to 0 is skipped.
        """
        ndim = len(shp)
        for axis in range(-ndim, ndim):
            if isinstance(k_, str):
                # NOTE(review): eval is only acceptable because k_ is assumed
                # to be a trusted parametrization string -- confirm.
                k = eval(k_.replace("n", str(shp[axis])))
            else:
                k = k_

            if k == 0:
                continue

            x = tensor(name="x", broadcastable=(False,) * ndim, dtype=dtype)
            y = argtopk(x, k, axis=axis, sorted=sorted, idx_dtype=idx_dtype)
            fn = aesara.function([x], y, mode=self.mode)
            # Generator rather than a list so ``any`` stops at the first match.
            assert any(
                isinstance(node.op, self.op_class)
                for node in fn.maker.fgraph.apply_nodes
            )
            size = reduce(int.__mul__, shp)
            xval = gen_unique_vector(size, dtype).reshape(shp)
            yval = fn(xval)

            # Full slices on every axis except ``axis``; a negative ``axis``
            # indexes the list from the end, matching numpy's axis semantics.
            index = [slice(None)] * ndim
            index[axis] = slice(-k, None) if k > 0 else slice(-k)
            goal = np.argsort(xval, axis=axis)[tuple(index)].astype(idx_dtype)

            # Only membership along ``axis`` is compared, not ordering.
            assert np.all(np.sort(yval, axis=axis) == np.sort(goal, axis=axis))
Esempio n. 2
0
    def test_argtopk_1d_collision(self, size, k, dtype, sorted):
        """Check 1-d ``argtopk`` on input containing duplicated values.

        The input is built by drawing random values and repeating each one,
        then shuffling, so the kth extreme value may be tied and the chosen
        indices are not unique.  The test therefore compares the selected
        *values* with those picked by ``np.argsort`` rather than the indices.

        ``k`` may be a string expression in which ``n`` is replaced by
        ``size`` and then evaluated.  Positive ``k`` selects the largest
        values, negative ``k`` the smallest ``-k`` values.
        """
        if isinstance(k, str):
            k = eval(k.replace("n", str(size)))

        x = vector(name="x", dtype=dtype)
        y = argtopk(x, k, sorted=sorted, idx_dtype="int32")
        # DebugMode won't like the index change on collision on CPU,
        # so fall back to a plain Mode that keeps the same optimizer.
        mode = self.mode
        if isinstance(self.mode, aesara.compile.debugmode.DebugMode):
            mode = Mode(optimizer=mode.optimizer)
        fn = aesara.function([x], y, mode=mode)
        assert any(
            isinstance(node.op, self.op_class)
            for node in fn.maker.fgraph.apply_nodes
        )
        rng = np.random.default_rng(utt.fetch_seed())
        # Draw ceil(size / 2) values, repeat each twice and trim to ``size``.
        # For even sizes this is identical to drawing size // 2 values; for
        # odd sizes it avoids an array one element short of ``size``, which
        # made the permutation below index out of bounds.
        n_distinct = (size + 1) // 2
        xval = np.repeat(
            rng.uniform(-100.0, 100.0, size=n_distinct).astype(dtype), 2
        )[:size]
        xval = xval[rng.permutation(size)]
        yval = fn(xval)
        idx = slice(-k, None) if k > 0 else slice(-k)
        goal = np.argsort(xval)[idx].astype("int32")
        # Tied values may be chosen through different indices, so compare the
        # selected values instead of the indices.
        utt.assert_allclose(np.sort(xval[yval]), np.sort(xval[goal]))
Esempio n. 3
0
 def test_argtopk_sanity(self, dtype, idx_dtype, axis, sorted):
     """Smoke-test ``argtopk`` with ``k=1`` on a one-element vector.

     The compiled function must contain a ``self.op_class`` node, return
     index ``0`` for the input ``[1]``, and produce ``idx_dtype`` indices.
     """
     x = vector(name="x", dtype=dtype)
     fn = aesara.function(
         [x],
         argtopk(x, 1, axis=axis, sorted=sorted, idx_dtype=idx_dtype),
         mode=self.mode,
     )
     assert any(
         isinstance(node.op, self.op_class)
         for node in fn.maker.fgraph.apply_nodes
     )
     xval = np.asarray([1]).astype(dtype)
     yval = fn(xval)
     # assert_array_equal avoids relying on an ndarray's truth value
     # (ambiguous for >1 element, deprecated for empty arrays) and reports
     # a readable mismatch on failure.
     np.testing.assert_array_equal(yval, np.asarray([0], dtype=idx_dtype))
     assert yval.dtype == np.dtype(idx_dtype)
Esempio n. 4
0
    def test_argtopk_1d(self, size, k, dtype, sorted, idx_dtype):
        """Compare 1-d ``argtopk`` with ``np.argsort`` on all-unique input.

        ``k`` may be a string in which ``n`` stands for ``size``; it is
        evaluated before the graph is built.  Positive ``k`` selects the
        largest values, negative ``k`` the smallest ``-k`` values.

        Besides the numeric check, the test asserts that the compiled graph
        contains a ``self.op_class`` node and that the node producing the
        graph's output has exactly one output (the ``local_useless_topk``
        check).
        """
        n_keep = eval(k.replace("n", str(size))) if isinstance(k, str) else k

        x = vector(name="x", dtype=dtype)
        fn = aesara.function(
            [x],
            argtopk(x, n_keep, sorted=sorted, idx_dtype=idx_dtype),
            mode=self.mode,
        )
        graph = fn.maker.fgraph
        assert any(isinstance(node.op, self.op_class) for node in graph.apply_nodes)

        # local_useless_topk must leave a single-output node behind.
        assert len(graph.outputs[0].owner.outputs) == 1

        # Input with all-distinct entries, so equal values imply equal indices.
        xval = gen_unique_vector(size, dtype)
        yval = fn(xval)

        order = np.argsort(xval)
        if n_keep > 0:
            expected = order[-n_keep:]
        else:
            expected = order[:-n_keep]
        goal = expected.astype(idx_dtype)

        picked = xval[np.sort(yval)]
        assert np.all(picked == xval[np.sort(goal)])