def type_inference_stage(typingctx, interp, args, return_type, locals=None,
                         raise_errors=True):
    """Run type inference over *interp* and collect the typing results.

    Parameters
    ----------
    typingctx
        Typing context; this inference is registered on its ``callstack``
        for the duration of the run.
    interp
        The function IR to type. Must provide ``arg_count``, ``arg_names``
        and ``func_id``.
    args
        Sequence of argument types, one per IR argument.
    return_type
        Return type to seed inference with, or ``None`` to leave it unseeded.
    locals
        Optional mapping of variable name -> type used to seed local
        variables. ``None`` (the default) seeds no locals.
    raise_errors
        Forwarded to ``infer.propagate`` and ``infer.unify``. Errors returned
        by ``propagate`` (e.g. for partial typing) are included in the result.

    Returns
    -------
    _TypingResults
        Built from ``(typemap, restype, calltypes, errs)``.

    Raises
    ------
    TypeError
        If ``len(args)`` differs from ``interp.arg_count``.
    """
    if len(args) != interp.arg_count:
        raise TypeError("Mismatch number of argument types")

    # Avoid a shared mutable default; treat "not given" as "no locals".
    if locals is None:
        locals = {}

    # Named so it does not shadow the stdlib ``warnings`` module.
    warning_fixer = errors.WarningsFixer(errors.NumbaWarning)
    infer = typeinfer.TypeInferer(typingctx, interp, warning_fixer)
    with typingctx.callstack.register(infer, interp.func_id, args):
        # Seed argument types
        for index, (name, ty) in enumerate(zip(interp.arg_names, args)):
            infer.seed_argument(name, index, ty)

        # Seed return type
        if return_type is not None:
            infer.seed_return(return_type)

        # Seed local types
        for k, v in locals.items():
            infer.seed_type(k, v)

        infer.build_constraint()
        # return errors in case of partial typing
        errs = infer.propagate(raise_errors=raise_errors)
        typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)

    # Output all Numba warnings
    warning_fixer.flush()
    return _TypingResults(typemap, restype, calltypes, errs)
def test_warnings_fixer(self):
    """Check that ``WarningsFixer.flush`` re-emits every caught warning.

    Two warnings with identical message text but different categories are
    caught. Both must be re-emitted on flush, in str-sorted order.
    For some context, see #4083.
    """
    wfix = errors.WarningsFixer(errors.NumbaWarning)
    with wfix.catch_warnings('foo', 10):
        warnings.warn(errors.NumbaWarning('same'))
        warnings.warn(errors.NumbaDeprecationWarning('same'))

    with warnings.catch_warnings(record=True) as w:
        # 'always' so duplicate-looking warnings are not suppressed by the
        # default once-per-location filter.
        warnings.simplefilter('always')
        wfix.flush()

        self.assertEqual(len(w), 2)
        # The order is reversed relative to the warn() calls above: the
        # WarningsFixer flush method sorts with a key based on str
        # comparison.
        self.assertEqual(w[0].category, errors.NumbaDeprecationWarning)
        self.assertEqual(w[1].category, errors.NumbaWarning)
        self.assertIn('same', str(w[0].message))
        self.assertIn('same', str(w[1].message))