Ejemplo n.º 1
0
 def test_match_equation_dim(self):
     # _match_equation pairs each dimension of the concrete dshape with
     # the corresponding dimension of the pattern, then the measures.
     cases = [
         # A size-1 dimension broadcasting against a fixed dimension
         ('1 * int32', '10 * int32',
          [(T.Fixed(1), T.Fixed(10)), (T.int32, T.int32)]),
         # A fixed dimension matched against a dim type variable
         ('3 * int32', 'M * int32',
          [(T.Fixed(3), T.TypeVar('M')), (T.int32, T.int32)]),
     ]
     for concrete, pattern, expected in cases:
         eqns = _match_equation(dshape(concrete), dshape(pattern))
         self.assertEqual(eqns, expected)
Ejemplo n.º 2
0
 def test_fixed_dims(self):
     # One to four fixed-size dimensions in front of a bool measure.
     F = ct.Fixed
     cases = [
         ('3 * bool', [F(3)]),
         ('7 * 3 * bool', [F(7), F(3)]),
         ('5 * 3 * 12 * bool', [F(5), F(3), F(12)]),
         ('2 * 3 * 4 * 5 * bool', [F(2), F(3), F(4), F(5)]),
     ]
     for text, dims in cases:
         self.assertEqual(parse(text, self.sym),
                          ct.DataShape(*(dims + [ct.bool_])))
Ejemplo n.º 3
0
 def test_dshape_matches_concrete(self):
     # Both scenarios are checked against the same function signature;
     # the signature is re-parsed for each scenario.
     sig_text = '(3 * int32, 2 * var * float64) -> 4 * int16'

     # Argument types identical to the parameters: zero-cost match
     sig = dshape(sig_text)
     exact_args = dshape('(3 * int32, 2 * var * float64)')
     self.assertEqual(match_argtypes_to_signature(exact_args, sig),
                      (sig[0], 0))

     # Argument types needing broadcasting: the expected cost is the
     # larger of the two dimension coercion costs involved
     sig = dshape(sig_text)
     broadcast_args = dshape('(1 * int32, 2 * 4 * float64)')
     expected_cost = max(dim_coercion_cost(T.Fixed(1), T.Fixed(3)),
                         dim_coercion_cost(T.Fixed(4), T.Var()))
     self.assertEqual(match_argtypes_to_signature(broadcast_args, sig),
                      (sig[0], expected_cost))
Ejemplo n.º 4
0
 def test_var_dims(self):
     # Variable-length dims, alone and mixed with typevar/fixed dims.
     V = ct.Var
     cases = [
         ('var * bool', [V()]),
         ('var * var * bool', [V(), V()]),
         ('M * 5 * var * bool', [ct.TypeVar('M'), ct.Fixed(5), V()]),
     ]
     for text, dims in cases:
         self.assertEqual(parse(text, self.sym),
                          ct.DataShape(*(dims + [ct.bool_])))
Ejemplo n.º 5
0
 def test_ellipses(self):
     # An anonymous ellipsis alone, after a typevar, and between a
     # typevar and a fixed dimension.
     expectations = (
         ('... * bool', (ct.Ellipsis(),)),
         ('M * ... * bool', (ct.TypeVar('M'), ct.Ellipsis())),
         ('M * ... * 3 * bool',
          (ct.TypeVar('M'), ct.Ellipsis(), ct.Fixed(3))),
     )
     for source, dims in expectations:
         actual = parse(source, self.sym)
         self.assertEqual(actual, ct.DataShape(*(dims + (ct.bool_,))))
Ejemplo n.º 6
0
 def test_typevar_dims(self):
     # Named type variables in dimension position, including a named
     # ellipsis ('A...') followed by further dims.
     TV = ct.TypeVar
     table = [
         ('M * bool', [TV('M')]),
         ('A * B * bool', [TV('A'), TV('B')]),
         ('A... * X * 3 * bool',
          [ct.Ellipsis(TV('A')), TV('X'), ct.Fixed(3)]),
     ]
     for text, dims in table:
         dims.append(ct.bool_)
         self.assertEqual(parse(text, self.sym), ct.DataShape(*dims))
Ejemplo n.º 7
0
 def test_match_equation_ellipsis(self):
     # _match_equation against patterns containing an ellipsis. The
     # concrete dims absorbed by an ellipsis are expected as a list
     # paired with that ellipsis. Every case ends with the int32
     # measure matched against itself.
     A, B, C = T.TypeVar('A'), T.TypeVar('B'), T.TypeVar('C')
     measure = (T.int32, T.int32)
     cases = [
         # Anonymous ellipsis absorbing zero, one and two dims
         ('int32', '... * int32',
          [([], T.Ellipsis()), measure]),
         ('3 * int32', '... * int32',
          [([T.Fixed(3)], T.Ellipsis()), measure]),
         ('3 * var * int32', '... * int32',
          [([T.Fixed(3), T.Var()], T.Ellipsis()), measure]),
         # Named ellipsis type variable absorbing zero, one and two dims
         ('int32', 'A... * int32',
          [([], T.Ellipsis(A)), measure]),
         ('3 * int32', 'A... * int32',
          [([T.Fixed(3)], T.Ellipsis(A)), measure]),
         ('3 * var * int32', 'A... * int32',
          [([T.Fixed(3), T.Var()], T.Ellipsis(A)), measure]),
         # Dim type variable to the left of the ellipsis
         ('3 * var * int32', 'A * B... * int32',
          [(T.Fixed(3), A), ([T.Var()], T.Ellipsis(B)), measure]),
         # Dim type variable to the right of the ellipsis
         ('3 * var * int32', 'A... * B * int32',
          [([T.Fixed(3)], T.Ellipsis(A)), (T.Var(), B), measure]),
         # Dim type variables on both sides of the ellipsis
         ('3 * var * int32', 'A * B... * C * int32',
          [(T.Fixed(3), A), ([], T.Ellipsis(B)), (T.Var(), C), measure]),
         ('3 * var * 4 * M * int32', 'A * B... * C * int32',
          [(T.Fixed(3), A), ([T.Var(), T.Fixed(4)], T.Ellipsis(B)),
           (T.TypeVar('M'), C), measure]),
     ]
     for concrete, pattern, expected in cases:
         eqns = _match_equation(dshape(concrete), dshape(pattern))
         self.assertEqual(eqns, expected)
Ejemplo n.º 8
0
 def test_struct(self):
     # Record (struct) parsing: simple fields, an optional trailing
     # comma, unusual field names, and a larger nested example.
     int16 = ct.DataShape(ct.int16)
     int32 = ct.DataShape(ct.int32)
     string = ct.DataShape(ct.string)
     date = ct.DataShape(ct.date_)

     def record_shape(*fields):
         # Wrap a Record built from (name, DataShape) pairs in a DataShape
         return ct.DataShape(ct.Record(list(fields)))

     flat_cases = [
         # Simple struct
         ('{x: int16, y: int32}',
          record_shape(('x', int16), ('y', int32))),
         # A trailing comma is ok
         ('{x: int16, y: int32,}',
          record_shape(('x', int16), ('y', int32))),
         # Field names starting with _ and caps
         ('{_x: int16, Zed: int32,}',
          record_shape(('_x', int16), ('Zed', int32))),
     ]
     for text, expected in flat_cases:
         self.assertEqual(parse(text, self.sym), expected)

     # A slightly bigger example
     ds_str = """3 * var * {
                     id : int32,
                     name : string,
                     description : {
                         language : string,
                         text : string
                     },
                     entries : var * {
                         date : date,
                         text : string
                     }
                 }"""
     description = record_shape(('language', string), ('text', string))
     entries = ct.DataShape(ct.Var(),
                            ct.Record([('date', date), ('text', string)]))
     top_record = ct.Record([('id', int32), ('name', string),
                             ('description', description),
                             ('entries', entries)])
     self.assertEqual(parse(ds_str, self.sym),
                      ct.DataShape(ct.Fixed(3), ct.Var(), top_record))
Ejemplo n.º 9
0
 def test_option(self):
     # The 'option[...]' and '?' spellings must parse to the same shape:
     # at top level, under dimensions, and wrapping a dimensioned type.
     scalar = ct.DataShape(ct.Option(ct.int32))
     dimmed = ct.DataShape(ct.Fixed(2), ct.Fixed(3), ct.Option(ct.int32))
     nested = ct.DataShape(ct.Fixed(2),
                           ct.Option(ct.DataShape(ct.Fixed(3), ct.int32)))
     spellings = [
         ('option[int32]', '?int32', scalar),
         ('2 * 3 * option[int32]', '2 * 3 * ?int32', dimmed),
         ('2 * option[3 * int32]', '2 * ?3 * int32', nested),
     ]
     for long_form, short_form, expected in spellings:
         self.assertEqual(parse(long_form, self.sym), expected)
         self.assertEqual(parse(short_form, self.sym), expected)