def test_arityfinder(self):
    """SubqueryArityFinder returns a subquery's arity, or False in the
    cases exercised at the end of this test."""
    # Unary subquery used as the iterated set of an outer query.
    inner = L.pe('COMP({x for x in S}, [], {})')
    outer = L.pe('COMP({y for y in C1}, [], {})', subst={'C1': inner})
    tree = L.p('''
        print(C2)
        ''', subst={'C2': outer})
    self.assertEqual(SubqueryArityFinder.run(tree, inner), 1)

    # Binary subquery; the outer query is printed twice.
    pair_inner = L.pe('COMP({(x, x) for x in S}, [], {})')
    pair_outer = L.pe('COMP({z for (z, z) in C3}, [], {})',
                      subst={'C3': pair_inner})
    tree = L.p('''
        print(C4, C4)
        ''', subst={'C4': pair_outer})
    self.assertEqual(SubqueryArityFinder.run(tree, pair_inner), 2)

    # Unary subquery iterated with a two-component pattern.
    mismatched = L.pe('COMP({z for (z, z) in C1}, [], {})',
                      subst={'C1': inner})
    tree = L.p('''
        print(C5)
        ''', subst={'C5': mismatched})
    self.assertEqual(SubqueryArityFinder.run(tree, inner), False)

    # The subquery also appears directly, outside any enumerator.
    tree = L.p('''
        print(C2, C1)
        ''', subst={'C2': outer, 'C1': inner})
    self.assertEqual(SubqueryArityFinder.run(tree, inner), False)
def test_spec(self):
    """Check AggrSpec.from_node() attributes and domain constraints."""
    # Aggregate of a relation.
    plain = AggrSpec.from_node(L.pe('count(R)'))
    self.assertEqual(plain.aggrop, 'count')
    self.assertEqual(plain.rel, 'R')
    self.assertEqual(plain.relmask, Mask('u'))
    self.assertEqual(plain.params, ())
    self.assertEqual(plain.oper_demname, None)
    self.assertEqual(plain.oper_demparams, None)
    self.assertEqual(plain.get_domain_constraints('A'), [])

    # Aggregate of a setmatch, with demand.
    demanded = AggrSpec.from_node(
        L.pe('count(DEMQUERY(foo, [c1], '
             'setmatch(R, "bub", (c1, c2))))'))
    self.assertEqual(demanded.aggrop, 'count')
    self.assertEqual(demanded.rel, 'R')
    self.assertEqual(demanded.relmask, Mask('bub'))
    self.assertEqual(demanded.params, ('c1', 'c2'))
    self.assertEqual(demanded.oper_demname, 'foo')
    self.assertEqual(demanded.oper_demparams, ('c1',))
    self.assertEqual(demanded.get_domain_constraints('A'),
                     [('A.1', 'R.1'), ('A.2', 'R.3')])
def test_retrieval_replacer(self):
    """RetrievalReplacer substitutes generated names for nested field
    and map retrievals and records each replacement it made."""
    def field_namer(lhs, rhs):
        return 'f_' + lhs + '_' + rhs

    def map_namer(lhs, rhs):
        return 'm_' + lhs + '_' + rhs

    replacer = RetrievalReplacer(field_namer, map_namer)
    result = replacer.process(L.pe('a.b[c.d].e + a[b[c]]'))

    self.assertEqual(result, L.pe('f_m_f_a_b_f_c_d_e + m_a_m_b_c'))
    self.assertSequenceEqual(replacer.field_repls, [
        ('a', 'b', 'f_a_b'),
        ('c', 'd', 'f_c_d'),
        ('m_f_a_b_f_c_d', 'e', 'f_m_f_a_b_f_c_d_e'),
    ])
    self.assertSequenceEqual(replacer.map_repls, [
        ('f_a_b', 'f_c_d', 'm_f_a_b_f_c_d'),
        ('b', 'c', 'm_b_c'),
        ('a', 'm_b_c', 'm_a_m_b_c'),
    ])
def test_ucon_params(self):
    """CompSpec.get_uncon_params() on acyclic and cyclic joins."""
    class DummyClause(EnumClause, ABCStruct):
        lhs = Field()
        rel = Field()
        con_mask = (False, True)

    def check(lhs_list, resexp, params, expected):
        # Build a join of DummyClauses over R and compare its uncons.
        clauses = [DummyClause(lhs, 'R') for lhs in lhs_list]
        join = Join(clauses, CF, None)
        spec = CompSpec(join, L.pe(resexp), params)
        self.assertSequenceEqual(spec.get_uncon_params(), expected)

    # Basic.
    check([['x', 'y'], ['y', 'z']], '(x, z)', ['x', 'y', 'z'], ['x'])

    # Cycle.
    check([['x', 'x']], 'x', ['x'], ['x'])

    # Cycle with two distinct minimal sets of uncons.
    check([['x', 'y'], ['y', 'x']], '(x, y)', ['x', 'y'], ['x'])
def test_deltaclause(self):
    """DeltaClause AST round-trip, attributes, and code generation."""
    clause = DeltaClause(('x', 'y'), 'R', L.pe('e'), 1)

    # AST round-trip.
    actual_ast = clause.to_AST()
    expected_ast = L.Enumerator(L.tuplify(['x', 'y'], lval=True),
                                L.pe('deltamatch(R, "bb", e, 1)'))
    self.assertEqual(actual_ast, expected_ast)
    self.assertEqual(DeltaClause.from_AST(expected_ast, DummyFactory),
                     clause)

    # Attributes.
    self.assertEqual(clause.rel, 'R')

    # Code generation, no fancy mask.
    generated = clause.get_code([], L.pc('pass'))
    self.assertEqual(generated, L.pc('''
        x, y = e
        pass
        '''))

    # Code generation, fancy mask.
    masked = DeltaClause(('x', 'x', '_'), 'R', L.pe('e'), 1)
    generated = masked.get_code([], L.pc('pass'))
    self.assertEqual(generated, L.pc('''
        for x in setmatch(deltamatch(R, 'b1w', e, 1), 'u1w', ()):
            pass
        '''))
def make_retrieval_code(self):
    """Make code for retrieving the value of the aggregate result,
    including demanding it.

    With demand, the result map lookup (smlookup) is wrapped in a
    DEMQUERY over the aggregate's parameters; without it, smdeflookup
    is used with make_zero_mapval_expr() as its third argument. The
    lookup expression is then passed through make_proj_mapval_code().
    """
    aggr = self.incaggr
    params_list = L.List(tuple(L.ln(p) for p in aggr.params), L.Load())

    # Substitutions are added in the same order the original evaluated
    # them, so helper calls happen in an unchanged sequence.
    subst = {}
    if aggr.has_demand:
        template = '''
            DEMQUERY(NAME, PARAMS_L, RES.smlookup(AGGRMASK, PARAMS_T))
            '''
        subst['NAME'] = aggr.name
        subst['PARAMS_L'] = params_list
        subst['PARAMS_T'] = L.tuplify(aggr.params)
        subst['RES'] = aggr.name
        subst['AGGRMASK'] = aggr.aggrmask.make_node()
    else:
        template = '''
            RES.smdeflookup(AGGRMASK, PARAMS_T, ZERO)
            '''
        subst['RES'] = aggr.name
        subst['AGGRMASK'] = aggr.aggrmask.make_node()
        subst['PARAMS_T'] = L.tuplify(aggr.params)
        subst['ZERO'] = self.make_zero_mapval_expr()

    return self.make_proj_mapval_code(L.pe(template, subst=subst))
def test_condclause(self):
    """CondClause round-trip, attributes, rating, and code generation."""
    clause = CondClause(L.pe('f(a) or g(b)'))

    # AST round-trip.
    actual_ast = clause.to_AST()
    expected_ast = L.pe('f(a) or g(b)')
    self.assertEqual(actual_ast, expected_ast)
    self.assertEqual(CondClause.from_AST(expected_ast, DummyFactory),
                     clause)

    # fits_string().
    self.assertTrue(clause.fits_string(['a', 'b'], 'f(a) or g(b)'))

    # Attributes.
    self.assertEqual(clause.enumvars, ())
    self.assertEqual(clause.vars, ('a', 'b'))
    equality = CondClause(L.pe('a == b'))
    self.assertEqual(equality.eqvars, ('a', 'b'))

    # Rating.
    self.assertEqual(clause.rate(['a', 'b']), Rate.CONSTANT)
    self.assertEqual(clause.rate(['a']), Rate.UNRUNNABLE)

    # Code generation.
    generated = clause.get_code(['a', 'b'], L.pc('pass'))
    self.assertEqual(generated, L.pc('''
        if f(a) or g(b):
            pass
        '''))
def visit_Aggregate(self, node):
    """Split a min/max aggregate whose operand is a union of sets.

    Children are visited first. If the aggregate's op is 'min' or 'max'
    and its operand is a set union of two or more operands, the node is
    replaced by a min2()/max2() call over one term per operand.
    Otherwise the (visited) node is returned unchanged.
    """
    node = self.generic_visit(node)

    if node.op not in ('min', 'max'):
        return node
    if not L.is_setunion(node.value):
        return node
    sets = L.get_setunion(node.value)
    if len(sets) == 1:
        # If there's just one set, don't change anything.
        return node

    func2 = {'min': 'min2', 'max': 'max2'}[node.op]

    def make_func2_call(args):
        # Build a min2(...)/max2(...) call node with the given arguments.
        call = L.pe('OP(__ARGS)', subst={'OP': L.ln(func2)})
        return call._replace(args=args)

    # Wrap each operand in an aggregate query with the same
    # options as the original aggregate. (This ensures that
    # 'impl' is carried over.) Set literals are wrapped in
    # a call to incoq.runtime's min2()/max2() instead of an
    # Aggregate query node.
    terms = []
    for s in sets:
        if isinstance(s, (L.Comp, L.Name)):
            terms.append(L.Aggregate(s, node.op, node.options))
        else:
            terms.append(make_func2_call(s.elts))

    # The new top-level aggregate is min2()/max2().
    return make_func2_call(tuple(terms))
def test_unflatten_subclause(self):
    """unflatten_comp() with a subtractive enumerator present."""
    # Make sure we don't do anything foolish when presented with
    # a subtractive enum.
    flat = L.pe('COMP({z for (x, y) in _M for (y, z) in _M - {e}}, [])')
    unflat = unflatten_comp(flat)
    expected = L.pe('COMP({z for y in x for (y, z) in _M - {e}}, [])')
    self.assertEqual(unflat, expected)
def test_filter_comps(self):
    """filter_comps() on two maintenance comprehensions of one join."""
    join = self.make_join(
        'for (a, b) in R for (b, c) in S for (c, d) in _M')
    delta_s = L.pe(
        'COMP({(a, b, c, d) for (a, b) in deltamatch(S, "bb", e, 1) for (b, c) in S for (c, d) in _M}, '
        '[], {})')
    delta_m = L.pe(
        'COMP({(a, b, c, d) for (a, b) in R for (b, c) in S for (c, d) in deltamatch(_M, "bb", e, 1)}, '
        '[], {})')
    program = L.p('''
        print(COMP1)
        print(COMP2)
        ''', subst={'COMP1': delta_s, 'COMP2': delta_m})

    structs = make_structures(join.clauses, 'Q', singletag=False,
                              subdem_tags=True)
    program, structs = filter_comps(program, CF, structs,
                                    [delta_s, delta_m], True,
                                    augmented=False, subdem_tags=True)
    names = [s.name for s in structs.tags + structs.filters + structs.usets]

    expected_program = L.p('''
        print(COMP({(a, b, c, d) for (a, b) in deltamatch(S, 'bb', e, 1) for (b, c) in Q_dS for (c, d) in _M}, [], {}))
        print(COMP({(a, b, c, d) for (a, b) in R for (b, c) in Q_dS for (c, d) in deltamatch(Q_d_M, 'bb', e, 1) for (c, d) in Q_d_M}, [], {}))
        ''')
    self.assertEqual(program, expected_program)
    self.assertCountEqual(names, ['Q_Tb1', 'Q_dS', 'Q_Tc', 'Q_d_M'])
def make_retrieval_code(self):
    """Return the expression that reads the aggregate's result value.

    Without demand, smdeflookup is used with make_zero_mapval_expr() as
    its third argument; with demand, an smlookup is wrapped in a
    DEMQUERY over the aggregate's parameters. Either way the lookup is
    passed through make_proj_mapval_code().
    """
    aggr = self.incaggr
    demand_args = L.List(tuple(L.ln(p) for p in aggr.params), L.Load())

    if not aggr.has_demand:
        lookup = L.pe('''
            RES.smdeflookup(AGGRMASK, PARAMS_T, ZERO)
            ''', subst={'RES': aggr.name,
                        'AGGRMASK': aggr.aggrmask.make_node(),
                        'PARAMS_T': L.tuplify(aggr.params),
                        'ZERO': self.make_zero_mapval_expr()})
    else:
        lookup = L.pe('''
            DEMQUERY(NAME, PARAMS_L, RES.smlookup(AGGRMASK, PARAMS_T))
            ''', subst={'NAME': aggr.name,
                        'PARAMS_L': demand_args,
                        'PARAMS_T': L.tuplify(aggr.params),
                        'RES': aggr.name,
                        'AGGRMASK': aggr.aggrmask.make_node()})

    return self.make_proj_mapval_code(lookup)
def visit_Module(self, node):
    """Register the aggregate invariant, prepend its result-set
    initialization plus add/remove maintenance functions to the module
    body, and then visit the module."""
    aggr = self.incaggr
    manager = self.manager
    manager.add_invariant(aggr.name, aggr)

    # Prefixes are drawn add-first, before generating any code.
    add_prefix = manager.namegen.next_prefix()
    remove_prefix = manager.namegen.next_prefix()

    add_body = self.cg.make_oper_maint(add_prefix, 'add', L.pe('_e'))
    remove_body = self.cg.make_oper_maint(remove_prefix, 'remove',
                                          L.pe('_e'))

    subst = {
        'RES': aggr.name,
        '<def>ADDFUNC': self.addfunc,
        '<c>ADDCODE': add_body,
        '<def>REMOVEFUNC': self.removefunc,
        '<c>REMOVECODE': remove_body,
    }
    header = L.pc('''
        RES = Set()
        def ADDFUNC(_e):
            ADDCODE
        def REMOVEFUNC(_e):
            REMOVECODE
        ''', subst=subst)

    node = node._replace(body=header + node.body)
    return self.generic_visit(node)
def test_reinterpreter_comp(self):
    """CostReinterpreter rewrites the cost of a comprehension result
    name using the unified domain information."""
    comp_q = L.pe('COMP({(x, y, (x, z)) for (x, y) in S '
                  'for (y, z) in T}, [], {})')
    comp_s = L.pe('COMP({(x, x) for (x, y) in U}, [], {})')
    factory = self.manager.factory
    spec_q = CompSpec.from_comp(comp_q, factory)
    spec_s = CompSpec.from_comp(comp_s, factory)

    # Dummy wrappers for what would be IncComp; they carry only .spec.
    invs = {'Q': SimpleNamespace(spec=spec_q),
            'S': SimpleNamespace(spec=spec_s)}

    # Boilerplate domain information regarding the comprehension.
    constrs = list(spec_q.get_domain_constraints('Q'))
    constrs += spec_s.get_domain_constraints('S')
    domain_subst = add_domain_names(unify(constrs), {})

    reinterpreter = CostReinterpreter(invs, domain_subst, {}, {})
    cost = normalize(reinterpreter.process(NameCost('Q')))
    self.assertEqual(str(cost), '(Q_x*Q_z)')
def test_spec(self):
    """AggrSpec.from_node() on a plain relation and on a demanded
    setmatch, including the derived domain constraints."""
    def check(source, expected_attrs, expected_constrs):
        # Parse source, then compare each listed attribute in order.
        spec = AggrSpec.from_node(L.pe(source))
        for attr, value in expected_attrs:
            self.assertEqual(getattr(spec, attr), value)
        self.assertEqual(spec.get_domain_constraints('A'),
                         expected_constrs)

    # Aggregate of a relation.
    check('count(R)', [
        ('aggrop', 'count'),
        ('rel', 'R'),
        ('relmask', Mask('u')),
        ('params', ()),
        ('oper_demname', None),
        ('oper_demparams', None),
    ], [])

    # Aggregate of a setmatch, with demand.
    check('count(DEMQUERY(foo, [c1], '
          'setmatch(R, "bub", (c1, c2))))', [
              ('aggrop', 'count'),
              ('rel', 'R'),
              ('relmask', Mask('bub')),
              ('params', ('c1', 'c2')),
              ('oper_demname', 'foo'),
              ('oper_demparams', ('c1',)),
          ], [('A.1', 'R.1'), ('A.2', 'R.3')])
def visit_Module(self, node):
    """Register the aggregate invariant and prepend its setup code.

    The prepended code creates the result set and defines the add and
    remove maintenance functions; the updated module is then visited.
    """
    incaggr = self.incaggr
    self.manager.add_invariant(incaggr.name, incaggr)

    namegen = self.manager.namegen
    add_prefix, remove_prefix = namegen.next_prefix(), namegen.next_prefix()

    maint_add = self.cg.make_oper_maint(add_prefix, 'add', L.pe('_e'))
    maint_remove = self.cg.make_oper_maint(remove_prefix, 'remove',
                                           L.pe('_e'))

    prelude = L.pc('''
        RES = Set()
        def ADDFUNC(_e):
            ADDCODE
        def REMOVEFUNC(_e):
            REMOVECODE
        ''', subst={'RES': incaggr.name,
                    '<def>ADDFUNC': self.addfunc,
                    '<c>ADDCODE': maint_add,
                    '<def>REMOVEFUNC': self.removefunc,
                    '<c>REMOVECODE': maint_remove})

    return self.generic_visit(node._replace(body=prelude + node.body))
def test_singletonclause(self):
    """SingletonClause construction, AST round-trip, and code generation."""
    clause = SingletonClause(('x', 'y'), L.pe('e'))

    # From expression.
    self.assertEqual(clause,
                     SingletonClause.from_expr(L.pe('(x, y) == e')))

    # AST round-trip.
    actual_ast = clause.to_AST()
    expected_ast = L.Enumerator(L.tuplify(['x', 'y'], lval=True),
                                L.pe('{e}'))
    self.assertEqual(actual_ast, expected_ast)
    self.assertEqual(SingletonClause.from_AST(expected_ast, DummyFactory),
                     clause)

    # Attributes.
    self.assertEqual(clause.enumvars, ('x', 'y'))

    # Code generation.
    generated = clause.get_code([], L.pc('pass'))
    self.assertEqual(generated, L.pc('''
        x, y = e
        pass
        '''))
def test_basic(self):
    """Join round-trip, attributes, and enumvar rewriting/prefixing."""
    enum_r = EnumClause(('a', 'b'), 'R')
    enum_s = EnumClause(('b', 'c'), 'S')
    cond = CondClause(L.pe('a != c'))
    join = Join([enum_r, enum_s, cond], CF, None)

    # AST round-trip.
    comp = join.to_comp({})
    expected_comp = L.Comp(L.pe('(a, b, c)'),
                           (enum_r.to_AST(), enum_s.to_AST(), cond.to_AST()),
                           (), {})
    self.assertEqual(comp, expected_comp)
    self.assertEqual(join, Join.from_comp(expected_comp, CF))

    # Attributes.
    self.assertEqual(join.enumvars, ('a', 'b', 'c'))
    self.assertEqual(join.vars, ('a', 'b', 'c'))
    self.assertEqual(join.rels, ('R', 'S'))
    self.assertTrue(join.robust)
    self.assertEqual(join.has_wildcards, False)
    self.assertIs(join.delta, None)

    # Rewriting/prefixing.
    renamed = join.rewrite_subst({'a': 'z'})
    self.assertEqual(renamed, Join([EnumClause(('z', 'b'), 'R'),
                                    enum_s,
                                    CondClause(L.pe('z != c'))],
                                   CF, None))

    prefixed = join.prefix_enumvars('_')
    self.assertEqual(prefixed, Join([EnumClause(('_a', '_b'), 'R'),
                                     EnumClause(('_b', '_c'), 'S'),
                                     CondClause(L.pe('_a != _c'))],
                                    CF, None))
def test_patternize_depatternize(self):
    """patternize_comp() and depatternize_comp() invert each other on
    this comprehension."""
    factory = self.manager.factory
    explicit = L.pe('COMP({z for (x_2, y) in R if x == x_2 '
                    'for (y_2, z) in S if y == y_2}, [x], {})')
    patterned = L.pe('COMP({z for (x, y) in R for (y, z) in S}, [x], {})')

    comp = patternize_comp(explicit, factory)
    self.assertEqual(comp, patterned)
    self.assertEqual(depatternize_comp(comp, factory), explicit)
def test_unflatten_subclause(self):
    # Make sure we don't do anything foolish when presented with
    # a subtractive enum.
    result = unflatten_comp(L.pe(
        'COMP({z for (x, y) in _M for (y, z) in _M - {e}}, [])'))
    self.assertEqual(result, L.pe(
        'COMP({z for y in x for (y, z) in _M - {e}}, [])'))
def test_retrieval_expander(self):
    """RetrievalExpander turns flattened names back into retrievals."""
    expansion = RetrievalExpander.run(
        L.pe('f_f_m_a_b_c_d + foo'),
        {'f_f_m_a_b_c_d': ('f_m_a_b_c', 'd'),
         'f_m_a_b_c': ('m_a_b', 'c')},
        {'m_a_b': ('a', 'b')})
    self.assertEqual(expansion, L.pe('a[b].c.d + foo'))
def test_patternize_depatternize(self):
    """Round-trip between explicit equality conditions and patterns."""
    factory = self.manager.factory
    with_conds = L.pe(
        'COMP({z for (x_2, y) in R if x == x_2 for (y_2, z) in S if y == y_2}, [x], {})'
    )
    patternized = patternize_comp(with_conds, factory)
    self.assertEqual(
        patternized,
        L.pe('COMP({z for (x, y) in R for (y, z) in S}, [x], {})'))
    self.assertEqual(depatternize_comp(patternized, factory), with_conds)
def test_unflatten_retrievals(self):
    """unflatten_retrievals() rebuilds field/map retrieval syntax from
    the flattened _F_/_MAP enumerators."""
    flat = L.pe('''
        COMP({x_a for (S, S_b) in _F_b
                  for (S_b, c, m_S_b_k_c) in _MAP
                  for x in m_S_b_k_c
                  for (x, x_a) in _F_a
                  if x_a > 5},
             [S])
        ''')
    expected = L.pe('COMP({x.a for x in S.b[c] if x.a > 5}, [S])')
    self.assertEqual(unflatten_retrievals(flat), expected)
def test_objclausefactory(self):
    """Object clause factories parse _M and _F_f enumerators."""
    mclause = MClause('S', 'x')
    m_ast = L.Enumerator(L.tuplify(['S', 'x'], lval=True), L.pe('_M'))
    self.assertEqual(ObjClauseFactory.from_AST(m_ast), mclause)

    fclause = FClause_NoTC('o', 'v', 'f')
    f_ast = L.Enumerator(L.tuplify(['o', 'v'], lval=True), L.pe('_F_f'))
    parsed = ObjClauseFactory_NoTC.from_AST(f_ast)
    self.assertEqual(parsed, fclause)
    self.assertIsInstance(parsed, FClause_NoTC)
def test_retrieval_expander(self):
    """Expand nested field/map names back into attribute/index syntax."""
    field_exps = {
        'f_f_m_a_b_c_d': ('f_m_a_b_c', 'd'),
        'f_m_a_b_c': ('m_a_b', 'c'),
    }
    map_exps = {'m_a_b': ('a', 'b')}
    expanded = RetrievalExpander.run(L.pe('f_f_m_a_b_c_d + foo'),
                                     field_exps, map_exps)
    self.assertEqual(expanded, L.pe('a[b].c.d + foo'))
def test_basic(self):
    """CompSpec converts to and from a comprehension node."""
    join = Join([EnumClause.from_expr(L.pe('(x, y) in R')),
                 EnumClause.from_expr(L.pe('(y, z) in S'))],
                CF, None)
    spec = CompSpec(join, L.pe('(x, z)'), ['x'])

    # AST round-trip.
    comp = spec.to_comp({})
    expected = L.pe('COMP({(x, z) for (x, y) in R for (y, z) in S}, '
                    '[x], {})')
    self.assertEqual(comp, expected)
    self.assertEqual(spec, CompSpec.from_comp(expected, CF))
def test_maintjoins(self):
    """get_maint_joins() under the 'sub', 'aug', and 'das' strategies,
    the last one with a variable prefix."""
    join = self.make_join('for (a, b) in R for (b, c) in R')

    def expected(source, lhs):
        # Expected maintenance join carrying delta info for an 'add'
        # of e to R, with the given delta LHS variables.
        return self.make_join(source, DeltaInfo('R', L.pe('e'), lhs, 'add'))

    # Disjoint, subtractive.
    mjoins = join.get_maint_joins(L.pe('e'), 'R', 'add', '',
                                  disjoint_strat='sub')
    self.assertSequenceEqual(mjoins, [
        expected('''
            for (a, b) in deltamatch(R, "bb", e, 1)
            for (b, c) in R - {e}''', ('a', 'b')),
        expected('''
            for (a, b) in R
            for (b, c) in deltamatch(R, "bb", e, 1)''', ('b', 'c')),
    ])

    # Disjoint, augmented.
    mjoins = join.get_maint_joins(L.pe('e'), 'R', 'add', '',
                                  disjoint_strat='aug')
    self.assertSequenceEqual(mjoins, [
        expected('''
            for (a, b) in deltamatch(R, "bb", e, 0)
            for (b, c) in R + {e}''', ('a', 'b')),
        expected('''
            for (a, b) in R
            for (b, c) in deltamatch(R, "bb", e, 0)''', ('b', 'c')),
    ])

    # Not disjoint. With prefix.
    mjoins = join.get_maint_joins(L.pe('e'), 'R', 'add', '_',
                                  disjoint_strat='das')
    self.assertSequenceEqual(mjoins, [
        expected('''
            for (_a, _b) in deltamatch(R, "bb", e, 1)
            for (_b, _c) in R''', ('_a', '_b')),
        expected('''
            for (_a, _b) in R
            for (_b, _c) in deltamatch(R, "bb", e, 1)''', ('_b', '_c')),
    ])
def test_replacer(self):
    """LookupReplacer hoists each distinct lookup into an enumerator
    over a fresh variable; the repeated DEM1 lookup shares one name."""
    plain = L.pe('R.smlookup("bu", x)')
    dem_foo = L.pe('DEMQUERY(foo, [y], R.smlookup("bu", y))')
    dem_bar = L.pe('DEMQUERY(bar, [z], R.smlookup("bu", z))')
    tree = L.pe('x + LOOK + DEM1 + DEM1 + DEM2',
                subst={'LOOK': plain, 'DEM1': dem_foo, 'DEM2': dem_bar})

    replacer = LookupReplacer(L.NameGenerator())
    tree, clauses = replacer.process(tree)

    self.assertEqual(tree, L.pe('x + v1 + v2 + v2 + v3'))
    self.assertEqual(clauses, [
        L.Enumerator(L.sn('v1'), L.pe('{R.smlookup("bu", x)}')),
        L.Enumerator(L.sn('v2'),
                     L.pe('DEMQUERY(foo, [y], {R.smlookup("bu", y)})')),
        L.Enumerator(L.sn('v3'),
                     L.pe('DEMQUERY(bar, [z], {R.smlookup("bu", z)})')),
    ])
    self.assertEqual(replacer.repls,
                     {plain: 'v1', dem_foo: 'v2', dem_bar: 'v3'})
def test_flatten_smlookups_dem(self):
    """flatten_smlookups() on a demanded lookup, applied twice."""
    comp = L.pe('COMP({x for x in S '
                'if DEMQUERY(foo, [u], Aggr1.smlookup("u", ())) > 5}, '
                '[], {})')
    # The second application must not alter the flattened result.
    for _ in range(2):
        comp = flatten_smlookups(comp)

    expected = L.pe('COMP({x for x in S '
                    'for _av1 in DEMQUERY(foo, [u], '
                    '{Aggr1.smlookup("u", ())}) '
                    'if (_av1 > 5)}, '
                    '[], {})')
    self.assertEqual(comp, expected)
def test_replacer(self):
    """Lookups, with or without DEMQUERY, become enumerators over fresh
    variables; a repeated lookup reuses its variable."""
    look = L.pe('R.smlookup("bu", x)')
    dem_y = L.pe('DEMQUERY(foo, [y], R.smlookup("bu", y))')
    dem_z = L.pe('DEMQUERY(bar, [z], R.smlookup("bu", z))')
    source = L.pe('x + LOOK + DEM1 + DEM1 + DEM2',
                  subst={'LOOK': look, 'DEM1': dem_y, 'DEM2': dem_z})

    replacer = LookupReplacer(L.NameGenerator())
    new_tree, new_clauses = replacer.process(source)

    want_clauses = [
        L.Enumerator(L.sn('v1'), L.pe('{R.smlookup("bu", x)}')),
        L.Enumerator(L.sn('v2'),
                     L.pe('DEMQUERY(foo, [y], {R.smlookup("bu", y)})')),
        L.Enumerator(L.sn('v3'),
                     L.pe('DEMQUERY(bar, [z], {R.smlookup("bu", z)})')),
    ]
    self.assertEqual(new_tree, L.pe('x + v1 + v2 + v2 + v3'))
    self.assertEqual(new_clauses, want_clauses)
    self.assertEqual(replacer.repls, {look: 'v1', dem_y: 'v2', dem_z: 'v3'})
def visit_Module(self, node):
    """Prepend the comprehension's result initialization and, for every
    relation in its join, add/remove maintenance functions; then visit
    the module.

    Maintenance comprehensions returned by make_comp_maint_code() are
    collected in self.maint_comps, and entries of the manager's vartypes
    for the join's enumvars are copied to their prefixed names.
    """
    inccomp = self.inccomp

    def make_maint(rel, op, prefix):
        # Only the relation, operation, and prefix vary between calls.
        return make_comp_maint_code(
            inccomp.spec, inccomp.name, rel, op, L.pe('_e'), prefix,
            maint_impl=inccomp.maint_impl, rc=inccomp.rc,
            selfjoin=inccomp.selfjoin)

    code = L.pc('''
        RES = RESINIT
        ''', subst={'RES': inccomp.name, 'RESINIT': L.pe('RCSet()')})

    for rel in inccomp.spec.join.rels:
        add_prefix = self.manager.namegen.next_prefix()
        remove_prefix = self.manager.namegen.next_prefix()
        add_code, add_comps = make_maint(rel, 'add', add_prefix)
        remove_code, remove_comps = make_maint(rel, 'remove', remove_prefix)
        self.maint_comps.extend(add_comps)
        self.maint_comps.extend(remove_comps)

        code += L.pc('''
            def ADDFUNC(_e):
                ADDCODE
            def REMOVEFUNC(_e):
                REMOVECODE
            ''', subst={'<def>ADDFUNC': self.addfuncs[rel],
                        '<c>ADDCODE': add_code,
                        '<def>REMOVEFUNC': self.removefuncs[rel],
                        '<c>REMOVECODE': remove_code})

        vartypes = self.manager.vartypes
        for var in inccomp.spec.join.enumvars:
            if var in vartypes:
                vartypes[add_prefix + var] = vartypes[var]
                vartypes[remove_prefix + var] = vartypes[var]

    return self.generic_visit(node._replace(body=code + node.body))
def test_flatten_smlookups_nodem(self):
    """flatten_smlookups() turns a bare smlookup in a condition into an
    enumerator over a singleton set."""
    original = L.pe('COMP({x for x in S '
                    'if Aggr1.smlookup("u", ()) > 5}, '
                    '[], {})')
    once = flatten_smlookups(original)
    # Ensure idempotence. We don't want to mess up an enumerator
    # in a maintenance comprehension.
    twice = flatten_smlookups(once)

    expected = L.pe('COMP({x for x in S '
                    'for _av1 in {Aggr1.smlookup("u", ())} '
                    'if (_av1 > 5)}, '
                    '[], {})')
    self.assertEqual(twice, expected)
def from_options(cls, options):
    """Build an instance from a comprehension's options dict.

    Returns None when options is None or has no '_deltarel' key.
    Otherwise the '_deltaelem' and '_deltalhs' entries are parsed with
    L.pe(), the latter further converted by L.get_vartuple().
    """
    if options is not None and '_deltarel' in options:
        return cls(options['_deltarel'],
                   L.pe(options['_deltaelem']),
                   L.get_vartuple(L.pe(options['_deltalhs'])),
                   options['_deltaop'])
    return None
def visit_Assign(self, node):
    """Record top-level relation initializers in self.inited.

    A top-level `name = Set()` / `incoq.runtime.Set()` / `set()`
    assignment adds name to self.inited and is not recursed into;
    every other assignment is visited generically.
    """
    allowed_inits = [L.pe(src) for src in
                     ('Set()', 'incoq.runtime.Set()', 'set()')]

    if self.toplevel and L.is_varassign(node):
        name, value = L.get_varassign(node)
        if value in allowed_inits:
            self.inited.add(name)
            return

    self.generic_visit(node)
def test_transform(self):
    """transform_all_queries() incrementalizes an 'inc' comprehension."""
    query = L.pe('COMP({z for (x, y) in R for (y, z) in S}, [x], '
                 '{"impl": "inc"})')
    program = L.p('''
        R.add(1)
        print(COMP)
        ''', subst={'COMP': query})

    program = transform_all_queries(program, self.manager)
    # Eliminate dead functions whose names start with '_maint_'.
    program = L.elim_deadfuncs(program, lambda n: n.startswith('_maint_'))

    expected = L.p('''
        Comp1 = RCSet()
        def _maint_Comp1_R_add(_e):
            Comment("Iterate {(v1_x, v1_y, v1_z) : (v1_x, v1_y) in deltamatch(R, 'bb', _e, 1), (v1_y, v1_z) in S}")
            (v1_x, v1_y) = _e
            for v1_z in setmatch(S, 'bu', v1_y):
                if ((v1_x, v1_z) not in Comp1):
                    Comp1.add((v1_x, v1_z))
                else:
                    Comp1.incref((v1_x, v1_z))
        with MAINT(Comp1, 'after', 'R.add(1)'):
            R.add(1)
            _maint_Comp1_R_add(1)
        print(setmatch(Comp1, 'bu', x))
        ''')
    self.assertEqual(program, expected)
def test_enumclause_basic(self):
    """EnumClause construction, round-trip, and derived attributes."""
    clause = EnumClause(('x', 'y', 'x', '_'), 'R')

    # From expression.
    self.assertEqual(EnumClause.from_expr(L.pe('(x, y, x, _) in R')),
                     clause)

    # AST round-trip.
    actual_ast = clause.to_AST()
    expected_ast = L.Enumerator(
        L.tuplify(['x', 'y', 'x', '_'], lval=True), L.ln('R'))
    self.assertEqual(actual_ast, expected_ast)
    self.assertEqual(EnumClause.from_AST(expected_ast, DummyFactory),
                     clause)

    # Attributes.
    self.assertFalse(clause.isdelta)
    self.assertEqual(clause.enumlhs, ('x', 'y', 'x', '_'))
    self.assertEqual(clause.enumvars, ('x', 'y'))
    self.assertEqual(clause.pat_mask, (True, True, True, True))
    self.assertEqual(clause.enumrel, 'R')
    self.assertTrue(clause.has_wildcards)
    self.assertEqual(clause.vars, ('x', 'y'))
    self.assertEqual(clause.eqvars, None)
    self.assertTrue(clause.robust)
    self.assertEqual(clause.demname, None)
    self.assertEqual(clause.demparams, ())
def test_flatten_retrievals(self):
    """flatten_retrievals() output plus the fields/map usage it reports."""
    flat, fields, uses_map = flatten_retrievals(
        L.pe('COMP({x.a for x in S.b[c] if x.a > 5}, [S])'))

    expected = L.pe('''
        COMP({x_a for (S, S_b) in _F_b
                  for (S_b, c, m_S_b_k_c) in _MAP
                  for x in m_S_b_k_c
                  for (x, x_a) in _F_a
                  if x_a > 5},
             [S])
        ''')
    self.assertEqual(flat, expected)
    self.assertEqual(fields, ['b', 'a'])
    self.assertEqual(uses_map, True)
def test_flatten_unflatten_sets(self):
    """flatten_sets() then unflatten_sets() restores the comprehension."""
    nested = L.pe('COMP({x for (o, o_s) in _F_s for (o, o_t) in _F_t '
                  'for x in o_s if x in o_t if x in T}, [S, T])')
    flat, use_mset = flatten_sets(nested, ['T'])

    self.assertEqual(flat, L.pe(
        'COMP({x for (o, o_s) in _F_s for (o, o_t) in _F_t '
        'for (o_s, x) in _M if (o_t, x) in _M if x in T}, '
        '[S, T])'))
    self.assertTrue(use_mset)
    self.assertEqual(unflatten_sets(flat), nested)
def make_join(self, source):
    """Build a Join from a comprehension's source code, ignoring the
    result expression; clauses come from the manager's factory.
    """
    return Join.from_comp(L.pe(source), self.manager.factory)
def to_AST(self):
    """Return the wrapped clause's enumerator with its iterated set
    replaced by `ITER + {EXTRA}`, where EXTRA is self.extra."""
    enum_node = self.cl.to_AST()
    assert isinstance(enum_node, L.Enumerator)
    extended = L.pe('ITER + {EXTRA}',
                    subst={'ITER': enum_node.iter, 'EXTRA': self.extra})
    return enum_node._replace(iter=extended)
def test_basic(self):
    """Import options into a DummyManager and read them back."""
    query1 = L.pe('COMP({x for x in S}, [S], {"e": "f2"})')
    query2 = L.pe('COMP({y for y in T}, [T], {})')

    # Test import/export.
    manager = DummyManager()
    manager.import_opts({'a': 'b2'}, {query1: {'e': 'f2'}})

    # Test retrievals. Keys never imported ('c', 'g', and 'e' for
    # query2) presumably fall back to DummyManager's defaults.
    self.assertEqual(manager.get_opt('a'), 'b2')
    self.assertEqual(manager.get_opt('c'), 'd')
    self.assertEqual(manager.get_queryopt(query1, 'e'), 'f2')
    self.assertEqual(manager.get_queryopt(query1, 'g'), 'h')
    self.assertEqual(manager.get_queryopt(query2, 'e'), 'f')
def test_resexp_vars(self):
    """split_resexp_vars() splits result-expression variables into
    bound and unbound sets according to a mask."""
    bounds, unbounds = split_resexp_vars(
        L.pe('(a + b, (c, d), (a, c, e, f))'), Mask('bbu'))
    self.assertEqual(bounds, {'c', 'd'})
    self.assertEqual(unbounds, {'a', 'e', 'f'})
def test_flatten_smlookups_dem(self):
    """A DEMQUERY-wrapped smlookup is flattened into an enumerator over
    a DEMQUERY of a singleton set; reapplying changes nothing."""
    source = L.pe(
        'COMP({x for x in S '
        'if DEMQUERY(foo, [u], Aggr1.smlookup("u", ())) > 5}, '
        '[], {})')
    result = flatten_smlookups(flatten_smlookups(source))

    expected = L.pe(
        'COMP({x for x in S '
        'for _av1 in DEMQUERY(foo, [u], '
        '{Aggr1.smlookup("u", ())}) '
        'if (_av1 > 5)}, '
        '[], {})')
    self.assertEqual(result, expected)
def test_flatten_unflatten_sets(self):
    """Set-valued fields are flattened onto _M (except for T) and
    unflattening restores the original comprehension."""
    original = L.pe(
        'COMP({x for (o, o_s) in _F_s for (o, o_t) in _F_t '
        'for x in o_s if x in o_t if x in T}, [S, T])')
    expected_flat = L.pe(
        'COMP({x for (o, o_s) in _F_s for (o, o_t) in _F_t '
        'for (o_s, x) in _M if (o_t, x) in _M if x in T}, '
        '[S, T])')

    flattened, use_mset = flatten_sets(original, ['T'])
    self.assertEqual(flattened, expected_flat)
    self.assertTrue(use_mset)

    restored = unflatten_sets(flattened)
    self.assertEqual(restored, original)