def test_define_dynamic(self):
    """Dispatch between two kernels registered under one category.

    Two implementations are registered on an element-wise blaze function,
    both in the 'funky' category. One takes plain float64 scalars; the other
    takes ``Axes... * float64``. ``best_match`` must return the scalar kernel
    for scalar arguments and the ``Axes...`` kernel for 10x10 array arguments.
    """
    # Element-wise blaze function over two arguments sharing dims and dtype.
    func = blaze_func("test_func",
                      dshape("(A... * T, A... * T) -> A... * T"),
                      elementwise=True)

    # Scalar implementation, registered in the 'funky' category.
    f64 = dshape("float64")
    scalar_sig = T.Function(f64, f64, f64)
    scalar_kernel = lambda a, b: a * b
    kernel(func, 'funky', scalar_kernel, scalar_sig)

    # Array implementation, also in 'funky'. It is a distinct lambda object so
    # the identity comparison below can tell the two kernels apart.
    arr_f64 = dshape("Axes... * float64")
    array_sig = T.Function(arr_f64, arr_f64, arr_f64)
    array_kernel = lambda a, b: a * b
    kernel(func, 'funky', array_kernel, array_sig)

    # float32/float64 scalars resolve to the float64 scalar signature.
    scalar_args = T.Tuple([dshape("float32"), dshape("float64")])
    chosen = func.best_match('funky', scalar_args)
    self.assertEqual(chosen.resolved_sig, scalar_sig)
    self.assertEqual(chosen.func, scalar_kernel)

    # 10x10 arrays resolve to the Axes... kernel, specialised to 10 * 10.
    array_args = T.Tuple([dshape("10 * 10 * float32"),
                          dshape("10 * 10 * float64")])
    chosen = func.best_match('funky', array_args)
    self.assertEqual(chosen.resolved_sig,
                     dshape("(10 * 10 * float64, 10 * 10 * float64)"
                            " -> 10 * 10 * float64")[0])
    self.assertEqual(chosen.func, array_kernel)
def test_tuple(self):
    """Parse tuple datashapes, with and without a trailing comma."""
    single = ct.DataShape(ct.Tuple([ct.DataShape(ct.float32)]))
    pair = ct.DataShape(ct.Tuple([ct.DataShape(ct.int16),
                                  ct.DataShape(ct.int32)]))
    cases = [
        # Simple tuple
        ('(float32)', single),
        ('(int16, int32)', pair),
        # A trailing comma is ok
        ('(float32,)', single),
        ('(int16, int32,)', pair),
    ]
    for source, expected in cases:
        self.assertEqual(parse(source, self.sym), expected)
def test_best_match_signed_vs_unsigned(self):
    """Signed and unsigned 64-bit arrays each pick their own overload."""
    resolver = OverloadResolver('h')
    resolver.extend_overloads([
        '(A... * int64) -> A... * int64',
        '(A... * uint64) -> A... * uint64',
    ])
    # (argument dshape, expected overload index, expected resolved signature)
    expectations = [
        ('10 * 3 * int64', 0, '(10 * 3 * int64) -> 10 * 3 * int64'),
        ('4 * 5 * uint64', 1, '(4 * 5 * uint64) -> 4 * 5 * uint64'),
    ]
    for arg_text, want_index, want_sig in expectations:
        args = coretypes.Tuple([dshape(arg_text)])
        index, resolved = resolver.resolve_overload(args)
        self.assertEqual(index, want_index)
        self.assertEqual(resolved, dshape(want_sig)[0])
def test_best_match(self):
    """Resolve typevar-dimensioned arguments against ``f``.

    ``f`` is a module-level name not visible in this view. The test checks
    the string form of the signature that ``best_match`` selects.
    """
    first = dshape('10 * T1 * int32')
    second = dshape('T2 * T2 * float32')
    selected = best_match(f, coretypes.Tuple([first, second]))
    expected_sig = '(X * Y * float32, X * Y * float32) -> X * Y * float32'
    self.assertEqual(str(selected.sig), expected_sig)
def test_best_match_int_float32_vs_float64(self):
    """An int32 array resolves to the float64 overload (index 1), not float32."""
    resolver = OverloadResolver('k')
    resolver.extend_overloads([
        '(A... * float32) -> A... * float32',
        '(A... * float64) -> A... * float64',
    ])
    int_vec = dshape('3 * int32')
    index, resolved = resolver.resolve_overload(coretypes.Tuple([int_vec]))
    self.assertEqual(index, 1)
    self.assertEqual(resolved, dshape('(3 * float64) -> 3 * float64')[0])
def test_best_match_float_int_complex(self):
    """A float64 array plus an int32 scalar picks the float64 overload.

    The int32 scalar is expected to resolve as a float64 scalar, with the
    float64 overload (index 0) preferred over the complex[float32] one.
    """
    resolver = OverloadResolver('j')
    resolver.extend_overloads([
        '(A... * float64, A... * float64) -> A... * float64',
        '(A... * complex[float32], A... * complex[float32]) -> A... * complex[float32]',
    ])
    vec, scalar = dshapes('3 * float64', 'int32')
    index, resolved = resolver.resolve_overload(coretypes.Tuple([vec, scalar]))
    self.assertEqual(index, 0)
    self.assertEqual(resolved,
                     dshape('(3 * float64, float64) -> 3 * float64')[0])
def test_best_match_int32_float32_ufunc_promotion(self):
    """int32 mixed with float32 resolves to the float64 overload (index 2)."""
    resolver = OverloadResolver('m')
    resolver.extend_overloads([
        '(A... * int32, A... * int32) -> A... * int32',
        '(A... * float32, A... * float32) -> A... * float32',
        '(A... * float64, A... * float64) -> A... * float64',
    ])
    ints, floats = dshapes('3 * int32', '3 * float32')
    index, resolved = resolver.resolve_overload(coretypes.Tuple([ints, floats]))
    self.assertEqual(index, 2)
    self.assertEqual(resolved,
                     dshape('(3 * float64, 3 * float64) -> 3 * float64')[0])
def test_best_match_ellipses(self):
    """An ellipsis-dimensioned parameter accepts a lower-rank argument.

    The second argument is one-dimensional, so only the overload whose second
    parameter is ``... * float32`` (index 1) should be chosen.
    """
    resolver = OverloadResolver('g')
    resolver.extend_overloads([
        '(X * Y * float32, X * Y * float32) -> X * int32',
        '(X * Y * float32, ... * float32) -> X * int32',
    ])
    matrix = dshape('10 * var * int32')
    vector = dshape('3 * float32')
    index, resolved = resolver.resolve_overload(coretypes.Tuple([matrix, vector]))
    self.assertEqual(index, 1)
    self.assertEqual(
        resolved,
        dshape('(10 * var * float32, 3 * float32) -> 10 * int32')[0])
def test_best_match_typevar_dims(self):
    """complex[float32] plus float32 arrays pick the complex[float32] overload.

    Both arguments share the X * Y dims (3 * 10). The expected winner is
    overload index 1, with both parameters resolved to complex[float32].
    """
    resolver = OverloadResolver('f')
    resolver.extend_overloads([
        '(X * Y * float32, X * Y * float32) -> X * Y * float32',
        '(X * Y * complex[float32], X * Y * complex[float32]) -> X * Y * complex[float32]',
        '(X * Y * complex[float64], X * Y * complex[float64]) -> X * Y * complex[float64]',
    ])
    complex_arr = dshape('3 * 10 * complex[float32]')
    real_arr = dshape('3 * 10 * float32')
    index, resolved = resolver.resolve_overload(
        coretypes.Tuple([complex_arr, real_arr]))
    self.assertEqual(index, 1)
    self.assertEqual(
        resolved,
        dshape('(3 * 10 * complex[float32], 3 * 10 * complex[float32])'
               ' -> 3 * 10 * complex[float32]')[0])