示例#1
0
def unbroadcast(x, target_meta, broadcast_idx=0):
    """Reduce a broadcast array back down to the shape described by target_meta.

    Args:
        x: Array-like value whose shape may have been enlarged by broadcasting.
        target_meta: 4-tuple ``(shape, ndim, <ignored>, iscomplex)`` describing
            the operand ``x`` has to be reduced to. The third entry is not
            used by this function.
        broadcast_idx: Axis that is summed away on each pass while ``x`` has
            more dimensions than the target.

    Returns:
        ``x`` summed down to ``ndim`` dimensions, with every axis whose target
        size is 1 summed with ``keepdims=True``, and reduced to its real part
        when ``x`` is complex but the target is not.
    """
    # The fourth entry must be bound by name: it is read below, and discarding
    # it into ``_`` left ``target_iscomplex`` undefined for complex inputs.
    target_shape, target_ndim, _, target_iscomplex = target_meta
    # Drop the extra dimensions that broadcasting added.
    while np.ndim(x) > target_ndim:
        x = np.sum(x, axis=broadcast_idx)
    # Collapse axes that were stretched from size 1.
    for axis, size in enumerate(target_shape):
        if size == 1:
            x = np.sum(x, axis=axis, keepdims=True)
    if np.iscomplexobj(x) and not target_iscomplex:
        x = np.real(x)
    return x
示例#2
0
 def vjp(g):
     """Sum the incoming gradient over the copies produced by np.repeat.

     ``axis``, ``shape`` and ``repeats`` are free names that are not defined
     in this block; ``shape`` is concatenated with tuples, so it must be a
     tuple.
     """
     if axis is None:
         # With axis=None, np.repeat() repeats the flattened array, so the
         # copies of each element are adjacent: one row per element.
         per_element = np.reshape(g, (np.prod(shape), repeats))
         return np.reshape(np.sum(per_element, axis=1, keepdims=False), shape)
     if shape[axis] == 1:
         # Common case: the whole axis of g is copies of a single entry.
         return np.sum(g, axis=axis, keepdims=True)
     # General case: split the repeated axis into (original, repeats) and sum
     # over the newly inserted repeats axis.
     copies_axis = axis + 1
     grouped_shape = shape[0:copies_axis] + (repeats,) + shape[copies_axis:]
     return np.sum(np.reshape(g, grouped_shape), axis=copies_axis, keepdims=False)
示例#3
0
 def jvp(g):
     """Forward-mode derivative for a reduction that selects entries of ``x``.

     The tangent ``g`` is averaged over the positions where ``x`` equals the
     reduced value, so tied positions contribute equally. ``x``, ``ans``,
     ``axis`` and ``keepdims`` are free names that are not defined in this
     block.

     Returns:
         ``g`` unchanged when ``x`` is a scalar, otherwise the masked sum of
         ``g`` divided by the number of selected positions.
     """
     if np.isscalar(x):
         return g
     # Use a separate local for the reshaped result. Assigning to ``ans``
     # inside this function would make ``ans`` local to it, and every read
     # (including the comparison below) would raise UnboundLocalError.
     ans_expanded = ans
     if not keepdims:
         # Re-insert the reduced axes so the result broadcasts against x.
         if isinstance(axis, int):
             ans_expanded = np.expand_dims(ans_expanded, axis)
         elif isinstance(axis, tuple):
             for ax in sorted(axis):
                 ans_expanded = np.expand_dims(ans_expanded, ax)
     chosen_locations = x == ans_expanded
     selected_sum = np.sum(g * chosen_locations, axis=axis, keepdims=keepdims)
     selected_count = np.sum(chosen_locations, axis=axis, keepdims=keepdims)
     return selected_sum / selected_count
示例#4
0
def grad_broadcast_to(ans, x, new_shape):
    """Build the gradient function for broadcasting ``x`` to ``new_shape``.

    Args:
        ans: Broadcast result; its shape is asserted to equal ``new_shape``.
        x: Original operand; must have as many dimensions as ``new_shape``.
        new_shape: Target shape of the broadcast.

    Returns:
        A function of ``g`` that sums ``g`` (keeping dimensions) over every
        axis that had size 1 in ``x`` and a size greater than 1 in
        ``new_shape``.
    """
    old_shape = np.shape(x)
    assert np.shape(ans) == new_shape
    assert len(old_shape) == len(new_shape), "Can't handle extra leading dims"

    # Collect the axes that were stretched from size 1.
    stretched = []
    for dim in range(len(old_shape)):
        if old_shape[dim] == 1 and new_shape[dim] > 1:
            stretched.append(dim)
    broadcast_axes = tuple(stretched)

    def vjp(g):
        return np.sum(g, axis=broadcast_axes, keepdims=True)

    return vjp
示例#5
0
 def vjp(g):
     """Gradient for reductions that pick a single item, such as min or max.

     The incoming gradient is kept only where ``x`` equals the reduced value
     and is divided by the number of such positions along ``axis``.
     ``repeat_to_match_shape`` and the names ``x``, ``ans``, ``shape``,
     ``dtype``, ``axis`` and ``keepdims`` are not defined in this block.
     """
     # Only the first element of each helper result is used here.
     g_full = repeat_to_match_shape(g, shape, dtype, axis, keepdims)[0]
     ans_full = repeat_to_match_shape(ans, shape, dtype, axis, keepdims)[0]
     hit_mask = x == ans_full
     hit_count = np.sum(hit_mask, axis=axis, keepdims=True)
     return g_full * hit_mask / hit_count
示例#6
0
    def jvp(g):
        """Forward-mode derivative of a spread statistic of ``x`` along ``axis``.

        ``axis``, ``x``, ``ans``, ``ddof`` and ``keepdims`` are free names that
        are not defined in this block. The division by ``ans`` looks like the
        derivative of a standard deviation -- TODO confirm against the
        enclosing function.
        """
        # Number of elements that take part in each reduction.
        if axis is None:
            count = np.size(g)
        elif isinstance(axis, int):
            count = np.shape(g)[axis]
        elif isinstance(axis, tuple):
            dims = np.array(np.shape(g))
            count = np.prod(dims[list(axis)])

        # With at most one element the derivative is returned as zeros.
        if count <= 1:
            return np.zeros_like(ans)

        centered = np.conj(x - np.mean(x, axis=axis, keepdims=True))
        numerator = np.sum(np.real(g * centered), axis=axis, keepdims=keepdims)
        denominator = (count - ddof) * ans
        return numerator / denominator
示例#7
0
 def vjp(g):
     """Reverse-mode gradient that scales ``g / ans`` by the centered input.

     ``repeat_to_match_shape`` and the names ``iscomplex``, ``shape``,
     ``dtype``, ``axis``, ``keepdims``, ``x``, ``ans`` and ``ddof`` are not
     defined in this block.

     Returns:
         A zero array shaped like the repeated gradient when at most one
         element was reduced, otherwise
         ``repeat(g / ans) * conj(x - mean(x)) / (num_reps - ddof)``.
     """
     if iscomplex:
         # Promote g to complex so the product below stays complex.
         g = g + 0j
     g_repeated, num_reps = repeat_to_match_shape(
         g, shape, dtype, axis, keepdims
     )
     if num_reps <= 1:
         # Avoid division by zero.
         return g_repeated * 0.0
     g_repeated, num_reps = repeat_to_match_shape(
         g / ans, shape, dtype, axis, keepdims
     )
     # Center x on its mean along ``axis``; the previous ``x / sum(x)`` term
     # is not the mean.
     x_minus_mean = np.conj(x - np.mean(x, axis=axis, keepdims=True))
     return g_repeated * x_minus_mean / (num_reps - ddof)
示例#8
0
 def vjp(g):
     """Reverse-mode gradient equal to ``2 * g * conj(x - mean(x)) / (n - ddof)``.

     ``repeat_to_match_shape`` and the names ``iscomplex``, ``shape``,
     ``dtype``, ``axis``, ``keepdims``, ``x`` and ``ddof`` are not defined in
     this block; ``n`` is the second value the helper returns.
     """
     if iscomplex:
         # Promote g to complex so the product below stays complex.
         g = g + 0j
     g_repeated, num_reps = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
     # Center x on its mean along ``axis``; the previous ``x / sum(x)`` term
     # is not the mean.
     x_minus_mean = np.conj(x - np.mean(x, axis=axis, keepdims=True))
     return 2.0 * g_repeated * x_minus_mean / (num_reps - ddof)
示例#9
0
# Each call passes a NumPy function and a builder to defvjp. A builder takes
# the result (ans) followed by the function's arguments and returns a function
# of g. defvjp itself is not defined in this block -- presumably it registers
# the builder as that function's reverse-mode gradient; confirm at its
# definition.

# Splits: the pieces of g are concatenated along the axis given below.
defvjp(np.vsplit, lambda ans, ary, idxs: lambda g: np.concatenate(g, axis=0))
defvjp(np.hsplit, lambda ans, ary, idxs: lambda g: np.concatenate(g, axis=1))
defvjp(np.dsplit, lambda ans, ary, idxs: lambda g: np.concatenate(g, axis=2))
# Shape-only changes: g is reshaped back to the shape of x.
defvjp(
    np.ravel,
    lambda ans, x, order=None: lambda g: np.reshape(g, np.shape(x), order=order),
)
defvjp(np.expand_dims, lambda ans, x, axis: lambda g: np.reshape(g, np.shape(x)))
defvjp(np.squeeze, lambda ans, x, axis=None: lambda g: np.reshape(g, np.shape(x)))
# np.diag is applied to g with the same offset k.
defvjp(np.diag, lambda ans, x, k=0: lambda g: np.diag(g, k))
# Flips are applied again to g; rot90 is applied with the opposite count -k.
defvjp(np.flipud, lambda ans, x,: lambda g: np.flipud(g))
defvjp(np.fliplr, lambda ans, x,: lambda g: np.fliplr(g))
defvjp(np.rot90, lambda ans, x, k=1: lambda g: np.rot90(g, -k))
# np.full: the builder returns the sum of all of g. argnums=(1,) is forwarded
# to defvjp -- presumably selecting fill_value (positional index 1) as the
# differentiated argument; confirm against defvjp.
defvjp(
    np.full,
    lambda ans, shape, fill_value, dtype=None: lambda g: np.sum(g),
    argnums=(1,),
)
# Triangular extraction: the same triangle (same k) is taken from g.
defvjp(np.triu, lambda ans, x, k=0: lambda g: np.triu(g, k=k))
defvjp(np.tril, lambda ans, x, k=0: lambda g: np.tril(g, k=k))
# np.clip: g is multiplied by a mask that is False wherever the result equals
# a_min or a_max, so those entries are zeroed.
defvjp(
    np.clip,
    lambda ans, x, a_min, a_max: lambda g: g
    * np.logical_and(ans != a_min, ans != a_max),
)
# Axis permutations: the same call is applied to g with its two axis
# arguments exchanged.
defvjp(np.swapaxes, lambda ans, x, axis1, axis2: lambda g: np.swapaxes(g, axis2, axis1))
defvjp(
    np.moveaxis,
    lambda ans, a, source, destination: lambda g: np.moveaxis(g, destination, source),
)
# match_complex is not defined in this block -- presumably it makes g agree
# with x in being real or complex; confirm at its definition.
defvjp(np.real_if_close, lambda ans, x: lambda g: match_complex(x, g))
示例#10
0
                                                      np.zeros(np.shape(g))),
    lambda ans, c, x=None, y=None: lambda g: np.where(c, np.zeros(g.shape), g),
)

# ----- Trickier grads -----
# defjvp(np.kron, "same", "same")
# defjvp is not defined in this block. The string "same" is passed in place of
# a builder -- presumably it tells defjvp to apply the function itself to the
# tangent; confirm at its definition.
defjvp(np.diff, "same")
defjvp(np.gradient, "same")
defjvp(np.repeat, "same")
defjvp(np.tile, "same")
defjvp(np.transpose, "same")
defjvp(np.sum, "same")

# np.prod: the builder returns ans * sum(g / x) over the same axis/keepdims.
# NOTE(review): g / x divides by x, so zeros in x divide by zero here.
defjvp(
    np.prod,
    lambda ans, x, axis=None, keepdims=False: lambda g: ans * np.sum(
        g / x, axis=axis, keepdims=keepdims),
)
# np.linspace: one builder per endpoint. The first returns a linspace from g
# to 0, the second a linspace from 0 to g; the remaining arguments are
# forwarded unchanged.
defjvp(
    np.linspace,
    lambda ans, start, stop, *args, **kwargs: lambda g: np.linspace(
        g, 0, *args, **kwargs),
    lambda ans, start, stop, *args, **kwargs: lambda g: np.linspace(
        0, g, *args, **kwargs),
)


def forward_grad_np_var(ans, x, axis=None, ddof=0, keepdims=False):
    def jvp(g):
        if axis is None:
            num_reps = np.size(g)
        elif isinstance(axis, int):
示例#11
0
    backend = request.param
    return backend


@pytest.fixture(scope="session", params=["vjp", "jvp"])
def mode(request):
    """Return the current parametrized mode string, either "vjp" or "jvp"."""
    return request.param


@pytest.mark.parametrize(
    "x, func, expect_jacobian",
    [
        (
            onp.arange(12).reshape(2, 3, 2),
            lambda x: np.sum(x, axis=1),
            [
                [
                    [[[1, 0], [1, 0], [1, 0]], [[0, 0], [0, 0], [0, 0]]],
                    [[[0, 1], [0, 1], [0, 1]], [[0, 0], [0, 0], [0, 0]]],
                ],
                [
                    [[[0, 0], [0, 0], [0, 0]], [[1, 0], [1, 0], [1, 0]]],
                    [[[0, 0], [0, 0], [0, 0]], [[0, 1], [0, 1], [0, 1]]],
                ],
            ],
        ),
        (
            onp.arange(4).reshape((2, 2)),
            lambda x: x,
            [