Example #1
0
 def jvp(g):
     """Split the tangent ``g`` into per-index slices along ``axis``.

     ``g`` is first broadcast to the shape of ``ans`` (both ``ans`` and
     ``axis`` come from the enclosing scope), then one ``np.take`` slice is
     taken for every index along ``axis``. The slices are returned as a tuple.
     """
     broadcast_g = np.broadcast_to(g, np.shape(ans))
     n_slices = np.shape(broadcast_g)[axis]
     return tuple(np.take(broadcast_g, i, axis=axis) for i in range(n_slices))
Example #2
0
def repeat_to_match_shape(g, shape, dtype, axis, keepdims):
    """Expand ``g`` back to ``shape`` along ``axis``.

    ``g`` is reshaped so that every dimension selected by ``axis`` has length
    1, then broadcast to ``shape``.

    Parameters
    ----------
    g : array_like
        Array to expand (presumably the result of a reduction over ``axis``).
    shape : tuple of int
        Target shape.
    dtype, keepdims
        Accepted for interface compatibility; not used here.
    axis : int, tuple of int or None
        Axis/axes to expand. A tuple is turned into a list so that it acts as
        a fancy index into the 1-D shape array.

    Returns
    -------
    (array, int)
        The broadcast array and the number of repetitions, i.e. the product of
        ``shape`` over ``axis``. For ``shape == ()`` this is ``(g, 1)``.
    """
    with ua.set_backend(numpy_backend, coerce=True):
        if shape == ():
            return g, 1
        if isinstance(axis, tuple):
            axis = list(axis)
        reshaped_dims = np.array(shape, dtype=int)
        reshaped_dims[axis] = 1
        repetitions = np.prod(np.array(shape)[axis])
    expanded = np.broadcast_to(np.reshape(g, reshaped_dims), shape)
    return expanded, repetitions
Example #3
0
    def __setitem__(self, k, v):
        """Store ``v`` under the ``Variable`` key ``k``.

        A ``DiffArray`` key is replaced by its ``var``. A value that is not
        already a ``DiffArray`` is wrapped in one. The value is then broadcast
        to ``self.shape`` and stored via the parent class.

        Raises:
            ValueError: if a ``DiffArray`` key has a falsy ``var``, or if the
                resolved key is not a ``Variable``.
        """
        key = k
        if isinstance(key, DiffArray):
            if not key.var:
                raise ValueError("k does not have a set var")
            key = key.var

        if not isinstance(key, Variable):
            raise ValueError("k must be a Variable")

        value = v if isinstance(v, DiffArray) else DiffArray(v)
        value = np.broadcast_to(value, self.shape)
        super().__setitem__(key, value)
Example #4
0
def __ua_function__(func, args, kwargs, skip_vars=None):
    """Evaluate ``func`` on ``DiffArray`` inputs and propagate derivatives.

    The primal result is computed on the underlying ``.arr`` values with this
    backend skipped. Then, for every differentiation key found in the
    ``diffs`` of the coercible array inputs and not in ``skip_vars``, the rule
    registered for ``func`` in ``global_registry`` is called with one
    ``(value, derivative)`` pair per array argument. Its result is stored in
    ``out.diffs[key]``.

    Parameters
    ----------
    func
        Multimethod being dispatched; ``func.arg_extractor`` locates the array
        arguments. When ``func is np.ufunc.__call__``, the first replaced
        positional argument is used as the registry key.
    args, kwargs
        Arguments of the call.
    skip_vars : iterable, optional
        Differentiation keys not to propagate. Together with the propagated
        keys they are handed to ``NoRecurseBackend`` while the rules run.

    Returns
    -------
    DiffArray
        The primal result; ``diffs`` holds one entry per propagated key.
    """
    from udiff import SKIP_SELF
    from ._func_diff_registry import global_registry

    if skip_vars is None:
        skip_vars = set()

    extracted_args = func.arg_extractor(*args, **kwargs)
    arr_args = tuple(x.value for x in extracted_args if x.type is np.ndarray)
    input_args = tuple(x.value for x in extracted_args
                       if x.coercible and x.type is np.ndarray)

    with SKIP_SELF:
        if len(arr_args) == 0:
            out = func(*args, **kwargs)
            return DiffArray(out)

        a, kw = replace_arrays(func, args, kwargs,
                               (x.arr if x is not None else None
                                for x in arr_args))
        out_arr = func(*a, **kw)

    out = DiffArray(out_arr)

    # ``set().union(...)`` accepts zero iterables, whereas ``set.union(*...)``
    # raises TypeError when there is no coercible array input. Array
    # arguments may be ``None`` (handled below as well), so skip those here.
    diff_keys = set().union(*(d_j.diffs.keys() for d_j in input_args
                              if d_j is not None))
    # ``set(...)`` and ``difference`` also accept skip_vars given as a
    # frozenset or another iterable, not only a ``set``.
    no_recurse_vars = set(skip_vars) | diff_keys
    keys_to_propagate = diff_keys.difference(skip_vars)

    with ua.set_backend(NoRecurseBackend(no_recurse_vars)):
        for k in keys_to_propagate:
            diff_args = []
            for arr in arr_args:
                if arr is None:
                    diff_args.append(None)
                    continue

                # Arrays without a derivative w.r.t. ``k`` get a zero one.
                if k in arr.diffs:
                    diff_args.append((arr, arr.diffs[k]))
                else:
                    diff_args.append((arr, np.broadcast_to(0, arr.shape)))

            a, kw = replace_arrays(func, args, kwargs, diff_args)

            if func is np.ufunc.__call__:
                diff_arr = global_registry[a[0]](*a[1:], **kw)
            else:
                diff_arr = global_registry[func](*a, **kw)
            out.diffs[k] = diff_arr

    return out
Example #5
0
def __ua_function__(func, args, kwargs, tree=None):
    """Evaluate ``func`` on ``DiffArray`` inputs and propagate derivatives.

    The primal value is computed on the underlying ``.arr`` values with this
    backend skipped. For each key of ``tree``, the rule registered for
    ``func`` in ``global_registry`` is then called with one
    ``(value, derivative)`` pair per array argument and stored in
    ``out.diffs[key]``. A ``None`` argument stays ``None``, and a missing
    derivative is replaced by a broadcast zero.

    Parameters
    ----------
    func
        Multimethod being dispatched; ``func.arg_extractor`` locates the array
        arguments. For ``np.ufunc.__call__``, the first replaced positional
        argument is the registry key and the rest are passed to the rule.
    args, kwargs
        Arguments of the call.
    tree : mapping, optional
        Maps each differentiation key to the value handed to
        ``NoRecurseBackend`` while that key's rule runs. Defaults to
        ``compute_diff_tree`` over the coercible array inputs.

    Returns
    -------
    DiffArray
        The primal result with one ``diffs`` entry per key of ``tree``.
    """
    from udiff import SKIP_SELF
    from ._func_diff_registry import global_registry

    extracted = func.arg_extractor(*args, **kwargs)
    arrays = tuple(e.value for e in extracted if e.type is np.ndarray)
    coercible_inputs = tuple(
        e.value for e in extracted if e.coercible and e.type is np.ndarray
    )

    if tree is None:
        tree = compute_diff_tree(*coercible_inputs)

    with SKIP_SELF:
        # No array arguments: nothing to differentiate, just wrap the result.
        if not arrays:
            return DiffArray(func(*args, **kwargs))

        unwrapped = (None if arr is None else arr.arr for arr in arrays)
        primal_args, primal_kwargs = replace_arrays(func, args, kwargs,
                                                    unwrapped)
        primal = func(*primal_args, **primal_kwargs)

    out = DiffArray(primal)
    for key in tree:
        # Pair every array with its derivative w.r.t. ``key``; arrays with no
        # ``diffs`` entry for ``key`` get a zero of matching shape.
        pairs = [
            None
            if arr is None
            else (
                arr,
                arr.diffs[key] if key in arr.diffs
                else np.broadcast_to(0, arr.shape),
            )
            for arr in arrays
        ]
        rule_args, rule_kwargs = replace_arrays(func, args, kwargs, pairs)

        with ua.set_backend(NoRecurseBackend(tree[key])):
            if func is np.ufunc.__call__:
                rule = global_registry[rule_args[0]]
                rule_args = rule_args[1:]
            else:
                rule = global_registry[func]
            out.diffs[key] = rule(*rule_args, **rule_kwargs)

    return out
Example #6
0
# Derivative rules for elementary functions. Judging by the rules below, each
# argument ``x`` is a pair: ``x[0]`` is the primal input and ``x[1]`` its
# derivative (tangent). Every rule returns the tangent of the output.
register_diff(np.true_divide, divide_diff)
register_diff(np.power, pow_diff)
register_diff(np.positive, lambda x: +x[1])
register_diff(np.negative, lambda x: -x[1])
register_diff(np.conj, lambda x: np.conj(x[1]))
register_diff(np.exp, lambda x: x[1] * np.exp(x[0]))
# d/dx 2**x = ln(2) * 2**x, so the rule needs np.exp2 (not np.exp).
register_diff(np.exp2, lambda x: x[1] * np.log(2) * np.exp2(x[0]))
register_diff(np.log, lambda x: x[1] / x[0])
register_diff(np.log2, lambda x: x[1] / (np.log(2) * x[0]))
register_diff(np.log10, lambda x: x[1] / (np.log(10) * x[0]))
register_diff(np.sqrt, lambda x: x[1] / (2 * np.sqrt(x[0])))
register_diff(np.square, lambda x: 2 * x[1] * x[0])
# d/dx cbrt(x) = 1 / (3 * cbrt(x)**2). Squaring np.cbrt keeps the rule real
# for negative x, where the float power x ** (2 / 3) evaluates to nan.
register_diff(np.cbrt, lambda x: x[1] / (3 * np.square(np.cbrt(x[0]))))
register_diff(np.reciprocal, lambda x: -x[1] / np.square(x[0]))
# broadcast_to is linear, so the tangent is broadcast the same way.
register_diff(np.broadcast_to, lambda x, shape: np.broadcast_to(x[1], shape))

register_diff(np.sin, lambda x: x[1] * np.cos(x[0]))
register_diff(np.cos, lambda x: -x[1] * np.sin(x[0]))
register_diff(np.tan, lambda x: x[1] / np.square(np.cos(x[0])))
register_diff(np.arcsin, lambda x: x[1] / np.sqrt(1 - np.square(x[0])))
register_diff(np.arccos, lambda x: -x[1] / np.sqrt(1 - np.square(x[0])))
register_diff(np.arctan, lambda x: x[1] / (1 + np.square(x[0])))
register_diff(np.arctan2, arctan2_diff)

register_diff(np.sinh, lambda x: x[1] * np.cosh(x[0]))
register_diff(np.cosh, lambda x: x[1] * np.sinh(x[0]))
register_diff(np.tanh, lambda x: x[1] / np.square(np.cosh(x[0])))
register_diff(np.arcsinh, lambda x: x[1] / np.sqrt(1 + np.square(x[0])))
# d/dx arccosh(x) = 1 / sqrt(x**2 - 1) on its domain x > 1.
register_diff(np.arccosh, lambda x: x[1] / np.sqrt(np.square(x[0]) - 1))
register_diff(np.arctanh, lambda x: x[1] / (1 - np.square(x[0])))
Example #7
0
    defjvp,
    defjvp_argnum,
    def_linear,
)

# ----- Functions that are constant w.r.t. continuous inputs -----
# Each rule below takes the primal result ``ans`` and the primal inputs and
# returns a function mapping an incoming tangent ``g`` to the output tangent.
# np.nan_to_num leaves finite entries unchanged and replaces nan/inf with
# fixed values, so ``g`` passes through where ``x`` is finite and is zero
# elsewhere.
defjvp(np.nan_to_num,
       lambda ans, x: lambda g: np.where(np.isfinite(x), g, 0.0))

# ----- Binary ufuncs (linear) -----
# NOTE(review): np.multiply is also registered explicitly with defjvp below;
# which registration takes effect depends on defjvp/def_linear internals not
# shown here (presumably the later one overrides) -- confirm and drop one.
def_linear(np.multiply)

# ----- Binary ufuncs -----
# Each tangent is broadcast to the result's shape ``np.shape(ans)``
# (presumably because an input's tangent ``g`` can have a smaller,
# broadcastable shape than the output).
defjvp(
    np.add,
    lambda ans, x, y: lambda g: np.broadcast_to(g, np.shape(ans)),
    lambda ans, x, y: lambda g: np.broadcast_to(g, np.shape(ans)),
)
defjvp(
    np.subtract,
    lambda ans, x, y: lambda g: np.broadcast_to(g, np.shape(ans)),
    # d(x - y)/dy = -1
    lambda ans, x, y: lambda g: np.broadcast_to(-g, np.shape(ans)),
)
defjvp(
    np.multiply,
    # d(x * y)/dx = y
    lambda ans, x, y: lambda g: np.broadcast_to(g * y, np.shape(ans)),
    # d(x * y)/dy = x
    lambda ans, x, y: lambda g: np.broadcast_to(x * g, np.shape(ans)),
)
# NOTE(review): "same" presumably tells defjvp to reuse np.divide itself for
# the x-argument (x / y is linear in x) -- confirm. d(x / y)/dy = -x / y**2.
defjvp(np.divide, "same", lambda ans, x, y: lambda g: -g * x / y**2)
defjvp(
    np.maximum,
Example #8
0
    gfpof = x1[1] / x1[0] * x2[0]
    return ftog * (gplogf + gfpof)


def matmul_diff(x1, x2):
    """Product rule for the matrix product ``x1[0] @ x2[0]``.

    ``x1`` and ``x2`` are indexable ``(value, derivative)`` pairs. The result
    is ``A @ dB + dA @ B`` with ``A, dA = x1[0], x1[1]`` and
    ``B, dB = x2[0], x2[1]``.
    """
    lhs, d_lhs = x1[0], x1[1]
    rhs, d_rhs = x2[0], x2[1]
    term_from_rhs = lhs @ d_rhs
    term_from_lhs = d_lhs @ rhs
    return term_from_rhs + term_from_lhs


def arctan2_diff(x1, x2):
    """Derivative rule for ``np.arctan2(x1[0], x2[0])``.

    With ``y, dy = x1[0], x1[1]`` and ``x, dx = x2[0], x2[1]`` this returns
    ``(dy * x - y * dx) / (y**2 + x**2)``.
    """
    y, dy = x1[0], x1[1]
    x, dx = x2[0], x2[1]
    numerator = dy * x - y * dx
    denominator = np.square(y) + np.square(x)
    return numerator / denominator


# Register derivative rules in table order. Judging by the rules, each
# argument is a pair whose [0] item is the primal value and whose [1] item is
# its derivative; each rule returns the derivative of the output.
for _ufunc, _rule in (
    (
        np.sign,
        # sign is piecewise constant: zero derivative, nan where the input
        # is exactly zero.
        lambda x: np.broadcast_to(np.where(x[0].arr == 0, float("nan"), 0),
                                  x[1].shape),
    ),
    (np.add, lambda x1, x2: x1[1] + x2[1]),
    (np.subtract, lambda x1, x2: x1[1] - x2[1]),
    (np.multiply, multiply_diff),
    (np.matmul, matmul_diff),
    (np.divide, divide_diff),
    (np.true_divide, divide_diff),
    (np.power, pow_diff),
    (
        np.absolute,
        # d|x| = sign(x) * dx, reported as nan at x == 0.
        lambda x: x[1] * np.where(np.sign(x[0]) == 0, float("nan"),
                                  np.sign(x[0])),
    ),
    (np.positive, lambda x: +x[1]),
    (np.negative, lambda x: -x[1]),
    (np.conj, lambda x: np.conj(x[1])),
):
    register_diff(_ufunc, _rule)
del _ufunc, _rule
Example #9
0
    gplogf = np.log(x1[0]) * x2[1]
    gfpof = x1[1] / x1[0] * x2[0]
    return ftog * (gplogf + gfpof)


def matmul_diff(x1, x2):
    """Derivative of ``x1[0] @ x2[0]`` via the product rule.

    Both arguments are indexable ``(value, derivative)`` pairs; the result is
    ``a @ db + da @ b``.
    """
    a, da = x1[0], x1[1]
    b, db = x2[0], x2[1]
    return (a @ db) + (da @ b)


def arctan2_diff(x1, x2):
    """Derivative rule for ``np.arctan2(x1[0], x2[0])``.

    ``x1`` supplies the first argument and its derivative, ``x2`` the second;
    the result is ``(x1' * x2 - x1 * x2') / (x1**2 + x2**2)``.
    """
    first, d_first = x1[0], x1[1]
    second, d_second = x2[0], x2[1]
    return ((d_first * second - first * d_second)
            / (np.square(first) + np.square(second)))


# Register derivative rules in table order. Judging by the rules, each
# argument is a pair whose [0] item is the primal value and whose [1] item is
# its derivative; each rule returns the derivative of the output.
for _ufunc, _rule in (
    (
        np.sign,
        # sign is piecewise constant: zero derivative, nan where the input
        # is exactly zero.
        lambda x: np.broadcast_to(
            np.where(x[0].arr == 0, float('nan'), 0), x[1].shape),
    ),
    (np.add, lambda x1, x2: x1[1] + x2[1]),
    (np.subtract, lambda x1, x2: x1[1] - x2[1]),
    (np.multiply, multiply_diff),
    (np.matmul, matmul_diff),
    (np.divide, divide_diff),
    (np.true_divide, divide_diff),
    (np.power, pow_diff),
    (
        np.absolute,
        # d|x| = sign(x) * dx, reported as nan at x == 0.
        lambda x: x[1] * np.where(np.sign(x[0]) == 0, float('nan'),
                                  np.sign(x[0])),
    ),
    (np.positive, lambda x: +x[1]),
    (np.negative, lambda x: -x[1]),
    (np.conj, lambda x: np.conj(x[1])),
    (np.exp, lambda x: x[1] * np.exp(x[0])),
    # d/dx 2**x = ln(2) * 2**x
    (np.exp2, lambda x: x[1] * np.log(2) * np.exp2(x[0])),
):
    register_diff(_ufunc, _rule)
del _ufunc, _rule