def check_reduce(shape, op, dim, keepdims, is_cuda=False):
    """Compare ``jt.arg_reduce`` against a NumPy reference.

    Builds a random tensor of ``shape`` with ``jt.random``, reduces it with
    ``jt.arg_reduce`` along ``dim``, and checks both outputs (indices first,
    values second) against ``np.argmax``/``np.max`` when ``op == 'max'`` and
    ``np.argmin``/``np.min`` otherwise.

    Args:
        shape: shape passed to ``jt.random``.
        op: ``'max'`` selects argmax/max; any other value selects argmin/min.
        dim: axis to reduce over.
        keepdims: when true, the NumPy reference keeps the reduced axis via
            ``np.expand_dims`` so it matches the Jittor outputs.
        is_cuda: when true, additionally assert that exactly one captured log
            line matches the ``cub_arg_reduce`` jit-op-key pattern.
    """
    with jt.log_capture_scope(
        log_silent=1, log_v=0, log_vprefix="op.cc=100"
    ) as raw_log:
        src = jt.random(shape)
        idx, val = jt.arg_reduce(src, op, dim, keepdims)
        # Materialise the results while the capture scope is active
        # (presumably so the ops execute and log inside it).
        src_np = src.data
        idx_np = idx.data
        val_np = val.data

    # Inspect the log only after the scope has closed; the captured lines
    # are presumably collected on scope exit.
    if is_cuda:
        matched = find_log_with_re(
            raw_log, "(Jit op key (not )?found: " + "cub_arg_reduce" + ".*)"
        )
        assert len(matched) == 1

    arg_fn, red_fn = (np.argmax, np.max) if op == 'max' else (np.argmin, np.min)
    want_idx = arg_fn(src_np, axis=dim)
    want_val = red_fn(src_np, axis=dim)
    if keepdims:
        want_idx = np.expand_dims(want_idx, axis=dim)
        want_val = np.expand_dims(want_val, axis=dim)

    assert np.allclose(idx_np, want_idx)
    assert np.allclose(val_np, want_val)
def check_backward(shape, op, dim, keepdims):
    """Check the gradient that flows back through ``jt.arg_reduce``'s values.

    The loss is the sum of the squared reduced values (the second output of
    ``jt.arg_reduce``). The test expects half of its gradient with respect to
    ``x`` to equal ``x`` at the selected positions and zero elsewhere.
    Comparing ``grad * x`` with ``grad * grad`` is consistent with that.

    Args:
        shape: shape passed to ``jt.random``.
        op: reduction op forwarded to ``jt.arg_reduce``.
        dim: axis to reduce over.
        keepdims: forwarded to ``jt.arg_reduce``.
    """
    x = jt.random(shape)
    # The first output (indices) is not used; only the values carry gradient.
    _, picked = jt.arg_reduce(x, op, dim, keepdims)
    loss = (picked * picked).sum()
    # d(v^2)/dv = 2v, so halving recovers the selected values themselves.
    half_grad = jt.grad(loss, x) / 2
    assert np.allclose((half_grad * x).data, (half_grad * half_grad).data)