def test_argtopk_nd(self, shp, k_, dtype, sorted, idx_dtype):
    """Check ``argtopk`` along every axis of an N-d input against ``np.argsort``.

    One function is compiled and evaluated per axis in ``range(-ndim, ndim)``.

    Parameters
    ----------
    shp
        Shape of the input array.
    k_
        Either the value of ``k`` itself, or a string expression in which
        ``n`` is textually replaced by the length of the current axis before
        the expression is evaluated.
    dtype
        dtype of the symbolic input and of the generated test data.
    sorted
        Forwarded unchanged to ``argtopk``.
    idx_dtype
        dtype requested for the indices returned by ``argtopk``.
    """
    ndim = len(shp)
    for axis in range(-ndim, ndim):
        if isinstance(k_, str):
            # NOTE(review): ``eval`` is only acceptable here if ``k_`` is a
            # literal coming from the test parametrization -- confirm it can
            # never carry external input.
            k = eval(k_.replace("n", str(shp[axis])))
        else:
            k = k_

        # k == 0 is not exercised by this test.
        if k == 0:
            continue

        x = tensor(name="x", broadcastable=(False,) * ndim, dtype=dtype)
        y = argtopk(x, k, axis=axis, sorted=sorted, idx_dtype=idx_dtype)
        fn = aesara.function([x], y, mode=self.mode)
        # The compiled graph must contain the op class under test.
        assert any(
            isinstance(n.op, self.op_class) for n in fn.maker.fgraph.apply_nodes
        )

        size = reduce(int.__mul__, shp)
        xval = gen_unique_vector(size, dtype).reshape(shp)
        yval = fn(xval)

        # Positive k selects the last k entries of the ascending argsort,
        # negative k the first |k| entries.
        k_slice = slice(-k, None) if k > 0 else slice(-k)
        # Build an index that applies ``k_slice`` on ``axis`` and keeps every
        # other dimension whole.
        n_leading = axis % ndim
        n_trailing = ndim - n_leading - 1
        idx = (
            (slice(None),) * n_leading
            + (k_slice,)
            + (slice(None),) * n_trailing
        )
        goal = np.argsort(xval, axis=axis)[idx].astype(idx_dtype)

        # Compare as sets along ``axis``: both sides are sorted first, so the
        # order in which the indices are returned does not matter.
        assert np.all(np.sort(yval, axis=axis) == np.sort(goal, axis=axis))
def test_argtopk_1d_collision(self, size, k, dtype, sorted):
    """Check ``argtopk`` on a vector with a non-unique k-th max value.

    The input is built from duplicated random values, so the selected
    indices may legitimately differ from ``np.argsort``; the *values* picked
    by both index sets are therefore compared instead of the indices.

    Parameters
    ----------
    size
        Length of the input vector.
    k
        Either the value of ``k`` itself, or a string expression in which
        ``n`` is textually replaced by ``size`` before it is evaluated.
    dtype
        dtype of the symbolic input and of the generated test data.
    sorted
        Forwarded unchanged to ``argtopk``.
    """
    if isinstance(k, str):
        # NOTE(review): ``eval`` is only acceptable here if ``k`` is a literal
        # coming from the test parametrization -- confirm it can never carry
        # external input.
        k = eval(k.replace("n", str(size)))

    x = vector(name="x", dtype=dtype)
    y = argtopk(x, k, sorted=sorted, idx_dtype="int32")

    # DebugMode won't like the index change on collision on CPU,
    # so don't use DebugMode here.
    mode = self.mode
    if isinstance(self.mode, aesara.compile.debugmode.DebugMode):
        mode = Mode(optimizer=mode.optimizer)
    fn = aesara.function([x], y, mode=mode)
    # The compiled graph must contain the op class under test.
    assert any(
        isinstance(n.op, self.op_class) for n in fn.maker.fgraph.apply_nodes
    )

    rng = np.random.default_rng(utt.fetch_seed())
    # Draw ceil(size / 2) values, duplicate each one and trim to exactly
    # ``size`` entries.  With ``size // 2`` draws an odd ``size`` produced only
    # ``size - 1`` entries and the permutation below indexed out of bounds;
    # for even ``size`` the draws and the resulting data are unchanged.
    n_distinct = (size + 1) // 2
    duplicated = np.repeat(
        rng.uniform(-100.0, 100.0, size=n_distinct).astype(dtype), 2
    )
    xval = duplicated[:size]
    xval = xval[rng.permutation(size)]
    yval = fn(xval)

    # Positive k selects the last k entries of the ascending argsort,
    # negative k the first |k| entries.
    idx = slice(-k, None) if k > 0 else slice(-k)
    goal = np.argsort(xval)[idx].astype("int32")
    utt.assert_allclose(np.sort(xval[yval]), np.sort(xval[goal]))
def test_argtopk_sanity(self, dtype, idx_dtype, axis, sorted):
    """Smoke-test ``argtopk`` with ``k=1`` on a one-element vector.

    The only valid index is 0, and the result must carry the requested
    ``idx_dtype``.
    """
    inp = vector(name="x", dtype=dtype)
    top1 = argtopk(inp, 1, axis=axis, sorted=sorted, idx_dtype=idx_dtype)
    fn = aesara.function([inp], top1, mode=self.mode)

    # The compiled graph must contain the op class under test.
    graph_ops = (node.op for node in fn.maker.fgraph.apply_nodes)
    assert any(isinstance(op, self.op_class) for op in graph_ops)

    xval = np.asarray([1]).astype(dtype)
    yval = fn(xval)

    expected = np.asarray([0], dtype=idx_dtype)
    assert yval == expected
    assert yval.dtype == np.dtype(idx_dtype)
def test_argtopk_1d(self, size, k, dtype, sorted, idx_dtype):
    """Check ``argtopk`` on a 1-d vector of unique values against ``np.argsort``.

    ``k`` may be given as a string expression; ``n`` in it is textually
    replaced by ``size`` and the result is evaluated.
    """
    if isinstance(k, str):
        k = eval(k.replace("n", str(size)))

    inp = vector(name="x", dtype=dtype)
    out = argtopk(inp, k, sorted=sorted, idx_dtype=idx_dtype)
    fn = aesara.function([inp], out, mode=self.mode)

    # The compiled graph must contain the op class under test.
    assert any(
        isinstance(node.op, self.op_class)
        for node in fn.maker.fgraph.apply_nodes
    )

    # Assert the local_useless_topk opt is done properly: the node that
    # produces the function output must have exactly one output.
    producer = fn.maker.fgraph.outputs[0].owner
    assert 1 == len(producer.outputs)

    # Generate an all-unique array.
    xval = gen_unique_vector(size, dtype)
    yval = fn(xval)

    # Positive k selects the last k entries of the ascending argsort,
    # negative k the first |k| entries.
    if k > 0:
        idx = slice(-k, None)
    else:
        idx = slice(-k)
    goal = np.argsort(xval)[idx].astype(idx_dtype)

    # Due to uniqueness, we expect the same indices; compare the values they
    # select once both index sets are sorted.
    picked = xval[np.sort(yval)]
    wanted = xval[np.sort(goal)]
    assert np.all(picked == wanted)