def best_match(func_wrapper, argtypes):
    """
    Find the right overload for a numba function.

    Arguments
    ---------
    func_wrapper: FunctionWrapper
        The function whose ``dispatcher`` holds the overloads and whose
        ``py_func`` determines the resolution scope.

    argtypes: [Type]
        Types to call the overloaded function with. Each one is converted
        with ``to_blaze`` before overload matching.

    Returns
    -------
    (py_func, result_signature, kwds)
        ``overload.func`` of the selected overload, that overload's
        ``resolved_sig`` resolved against the scope of
        ``func_wrapper.py_func``, and ``overload.kwds`` of the selected
        overload.
    """
    dispatcher = func_wrapper.dispatcher
    scope = determine_scope(func_wrapper.py_func)
    # NOTE(review): the original left a bare "TODO" here. `bound` starts
    # empty and is shared by resolve_overloads() and resolve() below;
    # confirm what it is meant to contain.
    bound = {}
    overloaded = resolve_overloads(dispatcher, scope, bound)

    # Use a new name rather than rebinding the caller's `argtypes` parameter.
    blaze_argtypes = [to_blaze(t) for t in argtypes]
    overload = overloading.best_match(overloaded, blaze_argtypes)

    signature = resolve(overload.resolved_sig, scope, bound)
    return (overload.func, signature, overload.kwds)
def test_best_match_ellipses(self):
    """Match ``g`` against a fixed-rank argument and an ellipsis argument.

    Check both the matched signature and its resolved form.
    """
    fixed_rank = dshape('10, T1, int32')
    variadic = dshape('..., float32')

    result = best_match(g, [fixed_rank, variadic])

    expected_sig = 'X, Y, float32 -> ..., float32 -> X, int32'
    expected_resolved = '10, T1, float32 -> ..., float32 -> 10, int32'
    self.assertEqual(str(result.sig), expected_sig)
    self.assertEqual(str(result.resolved_sig), expected_resolved)
def test_best_match_broadcasting(self):
    """Match ``f`` against two rank-1 arguments with different element types.

    Check the matched signature, and check that the resolved signature
    shows the arguments broadcast to ``1, 10, cfloat32``.
    """
    complex_arg = dshape('10, complex64')
    float_arg = dshape('10, float32')

    result = best_match(f, [complex_arg, float_arg])

    expected_sig = 'X, Y, cfloat32 -> X, Y, cfloat32 -> X, Y, cfloat32'
    expected_resolved = (
        '1, 10, cfloat32 -> 1, 10, cfloat32 -> 1, 10, cfloat32'
    )
    self.assertEqual(str(result.sig), expected_sig)
    self.assertEqual(str(result.resolved_sig), expected_resolved)
def test_best_match(self):
    """Match ``f`` against two rank-2 arguments, then unify with a call.

    First check the matched signature. Then unify a concrete call signature
    with the match's resolved signature and check the unified result.
    """
    d1 = dshape('10, T1, int32')
    d2 = dshape('T2, T2, float32')
    match = best_match(f, [d1, d2])
    self.assertEqual(
        str(match.sig),
        'X, Y, float32 -> X, Y, float32 -> X, Y, float32',
    )

    # Renamed from `input`, which shadowed the builtin input().
    call_signature = dshape('1, 1, float32 -> 1, 1, float32 -> R')
    unified = unify_simple(call_signature, match.resolved_sig)
    self.assertEqual(
        str(unified),
        '10, 1, float32 -> 10, 1, float32 -> 10, 1, float32',
    )