import numpy as np
import jittor as jt
# find_log_with_re is the log-matching helper from Jittor's own test suite
# (assumed import path).
from jittor.test.test_log import find_log_with_re


def check_reduce(shape, op, dim, keepdims, is_cuda=False):
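    # Capture Jittor's op-compile logs (op.cc at verbosity 100) so the CUDA
    # branch below can assert which kernel was actually used.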
    with jt.log_capture_scope(log_silent=1, log_v=0,
                              log_vprefix="op.cc=100") as raw_log:
        x = jt.random(shape)
        key, v = jt.arg_reduce(x, op, dim, keepdims)
        x_ = x.data
        key_ = key.data
        v_ = v.data
    if is_cuda:
        # On CUDA the reduction must have been dispatched to the
        # cub_arg_reduce kernel exactly once.
        logs = find_log_with_re(
            raw_log, "(Jit op key (not )?found: cub_arg_reduce.*)")
        assert len(logs) == 1
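    # Reference indices/values computed with numpy along the same axis.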
    if op == 'max':
        key__ = np.argmax(x_, axis=dim)
        v__ = np.max(x_, axis=dim)
    else:
        key__ = np.argmin(x_, axis=dim)
        v__ = np.min(x_, axis=dim)

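    # numpy's argmax/argmin and max/min drop the reduced axis; re-insert it
    # so the shapes match jt.arg_reduce's keepdims=True output.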
    if keepdims:
        key__ = np.expand_dims(key__, axis=dim)
        v__ = np.expand_dims(v__, axis=dim)
    assert np.allclose(key_, key__)
    assert np.allclose(v_, v__)


def check_backward(shape, op, dim, keepdims):
    # Gradient check: with loss = sum(v*v), d loss/d x is 2*v at each selected
    # (argmax/argmin) position and 0 elsewhere, so gs = grad/2 equals v there
    # and gs*x == gs*gs holds element-wise. Note the unpack order matches
    # check_reduce (indices first, values second), and the loss is taken over
    # the values, since the integer indices carry no gradient.
    x = jt.random(shape)
    key, v = jt.arg_reduce(x, op, dim, keepdims)
    loss = (v * v).sum()
    gs = jt.grad(loss, x) / 2
    assert np.allclose((gs * x).data, (gs * gs).data)
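

# A minimal usage sketch (hypothetical driver, not part of the original file):
# run both checkers on CPU over both ops and keepdims settings. With a CUDA
# build, one could set jt.flags.use_cuda = 1 and pass is_cuda=True to also
# verify the cub_arg_reduce dispatch.
if __name__ == "__main__":
    for op in ('max', 'min'):
        for keepdims in (False, True):
            check_reduce((5, 7, 3), op, dim=1, keepdims=keepdims)
            check_backward((5, 7, 3), op, dim=1, keepdims=keepdims)
    print("all arg_reduce checks passed")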