예제 #1
0
def _sort_check_key(key):
    """Validate the ``key`` argument of a sort-like routine.

    *key* is accepted when ``cgutils.is_nonelike`` reports it as None-like
    or when it is a ``types.Dispatcher`` (a Numba-compiled function).

    Raises ``errors.TypingError`` otherwise.  An ``Optional`` type gets its
    own, more specific message, since it could be either alternative and
    the key has to be concretely one of them.
    """
    if isinstance(key, types.Optional):
        raise errors.TypingError(
            "Key must concretely be None or a Numba JIT compiled function, "
            "an Optional (union of None and a value) was found")

    acceptable = (cgutils.is_nonelike(key)
                  or isinstance(key, types.Dispatcher))
    if not acceptable:
        raise errors.TypingError(
            "Key must be None or a Numba JIT compiled function")
예제 #2
0
    def mutate_with_body(self, func_ir, blocks, blk_start, blk_end,
                         body_blocks, dispatcher_factory, extra):
        """Lift the body of a with-region into a separate object-mode function.

        The blocks listed in *body_blocks* are copied into a new IR derived
        from *func_ir*. The new IR's arguments are the variables flowing into
        the region. The block at *blk_start* is then replaced by the block
        produced by ``_mutate_with_block_caller``, and the body blocks are
        cleared from *blocks* via ``_clear_blocks``.

        Type annotations come from *extra* (via ``self._legalize_args``).
        Every variable flowing out of the region must be annotated, and only
        such variables may be annotated. The annotations, taken in output
        order, form the tuple passed as ``output_types`` to
        *dispatcher_factory*.

        Returns the dispatcher created by *dispatcher_factory*.

        Raises ``errors.TypingError`` (located at the region's first block)
        when the annotated variables and the outgoing variables differ.
        """
        # Location of the region; used for every diagnostic below.
        loc = blocks[blk_start].loc
        typeanns = self._legalize_args(extra, loc=loc)
        vlt = func_ir.variable_lifetime

        inputs, outputs = find_region_inout_vars(
            blocks=blocks,
            livemap=vlt.livemap,
            callfrom=blk_start,
            returnto=blk_end,
            body_block_ids=set(body_blocks),
            )

        # Outgoing names may carry a version suffix after the first '.'
        # (e.g. ``x.1``). Annotations are looked up by the bare name.
        def strip_var_ver(x):
            return x.split('.', 1)[0]

        stripped_outs = list(map(strip_var_ver, outputs))

        # Verify that only outputs are annotated
        extra_annotated = set(typeanns) - set(stripped_outs)
        if extra_annotated:
            # Note the trailing space: the two literals are concatenated.
            msg = (
                'Invalid type annotation on non-outgoing variables: {}. '
                'Suggestion: remove annotation of the listed variables'
            )
            raise errors.TypingError(msg.format(extra_annotated), loc=loc)

        # Verify that all outputs are annotated
        not_annotated = set(stripped_outs) - set(typeanns)
        if not_annotated:
            msg = 'missing type annotation on outgoing variables: {}'
            raise errors.TypingError(msg.format(not_annotated), loc=loc)

        # Output types, one per outgoing variable, in output order.
        outtup = types.Tuple([typeanns[v] for v in stripped_outs])

        lifted_blks = {k: blocks[k] for k in body_blocks}
        _mutate_with_block_callee(lifted_blks, blk_start, blk_end,
                                  inputs, outputs)

        lifted_ir = func_ir.derive(
            blocks=lifted_blks,
            arg_names=tuple(inputs),
            arg_count=len(inputs),
            force_non_generator=True,
            )

        dispatcher = dispatcher_factory(lifted_ir, objectmode=True,
                                        output_types=outtup)

        newblk = _mutate_with_block_caller(
            dispatcher, blocks, blk_start, blk_end, inputs, outputs,
            )

        blocks[blk_start] = newblk
        _clear_blocks(blocks, body_blocks)
        return dispatcher
예제 #3
0
 def _legalize_arg_types(self, args):
     """Reject argument types that cannot be passed into a with-context.

     Raises ``errors.TypingError`` for the first argument whose type is a
     ``types.List`` or a ``types.Dispatcher``. The message reports the
     argument's 1-based position.
     """
     # (forbidden type, message template) pairs, checked in this order.
     forbidden = (
         (types.List, 'Does not support list type inputs into '
                      'with-context for arg {}'),
         (types.Dispatcher, 'Does not support function type inputs into '
                            'with-context for arg {}'),
     )
     for position, argty in enumerate(args, start=1):
         for bad_type, template in forbidden:
             if isinstance(argty, bad_type):
                 raise errors.TypingError(template.format(position))
예제 #4
0
    def get_call_type_with_literals(self, context, args, kws, literals=None):
        """Resolve a call signature by trying every template twice.

        Each template is first applied to *args* / *kws* as given (literal
        types included). It is then applied to versions with every type
        passed through ``unliteral``. The first non-None signature is
        recorded in ``self._impl_keys`` and returned.

        *literals* is accepted for interface compatibility but is not used by
        this implementation.

        Raises ``errors.TypingError`` summarising every recorded failure when
        no template produces a signature.
        """
        failures = _ResolutionFailures(context, self, args, kws)
        # The literal-free argument types do not depend on the template, so
        # compute them once rather than once per template.
        nolitargs = tuple(unliteral(a) for a in args)
        nolitkws = {k: unliteral(v) for k, v in kws.items()}
        for temp_cls in self.templates:
            temp = temp_cls(context)
            for support_literals in (True, False):
                if support_literals:
                    call_args, call_kws = args, kws
                else:
                    call_args, call_kws = nolitargs, nolitkws
                try:
                    sig = temp.apply(call_args, call_kws)
                except Exception as e:
                    # Any template failure is recorded and reported later.
                    failures.add_error(temp_cls, e)
                    continue
                if sig is not None:
                    self._impl_keys[sig.args] = temp.get_impl_key(sig)
                    return sig
                haslit = '' if support_literals else 'out'
                msg = "All templates rejected with%s literals." % haslit
                failures.add_error(temp_cls, msg)

        if len(failures) == 0:
            raise AssertionError("Internal Error. "
                                 "Function resolution ended with no failures "
                                 "or successful signature")

        raise errors.TypingError(failures.format())
    def get_call_type_with_literals(self, context, args, kws, literals):
        """Resolve a call signature by trying each template in turn.

        A template whose ``support_literals`` is true is applied to
        ``*literals`` when *literals* is not None. Otherwise the template is
        applied to ``(args, kws)``. The first non-None signature is recorded
        in ``self._impl_keys`` and returned.

        Raises ``errors.TypingError`` summarising every recorded failure when
        no template produces a signature.
        """
        failures = _ResolutionFailures(context, self, args, kws)
        for temp_cls in self.templates:
            temp = temp_cls(context)
            try:
                if literals is not None and temp.support_literals:
                    sig = temp.apply(*literals)
                else:
                    sig = temp.apply(args, kws)
            except Exception as e:
                # Any template failure is recorded and reported later.
                failures.add_error(temp_cls, e)
                continue
            if sig is None:
                failures.add_error(temp_cls, "All templates rejected")
                continue
            self._impl_keys[sig.args] = temp.get_impl_key(sig)
            return sig

        if len(failures) == 0:
            raise AssertionError("Internal Error. "
                                 "Function resolution ended with no failures "
                                 "or successful signature")

        raise errors.TypingError(failures.format())
예제 #6
0
def _sort_check_reverse(reverse):
    """Validate the ``reverse`` argument and return what it resolves to.

    ``types.Omitted`` is unwrapped to its ``value`` (presumably the omitted
    default). ``types.Optional`` is unwrapped to its contained ``type``.
    Anything else is used as-is.

    The resolved object must be a Numba ``Boolean`` or ``Integer`` type, or
    a Python ``int`` / ``bool``. It is returned on success. Anything else
    raises ``errors.TypingError``.
    """
    if isinstance(reverse, types.Omitted):
        unwrapped = reverse.value
    elif isinstance(reverse, types.Optional):
        unwrapped = reverse.type
    else:
        unwrapped = reverse

    allowed = (types.Boolean, types.Integer, int, bool)
    if isinstance(unwrapped, allowed):
        return unwrapped

    # The message reports the original (still wrapped) argument type.
    msg = "an integer is required for 'reverse' (got type %s)" % reverse
    raise errors.TypingError(msg)
예제 #7
0
 def add_return_type(self, return_type):
     """Record *return_type* among the inferred return types.

     *return_type* is always added to ``self._inferred_retty`` first. If that
     collection then holds 16 or more entries, a `TypingError` is raised,
     signalling that the recursive function's return type does not converge.
     """
     # Arbitrary cap; not expected to need user configuration.
     limit = 16
     seen = self._inferred_retty
     seen.add(return_type)
     if len(seen) < limit:
         return
     raise errors.TypingError(
         "Return type of recursive function does not converge")
예제 #8
0
 def generic(self, args, kws):
     """Return the signature for obtaining an iterator over a single argument.

     Returns ``signature(obj.iterator_type, obj)`` for an iterable *obj*, or
     None (no signature) when *obj* is not a ``types.IterableType``. Arrays
     with more than one dimension are rejected with a ``TypingError``.
     """
     assert not kws
     [obj] = args
     if not isinstance(obj, types.IterableType):
         return None
     # Reject multi-dimensional arrays here, with a targeted message.
     # Deferring the failure yields a vaguer error later (e.g. no getiter
     # support for array(<>, 2, 'C')).
     if isinstance(obj, types.Array) and obj.ndim > 1:
         raise errors.TypingError(
             "Direct iteration is not supported for arrays with "
             "dimension > 1. Try using indexing instead.")
     return signature(obj.iterator_type, obj)
예제 #9
0
    def resolve___call__(self, classty):
        """
        Resolve a call on a number class, e.g. ``int(...)`` or ``int32(42)``.

        The returned callable type uses a typer with a single parameter. It
        always yields the class's instance type. Classes whose instance type
        is not a ``types.Number`` raise ``errors.TypingError``.
        """
        number_ty = classty.instance_type
        if isinstance(number_ty, types.Number):
            def typer(val):
                # The result type is the number type, whatever the argument.
                return number_ty

            template = make_callable_template(key=number_ty, typer=typer)
            return types.Function(template)

        raise errors.TypingError("invalid use of non-number types")
예제 #10
0
def overload_where_scalars(cond, x, y):
    """
    Implement where() for scalars.

    Applies only when *cond* is not an array type; returns None otherwise.
    *x* and *y* must have the same type, else ``errors.TypingError`` is
    raised. The implementation returns a 0-d array holding the chosen value.
    """
    if isinstance(cond, types.Array):
        return None
    if x != y:
        raise errors.TypingError("x and y should have the same type")

    def where_impl(cond, x, y):
        """
        Scalar where() => return a 0-dim array
        """
        chosen = x if cond else y
        # empty_like + item assignment instead of full_like(), which is
        # unavailable on NumPy < 1.8.
        out = np.empty_like(chosen)
        out[()] = chosen
        return out

    return where_impl
예제 #11
0
def overload_where_arrays(cond, x, y):
    """
    Implement where() for arrays.

    Applies only when *cond*, *x* and *y* are all array types. For any other
    combination it returns None (no implementation offered). *x* and *y*
    must share a dtype, otherwise ``errors.TypingError`` is raised.

    The compiled implementation requires all three arrays to have the same
    shape at run time (no broadcasting) and raises ``ValueError`` otherwise.
    The result is allocated with ``np.empty_like(x)``.
    """
    # Choose implementation based on argument types. x or y may be a
    # non-array type (e.g. a scalar) without the ``dtype``/``layout``
    # attributes inspected below, so decline instead of failing with an
    # AttributeError during typing.
    if not all(isinstance(ty, types.Array) for ty in (cond, x, y)):
        return None

    if x.dtype != y.dtype:
        raise errors.TypingError("x and y should have the same dtype")

    # Array where() => return an array of the same shape
    if all(ty.layout == 'C' for ty in (cond, x, y)):

        def where_impl(cond, x, y):
            """
            Fast implementation for C-contiguous arrays
            """
            shape = cond.shape
            if x.shape != shape or y.shape != shape:
                raise ValueError("all inputs should have the same shape")
            res = np.empty_like(x)
            # All operands are C-contiguous with equal shapes, so flat
            # index i addresses the same logical element in each of them.
            cf = cond.flat
            xf = x.flat
            yf = y.flat
            rf = res.flat
            for i in range(cond.size):
                rf[i] = xf[i] if cf[i] else yf[i]
            return res
    else:

        def where_impl(cond, x, y):
            """
            Generic implementation for other arrays
            """
            shape = cond.shape
            if x.shape != shape or y.shape != shape:
                raise ValueError("all inputs should have the same shape")
            res = np.empty_like(x)
            for idx, c in np.ndenumerate(cond):
                res[idx] = x[idx] if c else y[idx]
            return res

    return where_impl
예제 #12
0
def _inject_hashsecret_read(tyctx, name):
    """Emit code to load the hashsecret.

    *name* must be a ``types.StringLiteral``, otherwise
    ``errors.TypingError`` is raised. Its literal value selects an entry of
    ``_hashsecret``. That entry's ``symbol`` names an LLVM global of type
    ``i64``, which the generated code loads.

    Returns ``(signature, codegen)``, where the signature maps *name* to
    ``uint64``. *tyctx* is not used.
    """
    if not isinstance(name, types.StringLiteral):
        raise errors.TypingError("requires literal string")

    symbol_name = _hashsecret[name.literal_value].symbol
    call_sig = types.uint64(name)

    def codegen(cgctx, builder, sig, args):
        module = builder.module
        try:
            # Reuse the global if this module already declares it.
            global_var = module.get_global(symbol_name)
        except KeyError:
            # Not yet present in this module: declare it.
            global_var = ir.GlobalVariable(module, ir.IntType(64),
                                           name=symbol_name)
        return builder.load(global_var)

    return call_sig, codegen
예제 #13
0
 def raise_error(self):
     """Raise the error that best represents the recorded failures.

     The first recorded ``errors.ForceLiteralArg`` (if any) is re-raised
     unchanged. Otherwise a ``errors.TypingError`` carrying ``self.format()``
     is raised.
     """
     first_forced = next(
         (err for _tempcls, err in self._failures
          if isinstance(err, errors.ForceLiteralArg)),
         None,
     )
     if first_forced is not None:
         raise first_forced
     raise errors.TypingError(self.format())