def test_singletonclause(self):
    """SingletonClause: from_expr, AST round-trip, attributes, codegen."""
    clause = SingletonClause(('x', 'y'), L.pe('e'))

    # Parsing the equality form yields an identical clause.
    parsed = SingletonClause.from_expr(L.pe('(x, y) == e'))
    self.assertEqual(clause, parsed)

    # Converting to an Enumerator over {e} and back is lossless.
    actual_ast = clause.to_AST()
    expected_ast = L.Enumerator(L.tuplify(['x', 'y'], lval=True),
                                L.pe('{e}'))
    self.assertEqual(actual_ast, expected_ast)
    rebuilt = SingletonClause.from_AST(expected_ast, DummyFactory)
    self.assertEqual(rebuilt, clause)

    # Enumeration variables are reported in order.
    self.assertEqual(clause.enumvars, ('x', 'y'))

    # Generated code unpacks e directly, then runs the body.
    generated = clause.get_code([], L.pc('pass'))
    expected_code = L.pc('''
        x, y = e
        pass
        ''')
    self.assertEqual(generated, expected_code)
def test_deltaclause(self):
    """DeltaClause: AST round-trip, attributes, plain and fancy codegen."""
    clause = DeltaClause(('x', 'y'), 'R', L.pe('e'), 1)

    # Round-trip through the deltamatch(...) enumerator form.
    actual_ast = clause.to_AST()
    expected_ast = L.Enumerator(L.tuplify(['x', 'y'], lval=True),
                                L.pe('deltamatch(R, "bb", e, 1)'))
    self.assertEqual(actual_ast, expected_ast)
    rebuilt = DeltaClause.from_AST(expected_ast, DummyFactory)
    self.assertEqual(rebuilt, clause)

    # The relation name is exposed as an attribute.
    self.assertEqual(clause.rel, 'R')

    # Distinct, non-wildcard variables: the delta element is unpacked.
    generated = clause.get_code([], L.pc('pass'))
    expected_code = L.pc('''
        x, y = e
        pass
        ''')
    self.assertEqual(generated, expected_code)

    # Repeated and wildcard variables: fall back to a setmatch.
    fancy_clause = DeltaClause(('x', 'x', '_'), 'R', L.pe('e'), 1)
    generated = fancy_clause.get_code([], L.pc('pass'))
    expected_code = L.pc('''
        for x in setmatch(deltamatch(R, 'b1w', e, 1), 'u1w', ()):
            pass
        ''')
    self.assertEqual(generated, expected_code)
def test_condclause(self):
    """CondClause: AST round-trip, fits_string, attributes, rating, codegen."""
    clause = CondClause(L.pe('f(a) or g(b)'))

    # The clause's AST is just its condition expression.
    actual_ast = clause.to_AST()
    expected_ast = L.pe('f(a) or g(b)')
    self.assertEqual(actual_ast, expected_ast)
    rebuilt = CondClause.from_AST(expected_ast, DummyFactory)
    self.assertEqual(rebuilt, clause)

    self.assertTrue(clause.fits_string(['a', 'b'], 'f(a) or g(b)'))

    # A condition enumerates nothing but reads its free variables.
    self.assertEqual(clause.enumvars, ())
    self.assertEqual(clause.vars, ('a', 'b'))
    equality_clause = CondClause(L.pe('a == b'))
    self.assertEqual(equality_clause.eqvars, ('a', 'b'))

    # Runnable only once every variable it reads is bound.
    self.assertEqual(clause.rate(['a', 'b']), Rate.CONSTANT)
    self.assertEqual(clause.rate(['a']), Rate.UNRUNNABLE)

    # Code generation guards the body with the condition.
    generated = clause.get_code(['a', 'b'], L.pc('pass'))
    expected_code = L.pc('''
        if f(a) or g(b):
            pass
        ''')
    self.assertEqual(generated, expected_code)
def visit_Module(self, node):
    """Visit the module, then optionally prepend the query function.

    The generated function (named by L.N.queryfunc(self.name), taking
    spec.params) begins with the spec's string form as a string
    expression. It then initializes `result = set()`, runs the join code
    that adds RESEXP to `result`, and returns `result`. The function is
    only inserted when self.need_func is true.
    """
    spec = self.spec
    node = self.generic_visit(node)

    add_result = L.pc('''
        result.nsadd(RESEXP)
        ''', subst={'RESEXP': spec.resexp})
    join_code = spec.join.get_code(spec.params, add_result,
                                   augmented=self.augmented)
    func_body = L.pc('''
        SPEC_STR
        result = set()
        COMPUTE
        return result
        ''', subst={'SPEC_STR': L.Str(s=str(self.spec)),
                    '<c>COMPUTE': join_code})
    funcdef = L.plainfuncdef(L.N.queryfunc(self.name), spec.params,
                             func_body)

    if not self.need_func:
        return node
    return node._replace(body=funcdef + node.body)
def visit_Module(self, node):
    """Visit children, then prepend the query function if self.need_func.

    The query function returns a fresh set populated by the join's
    generated code. Its first statement is the spec rendered as a
    string.
    """
    spec = self.spec
    visited = self.generic_visit(node)

    # Innermost statement: record one result.
    innermost = L.pc('''
        result.nsadd(RESEXP)
        ''', subst={'RESEXP': spec.resexp})
    # Wrap it in the join's loops/guards.
    loops = spec.join.get_code(spec.params, innermost,
                               augmented=self.augmented)
    body_subst = {
        'SPEC_STR': L.Str(s=str(self.spec)),
        '<c>COMPUTE': loops,
    }
    body = L.pc('''
        SPEC_STR
        result = set()
        COMPUTE
        return result
        ''', subst=body_subst)
    queryfunc = L.plainfuncdef(L.N.queryfunc(self.name), spec.params, body)

    if self.need_func:
        visited = visited._replace(body=queryfunc + visited.body)
    return visited
def visit_SetUpdate(self, node):
    """Wrap updates to self.spec.rel with calls to the maintenance functions.

    Additions get a call to self.addfunc_name(elem) afterwards; removals
    get a call to self.removefunc_name(elem) beforehand. The result is an
    L.Maintenance node tagged with self.spec.map_name. Any other
    SetUpdate is returned unchanged.

    Raises:
        AssertionError: if the variable update's op is neither 'add'
            nor 'remove'.
    """
    node = self.generic_visit(node)

    # No action if
    #   - this is not an update to a variable
    #   - this is not the variable you are looking for (jedi hand wave)
    if not node.is_varupdate():
        return node
    var, op, elem = node.get_varupdate()
    if var != self.spec.rel:
        return node

    precode = postcode = ()
    if op == 'add':
        postcode = L.pc('ADDFUNC(ELEM)',
                        subst={'ADDFUNC': self.addfunc_name,
                               'ELEM': elem})
    elif op == 'remove':
        precode = L.pc('REMOVEFUNC(ELEM)',
                       subst={'REMOVEFUNC': self.removefunc_name,
                              'ELEM': elem})
    else:
        # Explicit raise instead of `assert ()`. Asserts are stripped
        # under -O, which would silently drop maintenance here.
        raise AssertionError(
            'Unexpected set update operation: {!r}'.format(op))

    code = L.Maintenance(self.spec.map_name, L.ts(node),
                         precode, (node,), postcode)
    return code
def make_addu_maint(self, prefix):
    """Generate code for after an addition to U.

    Half-demand: emit code that looks up the aggregate map entry for
    the demanded key and, if absent, assigns it the zero map value.

    Otherwise: emit code that starts from the zero map value, applies
    the 'add' state update for every element of
    setmatch(spec.rel, spec.relmask, spec.params), and assigns the
    result into the aggregate map. Afterwards it calls the operand's
    demand function when spec.has_oper_demand is set.

    `prefix` is used to name the temporaries (`<prefix>val`,
    `<prefix>elem`) and is passed along as the PREFIX argument.
    """
    incaggr = self.incaggr
    assert incaggr.has_demand
    spec = incaggr.spec
    val_name = prefix + 'val'
    elem_name = prefix + 'elem'

    if incaggr.half_demand:
        # No demand to propagate to the operand; just make sure an
        # entry exists (initialized to the zero value) for this key.
        half_subst = {
            'A': L.ln(incaggr.name),
            'S_MV': L.sn(val_name),
            'MV': L.ln(val_name),
            'MASK': incaggr.aggrmask.make_node(),
            'KEY': L.tuplify(incaggr.params),
            'ZERO': self.make_zero_mapval_expr(),
            'PREFIX': L.Str(prefix),
        }
        return L.pc('''
            S_MV = A.smdeflookup(MASK, KEY, None)
            if MV is None:
                A.smassignkey(MASK, KEY, ZERO, PREFIX)
            ''', subst=half_subst)

    update_code = self.make_update_state_code(
        L.sn(val_name), L.ln(val_name), 'add', L.ln(elem_name), prefix)

    # Call the operand's demand function only if the operand uses demand.
    propagate_code = ()
    if spec.has_oper_demand:
        demfunc_name = L.N.demfunc(spec.oper_demname)
        demand_args = tuple(L.ln(v) for v in spec.oper_demparams)
        demand_call = L.Call(L.ln(demfunc_name), demand_args,
                             (), None, None)
        propagate_code = (L.Expr(demand_call),)

    full_subst = {
        'S_MV': L.sn(val_name),
        'ZERO': self.make_zero_mapval_expr(),
        'S_ELEM': L.sn(elem_name),
        'R': spec.rel,
        'RELMASK': spec.relmask.make_node(),
        'PARAMS': L.tuplify(spec.params),
        '<c>UPDATE_MAPVAL': update_code,
        'A': L.ln(incaggr.name),
        'AGGRMASK': incaggr.aggrmask.make_node(),
        'KEY': L.tuplify(incaggr.params),
        'MV': L.ln(val_name),
        'PREFIX': L.Str(prefix),
        '<c>PROPAGATE_DEMAND': propagate_code,
    }
    return L.pc('''
        S_MV = ZERO
        for S_ELEM in setmatch(R, RELMASK, PARAMS):
            UPDATE_MAPVAL
        A.smassignkey(AGGRMASK, KEY, MV, PREFIX)
        PROPAGATE_DEMAND
        ''', subst=full_subst)
def test_bindmatch_other(self):
    """A partially bound mask produces a setmatch loop."""
    actual = make_bindmatch('R', Mask('bu'), ['x'], ['y'], L.pc('pass'))
    expected = L.pc('''
        for y in setmatch(R, 'bu', x):
            pass
        ''')
    self.assertEqual(actual, expected)
def test_bindmatch_b1(self):
    """A bound mask with an equality still goes through setmatch."""
    actual = make_bindmatch('R', Mask('b1'), ['x'], [], L.pc('pass'))
    expected = L.pc('''
        for _ in setmatch(R, 'b1', x):
            pass
        ''')
    self.assertEqual(actual, expected)
def test_bindmatch_uu(self):
    """A fully unbound mask iterates the relation directly."""
    actual = make_bindmatch('R', Mask('uu'), [], ['x', 'y'], L.pc('pass'))
    expected = L.pc('''
        for (x, y) in R:
            pass
        ''')
    self.assertEqual(actual, expected)
def test_clauseflattener(self):
    """ClauseFlattener flattens the nested tuple pattern over R."""
    nested = L.pc('''
        COMP({x for x in S for (x, (y, z)) in R}, [], {})
        ''')
    flattened = ClauseFlattener.run(nested, 'R', self.tuptype)
    expected = L.pc('''
        COMP({x for x in S for (x, y, z) in R}, [], {})
        ''')
    self.assertEqual(flattened, expected)
def test_bindmatch_bb(self):
    """A fully bound mask becomes a membership test."""
    actual = make_bindmatch('R', Mask('bb'), ['x', 'y'], [], L.pc('pass'))
    expected = L.pc('''
        if ((x, y) in R):
            pass
        ''')
    self.assertEqual(actual, expected)
def test_tuplematch_other(self):
    """A partially bound mask matches via setmatch over {v}."""
    actual = make_tuplematch(L.pe('v'), Mask('bu'), ['x'], ['y'],
                             L.pc('pass'))
    expected = L.pc('''
        for y in setmatch({v}, 'bu', x):
            pass
        ''')
    self.assertEqual(actual, expected)
def test_tuplematch_bb(self):
    """A fully bound mask becomes an equality test against v."""
    actual = make_tuplematch(L.pe('v'), Mask('bb'), ['x', 'y'], [],
                             L.pc('pass'))
    expected = L.pc('''
        if ((x, y) == v):
            pass
        ''')
    self.assertEqual(actual, expected)
def test_tuplematch_uu(self):
    """A fully unbound mask becomes a plain unpacking assignment."""
    actual = make_tuplematch(L.pe('v'), Mask('uu'), [], ['x', 'y'],
                             L.pc('pass'))
    expected = L.pc('''
        (x, y) = v
        pass
        ''')
    self.assertEqual(actual, expected)
def test_code(self):
    """Join code with c bound nests S's setmatch outside R's."""
    join = self.make_join('for (a, b) in R for (b, c) in S')
    actual = join.get_code(['c'], L.pc('pass'), augmented=False)
    expected = L.pc('''
        for b in setmatch(S, 'ub', c):
            for a in setmatch(R, 'ub', b):
                pass
        ''')
    self.assertEqual(actual, expected)
def test_reltypegetter(self):
    """ReltypeGetter infers R's nested tuple type and rejects conflicts."""
    consistent = L.pc('''
        print(COMP({x for x in S for (x, (y, z)) in R}, [], {}))
        ''')
    inferred = ReltypeGetter.run(consistent, 'R')
    self.assertEqual(inferred, ('<T>', 'x', ('<T>', 'y', 'z')))

    # R is enumerated both as plain elements and as nested tuples.
    conflicting = L.pc('''
        print(COMP({x for x in R for (x, (y, z)) in R}, [], {}))
        ''')
    with self.assertRaises(AssertionError):
        ReltypeGetter.run(conflicting, 'R')
def visit_Module(self, node):
    """Prepend the result initialization and maintenance functions.

    The module body is prefixed with `RES = RCSet()` followed by, for
    each relation in the comprehension's join, an add function and a
    remove function taking `_e`. Their bodies come from
    make_comp_maint_code. Comprehensions returned by that helper are
    collected in self.maint_comps. The module is visited afterwards,
    so the inserted code is visited too.
    """
    inccomp = self.inccomp
    manager = self.manager

    code = L.pc('''
        RES = RESINIT
        ''', subst={'RES': inccomp.name,
                    'RESINIT': L.pe('RCSet()')})

    for rel in inccomp.spec.join.rels:
        add_prefix = manager.namegen.next_prefix()
        remove_prefix = manager.namegen.next_prefix()

        add_code, add_comps = make_comp_maint_code(
            inccomp.spec, inccomp.name, rel, 'add', L.pe('_e'),
            add_prefix, maint_impl=inccomp.maint_impl,
            rc=inccomp.rc, selfjoin=inccomp.selfjoin)
        remove_code, remove_comps = make_comp_maint_code(
            inccomp.spec, inccomp.name, rel, 'remove', L.pe('_e'),
            remove_prefix, maint_impl=inccomp.maint_impl,
            rc=inccomp.rc, selfjoin=inccomp.selfjoin)

        self.maint_comps.extend(add_comps)
        self.maint_comps.extend(remove_comps)

        code += L.pc('''
            def ADDFUNC(_e):
                ADDCODE
            def REMOVEFUNC(_e):
                REMOVECODE
            ''', subst={'<def>ADDFUNC': self.addfuncs[rel],
                        '<c>ADDCODE': add_code,
                        '<def>REMOVEFUNC': self.removefuncs[rel],
                        '<c>REMOVECODE': remove_code})

        # Copy known types of the join's enumeration variables to their
        # prefixed names (presumably the names the generated
        # maintenance code uses for them).
        vartypes = manager.vartypes
        for var in inccomp.spec.join.enumvars:
            if var in vartypes:
                vartypes[add_prefix + var] = vartypes[var]
                vartypes[remove_prefix + var] = vartypes[var]

    node = node._replace(body=code + node.body)
    return self.generic_visit(node)
def make_removeu_maint(self, prefix):
    """Generate code for before a removal from U.

    Half-demand: emit code that looks up the aggregate map entry for
    the demanded key and deletes it only when its count is 0.

    Otherwise: emit a call to the operand's undemand function (when
    spec.has_oper_demand is set), followed by deleting the key from
    the aggregate map.

    `prefix` names the `<prefix>val` temporary and is passed along as
    the PREFIX argument.
    """
    incaggr = self.incaggr
    assert incaggr.has_demand
    spec = incaggr.spec
    val_name = prefix + 'val'

    if incaggr.half_demand:
        # No demand to propagate; removal only happens at count 0.
        half_subst = {
            'S_MV': L.sn(val_name),
            'COUNT': self.mapval_proj_count(L.ln(val_name)),
            'A': incaggr.name,
            'MASK': incaggr.aggrmask.make_node(),
            'KEY': L.tuplify(incaggr.params),
            'PREFIX': L.Str(prefix),
        }
        return L.pc('''
            S_MV = A.smlookup(MASK, KEY)
            if COUNT == 0:
                A.smdelkey(MASK, KEY, PREFIX)
            ''', subst=half_subst)

    # Call the operand's undemand function only if it uses demand.
    propagate_code = ()
    if spec.has_oper_demand:
        undemfunc_name = L.N.undemfunc(spec.oper_demname)
        undemand_args = tuple(L.ln(v) for v in spec.oper_demparams)
        undemand_call = L.Call(L.ln(undemfunc_name), undemand_args,
                               (), None, None)
        propagate_code = (L.Expr(undemand_call),)

    full_subst = {
        'A': incaggr.name,
        'MASK': incaggr.aggrmask.make_node(),
        'KEY': L.tuplify(incaggr.params),
        'PREFIX': L.Str(prefix),
        '<c>PROPAGATE_DEMAND': propagate_code,
    }
    return L.pc('''
        PROPAGATE_DEMAND
        A.smdelkey(MASK, KEY, PREFIX)
        ''', subst=full_subst)
def test_costlabel(self):
    """CostLabelAdder prepends a cost comment to each costed function."""
    source = L.pc('''
        def f(x):
            for y in setmatch(R, 'bu', x):
                pass
        ''')
    costs = func_costs(source)
    labeled = CostLabelAdder.run(source, costs)
    expected = L.pc('''
        def f(x):
            Comment('Cost: O(R_out[x])')
            for y in setmatch(R, 'bu', x):
                pass
        ''')
    self.assertEqual(labeled, expected)
def test_iterate_code(self):
    """Iteration helpers: single relation, union, checked and disjoint union."""
    rel_r = L.pe('R')
    rel_s = L.pe('S')

    # A single relation, directly or via the union helper with one
    # operand, produces a plain loop.
    single_expected = L.pc('''
        for a, b in R:
            pass
        ''')
    self.assertEqual(for_rel_code(['a', 'b'], rel_r, L.pc('pass')),
                     single_expected)
    self.assertEqual(
        for_rels_union_code(['a', 'b'], [rel_r], L.pc('pass'), '_'),
        single_expected)

    # Union of several relations goes through a temporary set D.
    union_expected = L.pc('''
        D = set()
        for a, b in R:
            D.nsadd((a, b))
        for a, b in S:
            D.nsadd((a, b))
        for a, b in D:
            pass
        del D
        ''')
    self.assertEqual(
        for_rels_union_code(['a', 'b'], [rel_r, rel_s], L.pc('pass'), 'D'),
        union_expected)

    # With verify_disjoint, each insertion asserts novelty.
    checked_expected = L.pc('''
        D = set()
        for a, b in R:
            assert (a, b) not in D
            D.add((a, b))
        for a, b in S:
            assert (a, b) not in D
            D.add((a, b))
        for a, b in D:
            pass
        del D
        ''')
    self.assertEqual(
        for_rels_union_code(['a', 'b'], [rel_r, rel_s], L.pc('pass'), 'D',
                            verify_disjoint=True),
        checked_expected)

    # Known-disjoint union just runs the body once per relation.
    disjoint_expected = L.pc('''
        for a, b in R:
            pass
        for a, b in S:
            pass
        ''')
    self.assertEqual(
        for_rels_union_disjoint_code(['a', 'b'], [rel_r, rel_s],
                                     L.pc('pass')),
        disjoint_expected)
def test_flattup_code(self):
    """make_flattup_code indexes directly, or via a temp for complex sources."""
    target = L.sn('y')

    # A simple name is indexed in place.
    generated = make_flattup_code(self.tuptype, L.ln('x'), target, 't')
    self.assertEqual(generated, L.pc('''
        y = (x[0], x[1][0], x[1][1])
        '''))

    # A compound expression is first stored in the temporary `t`.
    generated = make_flattup_code(self.tuptype, L.pe('x.a'), target, 't')
    self.assertEqual(generated, L.pc('''
        t = x.a
        y = (t[0], t[1][0], t[1][1])
        '''))
def helper(self, node, var, op, elem):
    """Attach a maintenance call for an update of `var` by `op` with `elem`.

    The call FUNC(elem) uses the add function for var (always, when the
    comprehension is a change tracker) or the remove function for
    removals. It is placed before or after the update according to the
    rule below. The result is produced by self.with_outer_maint.
    """
    assert op in ['add', 'remove']
    inccomp = self.inccomp
    adding = (op == 'add')

    # Normally maintenance follows additions and precedes removals.
    # Augmented self-join code ('aug') inverts this, since it relies on
    # the set's value *without* the updated element.
    maint_after_add = inccomp.selfjoin != 'aug'

    # Change trackers turn removals into additions too, but the call is
    # still placed where removal maintenance would have been.
    if inccomp.change_tracker or adding:
        funcs = self.addfuncs
    else:
        funcs = self.removefuncs
    call = L.pc('FUNC(ELEM)', subst={'FUNC': funcs[var], 'ELEM': elem})

    run_before = (maint_after_add != adding)
    precode = call if run_before else ()
    postcode = () if run_before else call

    # with_outer_maint respects outsideinvs, so demand invariant
    # maintenance ends up before/after the query maintenance.
    return self.with_outer_maint(node, inccomp.name, L.ts(node),
                                 precode, postcode)
def visit_Module(self, node):
    """Register the aggregate invariant and prepend its setup code.

    Inserts `RES = Set()` plus an add and a remove maintenance function
    (each taking `_e`) whose bodies come from self.cg.make_oper_maint.
    The module, including the inserted code, is then visited.
    """
    incaggr = self.incaggr
    manager = self.manager
    manager.add_invariant(incaggr.name, incaggr)

    # Allocate both prefixes before generating either body.
    prefix_for_add = manager.namegen.next_prefix()
    prefix_for_remove = manager.namegen.next_prefix()
    add_body = self.cg.make_oper_maint(prefix_for_add, 'add', L.pe('_e'))
    remove_body = self.cg.make_oper_maint(prefix_for_remove, 'remove',
                                          L.pe('_e'))

    setup = L.pc('''
        RES = Set()
        def ADDFUNC(_e):
            ADDCODE
        def REMOVEFUNC(_e):
            REMOVECODE
        ''', subst={'RES': incaggr.name,
                    '<def>ADDFUNC': self.addfunc,
                    '<c>ADDCODE': add_body,
                    '<def>REMOVEFUNC': self.removefunc,
                    '<c>REMOVECODE': remove_body})

    return self.generic_visit(node._replace(body=setup + node.body))
def make_update_state_code(self, state_snode, state_lnode, op, val_node,
                           prefix):
    """Generate code that updates a min/max aggregate state by `op`.

    The state is unpacked as a pair whose first component is a tree
    (bound to `<prefix>tree`). VAL is inserted as a key ('add') or
    deleted ('remove'). The new state is the pair (tree, tree's
    `__min__()`/`__max__()` result), where the method is chosen by
    self.kind. An unknown `op` or `self.kind` raises KeyError.
    """
    templates = {
        'add': L.trim('''
            S_TREE, _ = STATE
            TREE[VAL] = None
            S_STATE = (TREE, TREE.MINMAX())
            '''),
        'remove': L.trim('''
            S_TREE, _ = STATE
            del TREE[VAL]
            S_STATE = (TREE, TREE.MINMAX())
            '''),
    }
    template = templates[op]
    tree_name = prefix + 'tree'
    method = {'min': '__min__', 'max': '__max__'}[self.kind]

    return L.pc(template, subst={'S_TREE': L.sn(tree_name),
                                 'TREE': L.ln(tree_name),
                                 '@MINMAX': method,
                                 'STATE': state_lnode,
                                 'S_STATE': state_snode,
                                 'VAL': val_node})
def test_checkbad(self):
    """check_bad_setmatches rejects this setmatch over _TUP2."""
    offending = L.pc('''
        for t in setmatch(_TUP2, 'ubw', x):
            pass
        ''')
    with self.assertRaises(AssertionError):
        check_bad_setmatches(offending)
def visit_Module(self, node):
    """Add the aggregate's invariant and prepend its maintenance setup.

    Emits `RES = Set()` and two functions of `_e` (add/remove) whose
    bodies are generated by self.cg.make_oper_maint. The resulting
    module is then visited by generic_visit.
    """
    incaggr = self.incaggr
    self.manager.add_invariant(incaggr.name, incaggr)

    namegen = self.manager.namegen
    prefixes = (namegen.next_prefix(), namegen.next_prefix())
    elem = L.pe('_e')
    addcode = self.cg.make_oper_maint(prefixes[0], 'add', elem)
    removecode = self.cg.make_oper_maint(prefixes[1], 'remove', L.pe('_e'))

    substitutions = {
        'RES': incaggr.name,
        '<def>ADDFUNC': self.addfunc,
        '<c>ADDCODE': addcode,
        '<def>REMOVEFUNC': self.removefunc,
        '<c>REMOVECODE': removecode,
    }
    prelude = L.pc('''
        RES = Set()
        def ADDFUNC(_e):
            ADDCODE
        def REMOVEFUNC(_e):
            REMOVECODE
        ''', subst=substitutions)

    new_node = node._replace(body=prelude + node.body)
    new_node = self.generic_visit(new_node)
    return new_node
def get_code(self, bindenv, body):
    """Return `body` guarded by `if <self.cond>:`.

    Every variable the condition reads (self.vars) must already be in
    `bindenv`; this is checked with an assert.
    """
    assert set(self.vars) <= set(bindenv)
    return L.pc('''
        if COND:
            BODY
        ''', subst={'COND': self.cond, '<c>BODY': body})
def make_bindmatch(rel, mask, bvars, uvars, body):
    """Generate code matching relation `rel` under `mask`, then running `body`.

    - All positions bound, no equalities: `if BVARS in REL:` test.
    - All positions unbound, no wildcards: `for UVARS in REL:` loop.
    - Anything else: `for UVARS in setmatch(REL, MASK, BVARS):` loop.
    """
    if mask.is_allbound and not mask.has_equalities:
        raw = '''
            if BVARS in REL:
                BODY
            '''
    elif mask.is_allunbound and not mask.has_wildcards:
        raw = '''
            for UVARS in REL:
                BODY
            '''
    else:
        raw = '''
            for UVARS in setmatch(REL, MASK, BVARS):
                BODY
            '''

    substitutions = {
        'REL': L.ln(rel),
        'MASK': mask.make_node(),
        'BVARS': L.tuplify(bvars),
        'UVARS': L.tuplify(uvars, lval=True),
        '<c>BODY': body,
    }
    return L.pc(L.trim(raw), subst=substitutions)
def make_tuplematch(val, mask, bvars, uvars, body):
    """Generate code matching the tuple value `val` under `mask`.

    - All positions bound, no equalities: `if BVARS == VAL:` test.
    - All positions unbound, no wildcards: `UVARS = VAL` assignment
      followed by `body`.
    - Anything else: `for UVARS in setmatch({VAL}, MASK, BVARS):`,
      i.e. `val` is treated as a singleton set.
    """
    # Masks with equality constraints (e.g. 'b1') bind fewer variables
    # than the tuple has components, so `BVARS == VAL` would compare
    # tuples of different shapes. Like make_bindmatch, route them
    # through setmatch instead.
    if mask.is_allbound and not mask.has_equalities:
        template = L.trim('''
            if BVARS == VAL:
                BODY
            ''')
    elif mask.is_allunbound and not mask.has_wildcards:
        template = L.trim('''
            UVARS = VAL
            BODY
            ''')
    else:
        template = L.trim('''
            for UVARS in setmatch({VAL}, MASK, BVARS):
                BODY
            ''')

    code = L.pc(template, subst={'VAL': val,
                                 'MASK': mask.make_node(),
                                 'BVARS': L.tuplify(bvars),
                                 'UVARS': L.tuplify(uvars, lval=True),
                                 '<c>BODY': body})
    return code
def make_update_state_code(self, state_snode, state_lnode, op, val_node,
                           prefix):
    """Build code applying an 'add' or 'remove' of VAL to a min/max state.

    The generated code unpacks the state's tree component into
    `<prefix>tree`, inserts or deletes VAL as a key, and rebinds the
    state to (tree, tree.__min__()) or (tree, tree.__max__()) depending
    on self.kind. An unrecognized op or kind raises KeyError.
    """
    insert_template = L.trim('''
        S_TREE, _ = STATE
        TREE[VAL] = None
        S_STATE = (TREE, TREE.MINMAX())
        ''')
    delete_template = L.trim('''
        S_TREE, _ = STATE
        del TREE[VAL]
        S_STATE = (TREE, TREE.MINMAX())
        ''')
    by_op = {'add': insert_template, 'remove': delete_template}
    chosen = by_op[op]

    tree = prefix + 'tree'
    extreme = {'min': '__min__', 'max': '__max__'}[self.kind]

    substitutions = {
        'S_TREE': L.sn(tree),
        'TREE': L.ln(tree),
        '@MINMAX': extreme,
        'STATE': state_lnode,
        'S_STATE': state_snode,
        'VAL': val_node,
    }
    return L.pc(chosen, subst=substitutions)
def visit_SetUpdate(self, node):
    """Insert demand propagation after updates to relations in self.at_rels.

    After the update, each element of the delta set (self.delta_name)
    is unpacked into prefixed copies of self.projvars and passed to the
    demand function (for 'add') or the undemand function (otherwise).
    The delta set is then cleared. Updates to other relations are
    returned unchanged.
    """
    rel = L.get_name(node.target)
    if rel not in self.at_rels:
        # Return the node explicitly. Every other visitor here does so,
        # and with ast.NodeTransformer-style semantics a None return
        # would delete the node.
        return node

    # New code gets inserted after the update.
    # This is true even if the update was a removal.
    # It shouldn't matter where we do the U-set update,
    # so long as the invariants are properly maintained
    # at the time.
    prefix = self.manager.namegen.next_prefix()
    # Renamed from `vars` to avoid shadowing the builtin.
    prefixed_vars = [prefix + v for v in self.projvars]

    if node.op == 'add':
        funcname = L.N.demfunc(self.demname)
    else:
        funcname = L.N.undemfunc(self.demname)

    call_func = L.Call(L.ln(funcname),
                       tuple(L.ln(v) for v in prefixed_vars),
                       (), None, None)
    postcode = L.pc('''
        for S_PROJVARS in DELTA.elements():
            CALL_FUNC
        DELTA.clear()
        ''', subst={'S_PROJVARS': L.tuplify(prefixed_vars, lval=True),
                    'DELTA': self.delta_name,
                    'CALL_FUNC': call_func})

    return self.with_outer_maint(node, funcname, L.ts(node),
                                 (), postcode)
def visit_DelKey(self, node):
    """Rewrite a DelKey into `TARGET.nsdelkey(KEY)` when maps are rewritten."""
    if self.rewrite_maps:
        return L.pc('''
            TARGET.nsdelkey(KEY)
            ''', subst={'TARGET': node.target, 'KEY': node.key})
    return node
def test_removeu(self):
    """Removal from U looks up the map value and removes the full tuple."""
    generated = self.cg.make_removeu_maint('_')
    expected = L.pc('''
        (_1, _2) = (p1, p2)
        _elem = A.smlookup('bbu', (p1, p2))
        A.remove((_1, _2, _elem))
        ''')
    self.assertEqual(generated, expected)
def test_removeu(self):
    """With no parameters, removal from U uses the empty key."""
    generated = self.cg.make_removeu_maint('_')
    expected = L.pc('''
        _ = ()
        _elem = A.smlookup('u', ())
        A.remove(_elem)
        ''')
    self.assertEqual(generated, expected)
def visit_SetUpdate(self, node):
    """Wrap updates to the aggregate's operand or its U-set with maintenance.

    - Updates to spec.rel: call self.addfunc(elem) after an addition,
      or self.removefunc(elem) before a removal.
    - Updates to the U-set (L.N.uset(incaggr.name)): insert the code
      from self.cg.make_addu_maint after an addition, or
      self.cg.make_removeu_maint before a removal, using a fresh prefix.
    - Anything else is returned unchanged.

    Raises:
        AssertionError: if a handled update's op is neither 'add' nor
            'remove'.
    """
    incaggr = self.incaggr
    spec = incaggr.spec
    node = self.generic_visit(node)

    if not node.is_varupdate():
        return node
    var, op, elem = node.get_varupdate()

    precode = postcode = ()
    if var == spec.rel:
        if op == 'add':
            postcode = L.pc('ADDFUNC(ELEM)', subst={
                'ADDFUNC': self.addfunc,
                'ELEM': elem
            })
        elif op == 'remove':
            precode = L.pc('REMOVEFUNC(ELEM)', subst={
                'REMOVEFUNC': self.removefunc,
                'ELEM': elem
            })
        else:
            # Explicit raise instead of `assert ()`, which -O strips.
            raise AssertionError(
                'Unexpected set update operation: {!r}'.format(op))
    elif var == L.N.uset(incaggr.name):
        prefix = self.manager.namegen.next_prefix()
        if op == 'add':
            postcode = self.cg.make_addu_maint(prefix)
        elif op == 'remove':
            precode = self.cg.make_removeu_maint(prefix)
        else:
            raise AssertionError(
                'Unexpected set update operation: {!r}'.format(op))
    else:
        return node

    return L.Maintenance(incaggr.name, L.ts(node),
                         precode, (node,), postcode)
def visit_SetUpdate(self, node):
    """Rewrite add/remove set updates into `nsadd`/`nsremove` method calls.

    Only active when self.rewrite_sets is set. Any op other than 'add'
    or 'remove' raises KeyError.
    """
    if self.rewrite_sets:
        method_for_op = {'add': 'nsadd', 'remove': 'nsremove'}
        method = method_for_op[node.op]
        return L.pc('TARGET.' + method + '(ELEM)',
                    subst={'TARGET': node.target, 'ELEM': node.elem})
    return node
def visit_AssignKey(self, node):
    """Rewrite an AssignKey into `TARGET.nsassignkey(KEY, VALUE)` when maps are rewritten."""
    if self.rewrite_maps:
        substitutions = {
            'TARGET': node.target,
            'KEY': node.key,
            'VALUE': node.value,
        }
        return L.pc('''
            TARGET.nsassignkey(KEY, VALUE)
            ''', subst=substitutions)
    return node
def test_enumclause_code(self):
    """EnumClause: fits_string, rating by bound vars, and codegen."""
    clause = EnumClause(('x', 'y'), 'R')

    self.assertTrue(clause.fits_string(['x'], 'R_out'))

    # Rating depends on which enumeration variables are bound.
    expected_rates = [
        (['x'], Rate.NORMAL),
        (['x', 'y'], Rate.CONSTANT_MEMBERSHIP),
        ([], Rate.NOTPREFERRED),
    ]
    for bound, rate in expected_rates:
        self.assertEqual(clause.rate(bound), rate)

    generated = clause.get_code(['x'], L.pc('pass'))
    expected = L.pc('''
        for y in setmatch(R, 'bu', x):
            pass
        ''')
    self.assertEqual(generated, expected)
def test_addu(self):
    """Addition to U inserts a zero-valued entry only if none exists."""
    generated = self.cg.make_addu_maint('_')
    expected = L.pc('''
        _val = A.smdeflookup('bbu', (p1, p2), None)
        if (_val is None):
            (_1, _2) = (p1, p2)
            A.add((_1, _2, (0, 0)))
        ''')
    self.assertEqual(generated, expected)