def test_omitted_args(self):
    # typeof() of an OmittedArg must give the matching types.Omitted, and
    # Omitted types built from different defaults must be distinct.
    defaults = (0.0, 1, 1.0, 1.0)
    ty0, ty1, ty2, ty3 = (typeof(OmittedArg(d)) for d in defaults)

    # The first three results are compared against directly built types.
    for inferred, default in zip((ty0, ty1, ty2), defaults):
        self.assertEqual(inferred, types.Omitted(default))

    # 0.0, 1 and 1.0 give three different types ...
    self.assertEqual(len({ty0, ty1, ty2}), 3)
    # ... while the same default (1.0) gives an equal type.
    self.assertEqual(ty3, ty2)
def test_inline_call_branch_pruning(self):
    # branch pruning pass should run properly in inlining to enable
    # functions with type checks
    @njit
    def foo(A=None):
        if A is None:
            return 2
        else:
            return A

    def test_impl(A=None):
        return foo(A)

    @register_pass(analysis_only=False, mutates_CFG=True)
    class PruningInlineTestPass(FunctionPass):
        _name = "pruning_inline_test_pass"

        def __init__(self):
            FunctionPass.__init__(self)

        def run_pass(self, state):
            # assuming the function has one block with one call inside
            func_ir = state.func_ir
            assert len(func_ir.blocks) == 1
            only_block = next(iter(func_ir.blocks.values()))

            for idx, stmt in enumerate(only_block.body):
                # skip statements that are not a recognisable call
                if guard(find_callname, func_ir, stmt.value) is None:
                    continue

                typemap = state.type_annotation.typemap
                # type of the call's first argument, as a 1-tuple
                arg_types = (typemap[stmt.value.args[0].name],)
                inline_closure_call(
                    func_ir,
                    {},
                    only_block,
                    idx,
                    foo.py_func,
                    state.typingctx,
                    arg_types,
                    typemap,
                    state.calltypes,
                )
                # only the first call found is inlined
                break
            return True

    class InlineTestPipelinePrune(compiler.CompilerBase):

        def define_pipelines(self):
            pipeline = gen_pipeline(self.state, PruningInlineTestPass)
            pipeline.finalize()
            return [pipeline]

    # make sure inline_closure_call runs in full pipeline
    j_func = njit(pipeline_class=InlineTestPipelinePrune)(test_impl)
    A = 3
    self.assertEqual(test_impl(A), j_func(A))
    self.assertEqual(test_impl(), j_func())

    # make sure IR doesn't have branches
    omitted_sig = (types.Omitted(None),)
    fir = j_func.overloads[omitted_sig].metadata["preserved_ir"]
    fir.blocks = simplify_CFG(fir.blocks)
    self.assertEqual(len(fir.blocks), 1)
def _lit_or_omitted(value):
    """Type *value* as a literal when possible, else as an omitted default.

    ``types.literal(value)`` is tried first; if it raises
    ``LiteralTypingError`` the result is ``types.Omitted(value)`` instead.
    """
    try:
        result = types.literal(value)
    except LiteralTypingError:
        # No literal type could be built for this value.
        result = types.Omitted(value)
    return result
def test_omitted_arg(self):
    # See issue 7726
    @njit(debug=True)
    def foo(missing=None):
        pass

    # check that it will actually compile (verifies DI emission is ok)
    with override_config('DEBUGINFO_DEFAULT', 1):
        foo()

    metadata = self._get_metadata(foo, sig=(types.Omitted(None),))
    metadata_definition_map = self._get_metadata_map(metadata)

    # Find DISubroutineType; exactly one entry is expected
    tmp_disubr = [md for md in metadata if "DISubroutineType" in md]
    self.assertEqual(len(tmp_disubr), 1)
    disubr = tmp_disubr.pop()

    disubr_matched = re.match(r'.*!DISubroutineType\(types: ([!0-9]+)\)$',
                              disubr)
    self.assertIsNotNone(disubr_matched)
    disubr_groups = disubr_matched.groups()
    self.assertEqual(len(disubr_groups), 1)
    disubr_meta = disubr_groups[0]

    # Find the types in the DISubroutineType arg list
    disubr_types = metadata_definition_map[disubr_meta]
    disubr_types_matched = re.match(r'!{(.*)}', disubr_types)
    # Assert on the match that was just made (not the earlier one), so a
    # mismatch is reported as a test failure rather than an AttributeError
    # from calling .groups() on None.
    self.assertIsNotNone(disubr_types_matched)
    disubr_types_groups = disubr_types_matched.groups()
    self.assertEqual(len(disubr_types_groups), 1)

    # fetch out and check the last argument type: the pattern below
    # requires a structure-typed DICompositeType named
    # "Anonymous struct ({})"
    md_fn_arg = [x.strip() for x in disubr_types_groups[0].split(',')][-1]
    arg_ty = metadata_definition_map[md_fn_arg]
    expected_arg_ty = (r'^.*!DICompositeType\(tag: DW_TAG_structure_type, '
                       r'name: "Anonymous struct \({}\)", elements: '
                       r'(![0-9]+), identifier: "{}"\)')
    self.assertRegex(arg_ty, expected_arg_ty)

    # the captured group is the metadata reference to the struct's elements
    md_base_ty = re.match(expected_arg_ty, arg_ty).groups()[0]
    base_ty = metadata_definition_map[md_base_ty]
    # expect ir.LiteralStructType([])
    self.assertEqual(base_ty, ('!{}'))
def test_cond_is_kwarg_none(self):
    # `impl` is the function handed to assert_prune, so its two
    # `is None` / `is not None` branches are left exactly as written.
    def impl(x=None):
        if x is None:
            y = 10
        else:
            y = 40

        if x is not None:
            z = 100
        else:
            z = 400

        return z, y

    # (argument types, expected prune list, trailing argument)
    cases = (
        ((types.Omitted(None),), [False, True], None),
        ((types.NoneType('none'),), [False, True], None),
        ((types.IntegerLiteral(10),), [True, False], 10),
    )
    for arg_types, expected_prune, trailing_arg in cases:
        self.assert_prune(impl, arg_types, expected_prune, trailing_arg)
def test_cond_is_kwarg_value(self):
    # `impl` is the function handed to assert_prune, so its two
    # comparisons against 1000 are left exactly as written.
    def impl(x=1000):
        if x == 1000:
            y = 10
        else:
            y = 40

        if x != 1000:
            z = 100
        else:
            z = 400

        return z, y

    # (argument types, expected prune list, trailing argument)
    cases = (
        ((types.Omitted(1000),), [None, None], 1000),
        ((types.IntegerLiteral(1000),), [None, None], 1000),
        ((types.IntegerLiteral(0),), [None, None], 0),
        ((types.NoneType('none'),), [True, False], None),
    )
    for arg_types, expected_prune, trailing_arg in cases:
        self.assert_prune(impl, arg_types, expected_prune, trailing_arg)
def default_handler(index, param, default):
    """Return the ``types.Omitted`` type wrapping *default*.

    *index* and *param* are accepted but not used.
    """
    omitted_ty = types.Omitted(default)
    return omitted_ty
def _numba_type_(self):
    """Return ``types.Omitted`` built from this object's ``value``."""
    default_value = self.value
    return types.Omitted(default_value)
def _compile_for_args(self, *args, **kws):
    """
    For internal use.  Type the given positional *args*, compile a
    specialization for those types and return what ``self.compile``
    returns.  Keyword arguments are not accepted (*kws* must be empty).

    Known error classes raised during compilation are re-raised, with
    extra help text appended when ``config.SHOW_HELP`` is set.
    """
    assert not kws
    # call any initialisation required for the compilation chain (e.g.
    # extension point registration).
    self._compilation_chain_init_hook()

    def error_rewrite(exc, issue_type):
        """
        Re-raise *exc*, first appending the help text registered for
        *issue_type* when help output is enabled.
        """
        if config.SHOW_HELP:
            extra_help = errors.error_extras[issue_type]
            exc.patch_message('\n'.join((str(exc).rstrip(), extra_help)))
        if config.FULL_TRACEBACKS:
            raise exc
        reraise(type(exc), exc, None)

    # Omitted arguments are typed from their default value; everything
    # else goes through the regular value-typing path.
    argtypes = tuple(
        types.Omitted(a.value) if isinstance(a, OmittedArg)
        else self.typeof_pyval(a)
        for a in args
    )

    try:
        return self.compile(argtypes)
    except errors.ForceLiteralArg as e:
        # Received request for compiler re-entry with the list of arguments
        # indicated by e.requested_args.
        requested = e.requested_args
        # First, check if any of these args are already Literal-ized
        already_lit_pos = [pos for pos in requested
                           if isinstance(args[pos], types.Literal)]
        if already_lit_pos:
            # Abort compilation if any argument is already a Literal.
            # Letting this continue will cause infinite compilation loop.
            m = ("Repeated literal typing request.\n"
                 "{}.\n"
                 "This is likely caused by an error in typing. "
                 "Please see nested and suppressed exceptions.")
            info = ', '.join('Arg #{} is {}'.format(pos, args[pos])
                             for pos in sorted(already_lit_pos))
            raise errors.CompilerError(m.format(info))
        # Convert requested arguments into a Literal.
        args = [types.literal(value) if pos in requested else value
                for pos, value in enumerate(args)]
        # Re-enter compilation with the Literal-ized arguments
        return self._compile_for_args(*args)
    except errors.TypingError as e:
        # Intercept typing error that may be due to an argument
        # that failed inferencing as a Numba type
        failed_args = []
        for pos, arg in enumerate(args):
            val = arg.value if isinstance(arg, OmittedArg) else arg
            try:
                tp = typeof(val, Purpose.argument)
            except ValueError as typeof_exc:
                failed_args.append((pos, str(typeof_exc)))
                continue
            if tp is None:
                failed_args.append(
                    (pos,
                     "cannot determine Numba type of value %r" % (val,)))
        if failed_args:
            # Patch error message to ease debugging
            culprits = "\n".join("- argument %d: %s" % (pos, err)
                                 for pos, err in failed_args)
            msg = str(e).rstrip() + (
                "\n\nThis error may have been caused by the following argument(s):\n%s\n"
                % culprits)
            e.patch_message(msg)
        error_rewrite(e, 'typing')
    except errors.UnsupportedError as e:
        # Something unsupported is present in the user code, add help info
        error_rewrite(e, 'unsupported_error')
    except (errors.NotDefinedError, errors.RedefinedError,
            errors.VerificationError) as e:
        # These errors are probably from an issue with either the code
        # supplied being syntactically or otherwise invalid
        error_rewrite(e, 'interpreter')
    except errors.ConstantInferenceError as e:
        # this is from trying to infer something as constant when it isn't
        # or isn't supported as a constant
        error_rewrite(e, 'constant_inference')
    except Exception as e:
        if config.SHOW_HELP and hasattr(e, 'patch_message'):
            help_msg = errors.error_extras['reportable']
            e.patch_message('\n'.join((str(e).rstrip(), help_msg)))
        # ignore the FULL_TRACEBACKS config, this needs reporting!
        raise e