Ejemplo n.º 1
0
 def test_match_with_resolver(self):
     """Match signatures whose output type variables come from a resolver.

     The resolver below is deliberately contrived: it builds the ``R...``
     dimensions by interleaving the ``A...`` dimensions with ``B``, which
     simple pattern matching cannot express. It is not useful on its own,
     but it exercises the resolver hook needed for reduction-function
     signature matching.
     """
     def resolver(tvar, tvdict):
         if tvar == T.Ellipsis(T.TypeVar('R')):
             # R... = B, A[0], B, A[1], B, ..., B
             ellipsis_dims = tvdict[T.Ellipsis(T.TypeVar('A'))]
             sep = tvdict[T.TypeVar('B')]
             dims = [sep]
             for dim in ellipsis_dims:
                 dims += [dim, sep]
             return dims
         if tvar == T.TypeVar('T'):
             return T.int16
         # No resolution for any other type variable.
         return None

     # B binds to 5 and A... to 4, so R... becomes 5 * 4 * 5.
     args = dshape('(5 * int32, 4 * float64)')
     signature = dshape('(B * int32, A... * float64) -> R... * T')
     matched = match_argtypes_to_signature(args, signature, resolver)
     expected = dshape('(5 * int32, 4 * float64) -> 5 * 4 * 5 * int16')[0]
     self.assertEqual(matched, (expected, 0.25))

     # A... binds to 5 * var * 2 and B to 4, so R... becomes
     # 4 * 5 * 4 * var * 4 * 2 * 4.
     args = dshape('(5 * var * 2 * int32, 4 * float64)')
     signature = dshape('(A... * int32, B * float64) -> R... * 2 * T')
     matched = match_argtypes_to_signature(args, signature, resolver)
     expected = dshape('(5 * var * 2 * int32, 4 * float64) ->'
                       ' 4 * 5 * 4 * var * 4 * 2 * 4 * 2 * int16')[0]
     self.assertEqual(matched, (expected, 0.25))
 def test_match_with_resolver(self):
     """Match signatures whose output type variables come from a resolver.

     The resolver below is deliberately contrived: it builds the ``R...``
     dimensions by interleaving the ``A...`` dimensions with ``B``, which
     simple pattern matching cannot express. It is not useful on its own,
     but it exercises the resolver hook needed for reduction-function
     signature matching.
     """
     def resolver(tvar, tvdict):
         if tvar == T.Ellipsis(T.TypeVar('R')):
             # R... = B, A[0], B, A[1], B, ..., B
             ellipsis_dims = tvdict[T.Ellipsis(T.TypeVar('A'))]
             sep = tvdict[T.TypeVar('B')]
             dims = [sep]
             for dim in ellipsis_dims:
                 dims += [dim, sep]
             return dims
         if tvar == T.TypeVar('T'):
             return T.int16
         # No resolution for any other type variable.
         return None

     # B binds to 5 and A... to 4, so R... becomes 5 * 4 * 5.
     args = dshape('(5 * int32, 4 * float64)')
     signature = dshape('(B * int32, A... * float64) -> R... * T')
     matched = match_argtypes_to_signature(args, signature, resolver)
     expected = dshape('(5 * int32, 4 * float64) -> 5 * 4 * 5 * int16')[0]
     self.assertEqual(matched, (expected, 0.25))

     # A... binds to 5 * var * 2 and B to 4, so R... becomes
     # 4 * 5 * 4 * var * 4 * 2 * 4.
     args = dshape('(5 * var * 2 * int32, 4 * float64)')
     signature = dshape('(A... * int32, B * float64) -> R... * 2 * T')
     matched = match_argtypes_to_signature(args, signature, resolver)
     expected = dshape('(5 * var * 2 * int32, 4 * float64) ->'
                       ' 4 * 5 * 4 * var * 4 * 2 * 4 * 2 * int16')[0]
     self.assertEqual(matched, (expected, 0.25))
 def test_dshape_matches_concrete(self):
     """A concrete signature matches as-is; the cost reflects any broadcasting."""
     # Same dimensions and dtypes as the signature: zero cost.
     args = dshape('(3 * int32, 2 * var * float64)')
     signature = dshape('(3 * int32, 2 * var * float64) -> 4 * int16')
     result = match_argtypes_to_signature(args, signature)
     self.assertEqual(result, (signature[0], 0))

     # Broadcasting is needed (1 -> 3 and 4 -> var); the cost is the larger
     # of the two dimension coercion costs.
     args = dshape('(1 * int32, 2 * 4 * float64)')
     signature = dshape('(3 * int32, 2 * var * float64) -> 4 * int16')
     result = match_argtypes_to_signature(args, signature)
     expected_cost = max(dim_coercion_cost(T.Fixed(1), T.Fixed(3)),
                         dim_coercion_cost(T.Fixed(4), T.Var()))
     self.assertEqual(result, (signature[0], expected_cost))
Ejemplo n.º 4
0
 def test_dshape_matches_concrete(self):
     """A concrete signature matches as-is; the cost reflects any broadcasting."""
     # Same dimensions and dtypes as the signature: zero cost.
     args = dshape('(3 * int32, 2 * var * float64)')
     signature = dshape('(3 * int32, 2 * var * float64) -> 4 * int16')
     result = match_argtypes_to_signature(args, signature)
     self.assertEqual(result, (signature[0], 0))

     # Broadcasting is needed (1 -> 3 and 4 -> var); the cost is the larger
     # of the two dimension coercion costs.
     args = dshape('(1 * int32, 2 * 4 * float64)')
     signature = dshape('(3 * int32, 2 * var * float64) -> 4 * int16')
     result = match_argtypes_to_signature(args, signature)
     expected_cost = max(dim_coercion_cost(T.Fixed(1), T.Fixed(3)),
                         dim_coercion_cost(T.Fixed(4), T.Var()))
     self.assertEqual(result, (signature[0], expected_cost))
 def test_broadcast_vs_not(self):
     """A scalar signature must be a cheaper match than a broadcasting one.

     Scalar arguments match both ``(float64, float64) -> int16`` and the
     ``A...`` broadcasting signature; both resolve to the same concrete
     signature, but the non-broadcasting match must have the lower cost.
     """
     at = dshape('(int32, float64)')
     sig_scalar = dshape('(float64, float64) -> int16')
     sig_bcast = dshape('(A... * float64, A... * float64) -> A... * int16')
     match_scalar = match_argtypes_to_signature(at, sig_scalar)
     match_bcast = match_argtypes_to_signature(at, sig_bcast)
     self.assertEqual(match_scalar[0],
                      dshape('(float64, float64) -> int16')[0])
     # With scalar arguments, A... matches no dimensions at all.
     self.assertEqual(match_bcast[0],
                      dshape('(float64, float64) -> int16')[0])
     # Should be cheaper to match without the broadcasting. assertLess
     # reports both costs on failure, unlike assertTrue(a < b).
     self.assertLess(match_scalar[1], match_bcast[1])
Ejemplo n.º 6
0
 def test_broadcast_vs_not(self):
     """A scalar signature must be a cheaper match than a broadcasting one.

     Scalar arguments match both ``(float64, float64) -> int16`` and the
     ``A...`` broadcasting signature; both resolve to the same concrete
     signature, but the non-broadcasting match must have the lower cost.
     """
     at = dshape('(int32, float64)')
     sig_scalar = dshape('(float64, float64) -> int16')
     sig_bcast = dshape('(A... * float64, A... * float64) -> A... * int16')
     match_scalar = match_argtypes_to_signature(at, sig_scalar)
     match_bcast = match_argtypes_to_signature(at, sig_bcast)
     self.assertEqual(match_scalar[0],
                      dshape('(float64, float64) -> int16')[0])
     # With scalar arguments, A... matches no dimensions at all.
     self.assertEqual(match_bcast[0],
                      dshape('(float64, float64) -> int16')[0])
     # Should be cheaper to match without the broadcasting. assertLess
     # reports both costs on failure, unlike assertTrue(a < b).
     self.assertLess(match_scalar[1], match_bcast[1])
 def test_dtype_matches_typevar(self):
     """Bind dtype type variables, promoting inputs that share a variable.

     Every match here is expected to report a cost of 0.125, unlike the zero
     cost of an exact concrete match. The nonzero cost presumably reflects
     binding a type variable.
     """
     # Exact dtype match (T binds to float64), but the cost is 0.125, not 0
     at = dshape('(int32, float64)')
     sig = dshape('(int32, T) -> T')
     self.assertEqual(match_argtypes_to_signature(at, sig),
                      (dshape('(int32, float64) -> float64')[0], 0.125))
     # Type promotion between the inputs: int32 and float64 share T, which
     # is promoted to float64
     at = dshape('(int32, float64)')
     sig = dshape('(T, T) -> T')
     self.assertEqual(match_argtypes_to_signature(at, sig),
                      (dshape('(float64, float64) -> float64')[0], 0.125))
     # Type promotion only among inputs sharing T (int32, float64 ->
     # float64); S binds independently to bool and is the output dtype
     at = dshape('(int32, bool, float64)')
     sig = dshape('(T, S, T) -> S')
     self.assertEqual(match_argtypes_to_signature(at, sig),
                      (dshape('(float64, bool, float64) -> bool')[0], 0.125))
 def test_dtype_matches_concrete(self):
     """Match dtypes against a concrete signature; cost is the worst coercion."""
     # Exact match, same signature and zero cost
     at = dshape('(int32, float64)')
     sig = dshape('(int32, float64) -> int16')
     self.assertEqual(match_argtypes_to_signature(at, sig),
                      (sig[0], 0))
     # Requires a coercion, cost is that of the coercion
     at = dshape('(int32, int32)')
     sig = dshape('(int32, float64) -> int16')
     self.assertEqual(match_argtypes_to_signature(at, sig),
                      (sig[0], dtype_coercion_cost(T.int32, T.float64)))
     # Requires two coercions (int16 -> int32 and int32 -> float64). The cost
     # is the maximum of the two, written out explicitly so the assertion
     # checks what this comment claims.
     at = dshape('(int16, int32)')
     sig = dshape('(int32, float64) -> int16')
     self.assertEqual(match_argtypes_to_signature(at, sig),
                      (sig[0], max(dtype_coercion_cost(T.int16, T.int32),
                                   dtype_coercion_cost(T.int32, T.float64))))
Ejemplo n.º 9
0
 def test_dtype_matches_typevar(self):
     """Bind dtype type variables, promoting inputs that share a variable.

     Every match here is expected to report a cost of 0.125, unlike the zero
     cost of an exact concrete match. The nonzero cost presumably reflects
     binding a type variable.
     """
     # Exact dtype match (T binds to float64), but the cost is 0.125, not 0
     at = dshape('(int32, float64)')
     sig = dshape('(int32, T) -> T')
     self.assertEqual(match_argtypes_to_signature(at, sig),
                      (dshape('(int32, float64) -> float64')[0], 0.125))
     # Type promotion between the inputs: int32 and float64 share T, which
     # is promoted to float64
     at = dshape('(int32, float64)')
     sig = dshape('(T, T) -> T')
     self.assertEqual(match_argtypes_to_signature(at, sig),
                      (dshape('(float64, float64) -> float64')[0], 0.125))
     # Type promotion only among inputs sharing T (int32, float64 ->
     # float64); S binds independently to bool and is the output dtype
     at = dshape('(int32, bool, float64)')
     sig = dshape('(T, S, T) -> S')
     self.assertEqual(match_argtypes_to_signature(at, sig),
                      (dshape('(float64, bool, float64) -> bool')[0], 0.125))
Ejemplo n.º 10
0
 def test_dtype_matches_concrete(self):
     """Match dtypes against a concrete signature; cost is the worst coercion."""
     # Exact match, same signature and zero cost
     at = dshape('(int32, float64)')
     sig = dshape('(int32, float64) -> int16')
     self.assertEqual(match_argtypes_to_signature(at, sig),
                      (sig[0], 0))
     # Requires a coercion, cost is that of the coercion
     at = dshape('(int32, int32)')
     sig = dshape('(int32, float64) -> int16')
     self.assertEqual(match_argtypes_to_signature(at, sig),
                      (sig[0], dtype_coercion_cost(T.int32, T.float64)))
     # Requires two coercions (int16 -> int32 and int32 -> float64). The cost
     # is the maximum of the two, written out explicitly so the assertion
     # checks what this comment claims.
     at = dshape('(int16, int32)')
     sig = dshape('(int32, float64) -> int16')
     self.assertEqual(match_argtypes_to_signature(at, sig),
                      (sig[0], max(dtype_coercion_cost(T.int16, T.int32),
                                   dtype_coercion_cost(T.int32, T.float64))))
 def test_dshape_matches_typevar(self):
     """Bind dimension and dtype type variables in array signatures."""
     # Arrays with matching size: N binds to 5 in both arguments.
     args = dshape('(5 * int32, 5 * float64)')
     signature = dshape('(N * int32, N * float64) -> N * int16')
     result = match_argtypes_to_signature(args, signature)
     expected = dshape('(5 * int32, 5 * float64) -> 5 * int16')[0]
     self.assertEqual(result, (expected, 0.125))

     # Matrix multiplication: M=3, N=5, R=6; the shared dtype A comes out
     # as float64 for both arguments.
     args = dshape('(3 * 5 * float64, 5 * 6 * float32)')
     signature = dshape('(M * N * A, N * R * A) -> M * R * A')
     result = match_argtypes_to_signature(args, signature)
     expected = dshape('(3 * 5 * float64, 5 * 6 * float64) ->'
                       ' 3 * 6 * float64')[0]
     self.assertEqual(result, (expected, 0.375))

     # Broadcasted matrix multiplication: the leading Dims... of the two
     # arguments (20 and 3 * 1) appear as 3 * 20 in the output.
     args = dshape('(20 * 3 * 5 * float64, 3 * 1 * 5 * 6 * float32)')
     signature = dshape('(Dims... * M * N * A, Dims... * N * R * A) ->'
                        ' Dims... * M * R * A')
     result = match_argtypes_to_signature(args, signature)
     expected = dshape('(20 * 3 * 5 * float64,'
                       ' 3 * 1 * 5 * 6 * float64) ->'
                       ' 3 * 20 * 3 * 6 * float64')[0]
     self.assertEqual(result, (expected, 0.625))
Ejemplo n.º 12
0
 def test_dshape_matches_typevar(self):
     """Bind dimension and dtype type variables in array signatures."""
     # Arrays with matching size: N binds to 5 in both arguments.
     args = dshape('(5 * int32, 5 * float64)')
     signature = dshape('(N * int32, N * float64) -> N * int16')
     result = match_argtypes_to_signature(args, signature)
     expected = dshape('(5 * int32, 5 * float64) -> 5 * int16')[0]
     self.assertEqual(result, (expected, 0.125))

     # Matrix multiplication: M=3, N=5, R=6; the shared dtype A comes out
     # as float64 for both arguments.
     args = dshape('(3 * 5 * float64, 5 * 6 * float32)')
     signature = dshape('(M * N * A, N * R * A) -> M * R * A')
     result = match_argtypes_to_signature(args, signature)
     expected = dshape('(3 * 5 * float64, 5 * 6 * float64) ->'
                       ' 3 * 6 * float64')[0]
     self.assertEqual(result, (expected, 0.375))

     # Broadcasted matrix multiplication: the leading Dims... of the two
     # arguments (20 and 3 * 1) appear as 3 * 20 in the output.
     args = dshape('(20 * 3 * 5 * float64, 3 * 1 * 5 * 6 * float32)')
     signature = dshape('(Dims... * M * N * A, Dims... * N * R * A) ->'
                        ' Dims... * M * R * A')
     result = match_argtypes_to_signature(args, signature)
     expected = dshape('(20 * 3 * 5 * float64,'
                       ' 3 * 1 * 5 * 6 * float64) ->'
                       ' 3 * 20 * 3 * 6 * float64')[0]
     self.assertEqual(result, (expected, 0.625))
Ejemplo n.º 13
0
 def test_tv_matches_struct(self):
     """Match a signature whose dtype type variable ``T`` is a struct.

     The expected result has ``T`` bound to the record dtype
     ``{x: int, y: string}``, with ``M`` shared by both arguments' leading
     dimension (3), and the output dshape is ``var * T``.
     """
     at = dshape('(3 * {x: int, y: string}, 3 * bool)')
     sig = dshape('(M * T, M * bool) -> var * T')
     match = match_argtypes_to_signature(at, sig)
     # This assertion checks the matched signature (match[0]), not the cost.
     self.assertEqual(match[0],
                      dshape('(3 * {x: int, y: string}, 3 * bool) -> var * {x: int, y: string}')[0])