def test_omitted_args(self):
    """
    typeof() on an OmittedArg must yield the matching types.Omitted.

    Also checks that the resulting types are hashable and distinct for
    distinct defaults (including 1 vs 1.0), and that two OmittedArg
    instances wrapping the same default produce equal types.
    """
    defaults = (0.0, 1, 1.0, 1.0)
    ty0, ty1, ty2, ty3 = [typeof(OmittedArg(d)) for d in defaults]

    # Each inferred type equals an Omitted built from the same default.
    for inferred, default in zip((ty0, ty1, ty2), defaults):
        self.assertEqual(inferred, types.Omitted(default))

    # 1 and 1.0 must not collapse into a single set element.
    self.assertEqual(len({ty0, ty1, ty2}), 3)
    # Same default value -> equal types.
    self.assertEqual(ty3, ty2)
def test_inline_call_branch_pruning(self):
    """
    Inline ``foo`` into ``test_impl`` through a custom compiler pass and
    check that the compiled results match pure Python and that the
    preserved IR for the omitted-argument overload simplifies to a
    single block.
    """
    # branch pruning pass should run properly in inlining to enable
    # functions with type checks
    @njit
    def foo(A=None):
        if A is None:
            return 2
        else:
            return A

    def test_impl(A=None):
        return foo(A)

    @register_pass(analysis_only=False, mutates_CFG=True)
    class PruningInlineTestPass(FunctionPass):
        _name = "pruning_inline_test_pass"

        def __init__(self):
            FunctionPass.__init__(self)

        def run_pass(self, state):
            # assuming the function has one block with one call inside
            assert len(state.func_ir.blocks) == 1
            only_block = list(state.func_ir.blocks.values())[0]
            for idx, inst in enumerate(only_block.body):
                if guard(find_callname, state.func_ir, inst.value) is None:
                    continue
                # Inline the first call found, then stop scanning.
                typemap = state.type_annotation.typemap
                arg_types = (typemap[inst.value.args[0].name], )
                inline_closure_call(state.func_ir, {}, only_block, idx,
                                    foo.py_func, state.typingctx,
                                    arg_types, typemap, state.calltypes)
                break
            return True

    class InlineTestPipelinePrune(numba.compiler.CompilerBase):

        def define_pipelines(self):
            pipeline = gen_pipeline(self.state, PruningInlineTestPass)
            pipeline.finalize()
            return [pipeline]

    # make sure inline_closure_call runs in full pipeline
    j_func = njit(pipeline_class=InlineTestPipelinePrune)(test_impl)

    A = 3
    self.assertEqual(test_impl(A), j_func(A))
    self.assertEqual(test_impl(), j_func())

    # make sure IR doesn't have branches
    omitted_sig = (types.Omitted(None), )
    fir = j_func.overloads[omitted_sig].metadata['preserved_ir']
    fir.blocks = numba.ir_utils.simplify_CFG(fir.blocks)
    self.assertEqual(len(fir.blocks), 1)
def _compile_for_args(self, *args, **kws):
    """
    For internal use.  Compile a specialized version of the function
    for the given *args* and *kws*, and return the resulting callable.

    ``OmittedArg`` placeholders are typed as ``types.Omitted`` carrying
    the wrapped value; every other argument is typed with
    ``self.typeof_pyval``.  *kws* must be empty.
    """
    assert not kws
    signature = tuple(
        types.Omitted(arg.value) if isinstance(arg, OmittedArg)
        else self.typeof_pyval(arg)
        for arg in args
    )
    return self.compile(signature)
def test_cond_is_kwarg_none(self):
    """
    Exercise assert_prune on ``is None`` / ``is not None`` checks of a
    keyword argument that defaults to None, for an omitted argument, an
    explicit None and an integer literal.
    """

    def impl(x=None):
        if x is None:
            y = 10
        else:
            y = 40

        if x is not None:
            z = 100
        else:
            z = 400

        return z, y

    # (argument types, expected prune list, call value) per scenario.
    # NOTE(review): presumably one prune entry per branch in impl --
    # confirm against assert_prune.
    scenarios = (
        ((types.Omitted(None),), [False, True], None),
        ((types.NoneType('none'),), [False, True], None),
        ((types.IntegerLiteral(10),), [True, False], 10),
    )
    for arg_types, expected_prune, value in scenarios:
        self.assert_prune(impl, arg_types, expected_prune, value)
def test_cond_is_kwarg_value(self):
    """
    Exercise assert_prune on ``==`` / ``!=`` comparisons of a keyword
    argument against its integer default, for an omitted argument,
    integer literals and None.
    """

    def impl(x=1000):
        if x == 1000:
            y = 10
        else:
            y = 40

        if x != 1000:
            z = 100
        else:
            z = 400

        return z, y

    # (argument types, expected prune list, call value) per scenario.
    # NOTE(review): presumably one prune entry per branch in impl --
    # confirm against assert_prune.
    scenarios = (
        ((types.Omitted(1000),), [None, None], 1000),
        ((types.IntegerLiteral(1000),), [None, None], 1000),
        ((types.IntegerLiteral(0),), [None, None], 0),
        ((types.NoneType('none'),), [True, False], None),
    )
    for arg_types, expected_prune, value in scenarios:
        self.assert_prune(impl, arg_types, expected_prune, value)
def test_inline_call_branch_pruning(self):
    """
    Inline ``foo`` into ``test_impl`` from a custom pipeline stage and
    check that the compiled results match pure Python and that the
    final IR for the omitted-argument overload simplifies to a single
    block.
    """
    # branch pruning pass should run properly in inlining to enable
    # functions with type checks
    @njit
    def foo(A=None):
        if A is None:
            return 2
        else:
            return A

    def test_impl(A=None):
        return foo(A)

    class InlineTestPipelinePrune(InlineTestPipeline):

        def stage_inline_test_pass(self):
            # assuming the function has one block with one call inside
            assert len(self.func_ir.blocks) == 1
            only_block = list(self.func_ir.blocks.values())[0]
            for idx, inst in enumerate(only_block.body):
                if guard(find_callname, self.func_ir, inst.value) is None:
                    continue
                # Inline the first call found, then stop scanning.
                arg_types = (self.typemap[inst.value.args[0].name], )
                inline_closure_call(self.func_ir, {}, only_block, idx,
                                    foo.py_func, self.typingctx,
                                    arg_types, self.typemap,
                                    self.calltypes)
                break

    # make sure inline_closure_call runs in full pipeline
    j_func = njit(pipeline_class=InlineTestPipelinePrune)(test_impl)

    A = 3
    self.assertEqual(test_impl(A), j_func(A))
    self.assertEqual(test_impl(), j_func())

    # make sure IR doesn't have branches
    omitted_sig = (types.Omitted(None), )
    fir = j_func.overloads[omitted_sig].metadata['final_func_ir']
    fir.blocks = numba.ir_utils.simplify_CFG(fir.blocks)
    self.assertEqual(len(fir.blocks), 1)
def _compile_for_args(self, *args, **kws):
    """
    For internal use.  Compile a specialized version of the function
    for the given *args* and *kws*, and return the resulting callable.

    If compilation fails with a TypingError, each argument is re-typed
    individually; arguments that cannot be typed are listed in the
    error message before the error is re-raised.  *kws* must be empty.
    """
    assert not kws
    argtypes = [
        types.Omitted(arg.value) if isinstance(arg, OmittedArg)
        else self.typeof_pyval(arg)
        for arg in args
    ]
    try:
        return self.compile(tuple(argtypes))
    except errors.TypingError as e:
        # Intercept typing error that may be due to an argument
        # that failed inferencing as a Numba type
        problems = []
        for pos, given in enumerate(args):
            value = given.value if isinstance(given, OmittedArg) else given
            try:
                inferred = typeof(value, Purpose.argument)
            except ValueError as typeof_exc:
                problems.append((pos, str(typeof_exc)))
                continue
            if inferred is None:
                problems.append(
                    (pos,
                     "cannot determine Numba type of value %r" % (value,)))
        if problems:
            # Patch error message to ease debugging
            listing = "\n".join("- argument %d: %s" % (pos, reason)
                                for pos, reason in problems)
            msg = str(e).rstrip() + (
                "\n\nThis error may have been caused by the following "
                "argument(s):\n%s\n" % listing)
            e.patch_message(msg)
        raise e
def default_handler(index, param, default):
    """
    Return the ``types.Omitted`` type wrapping *default*.

    *index* and *param* are accepted but not used.
    """
    omitted_type = types.Omitted(default)
    return omitted_type
def _numba_type_(self):
    """Return the ``types.Omitted`` type wrapping ``self.value``."""
    wrapped = self.value
    return types.Omitted(wrapped)
def _compile_for_args(self, *args, **kws):
    """
    For internal use.  Compile a specialized version of the function
    for the given *args* and *kws*, and return the resulting callable.

    ``OmittedArg`` placeholders are typed as ``types.Omitted``; every
    other argument is typed with ``self.typeof_pyval``.  *kws* must be
    empty.

    Errors raised by ``self.compile`` are re-raised after their message
    has optionally been extended:

    * TypingError: arguments that cannot be typed individually are
      listed, then the 'typing' help text is appended.
    * UnsupportedError: the 'unsupported_error' help text is appended.
    * NotDefinedError / RedefinedError / VerificationError: the
      'interpreter' help text is appended.
    * ConstantInferenceError: the 'constant_inference' help text is
      appended.
    * Any other exception exposing ``patch_message``: the 'reportable'
      help text is appended and the exception is always raised as-is.

    Help text is only appended when ``config.SHOW_HELP`` is set.
    """
    assert not kws

    def error_rewrite(e, issue_type):
        """
        Rewrite and raise Exception `e` with help supplied based on the
        specified issue_type.
        """
        if config.SHOW_HELP:
            help_msg = errors.error_extras[issue_type]
            # Stringify each arg: a bare ''.join(e.args) raises TypeError
            # on a non-str arg and would mask the real error.
            e.patch_message(''.join(map(str, e.args)) + help_msg)
        if config.FULL_TRACEBACKS:
            raise e
        else:
            # NOTE(review): presumably re-raises without the original
            # traceback (the traceback argument is None) -- confirm
            # against the reraise helper.
            reraise(type(e), e, None)

    argtypes = []
    for a in args:
        if isinstance(a, OmittedArg):
            argtypes.append(types.Omitted(a.value))
        else:
            argtypes.append(self.typeof_pyval(a))
    try:
        return self.compile(tuple(argtypes))
    except errors.TypingError as e:
        # Intercept typing error that may be due to an argument
        # that failed inferencing as a Numba type
        failed_args = []
        for i, arg in enumerate(args):
            val = arg.value if isinstance(arg, OmittedArg) else arg
            try:
                tp = typeof(val, Purpose.argument)
            except ValueError as typeof_exc:
                failed_args.append((i, str(typeof_exc)))
            else:
                if tp is None:
                    failed_args.append(
                        (i,
                         "cannot determine Numba type of value %r" % (val, )))
        if failed_args:
            # Patch error message to ease debugging
            msg = str(e).rstrip() + (
                "\n\nThis error may have been caused by the following "
                "argument(s):\n%s\n"
                % "\n".join("- argument %d: %s" % (i, err)
                            for i, err in failed_args))
            e.patch_message(msg)

        error_rewrite(e, 'typing')
    except errors.UnsupportedError as e:
        # Something unsupported is present in the user code, add help info
        error_rewrite(e, 'unsupported_error')
    except (errors.NotDefinedError, errors.RedefinedError,
            errors.VerificationError) as e:
        # These errors are probably from an issue with either the code
        # supplied being syntactically or otherwise invalid
        error_rewrite(e, 'interpreter')
    except errors.ConstantInferenceError as e:
        # this is from trying to infer something as constant when it isn't
        # or isn't supported as a constant
        error_rewrite(e, 'constant_inference')
    except Exception as e:
        if config.SHOW_HELP:
            if hasattr(e, 'patch_message'):
                help_msg = errors.error_extras['reportable']
                # Same stringification as in error_rewrite, for the same
                # reason.
                e.patch_message(''.join(map(str, e.args)) + help_msg)
        # ignore the FULL_TRACEBACKS config, this needs reporting!
        raise e
def _compile_for_args(self, *args, **kws):
    """
    For internal use.  Compile a specialized version of the function
    for the given *args* and *kws*, and return the resulting callable.

    ``OmittedArg`` placeholders are typed as ``types.Omitted``; every
    other argument is typed with ``self.typeof_pyval``.  *kws* must be
    empty.

    A ForceLiteralArg raised by ``self.compile`` triggers one re-entry
    with the requested arguments converted through ``types.literal``;
    if a requested argument is already a ``types.Literal`` a
    CompilerError is raised instead.  Other errors are re-raised, with
    help text appended to their message when ``config.SHOW_HELP`` is
    set.
    """
    assert not kws

    def error_rewrite(e, issue_type):
        """
        Rewrite and raise Exception `e` with help supplied based on the
        specified issue_type.
        """
        if config.SHOW_HELP:
            help_msg = errors.error_extras[issue_type]
            e.patch_message('\n'.join([str(e).rstrip(), help_msg]))
        if config.FULL_TRACEBACKS:
            raise e
        reraise(type(e), e, None)

    argtypes = [
        types.Omitted(a.value) if isinstance(a, OmittedArg)
        else self.typeof_pyval(a)
        for a in args
    ]
    try:
        return self.compile(tuple(argtypes))
    except errors.ForceLiteralArg as e:
        # Received request for compiler re-entry with the list of arguments
        # indicated by e.requested_args.
        # First, check if any of these args are already Literal-ized
        already_lit_pos = [pos for pos in e.requested_args
                           if isinstance(args[pos], types.Literal)]
        if already_lit_pos:
            # Abort compilation if any argument is already a Literal.
            # Letting this continue will cause infinite compilation loop.
            m = ("Repeated literal typing request.\n"
                 "{}.\n"
                 "This is likely caused by an error in typing. "
                 "Please see nested and suppressed exceptions.")
            info = ', '.join('Arg #{} is {}'.format(pos, args[pos])
                             for pos in sorted(already_lit_pos))
            raise errors.CompilerError(m.format(info))
        # Convert requested arguments into a Literal.
        literal_args = [
            types.literal(value) if pos in e.requested_args else value
            for pos, value in enumerate(args)
        ]
        # Re-enter compilation with the Literal-ized arguments
        return self._compile_for_args(*literal_args)
    except errors.TypingError as e:
        # Intercept typing error that may be due to an argument
        # that failed inferencing as a Numba type
        problems = []
        for pos, given in enumerate(args):
            value = given.value if isinstance(given, OmittedArg) else given
            try:
                inferred = typeof(value, Purpose.argument)
            except ValueError as typeof_exc:
                problems.append((pos, str(typeof_exc)))
                continue
            if inferred is None:
                problems.append(
                    (pos,
                     "cannot determine Numba type of value %r" % (value, )))
        if problems:
            # Patch error message to ease debugging
            listing = "\n".join("- argument %d: %s" % (pos, reason)
                                for pos, reason in problems)
            msg = str(e).rstrip() + (
                "\n\nThis error may have been caused by the following "
                "argument(s):\n%s\n" % listing)
            e.patch_message(msg)

        error_rewrite(e, 'typing')
    except errors.UnsupportedError as e:
        # Something unsupported is present in the user code, add help info
        error_rewrite(e, 'unsupported_error')
    except (errors.NotDefinedError, errors.RedefinedError,
            errors.VerificationError) as e:
        # These errors are probably from an issue with either the code
        # supplied being syntactically or otherwise invalid
        error_rewrite(e, 'interpreter')
    except errors.ConstantInferenceError as e:
        # this is from trying to infer something as constant when it isn't
        # or isn't supported as a constant
        error_rewrite(e, 'constant_inference')
    except Exception as e:
        if config.SHOW_HELP and hasattr(e, 'patch_message'):
            help_msg = errors.error_extras['reportable']
            e.patch_message('\n'.join([str(e).rstrip(), help_msg]))
        # ignore the FULL_TRACEBACKS config, this needs reporting!
        raise e
def _compile_for_args(self, *args, **kws):
    """
    For internal use.  Compile a specialized version of the function
    for the given *args* and *kws*, and return the resulting callable.

    ``OmittedArg`` placeholders are typed as ``types.Omitted``; every
    other argument is typed with ``self.typeof_pyval``.  *kws* must be
    empty.

    Errors raised by ``self.compile`` are re-raised after their message
    has optionally been extended:

    * TypingError: arguments that cannot be typed individually are
      listed, then the 'typing' help text is appended.
    * UnsupportedError: the 'unsupported_error' help text is appended.
    * Any other exception exposing ``patch_message``: the 'reportable'
      help text is appended and the exception is always raised as-is.

    Help text is only appended when ``config.SHOW_HELP`` is set.
    """
    assert not kws

    def error_rewrite(e, issue_type):
        """
        Append the help text registered for *issue_type* to `e` (when
        config.SHOW_HELP is set) and raise it.
        """
        if config.SHOW_HELP:
            help_msg = errors.error_extras[issue_type]
            # Stringify each arg: a bare ''.join(e.args) raises TypeError
            # on a non-str arg and would mask the real error.
            e.patch_message(''.join(map(str, e.args)) + help_msg)
        if config.FULL_TRACEBACKS:
            raise e
        else:
            # NOTE(review): presumably re-raises without the original
            # traceback (the traceback argument is None) -- confirm
            # against the reraise helper.
            reraise(type(e), e, None)

    argtypes = []
    for a in args:
        if isinstance(a, OmittedArg):
            argtypes.append(types.Omitted(a.value))
        else:
            argtypes.append(self.typeof_pyval(a))
    try:
        return self.compile(tuple(argtypes))
    except errors.TypingError as e:
        # Intercept typing error that may be due to an argument
        # that failed inferencing as a Numba type
        failed_args = []
        for i, arg in enumerate(args):
            val = arg.value if isinstance(arg, OmittedArg) else arg
            try:
                tp = typeof(val, Purpose.argument)
            except ValueError as typeof_exc:
                failed_args.append((i, str(typeof_exc)))
            else:
                if tp is None:
                    failed_args.append(
                        (i,
                         "cannot determine Numba type of value %r" % (val, )))
        if failed_args:
            # Patch error message to ease debugging
            msg = str(e).rstrip() + (
                "\n\nThis error may have been caused by the following "
                "argument(s):\n%s\n"
                % "\n".join("- argument %d: %s" % (i, err)
                            for i, err in failed_args))
            e.patch_message(msg)

        # add in help info and raise
        error_rewrite(e, 'typing')
    except errors.UnsupportedError as e:
        # Something unsupported is present in the user code, add help info
        error_rewrite(e, 'unsupported_error')
    except Exception as e:
        if config.SHOW_HELP:
            if hasattr(e, 'patch_message'):
                help_msg = errors.error_extras['reportable']
                # Same stringification as in error_rewrite, for the same
                # reason.
                e.patch_message(''.join(map(str, e.args)) + help_msg)
        # ignore the FULL_TRACEBACKS config, this needs reporting!
        raise e