Esempio n. 1
0
 def test_omitted_args(self):
     # typeof() applied to an OmittedArg wrapper must give types.Omitted of
     # the wrapped default. The test expects 0.0, 1 and 1.0 to produce three
     # distinct types (even though 1 == 1.0 in Python), and two wrappers of
     # the same default (1.0) to produce equal types.
     omitted_zero_f = typeof(OmittedArg(0.0))
     omitted_one_i = typeof(OmittedArg(1))
     omitted_one_f = typeof(OmittedArg(1.0))
     omitted_one_f_again = typeof(OmittedArg(1.0))
     expectations = (
         (omitted_zero_f, 0.0),
         (omitted_one_i, 1),
         (omitted_one_f, 1.0),
     )
     for got_ty, default in expectations:
         self.assertEqual(got_ty, types.Omitted(default))
     distinct = {omitted_zero_f, omitted_one_i, omitted_one_f}
     self.assertEqual(len(distinct), 3)
     self.assertEqual(omitted_one_f_again, omitted_one_f)
Esempio n. 2
0
    def test_inline_call_branch_pruning(self):
        # branch pruning pass should run properly in inlining to enable
        # functions with type checks
        #
        # A custom pipeline runs PruningInlineTestPass, which inlines ``foo``
        # into ``test_impl`` via ``inline_closure_call``. The test then checks
        # that compiled results match the pure-Python ones and that, for the
        # omitted-argument signature, the preserved IR simplifies to a single
        # block (i.e. the ``A is None`` branch did not survive).
        @njit
        def foo(A=None):
            if A is None:
                return 2
            else:
                return A

        # Pure-Python reference; it is also the function compiled below with
        # the custom pipeline.
        def test_impl(A=None):
            return foo(A)

        @register_pass(analysis_only=False, mutates_CFG=True)
        class PruningInlineTestPass(FunctionPass):
            _name = "pruning_inline_test_pass"

            def __init__(self):
                FunctionPass.__init__(self)

            def run_pass(self, state):
                # assuming the function has one block with one call inside
                assert len(state.func_ir.blocks) == 1
                block = list(state.func_ir.blocks.values())[0]
                for i, stmt in enumerate(block.body):
                    # guard() is relied on to yield None for statements that
                    # are not resolvable calls -- TODO confirm.
                    # NOTE(review): ``stmt.value`` is evaluated before guard()
                    # runs, so a statement lacking ``value`` would raise
                    # AttributeError here rather than being skipped.
                    if guard(find_callname, state.func_ir,
                             stmt.value) is not None:
                        # Inline foo's Python function at statement index i;
                        # its argument-type tuple is built from the type of
                        # the call's first argument.
                        inline_closure_call(
                            state.func_ir,
                            {},
                            block,
                            i,
                            foo.py_func,
                            state.typingctx,
                            (state.type_annotation.typemap[
                                stmt.value.args[0].name], ),
                            state.type_annotation.typemap,
                            state.calltypes,
                        )
                        # Only the first call found is inlined.
                        break
                # presumably tells the pass manager the IR was mutated --
                # confirm against the FunctionPass.run_pass contract
                return True

        # Compiler whose pipeline comes from gen_pipeline (helper not visible
        # here), parameterised with the test pass above.
        class InlineTestPipelinePrune(compiler.CompilerBase):
            def define_pipelines(self):
                pm = gen_pipeline(self.state, PruningInlineTestPass)
                pm.finalize()
                return [pm]

        # make sure inline_closure_call runs in full pipeline
        j_func = njit(pipeline_class=InlineTestPipelinePrune)(test_impl)
        A = 3
        # Check both an explicit argument and the defaulted (omitted) one.
        self.assertEqual(test_impl(A), j_func(A))
        self.assertEqual(test_impl(), j_func())

        # make sure IR doesn't have branches
        # The overload keyed by (types.Omitted(None),) is presumably the one
        # compiled by the no-argument call above; inspect its preserved IR.
        fir = j_func.overloads[(
            types.Omitted(None), )].metadata["preserved_ir"]
        fir.blocks = simplify_CFG(fir.blocks)
        self.assertEqual(len(fir.blocks), 1)
Esempio n. 3
0
def _lit_or_omitted(value):
    """Type *value* as a literal when possible, else as an omitted default.

    Returns ``types.literal(value)`` if that succeeds. When it raises
    ``LiteralTypingError`` (the value's type is not supported as a literal),
    returns ``types.Omitted(value)`` instead.
    """
    try:
        typed = types.literal(value)
    except LiteralTypingError:
        typed = types.Omitted(value)
    return typed
Esempio n. 4
0
    def test_omitted_arg(self):
        # See issue 7726
        # Compiling a function whose only argument is omitted (defaults to
        # None) with debug info must emit valid debug metadata. The omitted
        # argument is expected to be described as an empty anonymous struct.
        @njit(debug=True)
        def foo(missing=None):
            pass

        # check that it will actually compile (verifies DI emission is ok)
        with override_config('DEBUGINFO_DEFAULT', 1):
            foo()

        metadata = self._get_metadata(foo, sig=(types.Omitted(None), ))
        metadata_definition_map = self._get_metadata_map(metadata)

        # Find DISubroutineType; exactly one is expected.
        tmp_disubr = [md for md in metadata if "DISubroutineType" in md]
        self.assertEqual(len(tmp_disubr), 1)
        disubr = tmp_disubr.pop()

        disubr_matched = re.match(r'.*!DISubroutineType\(types: ([!0-9]+)\)$',
                                  disubr)
        self.assertIsNotNone(disubr_matched)
        disubr_groups = disubr_matched.groups()
        self.assertEqual(len(disubr_groups), 1)
        disubr_meta = disubr_groups[0]

        # Find the types in the DISubroutineType arg list
        disubr_types = metadata_definition_map[disubr_meta]
        disubr_types_matched = re.match(r'!{(.*)}', disubr_types)
        # Guard the match object that is dereferenced on the next line.
        self.assertIsNotNone(disubr_types_matched)
        disubr_types_groups = disubr_types_matched.groups()
        self.assertEqual(len(disubr_types_groups), 1)

        # Fetch the last argument type; for the omitted argument it should
        # be an anonymous struct whose element list is empty.
        md_fn_arg = [x.strip() for x in disubr_types_groups[0].split(',')][-1]
        arg_ty = metadata_definition_map[md_fn_arg]
        expected_arg_ty = (r'^.*!DICompositeType\(tag: DW_TAG_structure_type, '
                           r'name: "Anonymous struct \({}\)", elements: '
                           r'(![0-9]+), identifier: "{}"\)')
        self.assertRegex(arg_ty, expected_arg_ty)
        md_base_ty = re.match(expected_arg_ty, arg_ty).groups()[0]
        base_ty = metadata_definition_map[md_base_ty]
        # expect ir.LiteralStructType([])
        self.assertEqual(base_ty, '!{}')
Esempio n. 5
0
    def test_cond_is_kwarg_none(self):
        # ``impl`` branches twice on whether a keyword argument (default None)
        # is None. Its body is a fixture and must keep this exact shape. Each
        # check below is (argument types, expected pruning pattern for
        # ``assert_prune``, runtime argument value).
        def impl(x=None):
            if x is None:
                y = 10
            else:
                y = 40

            if x is not None:
                z = 100
            else:
                z = 400

            return z, y

        checks = (
            ((types.Omitted(None), ), [False, True], None),
            ((types.NoneType('none'), ), [False, True], None),
            ((types.IntegerLiteral(10), ), [True, False], 10),
        )
        for arg_tys, prune_pattern, arg_value in checks:
            self.assert_prune(impl, arg_tys, prune_pattern, arg_value)
Esempio n. 6
0
    def test_cond_is_kwarg_value(self):
        # ``impl`` branches twice on whether a keyword argument (default 1000)
        # equals 1000. Its body is a fixture and must keep this exact shape.
        # Each check below is (argument types, expected pruning pattern for
        # ``assert_prune``, runtime argument value).
        def impl(x=1000):
            if x == 1000:
                y = 10
            else:
                y = 40

            if x != 1000:
                z = 100
            else:
                z = 400

            return z, y

        checks = (
            ((types.Omitted(1000), ), [None, None], 1000),
            ((types.IntegerLiteral(1000), ), [None, None], 1000),
            ((types.IntegerLiteral(0), ), [None, None], 0),
            ((types.NoneType('none'), ), [True, False], None),
        )
        for arg_tys, prune_pattern, arg_value in checks:
            self.assert_prune(impl, arg_tys, prune_pattern, arg_value)
Esempio n. 7
0
 def default_handler(index, param, default):
     """Return ``types.Omitted`` wrapping *default*.

     *index* and *param* are not used by this handler.
     """
     omitted_ty = types.Omitted(default)
     return omitted_ty
Esempio n. 8
0
 def _numba_type_(self):
     """Return ``types.Omitted`` wrapping this object's ``value``."""
     wrapped = self.value
     return types.Omitted(wrapped)
Esempio n. 9
0
    def _compile_for_args(self, *args, **kws):
        """
        For internal use.  Compile a specialized version of the function
        for the given *args* and *kws*, and return the resulting callable.

        ``OmittedArg`` instances in *args* are typed as ``types.Omitted`` of
        their wrapped value; every other argument is typed with
        ``self.typeof_pyval``. Keyword arguments are not supported here
        (*kws* is asserted to be empty).

        If compilation raises ``errors.ForceLiteralArg``, the requested
        arguments are converted with ``types.literal`` and compilation is
        re-entered. ``errors.CompilerError`` is raised instead if any
        requested argument is already a ``types.Literal``, which would
        otherwise loop forever. Other compiler errors are re-raised,
        optionally with help text appended (``config.SHOW_HELP``) and with
        or without their traceback (``config.FULL_TRACEBACKS``).
        """
        assert not kws
        # call any initialisation required for the compilation chain (e.g.
        # extension point registration).
        self._compilation_chain_init_hook()

        def error_rewrite(e, issue_type):
            """
            Rewrite and raise Exception `e` with help supplied based on the
            specified issue_type.
            """
            if config.SHOW_HELP:
                help_msg = errors.error_extras[issue_type]
                e.patch_message('\n'.join((str(e).rstrip(), help_msg)))
            if config.FULL_TRACEBACKS:
                raise e
            else:
                reraise(type(e), e, None)

        argtypes = [types.Omitted(a.value) if isinstance(a, OmittedArg)
                    else self.typeof_pyval(a)
                    for a in args]
        try:
            return self.compile(tuple(argtypes))
        except errors.ForceLiteralArg as e:
            # Received request for compiler re-entry with the list of arguments
            # indicated by e.requested_args.
            # First, check if any of these args are already Literal-ized
            already_lit_pos = [i for i in e.requested_args
                               if isinstance(args[i], types.Literal)]
            if already_lit_pos:
                # Abort compilation if any argument is already a Literal.
                # Letting this continue will cause infinite compilation loop.
                m = ("Repeated literal typing request.\n"
                     "{}.\n"
                     "This is likely caused by an error in typing. "
                     "Please see nested and suppressed exceptions.")
                info = ', '.join('Arg #{} is {}'.format(i, args[i])
                                 for i in sorted(already_lit_pos))
                raise errors.CompilerError(m.format(info))
            # Convert requested arguments into a Literal; leave the rest
            # untouched.
            args = [types.literal(v) if i in e.requested_args else v
                    for i, v in enumerate(args)]
            # Re-enter compilation with the Literal-ized arguments
            return self._compile_for_args(*args)

        except errors.TypingError as e:
            # Intercept typing error that may be due to an argument
            # that failed inferencing as a Numba type
            failed_args = []
            for i, arg in enumerate(args):
                val = arg.value if isinstance(arg, OmittedArg) else arg
                try:
                    tp = typeof(val, Purpose.argument)
                except ValueError as typeof_exc:
                    failed_args.append((i, str(typeof_exc)))
                else:
                    if tp is None:
                        failed_args.append(
                            (i,
                             "cannot determine Numba type of value %r" % (val,)))
            if failed_args:
                # Patch error message to ease debugging
                msg = str(e).rstrip() + (
                    "\n\nThis error may have been caused by the following argument(s):\n%s\n"
                    % "\n".join("- argument %d: %s" % (i, err)
                                for i, err in failed_args))
                e.patch_message(msg)

            error_rewrite(e, 'typing')
        except errors.UnsupportedError as e:
            # Something unsupported is present in the user code, add help info
            error_rewrite(e, 'unsupported_error')
        except (errors.NotDefinedError, errors.RedefinedError,
                errors.VerificationError) as e:
            # These errors are probably from an issue with either the code supplied
            # being syntactically or otherwise invalid
            error_rewrite(e, 'interpreter')
        except errors.ConstantInferenceError as e:
            # this is from trying to infer something as constant when it isn't
            # or isn't supported as a constant
            error_rewrite(e, 'constant_inference')
        except Exception as e:
            if config.SHOW_HELP:
                if hasattr(e, 'patch_message'):
                    help_msg = errors.error_extras['reportable']
                    e.patch_message('\n'.join((str(e).rstrip(), help_msg)))
            # ignore the FULL_TRACEBACKS config, this needs reporting!
            # Bare ``raise`` re-raises the active exception with its original
            # traceback, without appending this re-raise site.
            raise