예제 #1
0
def numba_funcify_QRFull(op, node, **kwargs):
    """Return a Numba-compiled callable computing `numpy.linalg.qr` for ``op``.

    Parameters
    ----------
    op
        The operation being converted; only its ``mode`` attribute is read.
    node
        The node applying ``op``; its ``outputs`` (and, for the ``"reduced"``
        mode, its ``inputs``) determine the types used by the compiled code.
    **kwargs
        Accepted for interface compatibility and not used here.

    Returns
    -------
    callable
        A jitted function of one argument ``x``.

    Warns
    -----
    UserWarning
        When ``op.mode`` is anything other than ``"reduced"``, because the
        call is then routed through Numba's object mode.
    """
    qr_mode = op.mode

    if qr_mode == "reduced":
        # Fast path: no `mode` keyword is needed, so the call stays out of
        # object mode.  The input first goes through the converter built by
        # `int_to_float_fn` for the first output's dtype.
        result_dtype = node.outputs[0].type.numpy_dtype
        to_float = int_to_float_fn(node.inputs, result_dtype)

        @numba.njit(inline="always")
        def qr_full(x):
            return np.linalg.qr(to_float(x))

        return qr_full

    warnings.warn(
        ("Numba will use object mode to allow the "
         "`mode` argument to `numpy.linalg.qr`."),
        UserWarning,
    )

    # The object-mode block needs an explicit return signature: a tuple of
    # the output types when there are several outputs, a single type otherwise.
    objmode_sig = (
        numba.types.Tuple([get_numba_type(out.type) for out in node.outputs])
        if len(node.outputs) > 1
        else get_numba_type(node.outputs[0].type)
    )

    @numba.njit
    def qr_full(x):
        with numba.objmode(result=objmode_sig):
            result = np.linalg.qr(x, mode=qr_mode)
        return result

    return qr_full
예제 #2
0
def numba_funcify_SVD(op, node, **kwargs):
    """Return a Numba-compiled callable computing `numpy.linalg.svd` for ``op``.

    Parameters
    ----------
    op
        The operation being converted; its ``full_matrices`` and
        ``compute_uv`` attributes are forwarded to `numpy.linalg.svd`.
    node
        The node applying ``op``; its first output's type (and, when
        ``compute_uv`` is truthy, its ``inputs``) drive the compiled code.
    **kwargs
        Accepted for interface compatibility and not used here.

    Returns
    -------
    callable
        A jitted function of one argument ``x``.

    Warns
    -----
    UserWarning
        When ``op.compute_uv`` is falsy, because the call is then routed
        through Numba's object mode.
    """
    want_full = op.full_matrices
    want_uv = op.compute_uv

    if want_uv:
        # Only `full_matrices` has to be forwarded here, so the call stays
        # out of object mode.  The input first goes through the converter
        # built by `int_to_float_fn` for the first output's dtype.
        result_dtype = node.outputs[0].type.numpy_dtype
        to_float = int_to_float_fn(node.inputs, result_dtype)

        @numba.njit(inline="always")
        def svd(x):
            return np.linalg.svd(to_float(x), want_full)

        return svd

    warnings.warn(
        ("Numba will use object mode to allow the "
         "`compute_uv` argument to `numpy.linalg.svd`."),
        UserWarning,
    )

    # Return signature for the object-mode block, taken from the first output.
    objmode_sig = get_numba_type(node.outputs[0].type)

    @numba.njit
    def svd(x):
        with numba.objmode(result=objmode_sig):
            result = np.linalg.svd(x, want_full, want_uv)
        return result

    return svd
예제 #3
0
def numba_funcify_MatrixPinv(op, node, **kwargs):
    """Return a Numba-compiled callable computing `numpy.linalg.pinv`.

    Parameters
    ----------
    op
        The operation being converted; not inspected here.
    node
        The node applying the operation; the dtype of its first output is
        used both to build the input converter and to cast the result.
    **kwargs
        Accepted for interface compatibility and not used here.

    Returns
    -------
    callable
        A jitted function of one argument ``x`` whose result is cast to the
        first output's NumPy dtype.
    """
    result_dtype = node.outputs[0].type.numpy_dtype
    to_float = int_to_float_fn(node.inputs, result_dtype)

    @numba.njit(inline="always")
    def matrixpinv(x):
        pseudo_inverse = np.linalg.pinv(to_float(x))
        # Cast so the result dtype matches the one declared on the output.
        return pseudo_inverse.astype(result_dtype)

    return matrixpinv
예제 #4
0
def numba_funcify_Det(op, node, **kwargs):
    """Return a Numba-compiled callable computing `numpy.linalg.det`.

    Parameters
    ----------
    op
        The operation being converted; not inspected here.
    node
        The node applying the operation; the dtype of its first output is
        used both to build the input converter and to cast the result.
    **kwargs
        Accepted for interface compatibility and not used here.

    Returns
    -------
    callable
        A jitted function of one argument ``x`` whose result is passed
        through ``numba_basic.direct_cast`` with the first output's dtype.
    """
    result_dtype = node.outputs[0].type.numpy_dtype
    to_float = int_to_float_fn(node.inputs, result_dtype)

    @numba.njit(inline="always")
    def det(x):
        determinant = np.linalg.det(to_float(x))
        # Cast so the result dtype matches the one declared on the output.
        return numba_basic.direct_cast(determinant, result_dtype)

    return det
예제 #5
0
def numba_funcify_Eig(op, node, **kwargs):
    """Return a Numba-compiled callable computing `numpy.linalg.eig`.

    Parameters
    ----------
    op
        The operation being converted; not inspected here.
    node
        The node applying the operation; the dtypes of its first two outputs
        are used to cast the two results, and the first one is also used to
        build the input converter.
    **kwargs
        Accepted for interface compatibility and not used here.

    Returns
    -------
    callable
        A jitted function of one argument ``x`` returning a 2-tuple, each
        element cast to the dtype of the corresponding output.
    """
    first_dtype = node.outputs[0].type.numpy_dtype
    second_dtype = node.outputs[1].type.numpy_dtype

    to_float = int_to_float_fn(node.inputs, first_dtype)

    @numba.njit
    def eig(x):
        first, second = np.linalg.eig(to_float(x))
        # Cast each result to the dtype declared on the matching output.
        return (first.astype(first_dtype), second.astype(second_dtype))

    return eig