Exemplo n.º 1
0
def _register_operator_map(op_map, kernels, inplace=False):
    """Register the ufunc kernel backing every operator in *op_map*.

    :param op_map: mapping from an operator to the *name* of a NumPy ufunc;
        the ufunc object is resolved with ``getattr(np, name)``.
    :param kernels: mapping from ufunc object to the kernel returned by
        ``register_ufunc_kernel`` for it.
    :param inplace: when true, the operators are registered with
        ``inplace=True``.
    :raises KeyError: if a ufunc named in *op_map* has no entry in *kernels*.
    :raises RuntimeError: if a ufunc takes neither one nor two inputs.
    """
    # Only forward ``inplace`` when it is requested, so that non-inplace
    # registrations are issued exactly as before (the callee's own default
    # for ``inplace`` applies).
    extra_kwargs = {'inplace': True} if inplace else {}
    for op, ufunc_name in op_map.items():
        ufunc = getattr(np, ufunc_name)
        kernel = kernels[ufunc]
        # The ufunc's input count decides whether the operator is unary or
        # binary.
        if ufunc.nin == 1:
            register_unary_operator_kernel(op, kernel, **extra_kwargs)
        elif ufunc.nin == 2:
            register_binary_operator_kernel(op, kernel, **extra_kwargs)
        else:
            raise RuntimeError("There shouldn't be any non-unary or binary operators")


def _register_ufuncs():
    """Register kernels for all ufuncs in ``ufunc_db`` and the array operators.

    First, ``register_ufunc_kernel`` is called once per ufunc returned by
    ``ufunc_db.get_ufuncs()``, and the returned kernels are remembered per
    ufunc. Then each operator in the npydecl operator maps is wired to the
    kernel of its corresponding NumPy ufunc:

    * ``NumpyRulesUnaryArrayOperator._op_map`` and
      ``NumpyRulesArrayOperator._op_map`` are registered normally;
    * ``NumpyRulesInplaceArrayOperator._op_map`` is registered with
      ``inplace=True``.

    :raises KeyError: if an operator map names a ufunc absent from ufunc_db.
    :raises RuntimeError: if an operator's ufunc is neither unary nor binary.
    """
    kernels = {}
    for ufunc in ufunc_db.get_ufuncs():
        kernels[ufunc] = register_ufunc_kernel(ufunc, _ufunc_db_function(ufunc))

    for op_map in (npydecl.NumpyRulesUnaryArrayOperator._op_map,
                   npydecl.NumpyRulesArrayOperator._op_map):
        _register_operator_map(op_map, kernels)

    _register_operator_map(npydecl.NumpyRulesInplaceArrayOperator._op_map,
                           kernels, inplace=True)
Exemplo n.º 2
0
    _any = types.Any
    _arr_kind = types.Array
    formal_sigs = [(_arr_kind, _arr_kind), (_any, _arr_kind),
                   (_arr_kind, _any)]
    for sig in formal_sigs:
        if not inplace:
            lower(op, *sig)(lower_binary_operator)
        else:
            lower(op, *sig)(lower_inplace_operator)


################################################################################
# Use the contents of ufunc_db to initialize the supported ufuncs

# Register a ufunc kernel for every ufunc known to ufunc_db. The number of
# inputs (``nin``) selects the unary or binary registration helper; any
# other arity is rejected before anything is registered for that ufunc.
for ufunc in ufunc_db.get_ufuncs():
    if ufunc.nin not in (1, 2):
        raise RuntimeError(
            "Don't know how to register ufuncs from ufunc_db with arity > 2")
    (register_unary_ufunc_kernel if ufunc.nin == 1
     else register_binary_ufunc_kernel)(ufunc, _ufunc_db_function(ufunc))


@lower(operator.pos, types.Array)
def array_positive_impl(context, builder, sig, args):
    '''Lowering function for +(array) expressions.  Defined here
    (numba.targets.npyimpl) since the remaining array-operator
    lowering functions are also registered in this module.
    '''