예제 #1
0
    def jvp(g):
        """Push the input tangent ``g`` forward through the reduction.

        Reads ``x``, ``ans``, ``axis``, ``ddof`` and ``keepdims`` from the
        enclosing scope.  The expression has the form of the directional
        derivative of a standard deviation, so it presumably assumes that
        ``ans`` is ``std(x)`` along ``axis`` -- confirm against the enclosing
        function.

        Returns zeros shaped like ``ans`` when at most one element is reduced.
        """
        # Count how many elements are collapsed by the reduction.
        if axis is None:
            num_reps = np.size(g)
        elif isinstance(axis, tuple):
            num_reps = np.prod(np.array(np.shape(g))[list(axis)])
        else:
            # Any integer-like axis (Python int or NumPy integer scalar) can
            # index the shape tuple.  Testing only ``isinstance(axis, int)``
            # left ``num_reps`` unbound for e.g. ``np.int64`` axes.
            num_reps = np.shape(g)[axis]

        if num_reps <= 1:
            return np.zeros_like(ans)
        x_minus_mean = np.conj(x - np.mean(x, axis=axis, keepdims=True))
        return np.sum(np.real(g * x_minus_mean), axis=axis,
                      keepdims=keepdims) / ((num_reps - ddof) * ans)
예제 #2
0
 def vjp(g):
     """Pull the output cotangent ``g`` back to the shape of ``x``.

     Reads ``x``, ``ans``, ``shape``, ``dtype``, ``axis``, ``keepdims``,
     ``ddof`` and ``iscomplex`` from the enclosing scope.  The formula
     (``g / ans`` times the conjugated deviation from the mean, divided by
     ``num_reps - ddof``) has the form of a standard-deviation gradient;
     presumably ``ans`` is ``std(x)`` -- confirm against the enclosing
     function.
     """
     if iscomplex:
         # Promote the cotangent so the arithmetic below is complex.
         g = g + 0j
     # First call is made with the raw ``g`` so that ``num_reps`` is known
     # before anything is divided by ``ans``.
     g_repeated, num_reps = repeat_to_match_shape(
         g, shape, dtype, axis, keepdims
     )
     # Avoid division by zero.
     if num_reps <= 1:
         return g_repeated * 0.0
     g_repeated, num_reps = repeat_to_match_shape(
         g / ans, shape, dtype, axis, keepdims
     )
     # Deviation from the mean along the reduced axes.  The previous
     # ``x - x / np.sum(x, ...)`` was not a deviation from the mean.
     x_minus_mean = np.conj(x - np.mean(x, axis=axis, keepdims=True))
     return g_repeated * x_minus_mean / (num_reps - ddof)
예제 #3
0
def arctan2_diff(x1, x2):
    """Return the derivative of ``arctan2`` for two (value, derivative) pairs.

    ``x1`` and ``x2`` are indexable; element 0 is used as the value and
    element 1 as its derivative.  The result is
    ``(dy * x - y * dx) / (y**2 + x**2)`` with ``y, dy = x1`` and
    ``x, dx = x2``.
    """
    y, dy = x1[0], x1[1]
    x, dx = x2[0], x2[1]
    numerator = dy * x - y * dx
    denominator = np.square(y) + np.square(x)
    return numerator / denominator


# Derivative rules.  In every rule below an argument ``x`` is indexed so that
# ``x[0]`` is used as the value and ``x[1]`` as its derivative.

# --- Arithmetic ---
register_diff(np.add, lambda x1, x2: x1[1] + x2[1])
register_diff(np.subtract, lambda x1, x2: x1[1] - x2[1])
register_diff(np.multiply, multiply_diff)
register_diff(np.matmul, matmul_diff)
register_diff(np.divide, divide_diff)
register_diff(np.true_divide, divide_diff)
register_diff(np.power, pow_diff)
register_diff(np.positive, lambda x: +x[1])
register_diff(np.negative, lambda x: -x[1])
register_diff(np.conj, lambda x: np.conj(x[1]))
# NOTE(review): exact repeat of the registration above; kept so the set of
# register_diff calls is unchanged -- confirm whether it can be dropped.
register_diff(np.conj, lambda x: np.conj(x[1]))

# --- Exponentials, logarithms and roots ---
register_diff(np.exp, lambda x: x[1] * np.exp(x[0]))
# d/dx 2**x = ln(2) * 2**x, so the factor must be exp2, not exp.
register_diff(np.exp2, lambda x: x[1] * np.log(2) * np.exp2(x[0]))
register_diff(np.log, lambda x: x[1] / x[0])
register_diff(np.log2, lambda x: x[1] / (np.log(2) * x[0]))
register_diff(np.log10, lambda x: x[1] / (np.log(10) * x[0]))
register_diff(np.sqrt, lambda x: x[1] / (2 * np.sqrt(x[0])))
register_diff(np.square, lambda x: 2 * x[1] * x[0])
# cbrt is defined for negative reals, but ``x ** (2 / 3)`` is NaN there;
# squaring the real cube root gives the same value and stays finite.
register_diff(np.cbrt, lambda x: x[1] / (3 * np.square(np.cbrt(x[0]))))
register_diff(np.reciprocal, lambda x: -x[1] / np.square(x[0]))
register_diff(np.broadcast_to, lambda x, shape: np.broadcast_to(x[1], shape))

# --- Trigonometric ---
register_diff(np.sin, lambda x: x[1] * np.cos(x[0]))
register_diff(np.cos, lambda x: -x[1] * np.sin(x[0]))
register_diff(np.tan, lambda x: x[1] / np.square(np.cos(x[0])))
예제 #4
0
 def vjp(g):
     """Pull the output cotangent ``g`` back to the shape of ``x``.

     Reads ``x``, ``shape``, ``dtype``, ``axis``, ``keepdims``, ``ddof`` and
     ``iscomplex`` from the enclosing scope.  The formula
     (``2 * g`` times the conjugated deviation from the mean, divided by
     ``num_reps - ddof``) has the form of a variance gradient -- confirm
     against the enclosing function.
     """
     if iscomplex:
         # Promote the cotangent so the arithmetic below is complex.
         g = g + 0j
     g_repeated, num_reps = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
     # Deviation from the mean along the reduced axes.  The previous
     # ``x - x / np.sum(x, ...)`` was not a deviation from the mean.
     x_minus_mean = np.conj(x - np.mean(x, axis=axis, keepdims=True))
     return 2.0 * g_repeated * x_minus_mean / (num_reps - ddof)
예제 #5
0
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * y / (x ** 2 + y ** 2)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * -x / (x ** 2 + y ** 2)),
)
# hypot(x, y) = sqrt(x**2 + y**2): each partial is the argument over the result.
defvjp(
    np.hypot,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * x / ans),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * y / ans),
)

# ----- Simple grads -----
# sign is piecewise constant: zero gradient everywhere except at 0, where NaN
# is reported.  ``np.where`` keeps this element-wise; the former
# ``np.nan if x == 0 else 0`` raised on arrays with more than one element.
defvjp(
    np.sign,
    lambda ans, x: lambda g: np.where(x == 0, np.nan, np.zeros_like(g)),
)
defvjp(np.positive, lambda ans, x: lambda g: g)
defvjp(np.negative, lambda ans, x: lambda g: -g)
# replace_zero is applied to both numerator and denominator so that the
# division is not 0 / 0 when x is zero (helper defined elsewhere -- presumably
# it substitutes the second argument for zero entries; confirm).
# A second, unguarded ``np.absolute`` registration (``g * np.conj(x) / ans``)
# used to follow ``np.fabs`` below; it was removed so that it cannot replace
# this zero-safe rule.
defvjp(
    np.absolute,
    lambda ans, x: lambda g: g * replace_zero(np.conj(x), 0.0) / replace_zero(ans, 1.0),
)
defvjp(
    np.fabs, lambda ans, x: lambda g: np.sign(x) * g
)  # fabs doesn't take complex numbers.
defvjp(np.reciprocal, lambda ans, x: lambda g: -g / x ** 2)
# For the exponentials the already-computed result ``ans`` is reused.
defvjp(np.exp, lambda ans, x: lambda g: ans * g)
defvjp(np.exp2, lambda ans, x: lambda g: ans * np.log(2) * g)
defvjp(np.expm1, lambda ans, x: lambda g: (ans + 1) * g)
defvjp(np.log, lambda ans, x: lambda g: g / x)
defvjp(np.log2, lambda ans, x: lambda g: g / x / np.log(2))
defvjp(np.log10, lambda ans, x: lambda g: g / x / np.log(10))
defvjp(np.log1p, lambda ans, x: lambda g: g / (x + 1))
defvjp(np.sin, lambda ans, x: lambda g: g * np.cos(x))
defvjp(np.cos, lambda ans, x: lambda g: -g * np.sin(x))
예제 #6
0
# Registrations that pass the literal "same" to defjvp (presumably meaning the
# function itself serves as its own JVP -- confirm against defjvp).
for _same_fun in (np.rot90,):
    defjvp(_same_fun, "same")
defjvp(np.full, "same", argnums=(1, ))
for _same_fun in (
    np.triu,
    np.tril,
    np.swapaxes,
    np.rollaxis,
    np.moveaxis,
    np.broadcast_to,
):
    defjvp(_same_fun, "same")
del _same_fun
def_linear(np.cross)

# ----- Simple grads -----
# Each entry pairs a NumPy function with a factory ``(ans, x) -> (g -> tangent)``.
_simple_jvp_rules = (
    (np.positive, lambda ans, x: lambda g: np.ones_like(x) * g),
    (np.negative, lambda ans, x: lambda g: -np.ones_like(x) * g),
    # fabs doesn't take complex numbers.
    (np.fabs, lambda ans, x: lambda g: np.sign(x) * g),
    (np.absolute, lambda ans, x: lambda g: np.real(g * np.conj(x)) / ans),
    (np.reciprocal, lambda ans, x: lambda g: -g / x**2),
    # Exponentials reuse the already-computed result ``ans``.
    (np.exp, lambda ans, x: lambda g: ans * g),
    (np.exp2, lambda ans, x: lambda g: ans * np.log(2) * g),
    (np.expm1, lambda ans, x: lambda g: (ans + 1) * g),
    (np.log, lambda ans, x: lambda g: g / x),
    (np.log2, lambda ans, x: lambda g: g / x / np.log(2)),
    (np.log10, lambda ans, x: lambda g: g / x / np.log(10)),
    (np.log1p, lambda ans, x: lambda g: g / (x + 1)),
    (np.sin, lambda ans, x: lambda g: g * np.cos(x)),
    (np.cos, lambda ans, x: lambda g: -g * np.sin(x)),
    (np.tan, lambda ans, x: lambda g: g / np.cos(x)**2),
    (np.arcsin, lambda ans, x: lambda g: g / np.sqrt(1 - x**2)),
    (np.arccos, lambda ans, x: lambda g: -g / np.sqrt(1 - x**2)),
    (np.arctan, lambda ans, x: lambda g: g / (1 + x**2)),
    (np.sinh, lambda ans, x: lambda g: g * np.cosh(x)),
)
for _jvp_fun, _jvp_rule in _simple_jvp_rules:
    defjvp(_jvp_fun, _jvp_rule)
del _jvp_fun, _jvp_rule, _simple_jvp_rules