np.square(x2[0])) register_diff(np.add, lambda x1, x2: x1[1] + x2[1]) register_diff(np.subtract, lambda x1, x2: x1[1] - x2[1]) register_diff(np.multiply, multiply_diff) register_diff(np.matmul, matmul_diff) register_diff(np.divide, divide_diff) register_diff(np.true_divide, divide_diff) register_diff(np.power, pow_diff) register_diff(np.positive, lambda x: +x[1]) register_diff(np.negative, lambda x: -x[1]) register_diff(np.conj, lambda x: np.conj(x[1])) register_diff(np.conj, lambda x: np.conj(x[1])) register_diff(np.exp, lambda x: x[1] * np.exp(x[0])) register_diff(np.exp2, lambda x: x[1] * np.log(2) * np.exp(x[0])) register_diff(np.log, lambda x: x[1] / x[0]) register_diff(np.log2, lambda x: x[1] / (np.log(2) * x[0])) register_diff(np.log10, lambda x: x[1] / (np.log(10) * x[0])) register_diff(np.sqrt, lambda x: x[1] / (2 * np.sqrt(x[0]))) register_diff(np.square, lambda x: 2 * x[1] * x[0]) register_diff(np.cbrt, lambda x: x[1] / (3 * (x[0]**(2 / 3)))) register_diff(np.reciprocal, lambda x: -x[1] / np.square(x[0])) register_diff(np.broadcast_to, lambda x, shape: np.broadcast_to(x[1], shape)) register_diff(np.sin, lambda x: x[1] * np.cos(x[0])) register_diff(np.cos, lambda x: -x[1] * np.sin(x[0])) register_diff(np.tan, lambda x: x[1] / np.square(np.cos(x[0]))) register_diff(np.arcsin, lambda x: x[1] / np.sqrt(1 - np.square(x[0]))) register_diff(np.arccos, lambda x: -x[1] / np.sqrt(1 - np.square(x[0]))) register_diff(np.arctan, lambda x: x[1] / (1 + np.square(x[0])))
def pow_diff(x1, x2):
    """Differentiate ``x1 ** x2`` where both operands may vary.

    Each operand is indexed as ``operand[0]`` (its value) and ``operand[1]``
    (its derivative).  With ``f = x1[0]`` and ``g = x2[0]`` the result is::

        f**g * (log(f) * g' + f' / f * g)

    The expression takes ``log(f)`` and divides by ``f``, so it is only
    finite where the base is strictly positive.
    """
    base, d_base = x1[0], x1[1]
    exponent, d_exponent = x2[0], x2[1]

    value = base**exponent
    # Contribution of a varying exponent: g' * log(f).
    from_exponent = np.log(base) * d_exponent
    # Contribution of a varying base: (f' / f) * g.
    from_base = d_base / base * exponent
    return value * (from_exponent + from_base)
) defjvp(np.true_divide, "same", lambda ans, x, y: lambda g: -g * x / y**2) defjvp( np.mod, lambda ans, x, y: lambda g: np.broadcast_to(g, np.shape(ans)), lambda ans, x, y: lambda g: -g * np.floor(x / y), ) defjvp( np.remainder, lambda ans, x, y: lambda g: np.broadcast_to(g, np.shape(ans)), lambda ans, x, y: lambda g: -g * np.floor(x / y), ) defjvp( np.power, lambda ans, x, y: lambda g: g * y * x**np.where(y, y - 1, 1.0), lambda ans, x, y: lambda g: g * np.log(replace_zero(x, 1.0)) * ans, ) defjvp( np.arctan2, lambda ans, x, y: lambda g: g * y / (x**2 + y**2), lambda ans, x, y: lambda g: g * -x / (x**2 + y**2), ) # ----- Simple grads (linear) ----- defjvp(np.negative, "same") defjvp(np.rad2deg, "same") defjvp(np.degrees, "same") defjvp(np.deg2rad, "same") defjvp(np.radians, "same") defjvp(np.reshape, "same") defjvp(np.roll, "same")
) defvjp( np.mod, lambda ans, x, y: unbroadcast_f(x, lambda g: g), lambda ans, x, y: unbroadcast_f(y, lambda g: -g * np.floor(x / y)), ) defvjp( np.remainder, lambda ans, x, y: unbroadcast_f(x, lambda g: g), lambda ans, x, y: unbroadcast_f(y, lambda g: -g * np.floor(x / y)), ) defvjp( np.power, lambda ans, x, y: unbroadcast_f(x, lambda g: g * y * x ** np.where(y, y - 1, 1.0)), lambda ans, x, y: unbroadcast_f( y, lambda g: g * np.log(replace_non_positive(x, 1.0)) * ans ), ) defvjp( np.arctan2, lambda ans, x, y: unbroadcast_f(x, lambda g: g * y / (x ** 2 + y ** 2)), lambda ans, x, y: unbroadcast_f(y, lambda g: g * -x / (x ** 2 + y ** 2)), ) defvjp( np.hypot, lambda ans, x, y: unbroadcast_f(x, lambda g: g * x / ans), lambda ans, x, y: unbroadcast_f(y, lambda g: g * y / ans), ) # ----- Simple grads ----- defvjp(np.sign, lambda ans, x: lambda g: np.nan if x == 0 else 0)
pytest.xfail( reason="The backend has no implementation for this ufunc.") if isinstance(ret, da.Array): ret.compute() assert_allclose(ret.diffs[x].arr, y_d_arr) @pytest.mark.parametrize( "func, y_d", [(lambda x: np.power(2 * x + 1, 3), lambda x: 6 * np.power(2 * x + 1, 2)), (lambda x: np.sin(np.power(x, 2)) / np.power(np.sin(x), 2), lambda x: (2 * x * np.cos(np.power(x, 2)) * np.sin(x) - 2 * np.sin(np.power(x, 2)) * np.cos(x)) / np.power(np.sin(x), 3)), (lambda x: np.power(np.log(np.power(x, 3)), 1 / 3), lambda x: 2 * np.power(np.log(np.power(x, 2)), -2 / 3) / (3 * x)), (lambda x: np.log( (1 + x) / (1 - x)) / 4 - np.arctan(x) / 2, lambda x: np.power(x, 2) / (1 - np.power(x, 4)))], ) def test_arbitrary_function(backend, func, y_d): x_arr = [0.2, 0.3] try: with ua.set_backend(backend), ua.set_backend(udiff, coerce=True): x = np.asarray(x_arr) x.var = udiff.Variable('x') ret = func(x) y_d_arr = y_d(x) except ua.BackendNotImplementedError: if backend in FULLY_TESTED_BACKENDS:
assert_allclose(v_diff.value, expect_v_diff) @pytest.mark.parametrize( "func, y_d, domain", [ (lambda x: x * x, lambda x: 2 * x, None), (lambda x: (2 * x + 1)**3, lambda x: 6 * (2 * x + 1)**2, (0.5, None)), ( lambda x: np.sin(x**2) / np.sin(x)**2, lambda x: (2 * x * cos(x**2) * sin(x) - 2 * sin(x**2) * cos(x)) / sin(x)**3, (0, pi), ), ( lambda x: np.log(x**2)**(1 / 3), lambda x: 2 * log(x**2)**(-2 / 3) / (3 * x), (1, None), ), ( lambda x: np.log((1 + x) / (1 - x)) / 4 - np.arctan(x) / 2, lambda x: x**2 / (1 - x**4), (-1, 1), ), ( lambda x: np.log(1 + x**2) / np.arctanh(x), lambda x: ((2 * x * atanh(x) / (1 + x**2)) - (log(1 + x**2) / (1 - x**2))) / atanh(x)**2, (0, 1), ), ],
@pytest.mark.xfail @pytest.mark.parametrize( "func, y_d, domain", [ (lambda x: (2 * x + 1) ** 3, lambda x: 6 * (2 * x + 1) ** 2, (0.5, None)), ( lambda x: np.sin(x ** 2) / (np.sin(x)) ** 2, lambda x: ( 2 * x * np.cos(x ** 2) * np.sin(x) - 2 * np.sin(x ** 2) * np.cos(x) ) / (np.sin(x)) ** 3, (0, pi), ), ( lambda x: (np.log(x ** 2)) ** (1 / 3), lambda x: 2 * (np.log(x ** 2)) ** (-2 / 3) / (3 * x), (1, None), ), ( lambda x: np.log((1 + x) / (1 - x)) / 4 - np.arctan(x) / 2, lambda x: x ** 2 / (1 - x ** 4), (-1, 1), ), ( lambda x: np.arctanh(3 * x ** 3 + x ** 2 + 1), lambda x: (9 * x ** 2 + 2 * x) / (1 - (3 * x ** 3 + x ** 2 + 1) ** 2), (0, None), ), ( lambda x: np.sinh(np.cbrt(x)) + np.cosh(4 * x ** 3),
except ua.BackendNotImplementedError: if backend in FULLY_TESTED_BACKENDS: raise pytest.xfail(reason="The backend has no implementation for this ufunc.") if isinstance(ret, da.Array): ret.compute() assert_allclose(ret.diffs[x].arr, y_d_arr) @pytest.mark.parametrize( "func, y_d", [ (lambda x: np.power(2 * x + 1, 3), lambda x: 6 * np.power(2 * x + 1, 2)), (lambda x: np.sin(np.power(x, 2)) / np.power(np.sin(x), 2), lambda x: (2 * x * np.cos(np.power(x, 2)) * np.sin(x) - 2 * np.sin(np.power(x, 2)) * np.cos(x)) / np.power(np.sin(x), 3)), (lambda x: np.power(np.log(np.power(x, 3)), 1/3), lambda x: 2 * np.power(np.log(np.power(x, 2)), -2/3) / (3 * x)), (lambda x: np.log((1 + x) / (1 - x)) / 4 - np.arctan(x) / 2, lambda x: np.power(x, 2) / (1 - np.power(x, 4))), (lambda x: np.arctanh(3 * x ** 3 + x ** 2 +1), lambda x: (9 * x ** 2 + 2 * x) / (1 - np.power(3 * x ** 3 + x ** 2 + 1 , 2))), (lambda x: np.sinh(np.cbrt(x)) + np.cosh(4 * x ** 3) , lambda x: np.cosh(np.cbrt(x)) / (3 * x ** (2/3)) + 12 * (x ** 2) * np.sinh(4 * x ** 3)), (lambda x: np.log(1 + x ** 2) / np.arctanh(x), lambda x: ((2 * x * np.arctanh(x) / (1 + x ** 2)) - (np.log(1 + x ** 2)/(1 - x ** 2))) / np.power(np.arctanh(x) , 2)) ], ) def test_arbitrary_function(backend, func, y_d): x_arr = [0.2, 0.3] try: with ua.set_backend(backend), ua.set_backend(udiff, coerce=True): x = np.asarray(x_arr) x.var = udiff.Variable('x') ret = func(x) y_d_arr = y_d(x) except ua.BackendNotImplementedError:
pytest.xfail( reason="The backend has no implementation for this ufunc.") if isinstance(ret, da.Array): ret.compute() assert_allclose(ret.diffs[x].arr, y_d_arr) @pytest.mark.parametrize( "func, y_d, domain", [(lambda x: (2 * x + 1)**3, lambda x: 6 * (2 * x + 1)**2, (0.5, None)), (lambda x: np.sin(x**2) / (np.sin(x))**2, lambda x: (2 * x * np.cos(x**2) * np.sin(x) - 2 * np.sin(x**2) * np.cos(x)) / (np.sin(x))**3, (0, pi)), (lambda x: (np.log(x**2))**(1 / 3), lambda x: 2 * (np.log(x**2))**(-2 / 3) / (3 * x), (1, None)), (lambda x: np.log((1 + x) / (1 - x)) / 4 - np.arctan(x) / 2, lambda x: x**2 / (1 - x**4), (-1, 1)), (lambda x: np.arctanh(3 * x**3 + x**2 + 1), lambda x: (9 * x**2 + 2 * x) / (1 - (3 * x**3 + x**2 + 1)**2), (0, None)), (lambda x: np.sinh(np.cbrt(x)) + np.cosh(4 * x**3), lambda x: np.cosh(np.cbrt(x)) / (3 * x**(2 / 3)) + 12 * (x**2) * np.sinh(4 * x**3), (1 / 4, None)), (lambda x: np.log(1 + x**2) / np.arctanh(x), lambda x: ((2 * x * np.arctanh(x) / (1 + x**2)) - (np.log(1 + x**2) / (1 - x**2))) / (np.arctanh(x))**2, (0, 1))], ) def test_arbitrary_function(backend, func, y_d, domain): if domain is None: