예제 #1
0
    lambda ans, x: lambda g: g * replace_zero(np.conj(x), 0.0) / replace_zero(ans, 1.0),
)
# VJP (reverse-mode) rules for elementwise unary functions.  Each rule has the
# shape ``lambda ans, x: lambda g: ...`` where ``ans`` is the forward output,
# ``x`` the input and ``g`` the incoming output gradient; the inner lambda
# returns the gradient with respect to ``x``.
defvjp(
    np.fabs, lambda ans, x: lambda g: np.sign(x) * g
)  # fabs doesn't take complex numbers.
# np.absolute computes the same values as np.abs, so it gets the same
# replace_zero-guarded rule as the registration just before np.fabs.  Unguarded,
# x == 0 evaluates conj(0) / 0 == nan.  Presumably replace_zero(a, v) swaps the
# zero entries of ``a`` for ``v`` (confirm), which makes the gradient 0 at x == 0.
defvjp(
    np.absolute,
    lambda ans, x: lambda g: g * replace_zero(np.conj(x), 0.0) / replace_zero(ans, 1.0),
)
defvjp(np.reciprocal, lambda ans, x: lambda g: -g / x ** 2)
# Exponentials reuse the forward output: d exp(x) = exp(x) dx, d 2**x = ln2 * 2**x dx.
defvjp(np.exp, lambda ans, x: lambda g: ans * g)
defvjp(np.exp2, lambda ans, x: lambda g: ans * np.log(2) * g)
defvjp(np.expm1, lambda ans, x: lambda g: (ans + 1) * g)
defvjp(np.log, lambda ans, x: lambda g: g / x)
defvjp(np.log2, lambda ans, x: lambda g: g / x / np.log(2))
defvjp(np.log10, lambda ans, x: lambda g: g / x / np.log(10))
defvjp(np.log1p, lambda ans, x: lambda g: g / (x + 1))
defvjp(np.sin, lambda ans, x: lambda g: g * np.cos(x))
defvjp(np.cos, lambda ans, x: lambda g: -g * np.sin(x))
defvjp(np.tan, lambda ans, x: lambda g: g / np.cos(x) ** 2)
defvjp(np.arcsin, lambda ans, x: lambda g: g / np.sqrt(1 - x ** 2))
defvjp(np.arccos, lambda ans, x: lambda g: -g / np.sqrt(1 - x ** 2))
defvjp(np.arctan, lambda ans, x: lambda g: g / (1 + x ** 2))
defvjp(np.sinh, lambda ans, x: lambda g: g * np.cosh(x))
defvjp(np.cosh, lambda ans, x: lambda g: g * np.sinh(x))
defvjp(np.tanh, lambda ans, x: lambda g: g / np.cosh(x) ** 2)
defvjp(np.arcsinh, lambda ans, x: lambda g: g / np.sqrt(x ** 2 + 1))
defvjp(np.arccosh, lambda ans, x: lambda g: g / np.sqrt(x ** 2 - 1))
defvjp(np.arctanh, lambda ans, x: lambda g: g / (1 - x ** 2))
# rad2deg/degrees and deg2rad/radians are the same linear conversions, so each
# pair shares a constant-factor rule.
defvjp(np.rad2deg, lambda ans, x: lambda g: g / np.pi * 180.0)
defvjp(np.degrees, lambda ans, x: lambda g: g / np.pi * 180.0)
defvjp(np.deg2rad, lambda ans, x: lambda g: g * np.pi / 180.0)
defvjp(np.radians, lambda ans, x: lambda g: g * np.pi / 180.0)
defvjp(np.square, lambda ans, x: lambda g: g * 2 * x)
예제 #2
0
# Forward-mode derivative rules.  Judging by the formulas, each rule receives
# ``x`` with ``x[0]`` the primal input and ``x[1]`` its derivative (tangent),
# and returns the output's derivative via the chain rule: f'(x[0]) * x[1].
register_diff(np.power, pow_diff)
register_diff(np.positive, lambda x: +x[1])
register_diff(np.negative, lambda x: -x[1])
# conj is linear, so its derivative is the conjugated tangent.
register_diff(np.conj, lambda x: np.conj(x[1]))
register_diff(np.exp, lambda x: x[1] * np.exp(x[0]))
# d/dx 2**x = ln(2) * 2**x -- uses exp2, not exp.
register_diff(np.exp2, lambda x: x[1] * np.log(2) * np.exp2(x[0]))
register_diff(np.log, lambda x: x[1] / x[0])
register_diff(np.log2, lambda x: x[1] / (np.log(2) * x[0]))
register_diff(np.log10, lambda x: x[1] / (np.log(10) * x[0]))
register_diff(np.sqrt, lambda x: x[1] / (2 * np.sqrt(x[0])))
register_diff(np.square, lambda x: 2 * x[1] * x[0])
# d/dx cbrt(x) = 1 / (3 * cbrt(x)**2).  Using np.cbrt keeps the result real for
# negative inputs, where a fractional power x**(2/3) of a negative float is nan.
register_diff(np.cbrt, lambda x: x[1] / (3 * np.square(np.cbrt(x[0]))))
register_diff(np.reciprocal, lambda x: -x[1] / np.square(x[0]))
register_diff(np.broadcast_to, lambda x, shape: np.broadcast_to(x[1], shape))

# Trigonometric rules (x[0] is the primal input, x[1] its derivative).
register_diff(np.sin, lambda x: x[1] * np.cos(x[0]))
register_diff(np.cos, lambda x: -x[1] * np.sin(x[0]))
register_diff(np.tan, lambda x: x[1] / np.square(np.cos(x[0])))
register_diff(np.arcsin, lambda x: x[1] / np.sqrt(1 - np.square(x[0])))
register_diff(np.arccos, lambda x: -x[1] / np.sqrt(1 - np.square(x[0])))
register_diff(np.arctan, lambda x: x[1] / (1 + np.square(x[0])))
register_diff(np.arctan2, arctan2_diff)

# Hyperbolic rules.
register_diff(np.sinh, lambda x: x[1] * np.cosh(x[0]))
register_diff(np.cosh, lambda x: x[1] * np.sinh(x[0]))
register_diff(np.tanh, lambda x: x[1] / np.square(np.cosh(x[0])))
register_diff(np.arcsinh, lambda x: x[1] / np.sqrt(1 + np.square(x[0])))
# d/dx arccosh(x) = 1 / sqrt(x**2 - 1) on its domain x > 1; the
# sqrt(1 - x**2) form belongs to arcsin/arccos and is nan for every x > 1.
register_diff(np.arccosh, lambda x: x[1] / np.sqrt(np.square(x[0]) - 1))
register_diff(np.arctanh, lambda x: x[1] / (1 - np.square(x[0])))
예제 #3
0
# ----- Simple grads -----
# Forward-mode (JVP) rules for elementwise unary functions, kept as a table of
# ``(function, rule)`` pairs and registered in table order with ``defjvp``.
# Each rule is ``lambda ans, x: lambda g: ...``: ``ans`` is the forward output,
# ``x`` the input and ``g`` the input tangent; the inner lambda returns the
# output tangent.
for _ufunc, _jvp_rule in (
    (np.positive, lambda ans, x: lambda g: np.ones_like(x) * g),
    (np.negative, lambda ans, x: lambda g: -np.ones_like(x) * g),
    # fabs doesn't take complex numbers.
    (np.fabs, lambda ans, x: lambda g: np.sign(x) * g),
    (np.absolute, lambda ans, x: lambda g: np.real(g * np.conj(x)) / ans),
    (np.reciprocal, lambda ans, x: lambda g: -g / x**2),
    # Exponentials reuse the forward output ``ans``.
    (np.exp, lambda ans, x: lambda g: ans * g),
    (np.exp2, lambda ans, x: lambda g: ans * np.log(2) * g),
    (np.expm1, lambda ans, x: lambda g: (ans + 1) * g),
    (np.log, lambda ans, x: lambda g: g / x),
    (np.log2, lambda ans, x: lambda g: g / x / np.log(2)),
    (np.log10, lambda ans, x: lambda g: g / x / np.log(10)),
    (np.log1p, lambda ans, x: lambda g: g / (x + 1)),
    (np.sin, lambda ans, x: lambda g: g * np.cos(x)),
    (np.cos, lambda ans, x: lambda g: -g * np.sin(x)),
    (np.tan, lambda ans, x: lambda g: g / np.cos(x)**2),
    (np.arcsin, lambda ans, x: lambda g: g / np.sqrt(1 - x**2)),
    (np.arccos, lambda ans, x: lambda g: -g / np.sqrt(1 - x**2)),
    (np.arctan, lambda ans, x: lambda g: g / (1 + x**2)),
    (np.sinh, lambda ans, x: lambda g: g * np.cosh(x)),
    (np.cosh, lambda ans, x: lambda g: g * np.sinh(x)),
    (np.tanh, lambda ans, x: lambda g: g / np.cosh(x)**2),
    (np.arcsinh, lambda ans, x: lambda g: g / np.sqrt(x**2 + 1)),
    (np.arccosh, lambda ans, x: lambda g: g / np.sqrt(x**2 - 1)),
    (np.arctanh, lambda ans, x: lambda g: g / (1 - x**2)),
    (np.square, lambda ans, x: lambda g: g * 2 * x),
    (np.sqrt, lambda ans, x: lambda g: g * 0.5 * x**-0.5),
):
    defjvp(_ufunc, _jvp_rule)
# Drop the loop variables so the module namespace matches a plain call list.
del _ufunc, _jvp_rule
defjvp(
    np.sinc,
    lambda ans, x: lambda g: g *
예제 #4
0
    except ua.BackendNotImplementedError:
        if backend in FULLY_TESTED_BACKENDS:
            raise
        pytest.xfail(
            reason="The backend has no implementation for this ufunc.")

    if isinstance(ret, da.Array):
        ret.compute()

    assert_allclose(ret.diffs[x].arr, y_d_arr)


@pytest.mark.parametrize(
    "func, y_d",
    [(lambda x: np.power(2 * x + 1, 3), lambda x: 6 * np.power(2 * x + 1, 2)),
     (lambda x: np.sin(np.power(x, 2)) / np.power(np.sin(x), 2), lambda x:
      (2 * x * np.cos(np.power(x, 2)) * np.sin(x) - 2 * np.sin(np.power(x, 2))
       * np.cos(x)) / np.power(np.sin(x), 3)),
     (lambda x: np.power(np.log(np.power(x, 3)), 1 / 3),
      lambda x: 2 * np.power(np.log(np.power(x, 2)), -2 / 3) / (3 * x)),
     (lambda x: np.log(
         (1 + x) / (1 - x)) / 4 - np.arctan(x) / 2, lambda x: np.power(x, 2) /
      (1 - np.power(x, 4)))],
)
def test_arbitrary_function(backend, func, y_d):
    x_arr = [0.2, 0.3]
    try:
        with ua.set_backend(backend), ua.set_backend(udiff, coerce=True):
            x = np.asarray(x_arr)
            x.var = udiff.Variable('x')
            ret = func(x)
예제 #5
0
                     format(mode))

    if isinstance(y, da.Array):
        y.compute()

    assert_allclose(u_diff.value, expect_u_diff)
    assert_allclose(v_diff.value, expect_v_diff)


@pytest.mark.parametrize(
    "func, y_d, domain",
    [
        (lambda x: x * x, lambda x: 2 * x, None),
        (lambda x: (2 * x + 1)**3, lambda x: 6 * (2 * x + 1)**2, (0.5, None)),
        (
            lambda x: np.sin(x**2) / np.sin(x)**2,
            lambda x:
            (2 * x * cos(x**2) * sin(x) - 2 * sin(x**2) * cos(x)) / sin(x)**3,
            (0, pi),
        ),
        (
            lambda x: np.log(x**2)**(1 / 3),
            lambda x: 2 * log(x**2)**(-2 / 3) / (3 * x),
            (1, None),
        ),
        (
            lambda x: np.log((1 + x) / (1 - x)) / 4 - np.arctan(x) / 2,
            lambda x: x**2 / (1 - x**4),
            (-1, 1),
        ),
        (
예제 #6
0
            raise
        pytest.xfail(reason="The backend has no implementation for this ufunc.")

    if isinstance(ret, da.Array):
        ret.compute()

    assert_allclose(ret.diffs[x].arr, y_d_arr)


@pytest.mark.xfail
@pytest.mark.parametrize(
    "func, y_d, domain",
    [
        (lambda x: (2 * x + 1) ** 3, lambda x: 6 * (2 * x + 1) ** 2, (0.5, None)),
        (
            lambda x: np.sin(x ** 2) / (np.sin(x)) ** 2,
            lambda x: (
                2 * x * np.cos(x ** 2) * np.sin(x) - 2 * np.sin(x ** 2) * np.cos(x)
            )
            / (np.sin(x)) ** 3,
            (0, pi),
        ),
        (
            lambda x: (np.log(x ** 2)) ** (1 / 3),
            lambda x: 2 * (np.log(x ** 2)) ** (-2 / 3) / (3 * x),
            (1, None),
        ),
        (
            lambda x: np.log((1 + x) / (1 - x)) / 4 - np.arctan(x) / 2,
            lambda x: x ** 2 / (1 - x ** 4),
            (-1, 1),
예제 #7
0
    except ua.BackendNotImplementedError:
        if backend in FULLY_TESTED_BACKENDS:
            raise
        pytest.xfail(
            reason="The backend has no implementation for this ufunc.")

    if isinstance(ret, da.Array):
        ret.compute()

    assert_allclose(ret.diffs[x].arr, y_d_arr)


@pytest.mark.parametrize(
    "func, y_d, domain",
    [(lambda x: (2 * x + 1)**3, lambda x: 6 * (2 * x + 1)**2, (0.5, None)),
     (lambda x: np.sin(x**2) / (np.sin(x))**2, lambda x:
      (2 * x * np.cos(x**2) * np.sin(x) - 2 * np.sin(x**2) * np.cos(x)) /
      (np.sin(x))**3, (0, pi)),
     (lambda x: (np.log(x**2))**(1 / 3), lambda x: 2 *
      (np.log(x**2))**(-2 / 3) / (3 * x), (1, None)),
     (lambda x: np.log((1 + x) /
                       (1 - x)) / 4 - np.arctan(x) / 2, lambda x: x**2 /
      (1 - x**4), (-1, 1)),
     (lambda x: np.arctanh(3 * x**3 + x**2 + 1), lambda x: (9 * x**2 + 2 * x) /
      (1 - (3 * x**3 + x**2 + 1)**2), (0, None)),
     (lambda x: np.sinh(np.cbrt(x)) + np.cosh(4 * x**3),
      lambda x: np.cosh(np.cbrt(x)) / (3 * x**(2 / 3)) + 12 *
      (x**2) * np.sinh(4 * x**3), (1 / 4, None)),
     (lambda x: np.log(1 + x**2) / np.arctanh(x), lambda x:
      ((2 * x * np.arctanh(x) / (1 + x**2)) -
       (np.log(1 + x**2) / (1 - x**2))) / (np.arctanh(x))**2, (0, 1))],