Пример #1
0
 def test_omitted_args(self):
     """
     typeof() of an OmittedArg must give types.Omitted of the wrapped value.

     The wrapped values 0.0, 1 and 1.0 must yield three distinct types,
     while two OmittedArg objects wrapping 1.0 must yield equal types.
     """
     ty_zero = typeof(OmittedArg(0.0))
     ty_int_one = typeof(OmittedArg(1))
     ty_float_one = typeof(OmittedArg(1.0))
     ty_float_one_again = typeof(OmittedArg(1.0))

     # Each inferred type wraps exactly the value it was built from.
     checks = ((ty_zero, types.Omitted(0.0)),
               (ty_int_one, types.Omitted(1)),
               (ty_float_one, types.Omitted(1.0)))
     for inferred, expected in checks:
         self.assertEqual(inferred, expected)

     # 1 == 1.0 in Python, yet the three types must stay distinct.
     self.assertEqual(len({ty_zero, ty_int_one, ty_float_one}), 3)
     self.assertEqual(ty_float_one_again, ty_float_one)
Пример #2
0
    def test_inline_call_branch_pruning(self):
        """
        Inlining ``foo`` with inline_closure_call from a custom function pass
        must still allow branch pruning of its ``A is None`` check, so the
        overload compiled with ``A`` omitted collapses to a single block once
        its CFG is simplified.
        """
        # branch pruning pass should run properly in inlining to enable
        # functions with type checks
        @njit
        def foo(A=None):
            if A is None:
                return 2
            else:
                return A

        def test_impl(A=None):
            return foo(A)

        @register_pass(analysis_only=False, mutates_CFG=True)
        class PruningInlineTestPass(FunctionPass):
            _name = "pruning_inline_test_pass"

            def __init__(self):
                FunctionPass.__init__(self)

            def run_pass(self, state):
                func_ir = state.func_ir
                # assuming the function has one block with one call inside
                assert len(func_ir.blocks) == 1
                block = list(func_ir.blocks.values())[0]
                for idx, stmt in enumerate(block.body):
                    if guard(find_callname, func_ir, stmt.value) is None:
                        continue
                    # inline foo at the first call, typed by its argument
                    typemap = state.type_annotation.typemap
                    arg_types = (typemap[stmt.value.args[0].name], )
                    inline_closure_call(func_ir, {}, block, idx, foo.py_func,
                                        state.typingctx, arg_types, typemap,
                                        state.calltypes)
                    break
                return True

        class InlineTestPipelinePrune(numba.compiler.CompilerBase):
            def define_pipelines(self):
                manager = gen_pipeline(self.state, PruningInlineTestPass)
                manager.finalize()
                return [manager]

        # make sure inline_closure_call runs in full pipeline
        jitted = njit(pipeline_class=InlineTestPipelinePrune)(test_impl)
        value = 3
        self.assertEqual(test_impl(value), jitted(value))
        self.assertEqual(test_impl(), jitted())

        # make sure IR doesn't have branches: inspect the IR preserved for
        # the overload compiled with A omitted
        omitted_sig = (types.Omitted(None), )
        final_ir = jitted.overloads[omitted_sig].metadata['preserved_ir']
        final_ir.blocks = numba.ir_utils.simplify_CFG(final_ir.blocks)
        self.assertEqual(len(final_ir.blocks), 1)
Пример #3
0
 def _compile_for_args(self, *args, **kws):
     """
     For internal use.  Compile a specialized version of the function
     for the given *args* and *kws*, and return the resulting callable.

     ``OmittedArg`` arguments are typed as ``types.Omitted`` of their
     wrapped value; all other arguments are typed with
     ``self.typeof_pyval``.  Keyword arguments are not accepted (guarded
     by an assertion).
     """
     assert not kws

     def _arg_type(arg):
         # Omitted arguments are typed from the value they wrap.
         if isinstance(arg, OmittedArg):
             return types.Omitted(arg.value)
         return self.typeof_pyval(arg)

     return self.compile(tuple(_arg_type(arg) for arg in args))
    def test_cond_is_kwarg_none(self):
        """
        ``x is None`` / ``x is not None`` branches on a keyword argument that
        defaults to None are checked with assert_prune for Omitted(None),
        NoneType and IntegerLiteral argument types.
        """

        def impl(x=None):
            if x is None:
                y = 10
            else:
                y = 40

            if x is not None:
                z = 100
            else:
                z = 400

            return z, y

        # Each row: (argument types, expected prune list, runtime argument)
        cases = [
            ((types.Omitted(None),), [False, True], None),
            ((types.NoneType('none'),), [False, True], None),
            ((types.IntegerLiteral(10),), [True, False], 10),
        ]
        for argtys, prunes, call_arg in cases:
            self.assert_prune(impl, argtys, prunes, call_arg)
    def test_cond_is_kwarg_value(self):
        """
        ``x == 1000`` / ``x != 1000`` branches on a keyword argument that
        defaults to 1000 are checked with assert_prune for Omitted,
        IntegerLiteral and NoneType argument types.
        """

        def impl(x=1000):
            if x == 1000:
                y = 10
            else:
                y = 40

            if x != 1000:
                z = 100
            else:
                z = 400

            return z, y

        # Each row: (argument types, expected prune list, runtime argument)
        cases = [
            ((types.Omitted(1000),), [None, None], 1000),
            ((types.IntegerLiteral(1000),), [None, None], 1000),
            ((types.IntegerLiteral(0),), [None, None], 0),
            ((types.NoneType('none'),), [True, False], None),
        ]
        for argtys, prunes, call_arg in cases:
            self.assert_prune(impl, argtys, prunes, call_arg)
Пример #6
0
    def test_inline_call_branch_pruning(self):
        """
        Inlining ``foo`` with inline_closure_call from a custom pipeline
        stage must still allow branch pruning of its ``A is None`` check, so
        the overload compiled with ``A`` omitted collapses to a single block
        once its CFG is simplified.
        """
        # branch pruning pass should run properly in inlining to enable
        # functions with type checks
        @njit
        def foo(A=None):
            if A is None:
                return 2
            else:
                return A

        def test_impl(A=None):
            return foo(A)

        class InlineTestPipelinePrune(InlineTestPipeline):
            def stage_inline_test_pass(self):
                func_ir = self.func_ir
                # assuming the function has one block with one call inside
                assert len(func_ir.blocks) == 1
                block = list(func_ir.blocks.values())[0]
                for idx, stmt in enumerate(block.body):
                    if guard(find_callname, func_ir, stmt.value) is None:
                        continue
                    # inline foo at the first call, typed by its argument
                    arg_types = (self.typemap[stmt.value.args[0].name], )
                    inline_closure_call(
                        func_ir, {}, block, idx, foo.py_func,
                        self.typingctx, arg_types,
                        self.typemap, self.calltypes)
                    break

        # make sure inline_closure_call runs in full pipeline
        jitted = njit(pipeline_class=InlineTestPipelinePrune)(test_impl)
        value = 3
        self.assertEqual(test_impl(value), jitted(value))
        self.assertEqual(test_impl(), jitted())

        # make sure IR doesn't have branches: inspect the final IR of the
        # overload compiled with A omitted
        omitted_sig = (types.Omitted(None), )
        final_ir = jitted.overloads[omitted_sig].metadata['final_func_ir']
        final_ir.blocks = numba.ir_utils.simplify_CFG(final_ir.blocks)
        self.assertEqual(len(final_ir.blocks), 1)
Пример #7
0
 def _compile_for_args(self, *args, **kws):
     """
     For internal use.  Compile a specialized version of the function
     for the given *args* and *kws*, and return the resulting callable.

     If compilation raises ``errors.TypingError``, each argument is re-typed
     with ``typeof(..., Purpose.argument)``.  Arguments that cannot be typed
     are listed in the patched error message, and the error is re-raised.
     """
     assert not kws

     def _numba_type(arg):
         # Omitted arguments are typed from the value they wrap.
         if isinstance(arg, OmittedArg):
             return types.Omitted(arg.value)
         return self.typeof_pyval(arg)

     def _untypeable_args():
         # (index, reason) for every argument that has no Numba type.
         problems = []
         for pos, arg in enumerate(args):
             raw = arg.value if isinstance(arg, OmittedArg) else arg
             try:
                 inferred = typeof(raw, Purpose.argument)
             except ValueError as typeof_exc:
                 problems.append((pos, str(typeof_exc)))
                 continue
             if inferred is None:
                 problems.append(
                     (pos,
                      "cannot determine Numba type of value %r" % (raw,)))
         return problems

     signature = tuple(_numba_type(arg) for arg in args)
     try:
         return self.compile(signature)
     except errors.TypingError as e:
         # Intercept typing error that may be due to an argument
         # that failed inferencing as a Numba type
         problems = _untypeable_args()
         if problems:
             details = "\n".join("- argument %d: %s" % (pos, reason)
                                 for pos, reason in problems)
             # Patch error message to ease debugging
             e.patch_message(str(e).rstrip() + (
                 "\n\nThis error may have been caused by the following argument(s):\n%s\n"
                 % details))
         raise e
Пример #8
0
 def default_handler(index, param, default):
     """Type a defaulted argument as ``types.Omitted`` wrapping *default*.

     *index* and *param* are accepted but not used here.
     """
     omitted = types.Omitted(default)
     return omitted
Пример #9
0
 def _numba_type_(self):
     """Return the Numba type of this object: ``types.Omitted(self.value)``."""
     wrapped = self.value
     return types.Omitted(wrapped)
Пример #10
0
    def _compile_for_args(self, *args, **kws):
        """
        For internal use.  Compile a specialized version of the function
        for the given *args* and *kws*, and return the resulting callable.

        Known compiler errors (typing, unsupported, interpreter and
        constant-inference errors) are re-raised with the matching entry of
        ``errors.error_extras`` appended when ``config.SHOW_HELP`` is set.
        Their traceback is dropped unless ``config.FULL_TRACEBACKS`` is set.
        Any other exception gets the 'reportable' help (if it supports
        ``patch_message``) and is always re-raised with its traceback.
        """
        assert not kws

        def error_rewrite(e, issue_type):
            """
            Rewrite and raise Exception `e` with help supplied based on the
            specified issue_type.
            """
            if config.SHOW_HELP:
                help_msg = errors.error_extras[issue_type]
                # Build on str(e): ''.join(e.args) raises TypeError when an
                # entry of e.args is not a str, masking the error reported.
                e.patch_message('\n'.join((str(e).rstrip(), help_msg)))
            if config.FULL_TRACEBACKS:
                raise e
            else:
                reraise(type(e), e, None)

        argtypes = []
        for a in args:
            if isinstance(a, OmittedArg):
                argtypes.append(types.Omitted(a.value))
            else:
                argtypes.append(self.typeof_pyval(a))
        try:
            return self.compile(tuple(argtypes))
        except errors.TypingError as e:
            # Intercept typing error that may be due to an argument
            # that failed inferencing as a Numba type
            failed_args = []
            for i, arg in enumerate(args):
                val = arg.value if isinstance(arg, OmittedArg) else arg
                try:
                    tp = typeof(val, Purpose.argument)
                except ValueError as typeof_exc:
                    failed_args.append((i, str(typeof_exc)))
                else:
                    if tp is None:
                        failed_args.append(
                            (i, "cannot determine Numba type of value %r" %
                             (val, )))
            if failed_args:
                # Patch error message to ease debugging
                msg = str(e).rstrip() + (
                    "\n\nThis error may have been caused by the following argument(s):\n%s\n"
                    % "\n".join("- argument %d: %s" % (i, err)
                                for i, err in failed_args))
                e.patch_message(msg)

            error_rewrite(e, 'typing')
        except errors.UnsupportedError as e:
            # Something unsupported is present in the user code, add help info
            error_rewrite(e, 'unsupported_error')
        except (errors.NotDefinedError, errors.RedefinedError,
                errors.VerificationError) as e:
            # These errors are probably from an issue with either the code supplied
            # being syntactically or otherwise invalid
            error_rewrite(e, 'interpreter')
        except errors.ConstantInferenceError as e:
            # this is from trying to infer something as constant when it isn't
            # or isn't supported as a constant
            error_rewrite(e, 'constant_inference')
        except Exception as e:
            if config.SHOW_HELP:
                if hasattr(e, 'patch_message'):
                    help_msg = errors.error_extras['reportable']
                    # Same reasoning as in error_rewrite: never join e.args.
                    e.patch_message('\n'.join((str(e).rstrip(), help_msg)))
            # ignore the FULL_TRACEBACKS config, this needs reporting!
            raise e
Пример #11
0
    def _compile_for_args(self, *args, **kws):
        """
        For internal use.  Compile a specialized version of the function
        for the given *args* and *kws*, and return the resulting callable.

        Arguments are typed (``OmittedArg`` -> ``types.Omitted`` of its
        value, anything else via ``self.typeof_pyval``) and compiled.  An
        ``errors.ForceLiteralArg`` request re-enters this method with the
        requested arguments passed through ``types.literal``, unless one of
        them is already a ``types.Literal``; that case raises
        ``errors.CompilerError`` instead.  Other known compiler errors get
        help text appended (when ``config.SHOW_HELP``) and are re-raised.
        """
        assert not kws

        def error_rewrite(e, issue_type):
            """
            Append the help text registered for *issue_type* to `e` when help
            is enabled, then raise it.  The traceback is dropped unless
            ``config.FULL_TRACEBACKS`` is set.
            """
            if config.SHOW_HELP:
                extra = errors.error_extras[issue_type]
                e.patch_message('\n'.join((str(e).rstrip(), extra)))
            if not config.FULL_TRACEBACKS:
                reraise(type(e), e, None)
            else:
                raise e

        def untypeable_args(values):
            # (index, reason) pairs for arguments that have no Numba type.
            found = []
            for pos, arg in enumerate(values):
                raw = arg.value if isinstance(arg, OmittedArg) else arg
                try:
                    inferred = typeof(raw, Purpose.argument)
                except ValueError as typeof_exc:
                    found.append((pos, str(typeof_exc)))
                else:
                    if inferred is None:
                        found.append(
                            (pos, "cannot determine Numba type of value %r" %
                             (raw, )))
            return found

        argtypes = [types.Omitted(a.value) if isinstance(a, OmittedArg)
                    else self.typeof_pyval(a)
                    for a in args]
        try:
            return self.compile(tuple(argtypes))
        except errors.ForceLiteralArg as e:
            # Received request for compiler re-entry with the list of arguments
            # indicated by e.requested_args.
            requested = e.requested_args
            # First, check if any of these args are already Literal-ized
            already_lit_pos = [i for i in requested
                               if isinstance(args[i], types.Literal)]
            if already_lit_pos:
                # Abort compilation if any argument is already a Literal.
                # Letting this continue will cause infinite compilation loop.
                m = ("Repeated literal typing request.\n"
                     "{}.\n"
                     "This is likely caused by an error in typing. "
                     "Please see nested and suppressed exceptions.")
                info = ', '.join('Arg #{} is {}'.format(i, args[i])
                                 for i in sorted(already_lit_pos))
                raise errors.CompilerError(m.format(info))
            # Convert requested arguments into a Literal, leave the rest as-is,
            # and re-enter compilation with them.
            literalized = [types.literal(v) if i in requested else v
                           for i, v in enumerate(args)]
            return self._compile_for_args(*literalized)

        except errors.TypingError as e:
            # Intercept typing error that may be due to an argument
            # that failed inferencing as a Numba type
            failed_args = untypeable_args(args)
            if failed_args:
                details = "\n".join("- argument %d: %s" % (i, err)
                                    for i, err in failed_args)
                # Patch error message to ease debugging
                e.patch_message(str(e).rstrip() + (
                    "\n\nThis error may have been caused by the following argument(s):\n%s\n"
                    % details))

            error_rewrite(e, 'typing')
        except errors.UnsupportedError as e:
            # Something unsupported is present in the user code, add help info
            error_rewrite(e, 'unsupported_error')
        except (errors.NotDefinedError, errors.RedefinedError,
                errors.VerificationError) as e:
            # Likely invalid (syntactically or otherwise) user code
            error_rewrite(e, 'interpreter')
        except errors.ConstantInferenceError as e:
            # Something was inferred as constant when it isn't, or isn't
            # supported as a constant
            error_rewrite(e, 'constant_inference')
        except Exception as e:
            if config.SHOW_HELP and hasattr(e, 'patch_message'):
                extra = errors.error_extras['reportable']
                e.patch_message('\n'.join((str(e).rstrip(), extra)))
            # ignore the FULL_TRACEBACKS config, this needs reporting!
            raise e
Пример #12
0
    def _compile_for_args(self, *args, **kws):
        """
        For internal use.  Compile a specialized version of the function
        for the given *args* and *kws*, and return the resulting callable.

        ``errors.TypingError`` and ``errors.UnsupportedError`` are re-raised
        with the matching ``errors.error_extras`` help appended when
        ``config.SHOW_HELP`` is set.  Their traceback is dropped unless
        ``config.FULL_TRACEBACKS`` is set.  For a TypingError, arguments
        that cannot be typed are listed in the message first.  Any other
        exception gets the 'reportable' help (if it supports
        ``patch_message``) and is always re-raised with its traceback.
        """
        assert not kws

        def error_rewrite(e, issue_type):
            """
            Rewrite and raise Exception `e` with help supplied based on the
            specified issue_type.
            """
            if config.SHOW_HELP:
                help_msg = errors.error_extras[issue_type]
                # Build on str(e): ''.join(e.args) raises TypeError when an
                # entry of e.args is not a str, masking the error reported.
                e.patch_message('\n'.join((str(e).rstrip(), help_msg)))
            if config.FULL_TRACEBACKS:
                raise e
            else:
                reraise(type(e), e, None)

        argtypes = []
        for a in args:
            if isinstance(a, OmittedArg):
                argtypes.append(types.Omitted(a.value))
            else:
                argtypes.append(self.typeof_pyval(a))
        try:
            return self.compile(tuple(argtypes))
        except errors.TypingError as e:
            # Intercept typing error that may be due to an argument
            # that failed inferencing as a Numba type
            failed_args = []
            for i, arg in enumerate(args):
                val = arg.value if isinstance(arg, OmittedArg) else arg
                try:
                    tp = typeof(val, Purpose.argument)
                except ValueError as typeof_exc:
                    failed_args.append((i, str(typeof_exc)))
                else:
                    if tp is None:
                        failed_args.append(
                            (i, "cannot determine Numba type of value %r" %
                             (val, )))
            if failed_args:
                # Patch error message to ease debugging
                msg = str(e).rstrip() + (
                    "\n\nThis error may have been caused by the following argument(s):\n%s\n"
                    % "\n".join("- argument %d: %s" % (i, err)
                                for i, err in failed_args))
                e.patch_message(msg)

            # add in help info, then raise
            error_rewrite(e, 'typing')
        except errors.UnsupportedError as e:
            # Something unsupported is present in the user code, add help info
            error_rewrite(e, 'unsupported_error')
        except Exception as e:
            if config.SHOW_HELP:
                if hasattr(e, 'patch_message'):
                    help_msg = errors.error_extras['reportable']
                    # Same reasoning as in error_rewrite: never join e.args.
                    e.patch_message('\n'.join((str(e).rstrip(), help_msg)))
            # ignore the FULL_TRACEBACKS config, this needs reporting!
            raise e