示例#1
0
def numba_funcify_BroadcastTo(op, node, **kwargs):
    """Build the Numba-jitted implementation for a ``BroadcastTo`` node.

    The returned function takes the array ``x`` followed by the entries of
    the target shape as separate positional arguments, converts every entry
    with ``numba_basic.to_scalar`` and forwards the resulting tuple to
    ``np.broadcast_to``.

    The shape tuple is pre-sized to ``len(node.inputs) - 1``, i.e. this
    assumes ``node.inputs`` holds the array plus one input per shape entry.
    ``op`` and ``kwargs`` are not used here.
    """
    # Number of shape entries the jitted function will receive.
    n_shape_entries = len(node.inputs) - 1

    # Produces a tuple of `n_shape_entries` zeros that is overwritten
    # entry-by-entry inside the jitted function.
    make_shape_placeholder = numba_basic.create_tuple_creator(
        lambda _: 0, n_shape_entries
    )

    @numba_basic.numba_njit
    def broadcast_to(x, *shape):
        target_shape = make_shape_placeholder()

        # `literal_unroll` lets the loop visit shape entries one at a time
        # even when they do not all share the same Numba type.
        pos = 0
        for entry in literal_unroll(shape):
            target_shape = numba_basic.tuple_setitem(
                target_shape, pos, numba_basic.to_scalar(entry)
            )
            pos += 1

        return np.broadcast_to(x, target_shape)

    return broadcast_to
示例#2
0
def create_axis_reducer(
    reduce_fn: numba.np.ufunc.dufunc.DUFunc,
    identity: Union[np.ndarray, Number],
    axis: int,
    ndim: int,
    dtype: numba.types.Type,
    keepdims: bool = False,
) -> numba.core.dispatcher.Dispatcher:
    r"""Create a Numba JITed function that performs a NumPy reduction on a given axis.

    For ``ndim > 1`` the generated function takes the following form:

    .. code-block:: python

        def careduce_axis(x):
            res_shape = tuple(
                x.shape[i] if i < axis else x.shape[i + 1]
                for i in range(ndim - 1)
            )
            res = np.full(res_shape, identity, dtype=dtype)

            x_axis_first = x.transpose(reaxis_first)

            for m in range(x.shape[axis]):
                reduce_fn(res, x_axis_first[m], res)

            if keepdims:
                return np.expand_dims(res, axis)
            else:
                return res

    For ``ndim <= 1`` the input is iterated element by element and folded
    into a scalar accumulator via ``res = reduce_fn(res, val)``; ``axis`` is
    not used in that case.  With ``keepdims`` the scalar is wrapped in a
    one-element array, otherwise it is cast to ``dtype``.

    This can be removed/replaced when
    https://github.com/numba/numba/issues/4504 is implemented.

    Parameters
    ==========
    reduce_fn:
        The Numba ``ufunc`` representing a binary op that can perform the
        reduction on arbitrary ``ndarray``\s.
    identity:
        The identity value for the reduction.
    axis:
        The axis to reduce.  A negative value is interpreted relative to
        ``ndim`` (e.g. ``-1`` is the last axis).
    ndim:
        The number of dimensions of the input array.  For ``ndim > 1`` the
        result has ``ndim - 1`` dimensions (``ndim`` when ``keepdims`` is
        set).
    dtype:
        The data type of the result.
    keepdims:
        Determines whether or not the reduced dimension is retained.

    Returns
    =======
    The JITed ``careduce_axis`` function, taking the array to reduce as its
    only argument.
    """
    # The closures below compare `axis` against non-negative dimension
    # indices (`i < axis`, `i != axis`), so a negative axis must be wrapped
    # before they are created; otherwise the result shape and the transpose
    # permutation would both be wrong.
    if axis < 0:
        axis += ndim

    if ndim > 1:

        if keepdims:

            @numba.njit(inline="always")
            def set_out_dims(x):
                return np.expand_dims(x, axis)

        else:

            @numba.njit(inline="always")
            def set_out_dims(x):
                return x

        # Result shape: the input shape with entry `axis` removed.
        res_shape_tuple_ctor = create_tuple_creator(
            lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1)

        # Permutation that moves the reduced axis to the front so that the
        # loop below can slice along it with a single leading index.
        reaxis_first = (axis, ) + tuple(i for i in range(ndim) if i != axis)

        @numba.njit(boundscheck=False)
        def careduce_axis(x):
            res_shape = res_shape_tuple_ctor(x.shape)
            x_axis_first = x.transpose(reaxis_first)

            res = np.full(res_shape,
                          numba_basic.to_scalar(identity),
                          dtype=dtype)
            # Accumulate every slice along the reduced axis into `res`,
            # using `res` as the ufunc's output argument.
            for m in range(x.shape[axis]):
                reduce_fn(res, x_axis_first[m], res)

            return set_out_dims(res)

    else:

        if keepdims:

            @numba.njit(inline="always")
            def set_out_dims(x):
                return np.array([x], dtype)

        else:

            @numba.njit(inline="always")
            def set_out_dims(x):
                return numba_basic.direct_cast(x, dtype)

        @numba.njit(boundscheck=False)
        def careduce_axis(x):
            # Scalar accumulation: start from the identity and fold in each
            # element of `x`.
            res = numba_basic.to_scalar(identity)
            for val in x:
                res = reduce_fn(res, val)
            return set_out_dims(res)

    return careduce_axis
示例#3
0
def numba_funcify_DimShuffle(op, **kwargs):
    """Build the Numba-jitted implementation for a ``DimShuffle`` ``op``.

    The returned ``dimshuffle(x)`` function converts ``x`` with
    ``np.asarray`` and then, when the output has at least one dimension:

    1. transposes it according to ``op.transposition``;
    2. takes the first ``len(op.shuffle)`` entries of the transposed shape;
    3. builds an output shape of length ``len(op.shuffle) + len(op.augment)``
       that has ``1`` at every position listed in ``op.augment`` and the
       entries from step 2, in order, at the remaining positions;
    4. reshapes a C-contiguous version of the transposed array to that shape
       and returns it (copied first unless ``op.inplace`` is true).

    When the output has no dimensions at all, ``x.item()`` is returned.

    ``kwargs`` is not used here.
    """
    # Freeze the `Op` attributes as tuples/values so the jitted closures
    # below can capture them.
    shuffle = tuple(op.shuffle)
    transposition = tuple(op.transposition)
    augment = tuple(op.augment)
    inplace = op.inplace

    # Rank of the output: kept dimensions plus newly inserted ones.
    ndim_new_shape = len(shuffle) + len(augment)

    if len(shuffle) > 0:

        # Fill entry `i` of `new_shape`.  `j` counts how many entries of
        # `shuffle_shape` have been consumed so far; the updated counter is
        # returned together with the updated tuple.
        @numba.njit
        def populate_new_shape(i, j, new_shape, shuffle_shape):
            if i in augment:
                # Inserted dimension: length 1, nothing consumed.
                new_shape = tuple_setitem(new_shape, i, 1)
                return j, new_shape
            else:
                # Kept dimension: take the next entry of `shuffle_shape`.
                new_shape = tuple_setitem(new_shape, i, shuffle_shape[j])
                return j + 1, new_shape

    else:
        # When `len(shuffle) == 0`, the `shuffle_shape[j]` expression above
        # is typed as `getitem(Tuple(), int)`, which has no implementation
        # (since getting an item from an empty sequence doesn't make sense).
        # To avoid this compile-time error, we omit the expression altogether.
        @numba.njit(inline="always")
        def populate_new_shape(i, j, new_shape, shuffle_shape):
            return j, tuple_setitem(new_shape, i, 1)

    if ndim_new_shape > 0:
        # Produces a tuple of `ndim_new_shape` zeros that is overwritten
        # entry-by-entry by `populate_new_shape`.
        create_zeros_tuple = numba_basic.create_tuple_creator(
            lambda _: 0, ndim_new_shape)

        @numba.njit
        def dimshuffle_inner(x, shuffle):
            res = np.transpose(x, transposition)
            # Only the first `len(shuffle)` dimensions of the transposed
            # array contribute entries to the output shape.
            shuffle_shape = res.shape[:len(shuffle)]

            new_shape = create_zeros_tuple()

            j = 0
            for i in range(len(new_shape)):
                j, new_shape = populate_new_shape(i, j, new_shape,
                                                  shuffle_shape)

            # FIXME: Numba's `array.reshape` only accepts C arrays.
            res_reshape = np.reshape(np.ascontiguousarray(res), new_shape)

            if not inplace:
                return res_reshape.copy()
            else:
                return res_reshape

    else:

        # Zero-dimensional output: return the single element as a scalar.
        @numba.njit
        def dimshuffle_inner(x, shuffle):
            return x.item()

    # Without the following wrapper function we would see this error:
    # E   No implementation of function Function(<built-in function getitem>) found for signature:
    # E
    # E    >>> getitem(UniTuple(int64 x 2), slice<a:b>)
    # E
    # E   There are 22 candidate implementations:
    # E      - Of which 22 did not match due to:
    # E      Overload of function 'getitem': File: <numerous>: Line N/A.
    # E        With argument(s): '(UniTuple(int64 x 2), slice<a:b>)':
    # E       No match.
    # ...(on this line)...
    # E           shuffle_shape = res.shape[: len(shuffle)]
    @numba.njit(inline="always")
    def dimshuffle(x):
        return dimshuffle_inner(np.asarray(x), shuffle)

    return dimshuffle
示例#4
0
文件: elemwise.py 项目: mgorny/aesara
def create_axis_reducer(
    scalar_op: Op,
    identity: Union[np.ndarray, Number],
    axis: int,
    ndim: int,
    dtype: numba.types.Type,
    keepdims: bool = False,
) -> numba.core.dispatcher.Dispatcher:
    r"""Create Python function that performs a NumPy-like reduction on a given axis.

    The function is produced by generating Python source and compiling it
    with ``compile_function_src``.  For ``ndim > 1`` the generated function
    takes the following form:

    .. code-block:: python

        def careduce_axis(x):
            x_shape = np.shape(x)
            res_shape = res_shape_tuple_ctor(x_shape)
            res = np.full(res_shape, identity, dtype=out_dtype)

            axis_shape = x.shape[axis]

            for idx_arr in np.ndindex(res_shape):
                for i in range(axis_shape):
                    <in-place update of res[idx_arr...] with x[..., i, ...]>

            return np.expand_dims(res, axis) if keepdims else res

    where the in-place update statement is whatever ``scalar_in_place_fn``
    emits for ``scalar_op``.  For ``ndim <= 1`` the accumulator is a
    one-element array, the loop runs over ``x[i]`` only, and the result is
    either that array (``keepdims``) or ``res.item()``.

    This can be removed/replaced when
    https://github.com/numba/numba/issues/4504 is implemented.

    Parameters
    ==========
    scalar_op:
        The scalar :class:`Op` that performs the desired reduction.
    identity:
        The identity value for the reduction.  It is embedded in the
        generated source via ``str(identity)``; ``inf`` and ``-inf`` are
        spelled as ``np.inf`` and ``-np.inf``.
    axis:
        The axis to reduce.  A negative value is interpreted relative to
        ``ndim`` (e.g. ``-1`` is the last axis).
    ndim:
        The number of dimensions of the input array.  For ``ndim > 1`` the
        result has ``ndim - 1`` dimensions (``ndim`` when ``keepdims`` is
        set).
    dtype:
        The data type of the result.
    keepdims:
        Determines whether or not the reduced dimension is retained.


    Returns
    =======
    A Python function that can be JITed.

    """
    # NOTE(review): the return annotation says `Dispatcher`, but this returns
    # whatever `compile_function_src` produces, which the docstring describes
    # as a plain Python function -- confirm and adjust the annotation.

    # The index-building loop below matches `axis` against non-negative
    # dimension indices, so a negative axis must be wrapped first; otherwise
    # the reduction index `i` would never be placed in the generated
    # indexing expression.
    if axis < 0:
        axis += ndim

    reduce_elemwise_fn_name = "careduce_axis"

    # The identity is spliced into the generated source as text; `inf` is
    # not a defined name there, so map it onto NumPy's constants.
    identity = str(identity)
    if identity == "inf":
        identity = "np.inf"
    elif identity == "-inf":
        identity = "-np.inf"

    # Names that the generated source refers to.
    global_env = {
        "np": np,
        "numba_basic": numba_basic,
        "out_dtype": dtype,
    }

    if ndim > 1:
        # Result shape: the input shape with entry `axis` removed.
        res_shape_tuple_ctor = create_tuple_creator(
            lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1)
        global_env["res_shape_tuple_ctor"] = res_shape_tuple_ctor

        # Build the textual index expressions: `res` is indexed by the
        # entries of `idx_arr`, while `x` uses the same entries with the
        # loop variable `i` inserted at position `axis`.
        res_indices = []
        arr_indices = []
        count = 0

        for i in range(ndim):
            if i == axis:
                arr_indices.append("i")
            else:
                res_indices.append(f"idx_arr[{count}]")
                arr_indices.append(f"idx_arr[{count}]")
                count = count + 1

        res_indices = ", ".join(res_indices)
        arr_indices = ", ".join(arr_indices)

        inplace_update_statement = scalar_in_place_fn(scalar_op, res_indices,
                                                      "res",
                                                      f"x[{arr_indices}]")
        # Indent to sit inside the two nested `for` loops of the template.
        inplace_update_statement = indent(inplace_update_statement,
                                          " " * 4 * 3)

        return_expr = f"np.expand_dims(res, {axis})" if keepdims else "res"
        reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x):

    x_shape = np.shape(x)
    res_shape = res_shape_tuple_ctor(x_shape)
    res = np.full(res_shape, numba_basic.to_scalar({identity}), dtype=out_dtype)

    axis_shape = x.shape[{axis}]

    for idx_arr in np.ndindex(res_shape):
        for i in range(axis_shape):
{inplace_update_statement}

    return {return_expr}
        """
    else:
        inplace_update_statement = scalar_in_place_fn(scalar_op, "0", "res",
                                                      "x[i]")
        # Indent to sit inside the single `for` loop of the template.
        inplace_update_statement = indent(inplace_update_statement,
                                          " " * 4 * 2)

        # The accumulator is a one-element array; without `keepdims` its
        # single value is returned as a scalar.
        return_expr = "res" if keepdims else "res.item()"
        reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x):

    res = np.full(1, numba_basic.to_scalar({identity}), dtype=out_dtype)

    axis_shape = x.shape[{axis}]

    for i in range(axis_shape):
{inplace_update_statement}

    return {return_expr}
        """

    # Module globals are provided as a base; `global_env` entries take
    # precedence on name clashes.
    reduce_elemwise_fn_py = compile_function_src(reduce_elemwise_def_src,
                                                 reduce_elemwise_fn_name, {
                                                     **globals(),
                                                     **global_env
                                                 })

    return reduce_elemwise_fn_py