def test_grad_reverse():
    """Reverse mode: the derivative of x**2 at x=2 is the 1x1 Jacobian [[4]]."""
    import numpy as np

    def f(x):
        return x**2

    adobj = AD(f)
    adobj.set_mode('reverse')
    grad = adobj.grad(2)
    # np.array_equal compares shape as well as values, so a 0-d or 1-d
    # result cannot pass by broadcasting against the 2-D expected value.
    assert np.array_equal(grad, [[4]]), grad


def test_constant_function():
    """A constant function has zero derivative: expect the 1x1 result [[0]]."""
    import numpy as np

    adobj = AD(lambda x: 4)
    grad = adobj.grad(2)
    # Shape-sensitive comparison: `grad == [[0]]` would also accept a
    # 0-d or 1-d zero array via broadcasting.
    assert np.array_equal(grad, [[0]]), grad


def test_multidim_m():
    """Vector-valued f(x, y) = [x + y, x**2] evaluated at (1, 1).

    Rows of the expected result are outputs and columns are inputs:
    d(x + y) = [1, 1] and d(x**2) = [2x, 0] = [2, 0].
    """
    import numpy as np

    adobj = AD(lambda x, y: [x + y, x**2])
    grad = adobj.grad(1, 1)
    truth = [[1, 1], [2, 0]]
    # np.array_equal also checks shape; `(grad == truth).all()` could let a
    # wrongly shaped result through via broadcasting.
    assert np.array_equal(grad, truth), grad


def test_multidim_n():
    """Scalar f(x, y) = x + y at (1, 1): expect a single row [[1, 1]]."""
    import numpy as np

    adobj = AD(lambda x, y: x + y)
    grad = adobj.grad(1, 1)
    truth = [[1, 1]]
    # With `(grad == truth).all()`, a wrong [[1]] or [1, 1] result would
    # broadcast and pass; np.array_equal rejects any shape mismatch.
    assert np.array_equal(grad, truth), grad


def test_grad_reverse_multidim_n():
    """Reverse mode on f(x, y) = x**2 + 2*y at (2, 1): expect [[2x, 2]] = [[4, 2]]."""
    import numpy as np

    adobj = AD(lambda x, y: x**2 + 2 * y)
    adobj.set_mode('reverse')
    grad = adobj.grad(2, 1)
    truth = [[4, 2]]
    # Shape-sensitive comparison instead of a broadcasting `==` + `.all()`.
    assert np.array_equal(grad, truth), grad


def test_set_incorrect_mode():
    """An unrecognised mode must make ``grad`` raise ValueError.

    The mode is assigned directly rather than through ``set_mode``. Only the
    ``grad`` call sits inside ``pytest.raises``, so this checks that ``grad``
    itself rejects the mode.
    """
    def add(x, y):
        return x + y

    ad_instance = AD(add)
    ad_instance.mode = 'backprop'
    with pytest.raises(ValueError):
        ad_instance.grad(1, 2)