def test_combine_nested_maps():
    """Check the optimized form of ``allpairs_dist`` for a 2x2 integer array.

    After ``high_level_optimizations.apply`` the function body must be a
    single ``Return`` statement whose value is an ``OuterMap`` with exactly
    two arguments.
    """
    x = np.array([[1, 2], [3, 4]])
    typed_fn = parakeet.typed_repr(allpairs_dist, [x])
    typed_fn = high_level_optimizations.apply(typed_fn)

    # Every assert carries a message so that a failure shows the offending IR.
    # Operands are wrapped in a one-element tuple so that formatting cannot
    # misbehave if the value itself happens to be a tuple.
    assert len(typed_fn.body) == 1, \
        "Expected a single statement but got %s" % (typed_fn.body,)
    stmt = typed_fn.body[0]
    assert stmt.__class__ is Return, "Expected Return but got %s" % (stmt,)
    v = stmt.value
    assert v.__class__ is OuterMap, "Expected OuterMap but got %s" % (v,)
    assert len(v.args) == 2, \
        "Expected OuterMap to have two args, but got %s" % (v.args,)
def test_combine_nested_maps():
    """Optimizing ``allpairs_dist`` must leave one ``Return`` of an ``OuterMap``.

    The typed representation is built for a 2x2 integer array, run through
    ``high_level_optimizations.apply``, and the resulting body is inspected:
    one statement, a ``Return``, whose value is an ``OuterMap`` with two args.
    """
    matrix = np.array([[1, 2], [3, 4]])
    optimized = high_level_optimizations.apply(
        parakeet.typed_repr(allpairs_dist, [matrix]))

    assert len(optimized.body) == 1
    ret = optimized.body[0]
    assert ret.__class__ is Return, "Expected Return but got %s" % ret

    result = ret.value
    assert result.__class__ is OuterMap, \
        "Expected OuterMap but got %s" % result
    assert len(result.args) == 2, \
        "Expected OuterMap to have two args, but got %s" % (result.args,)
def test_combine_nested_maps():
    # Build the typed form of allpairs_dist for a small 2x2 input, optimize
    # it, and verify the shape of the result: a lone Return statement
    # yielding an OuterMap that takes two arguments.
    sample = np.array([[1, 2], [3, 4]])
    fn = parakeet.typed_repr(allpairs_dist, [sample])
    fn = high_level_optimizations.apply(fn)

    # The optimized body is expected to hold exactly one statement.
    assert 1 == len(fn.body)
    only_stmt = fn.body[0]
    assert only_stmt.__class__ is Return, (
        "Expected Return but got %s" % only_stmt)

    returned = only_stmt.value
    assert returned.__class__ is OuterMap, (
        "Expected OuterMap but got %s" % returned)
    assert 2 == len(returned.args), (
        "Expected OuterMap to have two args, but got %s" % (returned.args,))
def test_combine_nested_index_maps():
    """Check the optimized form of ``allpairs_add_imap`` for ``n = 3``.

    After ``high_level_optimizations.apply`` the body must have one or two
    statements, the last being a ``Return`` of an ``IndexMap``.  The function
    obtained from that ``IndexMap`` via ``get_fn`` must itself consist of a
    single ``Return`` whose value is a ``PrimCall``.
    """
    n = 3
    typed_fn = parakeet.typed_repr(allpairs_add_imap, [n])
    # Single-argument parenthesised prints produce the same text under the
    # Python 2 print statement and the Python 3 print function.
    print("TYPED FN %s" % (typed_fn,))
    typed_fn = high_level_optimizations.apply(typed_fn)
    print("OPT FN %s" % (typed_fn,))

    assert len(typed_fn.body) in (1, 2), \
        "Expected one or two statements but got %s" % (typed_fn.body,)
    stmt = typed_fn.body[-1]
    print("STMT %s" % (stmt,))
    assert stmt.__class__ is Return, "Expected Return but got %s" % (stmt,)

    v = stmt.value
    print("VALUE %s" % (v,))
    assert v.__class__ is IndexMap, "Expected IndexMap but got %s" % (v,)

    # Inspect the function carried by the IndexMap.
    nested_fn = get_fn(v.fn)
    assert len(nested_fn.body) == 1, \
        "Expected a single nested statement but got %s" % (nested_fn.body,)
    nested_stmt = nested_fn.body[0]
    assert nested_stmt.__class__ is Return, \
        "Expected nested Return but got %s" % (nested_stmt,)
    nested_value = nested_stmt.value
    assert nested_value.__class__ is PrimCall, \
        "Expected PrimCall but got %s" % (nested_value,)
def test_combine_nested_index_maps():
    """Check the optimized form of ``allpairs_add_imap`` for ``n = 3``.

    After ``high_level_optimizations.apply`` the body must have one or two
    statements, the last being a ``Return`` of an ``IndexMap``.  The function
    obtained from that ``IndexMap`` via ``get_fn`` must itself consist of a
    single ``Return`` whose value is a ``PrimCall``.
    """
    n = 3
    typed_fn = parakeet.typed_repr(allpairs_add_imap, [n])
    # Single-argument parenthesised prints produce the same text under the
    # Python 2 print statement and the Python 3 print function.
    print("TYPED FN %s" % (typed_fn,))
    typed_fn = high_level_optimizations.apply(typed_fn)
    print("OPT FN %s" % (typed_fn,))

    assert len(typed_fn.body) in (1, 2), \
        "Expected one or two statements but got %s" % (typed_fn.body,)
    stmt = typed_fn.body[-1]
    print("STMT %s" % (stmt,))
    assert stmt.__class__ is Return, "Expected Return but got %s" % (stmt,)

    v = stmt.value
    print("VALUE %s" % (v,))
    assert v.__class__ is IndexMap, "Expected IndexMap but got %s" % (v,)

    # Inspect the function carried by the IndexMap.
    nested_fn = get_fn(v.fn)
    assert len(nested_fn.body) == 1, \
        "Expected a single nested statement but got %s" % (nested_fn.body,)
    nested_stmt = nested_fn.body[0]
    assert nested_stmt.__class__ is Return, \
        "Expected nested Return but got %s" % (nested_stmt,)
    nested_value = nested_stmt.value
    assert nested_value.__class__ is PrimCall, \
        "Expected PrimCall but got %s" % (nested_value,)