def test_sum():
    """``AD._reverse`` of ``a.sum()`` over five inputs should be a 1x5 row of ones."""
    def myfunc(a):
        return a.sum()

    adobj = AD(myfunc)
    inp = np.arange(5)
    truth = [[1, 1, 1, 1, 1]]
    # assert_array_equal also checks the shape; ``(x == truth).all()`` would
    # broadcast and silently accept e.g. a flat (5,) result here.
    np.testing.assert_array_equal(adobj._reverse(inp), truth)
def test_reverse_noinfl_in():
    """An input that does not appear in the function (``y``) gets a zero entry.

    f(x, y) = x**2 + x at (1, 1): df/dx = 2x + 1 = 3, df/dy = 0.
    """
    adobj = AD(lambda x, y: x**2 + x)
    grad = adobj._reverse(1, 1)
    truth = [[3, 0]]
    # Shape-checking comparison: a flat [3, 0] must not pass for [[3, 0]].
    np.testing.assert_array_equal(grad, truth)
def test_reverse_multidim_m():
    """Two outputs give one row each in the result.

    f(x, y) = [x + y, x**2] at (1, 1): rows [1, 1] and [2, 0].
    """
    adobj = AD(lambda x, y: [x + y, x**2])
    grad = adobj._reverse(1, 1)
    truth = [[1, 1], [2, 0]]
    # Compare values AND shape rather than a broadcasting ``==``.
    np.testing.assert_array_equal(grad, truth)
def test_reverse_multidim_n():
    adobj = AD(lambda x, y: x**2 + 2 * y)
    truth = [[4, 2]]
    assert (adobj._reverse(2, 1) == truth).all()
    return a.sum()


ad_object = AD(my_sum)

# Benchmark: time forward vs. reverse mode on my_sum for a growing number of
# inputs (intended to show reverse mode scaling better), then save a plot.
fwd_times = []
rev_times = []
ninputs = range(1, 10000, 100)
# Iterate the same range that is plotted, so x-axis and measurements match.
for n in ninputs:
    inp = np.arange(n)

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # adjustable wall clock and is coarse on some platforms.
    start_time = time.perf_counter()
    ad_object._forward(inp)
    fwd_times.append(time.perf_counter() - start_time)

    start_time = time.perf_counter()
    ad_object._reverse(inp)
    rev_times.append(time.perf_counter() - start_time)

# plot results
plt.plot(ninputs, fwd_times, label='forward')
plt.plot(ninputs, rev_times, label='backward')
plt.legend()
plt.title('Comparing times under different AD modes')
plt.xlabel('Number of inputs')
plt.ylabel('Elapsed time in seconds')
plt.savefig('../img/fwd_rev_increasing_n.png')