def test_match_equation_dim(self):
    """Check that _match_equation pairs leading dimensions one-to-one.

    Covers a fixed dim matched against a different fixed dim (the
    broadcasting case) and a fixed dim matched against a dim type variable.
    The element type equation (int32, int32) always comes last.
    """
    cases = [
        # Broadcasting a single dimension
        ('1 * int32', '10 * int32', (T.Fixed(1), T.Fixed(10))),
        # Matching a dim type variable
        ('3 * int32', 'M * int32', (T.Fixed(3), T.TypeVar('M'))),
    ]
    for lhs, rhs, dim_eqn in cases:
        eqns = _match_equation(dshape(lhs), dshape(rhs))
        self.assertEqual(eqns, [dim_eqn, (T.int32, T.int32)])
def test_fixed_dims(self):
    """Parse datashapes built only from fixed-size dimensions over bool."""
    cases = (
        ('3 * bool', (3,)),
        ('7 * 3 * bool', (7, 3)),
        ('5 * 3 * 12 * bool', (5, 3, 12)),
        ('2 * 3 * 4 * 5 * bool', (2, 3, 4, 5)),
    )
    for text, sizes in cases:
        dims = [ct.Fixed(size) for size in sizes]
        expected = ct.DataShape(*(dims + [ct.bool_]))
        self.assertEqual(parse(text, self.sym), expected)
def test_dshape_matches_concrete(self):
    """Match concrete argument types against a concrete function signature.

    match_argtypes_to_signature is expected to return the signature's
    argument part (sig[0]) together with a coercion cost.
    """
    signature_text = '(3 * int32, 2 * var * float64) -> 4 * int16'

    # Exact match, same signature and zero cost
    args = dshape('(3 * int32, 2 * var * float64)')
    sig = dshape(signature_text)
    self.assertEqual(match_argtypes_to_signature(args, sig), (sig[0], 0))

    # Requires broadcasting: 1 -> 3 in the first argument and 4 -> var in
    # the second; the expected cost is the maximum of those two coercions.
    args = dshape('(1 * int32, 2 * 4 * float64)')
    sig = dshape(signature_text)
    matched = match_argtypes_to_signature(args, sig)
    broadcast_cost = max(dim_coercion_cost(T.Fixed(1), T.Fixed(3)),
                         dim_coercion_cost(T.Fixed(4), T.Var()))
    self.assertEqual(matched, (sig[0], broadcast_cost))
def test_var_dims(self):
    """Parse datashapes containing variable-length ('var') dimensions."""
    texts = ['var * bool', 'var * var * bool', 'M * 5 * var * bool']
    expected = [
        ct.DataShape(ct.Var(), ct.bool_),
        ct.DataShape(ct.Var(), ct.Var(), ct.bool_),
        ct.DataShape(ct.TypeVar('M'), ct.Fixed(5), ct.Var(), ct.bool_),
    ]
    for text, want in zip(texts, expected):
        self.assertEqual(parse(text, self.sym), want)
def test_ellipses(self):
    """Parse datashapes containing an unnamed '...' dimension.

    An unnamed ellipsis parses to ct.Ellipsis() with no type variable and
    may be preceded or followed by other dimensions.
    """
    def check(text, *parts):
        self.assertEqual(parse(text, self.sym), ct.DataShape(*parts))

    check('... * bool', ct.Ellipsis(), ct.bool_)
    check('M * ... * bool', ct.TypeVar('M'), ct.Ellipsis(), ct.bool_)
    check('M * ... * 3 * bool',
          ct.TypeVar('M'), ct.Ellipsis(), ct.Fixed(3), ct.bool_)
def test_typevar_dims(self):
    """Parse datashapes whose dimensions are type variables.

    A named ellipsis such as 'A...' parses to ct.Ellipsis wrapping the
    type variable ct.TypeVar('A').
    """
    expectations = (
        ('M * bool', [ct.TypeVar('M'), ct.bool_]),
        ('A * B * bool', [ct.TypeVar('A'), ct.TypeVar('B'), ct.bool_]),
        ('A... * X * 3 * bool',
         [ct.Ellipsis(ct.TypeVar('A')), ct.TypeVar('X'), ct.Fixed(3),
          ct.bool_]),
    )
    for text, parts in expectations:
        actual = parse(text, self.sym)
        self.assertEqual(actual, ct.DataShape(*parts))
def test_match_equation_ellipsis(self):
    """Check how _match_equation assigns dimensions to an ellipsis.

    The ellipsis side of each dimension equation is a list of the concrete
    dims it absorbs (possibly empty). Dims pinned by type variables before
    or after the ellipsis are matched one-to-one. Every result ends with
    the element type equation (int32, int32).
    """
    cases = [
        # Matching an ellipsis
        ('int32', '... * int32',
         [([], T.Ellipsis())]),
        ('3 * int32', '... * int32',
         [([T.Fixed(3)], T.Ellipsis())]),
        ('3 * var * int32', '... * int32',
         [([T.Fixed(3), T.Var()], T.Ellipsis())]),
        # Matching an ellipsis type variable
        ('int32', 'A... * int32',
         [([], T.Ellipsis(T.TypeVar('A')))]),
        ('3 * int32', 'A... * int32',
         [([T.Fixed(3)], T.Ellipsis(T.TypeVar('A')))]),
        ('3 * var * int32', 'A... * int32',
         [([T.Fixed(3), T.Var()], T.Ellipsis(T.TypeVar('A')))]),
        # Matching an ellipsis with a dim type variable on the left
        ('3 * var * int32', 'A * B... * int32',
         [(T.Fixed(3), T.TypeVar('A')),
          ([T.Var()], T.Ellipsis(T.TypeVar('B')))]),
        # Matching an ellipsis with a dim type variable on the right
        ('3 * var * int32', 'A... * B * int32',
         [([T.Fixed(3)], T.Ellipsis(T.TypeVar('A'))),
          (T.Var(), T.TypeVar('B'))]),
        # Matching an ellipsis with a dim type variable on both sides
        ('3 * var * int32', 'A * B... * C * int32',
         [(T.Fixed(3), T.TypeVar('A')),
          ([], T.Ellipsis(T.TypeVar('B'))),
          (T.Var(), T.TypeVar('C'))]),
        ('3 * var * 4 * M * int32', 'A * B... * C * int32',
         [(T.Fixed(3), T.TypeVar('A')),
          ([T.Var(), T.Fixed(4)], T.Ellipsis(T.TypeVar('B'))),
          (T.TypeVar('M'), T.TypeVar('C'))]),
    ]
    for lhs, rhs, dim_eqns in cases:
        eqns = _match_equation(dshape(lhs), dshape(rhs))
        self.assertEqual(eqns, dim_eqns + [(T.int32, T.int32)])
def test_struct(self):
    """Parse record (struct) datashapes, including nested and dimensioned ones."""
    int16 = ct.DataShape(ct.int16)
    int32 = ct.DataShape(ct.int32)
    string = ct.DataShape(ct.string)
    date = ct.DataShape(ct.date_)

    def record(*fields):
        # A dimensionless DataShape wrapping a Record of (name, type) pairs.
        return ct.DataShape(ct.Record(list(fields)))

    xy = record(('x', int16), ('y', int32))
    # Simple struct
    self.assertEqual(parse('{x: int16, y: int32}', self.sym), xy)
    # A trailing comma is ok
    self.assertEqual(parse('{x: int16, y: int32,}', self.sym), xy)
    # Field names starting with _ and caps
    self.assertEqual(parse('{_x: int16, Zed: int32,}', self.sym),
                     record(('_x', int16), ('Zed', int32)))

    # A slightly bigger example: dimensions outside the record, a nested
    # record field, and a field that is itself a var-length record array.
    ds_str = ("3 * var * { id : int32, name : string, "
              "description : { language : string, text : string }, "
              "entries : var * { date : date, text : string } }")
    description = record(('language', string), ('text', string))
    entries = ct.DataShape(ct.Var(),
                           ct.Record([('date', date), ('text', string)]))
    expected = ct.DataShape(
        ct.Fixed(3), ct.Var(),
        ct.Record([('id', int32),
                   ('name', string),
                   ('description', description),
                   ('entries', entries)]))
    self.assertEqual(parse(ds_str, self.sym), expected)
def test_option(self):
    """Parse option types in both the 'option[T]' and '?T' spellings.

    Each group lists spellings that must parse to the same datashape. In
    '2 * ?3 * int32' the '?' applies to the whole trailing '3 * int32'.
    """
    groups = [
        (('option[int32]', '?int32'),
         ct.DataShape(ct.Option(ct.int32))),
        (('2 * 3 * option[int32]', '2 * 3 * ?int32'),
         ct.DataShape(ct.Fixed(2), ct.Fixed(3), ct.Option(ct.int32))),
        (('2 * option[3 * int32]', '2 * ?3 * int32'),
         ct.DataShape(ct.Fixed(2),
                      ct.Option(ct.DataShape(ct.Fixed(3), ct.int32)))),
    ]
    for spellings, expected in groups:
        for text in spellings:
            self.assertEqual(parse(text, self.sym), expected)