def test_grad_reverse():
    """Reverse mode on f(x) = x**2 at x = 2 must yield exactly [[4]].

    ``assert_array_equal`` is used instead of a bare ``==`` so that the shape
    is checked as well as the values: with an array result, ``grad(2) ==
    [[4]]`` broadcasts, and a result such as ``array(4)`` or ``array([4])``
    would otherwise pass.
    """
    from numpy.testing import assert_array_equal

    def f(x):
        return x**2

    adobj = AD(f)
    adobj.set_mode('reverse')
    assert_array_equal(adobj.grad(2), [[4]])


def test_set_incorrect_set_mode():
    """An unrecognised mode name passed to ``set_mode`` raises ValueError."""
    unsupported_mode = 'backprop'
    differentiator = AD(lambda x, y: x + y)

    with pytest.raises(ValueError):
        differentiator.set_mode(unsupported_mode)


def test_set_mode():
    """``set_mode('reverse')`` is reflected in the ``mode`` attribute."""
    requested_mode = 'reverse'
    differentiator = AD(lambda x, y: x + y)

    differentiator.set_mode(requested_mode)

    assert differentiator.mode == 'reverse'


def test_grad_reverse_multidim_n():
    """Reverse mode on f(x, y) = x**2 + 2*y at (2, 1) must yield [[4, 2]].

    ``assert_array_equal`` checks shape as well as values. A plain
    ``(result == truth).all()`` broadcasts, so a 1-D ``[4, 2]`` or a stacked
    ``[[4, 2], [4, 2]]`` would also pass. It would also raise AttributeError
    rather than fail cleanly if ``grad`` returned a plain list.
    """
    from numpy.testing import assert_array_equal

    adobj = AD(lambda x, y: x**2 + 2 * y)
    adobj.set_mode('reverse')
    truth = [[4, 2]]  # [[df/dx, df/dy]] = [[2*x, 2]] at x=2, y=1
    assert_array_equal(adobj.grad(2, 1), truth)