def test_replacer(self):
    """LookupReplacer swaps each lookup, plain or DEMQUERY-wrapped, for a
    variable, and reuses that variable when the same node repeats."""
    lookup_node = L.pe('R.smlookup("bu", x)')
    foo_node = L.pe('DEMQUERY(foo, [y], R.smlookup("bu", y))')
    bar_node = L.pe('DEMQUERY(bar, [z], R.smlookup("bu", z))')
    substitutions = {
        'LOOK': lookup_node,
        'DEM1': foo_node,
        'DEM2': bar_node,
    }
    source_tree = L.pe('x + LOOK + DEM1 + DEM1 + DEM2', subst=substitutions)

    replacer = LookupReplacer(L.NameGenerator())
    result_tree, result_clauses = replacer.process(source_tree)
    mapping = replacer.repls

    # DEM1 occurs twice but is bound once (v2) with a single clause.
    expected_tree = L.pe('x + v1 + v2 + v2 + v3')
    expected_clauses = [
        L.Enumerator(L.sn('v1'), L.pe('{R.smlookup("bu", x)}')),
        L.Enumerator(L.sn('v2'),
                     L.pe('DEMQUERY(foo, [y], {R.smlookup("bu", y)})')),
        L.Enumerator(L.sn('v3'),
                     L.pe('DEMQUERY(bar, [z], {R.smlookup("bu", z)})')),
    ]
    expected_mapping = {
        lookup_node: 'v1',
        foo_node: 'v2',
        bar_node: 'v3',
    }

    self.assertEqual(result_tree, expected_tree)
    self.assertEqual(result_clauses, expected_clauses)
    self.assertEqual(mapping, expected_mapping)
def process(expr): """Rewrite any retrievals in the given expression. Return a pair of the new expression, and a list of new clauses to be added for any retrievals not already seen. """ nonlocal seen_map new_expr = replacer.process(expr) new_field_repls = replacer.field_repls - seen_field_repls new_map_repls = replacer.map_repls - seen_map_repls new_clauses = [] for repl in new_field_repls: obj, field, value = repl seen_fields.add(field) seen_field_repls.add(repl) new_cl = L.Enumerator(L.tuplify((obj, value), lval=True), L.ln(make_frel(field))) new_clauses.append(new_cl) for repl in new_map_repls: map, key, value = repl seen_map = True seen_map_repls.add(repl) new_cl = L.Enumerator(L.tuplify((map, key, value), lval=True), L.ln(make_maprel())) new_clauses.append(new_cl) return new_expr, new_clauses
def test_objclausefactory(self):
    """Both object-clause factories rebuild the matching clause from an AST."""
    # Type-checking factory, M-set membership clause.
    expected = MClause('S', 'x')
    m_ast = L.Enumerator(L.tuplify(['S', 'x'], lval=True), L.pe('_M'))
    rebuilt = ObjClauseFactory.from_AST(m_ast)
    self.assertEqual(rebuilt, expected)

    # No-type-check factory must yield the _NoTC clause class.
    expected = FClause_NoTC('o', 'v', 'f')
    f_ast = L.Enumerator(L.tuplify(['o', 'v'], lval=True), L.pe('_F_f'))
    rebuilt = ObjClauseFactory_NoTC.from_AST(f_ast)
    self.assertEqual(rebuilt, expected)
    self.assertIsInstance(rebuilt, FClause_NoTC)
def unflatten_set_clause(cl):
    """Undo flatten_set_clause: turn a clause over the M-set back into a
    direct membership clause.

    Enumerators are recognized with get_menum; condition clauses must have
    the shape ``(cont, item) in <name>`` where ``is_mrel(<name>)`` holds.
    Anything else is returned unchanged.
    """
    if isinstance(cl, L.Enumerator):
        parts = get_menum(cl)
        if parts is None:
            return cl
        container, element = parts
        # The container moves from the (store-context) target into the
        # iterated position, so switch it to load context.
        container = L.ContextSetter.run(container, L.Load)
        return L.Enumerator(element, container)

    if not (isinstance(cl, L.expr) and L.is_cmp(cl)):
        return cl

    lhs, op, rhs = L.get_cmp(cl)
    over_mset = (isinstance(op, L.In)
                 and isinstance(lhs, L.Tuple)
                 and len(lhs.elts) == 2
                 and L.is_name(rhs)
                 and is_mrel(L.get_name(rhs)))
    if not over_mset:
        return cl

    container, element = lhs.elts
    return L.cmp(element, L.In(), container)
def visit_DemQuery(self, node):
    # Remember the wrapped lookup so visit_SMLookup leaves it alone; it is
    # handled here together with its DEMQUERY wrapper.
    self.demwrapped_nodes.add(node.value)
    node = self.generic_visit(node)

    lookup = node.value
    if not isinstance(lookup, L.SMLookup):
        return node
    assert lookup.default is None

    var = self.repls.get(node)
    if var is None:
        # First time this query is seen: allocate a variable and emit
        #     var in DEMQUERY(..., {smlookup})
        # Clause construction later rewrites that, or fails if the syntax
        # is unsupported.
        var = next(self.namer)
        self.repls[node] = var
        target = L.sn(var)
        wrapped_set = node._replace(value=L.Set((lookup,)))
        self.new_clauses.append(L.Enumerator(target, wrapped_set))
    return L.ln(var)
def test_deltaclause(self):
    """DeltaClause: AST round-trip, attributes, and code generation."""
    clause = DeltaClause(('x', 'y'), 'R', L.pe('e'), 1)

    # Clause -> AST -> clause.
    generated = clause.to_AST()
    expected_ast = L.Enumerator(L.tuplify(['x', 'y'], lval=True),
                                L.pe('deltamatch(R, "bb", e, 1)'))
    self.assertEqual(generated, expected_ast)
    self.assertEqual(DeltaClause.from_AST(expected_ast, DummyFactory), clause)

    self.assertEqual(clause.rel, 'R')

    # Plain all-bound mask: the delta element is unpacked directly.
    self.assertEqual(clause.get_code([], L.pc('pass')), L.pc('''
        x, y = e
        pass
        '''))

    # Repeated variable plus wildcard: the delta goes through setmatch.
    fancy = DeltaClause(('x', 'x', '_'), 'R', L.pe('e'), 1)
    self.assertEqual(fancy.get_code([], L.pc('pass')), L.pc('''
        for x in setmatch(deltamatch(R, 'b1w', e, 1), 'u1w', ()):
            pass
        '''))
def test_enumclause_basic(self):
    """EnumClause: parsing, AST round-trip, and derived attributes."""
    clause = EnumClause(('x', 'y', 'x', '_'), 'R')

    # A membership expression parses to the same clause.
    parsed = EnumClause.from_expr(L.pe('(x, y, x, _) in R'))
    self.assertEqual(parsed, clause)

    # Clause -> AST -> clause.
    generated = clause.to_AST()
    expected_ast = L.Enumerator(L.tuplify(['x', 'y', 'x', '_'], lval=True),
                                L.ln('R'))
    self.assertEqual(generated, expected_ast)
    self.assertEqual(EnumClause.from_AST(expected_ast, DummyFactory), clause)

    # Derived attributes; the repeated x and the wildcard collapse out of
    # enumvars/vars.
    self.assertFalse(clause.isdelta)
    self.assertEqual(clause.enumlhs, ('x', 'y', 'x', '_'))
    self.assertEqual(clause.enumvars, ('x', 'y'))
    self.assertEqual(clause.pat_mask, (True, True, True, True))
    self.assertEqual(clause.enumrel, 'R')
    self.assertTrue(clause.has_wildcards)
    self.assertEqual(clause.vars, ('x', 'y'))
    self.assertEqual(clause.eqvars, None)
    self.assertTrue(clause.robust)
    self.assertEqual(clause.demname, None)
    self.assertEqual(clause.demparams, ())
def to_AST(self):
    """Return this clause as ``var in {rel.smlookup(mask, keys)}``.

    The last entry of lhs is the looked-up variable; all earlier entries
    are the keys, and the mask is built from their count.
    """
    key_count = len(self.lhs) - 1
    mask_str = Mask.from_keylen(key_count).make_node().s
    keys = self.lhs[:-1]
    result_var = self.lhs[-1]
    lookup = L.SMLookup(L.ln(self.rel), mask_str, L.tuplify(keys), None)
    return L.Enumerator(L.sn(result_var), L.Set((lookup,)))
def test_enumclause_setmatch(self):
    # An enumerator over a setmatch converts to a plain EnumClause over the
    # underlying relation, with the bound key placed before the lhs.
    target = L.tuplify(['y'], lval=True)
    matched = L.SetMatch(L.ln('R'), 'bu', L.ln('x'))
    converted = EnumClause.from_AST(L.Enumerator(target, matched),
                                    DummyFactory)
    self.assertEqual(converted, EnumClause(('x', 'y'), 'R'))
def flatten_set_clause(cl, input_rels):
    """Rewrite a membership clause into one over the M-set.

    Clauses whose iterated expression is a comprehension, a special
    relation, or one of ``input_rels`` are left alone. Membership
    conditions (``item in cont``) are handled too and remain conditions.

    Returns ``(clause, changed)``: the possibly rewritten clause and whether
    a rewrite happened.
    """
    def is_flattenable(rhs):
        if isinstance(rhs, L.Comp):
            return False
        if isinstance(rhs, L.Name):
            return not (is_specialrel(rhs.id) or rhs.id in input_rels)
        return True

    if isinstance(cl, L.Enumerator) and is_flattenable(cl.iter):
        # The container becomes part of the enumerator's target tuple,
        # so it must be in store context.
        container = L.ContextSetter.run(cl.iter, L.Store)
        new_target = L.tuplify((container, cl.target), lval=True)
        return L.Enumerator(new_target, L.ln(make_mrel())), True

    if isinstance(cl, L.expr) and L.is_cmp(cl):
        element, op, container = L.get_cmp(cl)
        if isinstance(op, L.In) and is_flattenable(container):
            pair = L.tuplify((container, element))
            return L.cmp(pair, L.In(), L.ln(make_mrel())), True

    return cl, False
def test_objclausefactory(self):
    """Both tuple-clause factories rebuild a TClause from its AST form."""
    expected = TClause('t', ['x', 'y'])
    tuple_ast = L.Enumerator(L.tuplify(['t', 'x', 'y'], lval=True),
                             L.pe('_TUP2'))
    self.assertEqual(TupClauseFactory.from_AST(tuple_ast), expected)

    # The same AST through the no-type-check factory gives TClause_NoTC.
    expected = TClause_NoTC('t', ['x', 'y'])
    rebuilt = TupClauseFactory_NoTC.from_AST(tuple_ast)
    self.assertEqual(rebuilt, expected)
    self.assertIsInstance(rebuilt, TClause_NoTC)
def visit_Tuple(self, node):
    """Replace a tuple expression with a fresh tuple variable.

    Records a clause ``(tupvar, *elts) in <trel>``, where ``<trel>`` is the
    tuple relation make_trel returns for the tuple's arity. Returns the new
    variable in store context.

    No recursion here: the caller of this visitor takes care of visiting
    subexpressions.
    """
    # Use the builtin next(): it works for any Python 3 iterator and matches
    # how the other visitors draw names (next(self.namer)). The former
    # `.next()` method call is Python 2 iterator style.
    tupvar = next(self.tupvar_namer)
    trel = make_trel(len(node.elts))
    elts = (L.sn(tupvar),) + node.elts
    new_cl = L.Enumerator(L.tuplify(elts, lval=True), L.ln(trel))
    self.new_clauses.append(new_cl)
    self.trels.add(trel)
    return L.sn(tupvar)
def test_mapclause(self):
    """MapClause: parsing, AST round-trip, attributes, rating, codegen."""
    clause = MapClause('m', 'k', 'v')

    # A membership expression over _MAP parses to the same clause.
    self.assertEqual(MapClause.from_expr(L.pe('(m, k, v) in _MAP')), clause)

    # Clause -> AST -> clause.
    generated = clause.to_AST()
    expected_ast = L.Enumerator(L.tuplify(('m', 'k', 'v'), lval=True),
                                L.ln('_MAP'))
    self.assertEqual(generated, expected_ast)
    self.assertEqual(MapClause.from_AST(expected_ast, ObjClauseFactory),
                     clause)

    # Only the map is tagged as input; key and value are outputs.
    self.assertEqual(clause.enumlhs, ('m', 'k', 'v'))
    self.assertEqual(clause.pat_mask, (False, True, True))
    self.assertEqual(clause.enumvars_tagsin, ('m',))
    self.assertEqual(clause.enumvars_tagsout, ('k', 'v'))

    # Nothing bound: not runnable.
    self.assertEqual(clause.rate([]), Rate.UNRUNNABLE)

    # Map bound: iterate all items behind a type check.
    self.assertEqual(clause.get_code(['m'], L.pc('pass')), L.pc('''
        if isinstance(m, Map):
            for k, v in m.items():
                pass
        '''))

    # Map and key bound: membership test and a single lookup.
    self.assertEqual(clause.get_code(['m', 'k'], L.pc('pass')), L.pc('''
        if isinstance(m, Map):
            if k in m:
                v = m[k]
                pass
        '''))

    # The _NoTC variant drops the isinstance guard.
    unchecked = MapClause_NoTC('m', 'k', 'v')
    self.assertEqual(unchecked.get_code(['m'], L.pc('pass')), L.pc('''
        for k, v in m.items():
            pass
        '''))
def test_clausefactory(self):
    """Exercise each ClauseFactory operation on small clauses."""
    factory = ClauseFactory

    def minus_e(lhs, rel):
        # `rel - {e}` enumerated into lhs; several checks below use this shape.
        return SubClause(EnumClause(lhs, rel), L.pe('e'))

    # from_AST(): a set difference becomes a SubClause.
    parsed = factory.from_AST(
        L.Enumerator(L.tuplify(['x', 'y'], lval=True), L.pe('R - {e}')))
    self.assertEqual(parsed, minus_e(('x', 'y'), 'R'))

    # rewrite_subst(): renames variables inside the wrapped clause.
    renamed = factory.rewrite_subst(minus_e(('x', 'y'), 'R'), {'x': 'z'})
    self.assertEqual(renamed, minus_e(('z', 'y'), 'R'))

    # bind(): a non-augmented binding yields a DeltaClause.
    bound = factory.bind(EnumClause(('x', 'y'), 'R'), L.pe('e'),
                         augmented=False)
    self.assertEqual(bound, DeltaClause(['x', 'y'], 'R', L.pe('e'), 1))

    # subtract(): wraps the clause in a SubClause.
    subtracted = factory.subtract(EnumClause(('x', 'y'), 'R'), L.pe('e'))
    self.assertEqual(subtracted, minus_e(('x', 'y'), 'R'))

    # augment(): wraps the clause in an AugClause.
    augmented = factory.augment(EnumClause(['x', 'y'], 'R'), L.pe('e'))
    self.assertEqual(augmented,
                     AugClause(EnumClause(['x', 'y'], 'R'), L.pe('e')))

    # rewrite_rel(): replaces the relation beneath the wrapper.
    moved = factory.rewrite_rel(minus_e(('x', 'y'), 'R'), 'S')
    self.assertEqual(moved, minus_e(('x', 'y'), 'S'))

    # membercond_to_enum() / enum_to_membercond() convert in each direction.
    as_enum = factory.membercond_to_enum(CondClause(L.pe('(x, y) in R')))
    self.assertEqual(as_enum, EnumClause(('x', 'y'), 'R'))
    as_cond = factory.enum_to_membercond(EnumClause(('x', 'y'), 'R'))
    self.assertEqual(as_cond, CondClause(L.pe('(x, y) in R')))
def test_tclause(self):
    """TClause: AST round-trip, attributes, domain constraints, codegen."""
    clause = TClause('t', ['x', 'y'])

    # Clause -> AST -> clause.
    generated = clause.to_AST()
    expected_ast = L.Enumerator(L.tuplify(('t', 'x', 'y'), lval=True),
                                L.ln('_TUP2'))
    self.assertEqual(generated, expected_ast)
    self.assertEqual(TClause.from_AST(expected_ast, TupClauseFactory), clause)

    self.assertEqual(clause.enumlhs, ('t', 'x', 'y'))
    self.assertEqual(clause.pat_mask, (True, True, True))
    self.assertEqual(clause.enumvars_tagsin, ('t',))
    self.assertEqual(clause.enumvars_tagsout, ('x', 'y'))
    # With prefix '_': t is a 2-tuple whose components are x and y.
    self.assertCountEqual(clause.get_domain_constrs('_'),
                          [('_t', ('<T>', '_t.1', '_t.2')),
                           ('_t.1', '_x'),
                           ('_t.2', '_y')])

    # Unrunnable until t is bound; then constant-time.
    self.assertEqual(clause.rate([]), Rate.UNRUNNABLE)
    self.assertEqual(clause.rate(['t']), Rate.CONSTANT)

    self.assertEqual(clause.get_code(['t'], L.pc('pass')), L.pc('''
        if isinstance(t, tuple) and len(t) == 2:
            for x, y in setmatch({(t, t[0], t[1])}, 'buu', t):
                pass
        '''))

    # The _NoTC variant drops the isinstance/len guard.
    unchecked = TClause_NoTC('t', ['x', 'y'])
    self.assertEqual(unchecked.get_code(['t'], L.pc('pass')), L.pc('''
        for x, y in setmatch({(t, t[0], t[1])}, 'buu', t):
            pass
        '''))
def test_fclause(self):
    """FClause: parsing, AST round-trip, attributes, rating, codegen."""
    clause = FClause('o', 'v', 'f')

    # A membership expression over _F_f parses to the same clause.
    self.assertEqual(FClause.from_expr(L.pe('(o, v) in _F_f')), clause)

    # Clause -> AST -> clause.
    generated = clause.to_AST()
    expected_ast = L.Enumerator(L.tuplify(('o', 'v'), lval=True),
                                L.ln('_F_f'))
    self.assertEqual(generated, expected_ast)
    self.assertEqual(FClause.from_AST(expected_ast, ObjClauseFactory), clause)

    # The object is the input; the field value is the output.
    self.assertEqual(clause.enumlhs, ('o', 'v'))
    self.assertEqual(clause.pat_mask, (False, True))
    self.assertEqual(clause.enumvars_tagsin, ('o',))
    self.assertEqual(clause.enumvars_tagsout, ('v',))

    self.assertEqual(clause.rate([]), Rate.UNRUNNABLE)
    self.assertEqual(clause.rate(['o']), Rate.CONSTANT)

    self.assertEqual(clause.get_code(['o'], L.pc('pass')), L.pc('''
        if hasattr(o, 'f'):
            v = o.f
            pass
        '''))

    # The _NoTC variant reads the field without the hasattr guard.
    unchecked = FClause_NoTC('o', 'v', 'f')
    self.assertEqual(unchecked.get_code(['o'], L.pc('pass')), L.pc('''
        v = o.f
        pass
        '''))
def test(self):
    """DemClause: AST round-trip, attributes, rewriting, rating, codegen."""
    clause = DemClause(EnumClause(['x', 'y'], 'R'), 'f', ['x'])

    # Clause -> AST -> clause.
    generated = clause.to_AST()
    expected_ast = L.Enumerator(L.tuplify(['x', 'y'], lval=True),
                                L.DemQuery('f', (L.ln('x'),), L.ln('R')))
    self.assertEqual(generated, expected_ast)
    self.assertEqual(DemClause.from_AST(expected_ast, DemClauseFactory),
                     clause)

    self.assertEqual(clause.pat_mask, (True, True))
    self.assertEqual(clause.enumvars_tagsin, ('x',))
    self.assertEqual(clause.enumvars_tagsout, ('y',))

    # Substitution renames both the inner clause and the demand parameters.
    self.assertEqual(clause.rewrite_subst({'x': 'z'}, DemClauseFactory),
                     DemClause(EnumClause(['z', 'y'], 'R'), 'f', ['z']))

    # Same, with a LookupClause as the inner clause.
    over_lookup = DemClause(LookupClause(['x', 'y'], 'R'), 'f', ['x'])
    self.assertEqual(over_lookup.rewrite_subst({'x': 'z'}, DemClauseFactory),
                     DemClause(LookupClause(['z', 'y'], 'R'), 'f', ['z']))

    self.assertEqual(clause.rate(['x']), Rate.NORMAL)
    self.assertEqual(clause.rate([]), Rate.UNRUNNABLE)

    # The demand query is issued before enumerating the inner clause.
    self.assertEqual(clause.get_code(['x'], L.pc('pass')), L.pc('''
        DEMQUERY(f, [x], None)
        for y in setmatch(R, 'bu', x):
            pass
        '''))
def membercond_to_enum(cls, cl):
    """For a condition clause that expresses a membership, return an
    equivalent enumerator clause. For other kinds of conditions, return
    the same clause. For enumerators, raise TypeError.

    A condition counts as a membership when its AST is a comparison
    ``<tuple of variables> in <expr>``. The enumerator is built from that
    AST through ``cls.from_AST``.
    """
    # Compare by equality, not identity. This is correct whether KIND_COND
    # is a sentinel object or a plain value such as a string.
    if cl.kind != Clause.KIND_COND:
        raise TypeError(
            'membercond_to_enum() expects a condition clause, got {!r}'
            .format(cl))

    clast = cl.to_AST()
    if not L.is_cmp(clast):
        return cl

    lhs, op, rhs = L.get_cmp(clast)
    if not (L.is_vartuple(lhs) and isinstance(op, L.In)):
        return cl

    enum_ast = L.Enumerator(L.tuplify(L.get_vartuple(lhs), lval=True), rhs)
    return cls.from_AST(enum_ast)
def visit_SMLookup(self, node):
    node = self.generic_visit(node)
    # Lookups wrapped in a DEMQUERY are replaced by visit_DemQuery instead.
    if node in self.demwrapped_nodes:
        return node
    assert node.default is None

    var = self.repls.get(node)
    if var is None:
        # First occurrence: allocate a variable and emit `var in {smlookup}`.
        var = next(self.namer)
        self.repls[node] = var
        self.new_clauses.append(L.Enumerator(L.sn(var), L.Set((node,))))
    return L.ln(var)
def test_mclause(self):
    """MClause: parsing, AST round-trip, attributes, rating, codegen."""
    clause = MClause('S', 'x')

    # A membership expression over _M parses to the same clause.
    self.assertEqual(MClause.from_expr(L.pe('(S, x) in _M')), clause)

    # Clause -> AST -> clause.
    generated = clause.to_AST()
    expected_ast = L.Enumerator(L.tuplify(('S', 'x'), lval=True), L.ln('_M'))
    self.assertEqual(generated, expected_ast)
    self.assertEqual(MClause.from_AST(expected_ast, ObjClauseFactory), clause)

    # The set is the input; the element is the output.
    self.assertEqual(clause.enumlhs, ('S', 'x'))
    self.assertEqual(clause.pat_mask, (False, True))
    self.assertEqual(clause.enumvars_tagsin, ('S',))
    self.assertEqual(clause.enumvars_tagsout, ('x',))

    self.assertEqual(clause.rate([]), Rate.UNRUNNABLE)

    self.assertEqual(clause.get_code(['S'], L.pc('pass')), L.pc('''
        if isinstance(S, Set):
            for x in S:
                pass
        '''))

    # The _NoTC variant iterates without the isinstance guard.
    unchecked = MClause_NoTC('S', 'x')
    self.assertEqual(unchecked.get_code(['S'], L.pc('pass')), L.pc('''
        for x in S:
            pass
        '''))
def test_lookupclause(self):
    """LookupClause: AST round-trip, attributes, rewriting, rating."""
    clause = LookupClause(('x', 'y', 'z'), 'R')

    # Expected AST: z in {R.smlookup('bbu', (x, y))}.
    generated = clause.to_AST()
    lookup = L.SMLookup(L.ln('R'), 'bbu', L.tuplify(['x', 'y']), None)
    expected_ast = L.Enumerator(L.sn('z'), L.Set((lookup,)))
    self.assertEqual(generated, expected_ast)
    self.assertEqual(LookupClause.from_AST(expected_ast, DummyFactory),
                     clause)

    self.assertEqual(clause.enumvars, ('x', 'y', 'z'))

    # Substitution applies to keys and the result variable alike.
    renamed = clause.rewrite_subst({'x': 'xx', 'z': 'zz'}, DummyFactory)
    self.assertEqual(renamed, LookupClause(('xx', 'y', 'zz'), 'R'))

    # One key bound rates NORMAL; both keys bound rates CONSTANT.
    self.assertEqual(clause.rate(['x']), Rate.NORMAL)
    self.assertEqual(clause.rate(['x', 'y']), Rate.CONSTANT)
def visit_Aggregate(self, node):
    node = self.generic_visit(node)
    operand = node.value

    # Leave the aggregate alone if its operand is already a comprehension,
    # or is not a retrieval chain (we could not incrementalize it later).
    if isinstance(operand, L.Comp) or not is_retrievalchain(operand):
        return node

    # Rewrite the operand as {_e for _e in OPERAND}. This covers both single
    # variables and retrieval chains; the comprehension inherits the
    # aggregate's options.
    # NOTE(review): the element name '_e' is fixed; confirm an operand can
    # never itself mention '_e'.
    params = get_retrieval_params(operand)
    enumerator = L.Enumerator(target=L.sn('_e'), iter=operand)
    wrapped = L.Comp(resexp=L.ln('_e'), clauses=(enumerator,),
                     params=params, options=node.options)
    return node._replace(value=wrapped)
def test_subclause(self):
    """SubClause: AST round-trip, attributes, code generation."""
    clause = SubClause(EnumClause(('x', 'y'), 'R'), L.pe('e'))

    # Clause -> AST -> clause; the AST is a set difference R - {e}.
    generated = clause.to_AST()
    expected_ast = L.Enumerator(L.tuplify(['x', 'y'], lval=True),
                                L.pe('R - {e}'))
    self.assertEqual(generated, expected_ast)
    self.assertEqual(SubClause.from_AST(expected_ast, DummyFactory), clause)

    self.assertEqual(clause.enumlhs, ('x', 'y'))
    self.assertFalse(clause.robust)

    # The subtracted element is skipped with an explicit inequality test.
    self.assertEqual(clause.get_code([], L.pc('pass')), L.pc('''
        for (x, y) in R:
            if (x, y) != e:
                pass
        '''))
def test_singletonclause(self):
    """SingletonClause: parsing, AST round-trip, attributes, codegen."""
    clause = SingletonClause(('x', 'y'), L.pe('e'))

    # An equality expression parses to the same clause.
    parsed = SingletonClause.from_expr(L.pe('(x, y) == e'))
    self.assertEqual(clause, parsed)

    # Clause -> AST -> clause; the AST enumerates the one-element set {e}.
    generated = clause.to_AST()
    expected_ast = L.Enumerator(L.tuplify(['x', 'y'], lval=True),
                                L.pe('{e}'))
    self.assertEqual(generated, expected_ast)
    self.assertEqual(SingletonClause.from_AST(expected_ast, DummyFactory),
                     clause)

    self.assertEqual(clause.enumvars, ('x', 'y'))

    # Enumerating a singleton is just an unpacking assignment.
    self.assertEqual(clause.get_code([], L.pc('pass')), L.pc('''
        x, y = e
        pass
        '''))
def to_AST(self):
    """Return this clause as the enumerator ``(lhs...) in rel``."""
    target = L.tuplify(self.lhs, lval=True)
    relation = L.ln(self.rel)
    return L.Enumerator(target, relation)
def to_AST(self):
    """Return this clause as the enumerator ``(lhs...) in {val}``."""
    target = L.tuplify(self.lhs, lval=True)
    singleton = L.Set((self.val,))
    return L.Enumerator(target, singleton)
def to_AST(self):
    """Return ``(lhs...) in deltamatch(rel, mask, val, limit)``.

    The mask comes from Mask.from_vars(lhs, lhs).
    """
    mask = Mask.from_vars(self.lhs, self.lhs)
    target = L.tuplify(self.lhs, lval=True)
    delta = L.DeltaMatch(L.ln(self.rel), mask.make_node().s,
                         self.val, self.limit)
    return L.Enumerator(target, delta)
def test_getclausevars(self):
    """get_clause_vars flattens a nested tuple target into its variables,
    in left-to-right order."""
    # Target (x, (y, z)), both tuples in store context.
    lhs = L.Tuple((L.sn('x'),
                   L.Tuple((L.sn('y'), L.sn('z')), L.Store())),
                  L.Store())
    # Named clause_vars rather than `vars` to avoid shadowing the builtin.
    clause_vars = get_clause_vars(L.Enumerator(lhs, L.ln('R')), self.tuptype)
    self.assertEqual(clause_vars, ['x', 'y', 'z'])