def process(expr): """Rewrite any retrievals in the given expression. Return a pair of the new expression, and a list of new clauses to be added for any retrievals not already seen. """ nonlocal seen_map new_expr = replacer.process(expr) new_field_repls = replacer.field_repls - seen_field_repls new_map_repls = replacer.map_repls - seen_map_repls new_clauses = [] for repl in new_field_repls: obj, field, value = repl seen_fields.add(field) seen_field_repls.add(repl) new_cl = L.Enumerator(L.tuplify((obj, value), lval=True), L.ln(make_frel(field))) new_clauses.append(new_cl) for repl in new_map_repls: map, key, value = repl seen_map = True seen_map_repls.add(repl) new_cl = L.Enumerator(L.tuplify((map, key, value), lval=True), L.ln(make_maprel())) new_clauses.append(new_cl) return new_expr, new_clauses
def visit_Aggregate(self, node):
    """Distribute a min/max aggregate over the operands of a set union.

    The node is returned as produced by generic_visit() unless it is a
    'min' or 'max' aggregate whose operand is a set union of more than
    one set. In that case the result is a call to min2()/max2() whose
    arguments are one term per operand of the union.
    """
    node = self.generic_visit(node)

    if node.op not in ['min', 'max']:
        return node
    if not L.is_setunion(node.value):
        return node
    operands = L.get_setunion(node.value)
    if len(operands) == 1:
        # If there's just one set, don't change anything.
        return node

    binary_name = {'min': 'min2', 'max': 'max2'}[node.op]

    def call_binary(args):
        # Build a call to min2()/max2() with the given arguments.
        call = L.pe('OP(__ARGS)', subst={'OP': L.ln(binary_name)})
        return call._replace(args=args)

    # Wrap each operand in an aggregate query with the same
    # options as the original aggregate. (This ensures that
    # 'impl' is carried over.) Set literals are wrapped in
    # a call to incoq.runtime's min2()/max2() instead of an
    # Aggregate query node.
    terms = []
    for operand in operands:
        if isinstance(operand, (L.Comp, L.Name)):
            term = L.Aggregate(operand, node.op, node.options)
        else:
            term = call_binary(operand.elts)
        terms.append(term)

    # The new top-level aggregate is min2()/max2().
    return call_binary(tuple(terms))
def test_enumclause_setmatch(self):
    """An enumerator over a SetMatch converts to an EnumClause."""
    # Make sure we can convert clauses over setmatches.
    target = L.tuplify(['y'], lval=True)
    setmatch = L.SetMatch(L.ln('R'), 'bu', L.ln('x'))
    actual = EnumClause.from_AST(L.Enumerator(target, setmatch),
                                 DummyFactory)

    expected = EnumClause(('x', 'y'), 'R')
    self.assertEqual(actual, expected)
def visit_SetUpdate(self, node):
    """Add demand propagation code after updates to tracked relations.

    Updates whose target is not named in self.at_rels are left alone
    (None is returned). For the others, the update is wrapped via
    with_outer_maint() with trailing code that calls the demand
    function (for 'add') or the undemand function (any other op) on
    every element of the delta set and then clears the delta set.
    """
    if L.get_name(node.target) not in self.at_rels:
        return

    # New code gets inserted after the update.
    # This is true even if the update was a removal.
    # It shouldn't matter where we do the U-set update,
    # so long as the invariants are properly maintained
    # at the time.

    prefix = self.manager.namegen.next_prefix()
    projvars = [prefix + v for v in self.projvars]

    make_funcname = L.N.demfunc if node.op == 'add' else L.N.undemfunc
    funcname = make_funcname(self.demname)

    call_args = tuple(L.ln(v) for v in projvars)
    call_func = L.Call(L.ln(funcname), call_args, (), None, None)

    postcode = L.pc('''
        for S_PROJVARS in DELTA.elements():
            CALL_FUNC
        DELTA.clear()
        ''', subst={'S_PROJVARS': L.tuplify(projvars, lval=True),
                    'DELTA': self.delta_name,
                    'CALL_FUNC': call_func})

    return self.with_outer_maint(node, funcname, L.ts(node),
                                 (), postcode)
def flatten_set_clause(cl, input_rels):
    """Turn a membership clause that is not over a comprehension,
    special relation, or input relation, into a clause over the M-set.
    Return a pair of the (possibly unchanged) clause and a bool
    indicating whether or not the change was done.

    This also works on condition clauses that express membership
    constraints. The rewritten clause is still a condition clause.
    """

    def should_trans(rhs):
        # Comprehensions are never translated; names are translated
        # unless they denote a special relation or an input relation.
        if isinstance(rhs, L.Comp):
            return False
        if isinstance(rhs, L.Name):
            return not (is_specialrel(rhs.id) or rhs.id in input_rels)
        return True

    # Enumerator case.
    if isinstance(cl, L.Enumerator) and should_trans(cl.iter):
        container = L.ContextSetter.run(cl.iter, L.Store)
        lhs = L.tuplify((container, cl.target), lval=True)
        return L.Enumerator(lhs, L.ln(make_mrel())), True

    # Condition case.
    if isinstance(cl, L.expr) and L.is_cmp(cl):
        item, op, container = L.get_cmp(cl)
        if isinstance(op, L.In) and should_trans(container):
            pair = L.tuplify((container, item))
            return L.cmp(pair, L.In(), L.ln(make_mrel())), True

    return cl, False
def get_res_code(self):
    """Return code (expression) to lookup the result.

    Without parameters this is just the name of the result set. With
    parameters it is a setmatch over the result set whose mask binds
    the leading components (one per parameter) and leaves the
    remaining components of the result expression unbound.
    """
    inccomp = self.inccomp
    params = inccomp.comp.params

    if len(params) == 0:
        return L.ln(inccomp.name)

    resexp = inccomp.spec.resexp
    assert isinstance(resexp, L.Tuple)
    n_unbound = len(resexp.elts) - len(params)
    maskstr = 'b' * len(params) + 'u' * n_unbound
    masknode = Mask(maskstr).make_node()
    paramsnode = L.tuplify(params)

    return L.pe('''
        setmatch(RES, MASK, PARAMS)
        ''', subst={'RES': L.ln(inccomp.name),
                    'MASK': masknode,
                    'PARAMS': paramsnode})
def make_vareq_cond(eqs):
    """Given a list of pairs of variables, return a conjunction of
    equalities, one for each pair.
    """
    conjunction = L.And()
    comparisons = tuple(L.cmpeq(L.ln(left), L.ln(right))
                        for left, right in eqs)
    return L.BoolOp(conjunction, comparisons)
def __init__(self, aggrop, rel, relmask, params,
             oper_demname, oper_demparams):
    """Build and store the AST node representation of the aggregate.

    The operand is the relation name, wrapped in a SetMatch when there
    are parameters and in a DemQuery when oper_demname is given. The
    resulting Aggregate node is stored as self.node.
    """
    # NOTE(review): this checks self.aggrop rather than the aggrop
    # argument, so the attribute must already be set when __init__
    # runs -- presumably by the enclosing class's machinery; confirm.
    assert self.aggrop in ['count', 'sum', 'min', 'max']

    # AST node representation.
    operand = L.ln(rel)
    if len(params) > 0:
        operand = L.SetMatch(operand, relmask.make_node().s,
                             L.tuplify(params))
    if oper_demname is not None:
        demargs = [L.ln(p) for p in oper_demparams]
        operand = L.DemQuery(oper_demname, demargs, operand)

    self.node = L.Aggregate(operand, aggrop, None)
def visit_Comp(self, node):
    """Replace occurrences of self.comp with a call to its query function.

    Other comprehensions are returned as produced by generic_visit().
    When a replacement happens, self.need_func is set to True.
    """
    node = self.generic_visit(node)
    if node != self.comp:
        return node

    self.need_func = True

    func = L.ln(L.N.queryfunc(self.name))
    call = L.pe('QFUN(__ARGS)', subst={'QFUN': func})
    # The comprehension's parameters become the call's arguments.
    args = tuple(L.ln(p) for p in self.comp.params)
    return call._replace(args=args)
def make_addu_maint(self, prefix):
    """Generate code for after an addition to U.

    prefix is prepended to the names of the temporaries used by the
    generated code and is also passed, as a string literal, to the
    generated smassignkey() call. The aggregate must use demand
    (asserted).
    """
    incaggr = self.incaggr
    assert incaggr.has_demand
    spec = incaggr.spec

    mapval_name = prefix + 'val'
    elem_name = prefix + 'elem'

    # If we're using half-demand, there's no demand to propagate
    # to the operand. All we need to do is add an entry with count
    # 0 if one is not already there.
    if incaggr.half_demand:
        half_subst = {
            'A': L.ln(incaggr.name),
            'S_MV': L.sn(mapval_name),
            'MV': L.ln(mapval_name),
            'MASK': incaggr.aggrmask.make_node(),
            'KEY': L.tuplify(incaggr.params),
            'ZERO': self.make_zero_mapval_expr(),
            'PREFIX': L.Str(prefix),
        }
        return L.pc('''
            S_MV = A.smdeflookup(MASK, KEY, None)
            if MV is None:
                A.smassignkey(MASK, KEY, ZERO, PREFIX)
            ''', subst=half_subst)

    update_code = self.make_update_state_code(
        L.sn(mapval_name), L.ln(mapval_name), 'add',
        L.ln(elem_name), prefix)

    # Make operand demand function call, if operand uses demand.
    propagate_code = ()
    if spec.has_oper_demand:
        demfunc = L.N.demfunc(spec.oper_demname)
        demargs = tuple(L.ln(v) for v in spec.oper_demparams)
        call_demfunc = L.Call(L.ln(demfunc), demargs, (), None, None)
        propagate_code = (L.Expr(call_demfunc),)

    # Recompute the map value from scratch over the matching operand
    # elements, store it, then propagate demand.
    full_subst = {
        'S_MV': L.sn(mapval_name),
        'ZERO': self.make_zero_mapval_expr(),
        'S_ELEM': L.sn(elem_name),
        'R': spec.rel,
        'RELMASK': spec.relmask.make_node(),
        'PARAMS': L.tuplify(spec.params),
        '<c>UPDATE_MAPVAL': update_code,
        'A': L.ln(incaggr.name),
        'AGGRMASK': incaggr.aggrmask.make_node(),
        'KEY': L.tuplify(incaggr.params),
        'MV': L.ln(mapval_name),
        'PREFIX': L.Str(prefix),
        '<c>PROPAGATE_DEMAND': propagate_code,
    }
    return L.pc('''
        S_MV = ZERO
        for S_ELEM in setmatch(R, RELMASK, PARAMS):
            UPDATE_MAPVAL
        A.smassignkey(AGGRMASK, KEY, MV, PREFIX)
        PROPAGATE_DEMAND
        ''', subst=full_subst)
def make_removeu_maint(self, prefix):
    """Generate code for before a removal from U.

    prefix names the temporary used by the generated code and is also
    passed, as a string literal, to the generated smdelkey() call. The
    aggregate must use demand (asserted).
    """
    incaggr = self.incaggr
    assert incaggr.has_demand
    spec = incaggr.spec

    mapval_name = prefix + 'val'

    # If we're using half-demand, there's no demand to propagate
    # to the operand. All we need to do is determine whether to
    # do the removal by checking whether the count is 0.
    if incaggr.half_demand:
        half_subst = {
            'S_MV': L.sn(mapval_name),
            'COUNT': self.mapval_proj_count(L.ln(mapval_name)),
            'A': incaggr.name,
            'MASK': incaggr.aggrmask.make_node(),
            'KEY': L.tuplify(incaggr.params),
            'PREFIX': L.Str(prefix),
        }
        return L.pc('''
            S_MV = A.smlookup(MASK, KEY)
            if COUNT == 0:
                A.smdelkey(MASK, KEY, PREFIX)
            ''', subst=half_subst)

    # Generate operand undemand function call, if operand
    # uses demand.
    propagate_code = ()
    if spec.has_oper_demand:
        undemfunc = L.N.undemfunc(spec.oper_demname)
        undemargs = tuple(L.ln(v) for v in spec.oper_demparams)
        call_undemfunc = L.Call(L.ln(undemfunc), undemargs,
                                (), None, None)
        propagate_code = (L.Expr(call_undemfunc),)

    full_subst = {
        'A': incaggr.name,
        'MASK': incaggr.aggrmask.make_node(),
        'KEY': L.tuplify(incaggr.params),
        'PREFIX': L.Str(prefix),
        '<c>PROPAGATE_DEMAND': propagate_code,
    }
    return L.pc('''
        PROPAGATE_DEMAND
        A.smdelkey(MASK, KEY, PREFIX)
        ''', subst=full_subst)
def visit_Name(self, node):
    """Expand a replacement variable back into the retrieval it denotes.

    Names listed in self.field_exps become attribute accesses and names
    listed in self.map_exps become subscripts, both keeping the
    original context; the new node is then visited in turn. Any other
    name is returned unchanged.
    """
    name = node.id

    if name in self.field_exps:
        obj, field = self.field_exps[name]
        attribute = L.Attribute(L.ln(obj), field, node.ctx)
        return self.generic_visit(attribute)

    if name in self.map_exps:
        mapvar, key = self.map_exps[name]
        subscript = L.Subscript(L.ln(mapvar), L.Index(L.ln(key)),
                                node.ctx)
        return self.generic_visit(subscript)

    return node
def make_update_state_code(self, state_snode, state_lnode, op,
                           val_node, prefix):
    """Produce code that updates a (tree, min-or-max) state pair.

    The generated code unpacks the state read from state_lnode into a
    tree variable named prefix + 'tree', adds val_node as a key
    (op == 'add') or deletes it (op == 'remove'), and writes the pair
    of the tree and the result of its __min__()/__max__() method
    (chosen by self.kind) to state_snode.

    Raises KeyError if op is not 'add'/'remove' or self.kind is not
    'min'/'max'.
    """
    templates = {
        'add': L.trim('''
            S_TREE, _ = STATE
            TREE[VAL] = None
            S_STATE = (TREE, TREE.MINMAX())
            '''),
        'remove': L.trim('''
            S_TREE, _ = STATE
            del TREE[VAL]
            S_STATE = (TREE, TREE.MINMAX())
            '''),
    }
    template = templates[op]

    tree_name = prefix + 'tree'
    minmax_method = {'min': '__min__', 'max': '__max__'}[self.kind]

    return L.pc(template, subst={'S_TREE': L.sn(tree_name),
                                 'TREE': L.ln(tree_name),
                                 '@MINMAX': minmax_method,
                                 'STATE': state_lnode,
                                 'S_STATE': state_snode,
                                 'VAL': val_node})
def test_enumclause_basic(self):
    """Check EnumClause construction, AST round-trip and attributes."""
    lhs = ('x', 'y', 'x', '_')
    clause = EnumClause(lhs, 'R')

    # From expression.
    from_expr = EnumClause.from_expr(L.pe('(x, y, x, _) in R'))
    self.assertEqual(from_expr, clause)

    # AST round-trip.
    actual_ast = clause.to_AST()
    expected_ast = L.Enumerator(
        L.tuplify(['x', 'y', 'x', '_'], lval=True),
        L.ln('R'))
    self.assertEqual(actual_ast, expected_ast)
    from_ast = EnumClause.from_AST(expected_ast, DummyFactory)
    self.assertEqual(from_ast, clause)

    # Attributes.
    self.assertFalse(clause.isdelta)
    self.assertEqual(clause.enumlhs, ('x', 'y', 'x', '_'))
    self.assertEqual(clause.enumvars, ('x', 'y'))
    self.assertEqual(clause.pat_mask, (True, True, True, True))
    self.assertEqual(clause.enumrel, 'R')
    self.assertTrue(clause.has_wildcards)

    self.assertEqual(clause.vars, ('x', 'y'))
    self.assertEqual(clause.eqvars, None)

    self.assertTrue(clause.robust)
    self.assertEqual(clause.demname, None)
    self.assertEqual(clause.demparams, ())
def to_AST(self):
    """Return this clause as an enumerator over a singleton set.

    The last component of self.lhs is the enumeration target; the
    components before it are the key of an SMLookup on self.rel. The
    enumerator iterates over a set literal containing only that lookup.
    """
    lhs = self.lhs
    mask = Mask.from_keylen(len(lhs) - 1)
    keyvars = lhs[:-1]
    resultvar = lhs[-1]

    lookup = L.SMLookup(L.ln(self.rel), mask.make_node().s,
                        L.tuplify(keyvars), None)
    return L.Enumerator(L.sn(resultvar), L.Set((lookup,)))
def make_bindmatch(rel, mask, bvars, uvars, body):
    """Produce code that runs body for the matches of mask in rel.

    Three shapes are generated, depending on the mask:
      - all bound, no equalities: a membership test of bvars in rel;
      - all unbound, no wildcards: a plain loop over rel binding uvars;
      - otherwise: a loop over a setmatch of rel with the mask and
        bvars, binding uvars.
    """
    if mask.is_allbound and not mask.has_equalities:
        raw_template = '''
            if BVARS in REL:
                BODY
            '''
    elif mask.is_allunbound and not mask.has_wildcards:
        raw_template = '''
            for UVARS in REL:
                BODY
            '''
    else:
        raw_template = '''
            for UVARS in setmatch(REL, MASK, BVARS):
                BODY
            '''
    template = L.trim(raw_template)

    # Every placeholder is supplied regardless of which template
    # was picked.
    replacements = {
        'REL': L.ln(rel),
        'MASK': mask.make_node(),
        'BVARS': L.tuplify(bvars),
        'UVARS': L.tuplify(uvars, lval=True),
        '<c>BODY': body,
    }
    return L.pc(template, subst=replacements)
def visit_DemQuery(self, node):
    """Replace a demand query over an SMLookup with a variable.

    The query's operand is recorded in self.demwrapped_nodes before
    the children are visited. If the (visited) operand is not an
    SMLookup the node is returned as is. Otherwise the query is mapped
    to a variable, reusing the one in self.repls if the same query was
    seen before, and a load of that variable is returned.
    """
    self.demwrapped_nodes.add(node.value)
    node = self.generic_visit(node)

    lookup = node.value
    if not isinstance(lookup, L.SMLookup):
        return node
    assert lookup.default is None

    var = self.repls.get(node, None)
    if var is None:
        # Create new entry.
        var = next(self.namer)
        self.repls[node] = var
        # Create accompanying clause. Has form
        #    var in DEMQUERY(..., {smlookup})
        # The clause constructor logic will later rewrite that,
        # or else fail if there's a syntax problem.
        singleton_query = node._replace(value=L.Set((lookup,)))
        self.new_clauses.append(
            L.Enumerator(L.sn(var), singleton_query))
    # Otherwise the existing entry is reused and no clause is added.

    return L.ln(var)
def make_update_state_code(self, state_snode, state_lnode, op,
                           val_node, prefix):
    """Produce code that updates a (tree, min-or-max) state pair.

    The generated code unpacks the state read from state_lnode into a
    tree variable named prefix + 'tree', adds val_node as a key
    (op == 'add') or deletes it (op == 'remove'), and writes the pair
    of the tree and the result of its __min__()/__max__() method
    (chosen by self.kind) to state_snode.

    Raises KeyError if op is not 'add'/'remove' or self.kind is not
    'min'/'max'.
    """
    templates = {
        'add': L.trim('''
            S_TREE, _ = STATE
            TREE[VAL] = None
            S_STATE = (TREE, TREE.MINMAX())
            '''),
        'remove': L.trim('''
            S_TREE, _ = STATE
            del TREE[VAL]
            S_STATE = (TREE, TREE.MINMAX())
            '''),
    }
    template = templates[op]

    tree_name = prefix + 'tree'
    minmax_method = {'min': '__min__', 'max': '__max__'}[self.kind]

    return L.pc(template, subst={'S_TREE': L.sn(tree_name),
                                 'TREE': L.ln(tree_name),
                                 '@MINMAX': minmax_method,
                                 'STATE': state_lnode,
                                 'S_STATE': state_snode,
                                 'VAL': val_node})
def visit_SetUpdate(self, node):
    """Hoist update operands rejected by LegalUpdateValidator.

    If both the target and the element pass LegalUpdateValidator, the
    node is returned unchanged. Otherwise each rejected operand is
    assigned to a fresh variable and replaced by a load of that
    variable; the result is a tuple of the assignments followed by the
    rewritten update.
    """
    target_ok = LegalUpdateValidator.run(node.target)
    elem_ok = LegalUpdateValidator.run(node.elem)
    if target_ok and elem_ok:
        return node

    hoisted = []

    if not target_ok:
        targetvar = next(self.namegen)
        hoisted.append(L.Assign((L.sn(targetvar),), node.target))
        node = node._replace(target=L.ln(targetvar))

    if not elem_ok:
        elemvar = next(self.namegen)
        hoisted.append(L.Assign((L.sn(elemvar),), node.elem))
        node = node._replace(elem=L.ln(elemvar))

    return tuple(hoisted) + (node,)
def to_AST(self):
    """Return the wrapped clause's enumerator with a demand query.

    The AST of self.cl must be an Enumerator (asserted); its iter is
    wrapped in a DemQuery built from self.demname and self.demparams.
    """
    enumerator = self.cl.to_AST()
    assert isinstance(enumerator, L.Enumerator)

    demargs = tuple(L.ln(p) for p in self.demparams)
    wrapped_iter = L.DemQuery(self.demname, demargs, enumerator.iter)
    return enumerator._replace(iter=wrapped_iter)
def make_retrieval_code(self):
    """Make code for retrieving the value of the aggregate result,
    including demanding it.

    With demand, the lookup is wrapped in a DEMQUERY on the aggregate's
    parameters. Without demand, a defaulting lookup is used whose
    default is the zero map value. Either way the lookup is passed
    through make_proj_mapval_code() before being returned.
    """
    incaggr = self.incaggr

    params_list = L.List(tuple(L.ln(p) for p in incaggr.params),
                         L.Load())

    if incaggr.has_demand:
        lookup = L.pe('''
            DEMQUERY(NAME, PARAMS_L, RES.smlookup(AGGRMASK, PARAMS_T))
            ''', subst={'NAME': incaggr.name,
                        'PARAMS_L': params_list,
                        'PARAMS_T': L.tuplify(incaggr.params),
                        'RES': incaggr.name,
                        'AGGRMASK': incaggr.aggrmask.make_node()})
    else:
        lookup = L.pe('''
            RES.smdeflookup(AGGRMASK, PARAMS_T, ZERO)
            ''', subst={'RES': incaggr.name,
                        'AGGRMASK': incaggr.aggrmask.make_node(),
                        'PARAMS_T': L.tuplify(incaggr.params),
                        'ZERO': self.make_zero_mapval_expr()})

    return self.make_proj_mapval_code(lookup)
def to_AST(self):
    """Return this clause as an enumerator over a singleton set.

    The last component of self.lhs is the enumeration target; the
    components before it are the key of an SMLookup on self.rel. The
    enumerator iterates over a set literal containing only that lookup.
    """
    lhs = self.lhs
    mask = Mask.from_keylen(len(lhs) - 1)
    keyvars = lhs[:-1]
    resultvar = lhs[-1]

    lookup = L.SMLookup(L.ln(self.rel), mask.make_node().s,
                        L.tuplify(keyvars), None)
    return L.Enumerator(L.sn(resultvar), L.Set((lookup,)))
def make_retrieval_code(self):
    """Make code for retrieving the value of the aggregate result,
    including demanding it.

    With demand, the lookup is wrapped in a DEMQUERY on the aggregate's
    parameters. Without demand, a defaulting lookup is used whose
    default is the zero map value. Either way the lookup is passed
    through make_proj_mapval_code() before being returned.
    """
    incaggr = self.incaggr

    params_list = L.List(tuple(L.ln(p) for p in incaggr.params),
                         L.Load())

    if incaggr.has_demand:
        lookup = L.pe('''
            DEMQUERY(NAME, PARAMS_L, RES.smlookup(AGGRMASK, PARAMS_T))
            ''', subst={'NAME': incaggr.name,
                        'PARAMS_L': params_list,
                        'PARAMS_T': L.tuplify(incaggr.params),
                        'RES': incaggr.name,
                        'AGGRMASK': incaggr.aggrmask.make_node()})
    else:
        lookup = L.pe('''
            RES.smdeflookup(AGGRMASK, PARAMS_T, ZERO)
            ''', subst={'RES': incaggr.name,
                        'AGGRMASK': incaggr.aggrmask.make_node(),
                        'PARAMS_T': L.tuplify(incaggr.params),
                        'ZERO': self.make_zero_mapval_expr()})

    return self.make_proj_mapval_code(lookup)
def structures_to_comps(ds, factory):
    """Convert tags and filters to comprehensions that define them.

    Return a list of (name, comp) pairs, one per tag or filter, in the
    order the structures appear in ds.structs (presumably already a
    dependency order -- this function does not sort). Usets are
    ignored.

    Raises AssertionError if a structure has an unrecognized kind.
    """
    tags_by_name = {t.name: t for t in ds.tags}

    result = []
    for s in ds.structs:
        if s.kind is KIND_USET:
            continue

        if s.kind is KIND_TAG:
            cl = EnumClause(s.lhs, s.rel)
            spec = CompSpec(Join([cl], factory, None), L.ln(s.var), [])
        elif s.kind is KIND_FILTER:
            # One clause per predecessor tag, then the filtered
            # relation itself.
            clauses = []
            for tname in s.preds:
                t = tags_by_name[tname]
                clauses.append(EnumClause([t.var], t.name))
            # Be sure to replace wildcards with fresh vars.
            lhs = inst_wildcards(s.lhs)
            clauses.append(EnumClause(lhs, s.rel))
            spec = CompSpec(Join(clauses, factory, None),
                            L.tuplify(lhs), [])
        else:
            # Raised explicitly rather than via a bare assert so that
            # the check survives running under -O; otherwise a stale
            # (or undefined) spec would be used below.
            raise AssertionError(
                'Unknown structure kind: {!r}'.format(s.kind))

        result.append((s.name, spec.to_comp({})))

    return result
def visit_SetUpdate(self, node):
    """Hoist update operands rejected by LegalUpdateValidator.

    If both the target and the element pass LegalUpdateValidator, the
    node is returned unchanged. Otherwise each rejected operand is
    assigned to a fresh variable and replaced by a load of that
    variable; the result is a tuple of the assignments followed by the
    rewritten update.
    """
    target_ok = LegalUpdateValidator.run(node.target)
    elem_ok = LegalUpdateValidator.run(node.elem)
    if target_ok and elem_ok:
        return node

    hoisted = []

    if not target_ok:
        targetvar = next(self.namegen)
        hoisted.append(L.Assign((L.sn(targetvar),), node.target))
        node = node._replace(target=L.ln(targetvar))

    if not elem_ok:
        elemvar = next(self.namegen)
        hoisted.append(L.Assign((L.sn(elemvar),), node.elem))
        node = node._replace(elem=L.ln(elemvar))

    return tuple(hoisted) + (node,)
def visit_DelKey(self, node):
    """Hoist key-deletion operands rejected by LegalUpdateValidator.

    If both the target and the key pass LegalUpdateValidator, the node
    is returned unchanged. Otherwise each rejected operand is assigned
    to a fresh variable and replaced by a load of that variable; the
    result is a tuple of the assignments followed by the rewritten
    deletion.
    """
    target_ok = LegalUpdateValidator.run(node.target)
    key_ok = LegalUpdateValidator.run(node.key)
    if target_ok and key_ok:
        return node

    hoisted = []

    if not target_ok:
        targetvar = next(self.namegen)
        hoisted.append(L.Assign((L.sn(targetvar),), node.target))
        node = node._replace(target=L.ln(targetvar))

    if not key_ok:
        keyvar = next(self.namegen)
        hoisted.append(L.Assign((L.sn(keyvar),), node.key))
        node = node._replace(key=L.ln(keyvar))

    return tuple(hoisted) + (node,)
def visit_DemQuery(self, node):
    """Return the cost of a demand query.

    The query is costed as a call to the function named by
    L.N.queryfunc(node.demname) with the query's arguments, plus the
    cost of the query's operand.
    """
    # Translate into a call to the demand function. Cost is
    # that plus the result retrieval cost.
    func = L.ln(L.N.queryfunc(node.demname))
    call = L.Call(func, node.args, (), None, None)

    call_cost = self.visit(call)
    retrieve_cost = self.visit(node.value)
    return SumCost([call_cost, retrieve_cost])
def visit_DelKey(self, node):
    """Hoist key-deletion operands rejected by LegalUpdateValidator.

    If both the target and the key pass LegalUpdateValidator, the node
    is returned unchanged. Otherwise each rejected operand is assigned
    to a fresh variable and replaced by a load of that variable; the
    result is a tuple of the assignments followed by the rewritten
    deletion.
    """
    target_ok = LegalUpdateValidator.run(node.target)
    key_ok = LegalUpdateValidator.run(node.key)
    if target_ok and key_ok:
        return node

    hoisted = []

    if not target_ok:
        targetvar = next(self.namegen)
        hoisted.append(L.Assign((L.sn(targetvar),), node.target))
        node = node._replace(target=L.ln(targetvar))

    if not key_ok:
        keyvar = next(self.namegen)
        hoisted.append(L.Assign((L.sn(keyvar),), node.key))
        node = node._replace(key=L.ln(keyvar))

    return tuple(hoisted) + (node,)
def helper(self, node, no_update=False):
    """Bind a fresh variable to a new Set() and return a load of it.

    The statements creating the set are appended to self.pre_stmts.
    Unless no_update is true, they also update the new set with the
    expression node.
    """
    fresh = next(self.namegen)

    if no_update:
        # Only create the empty set.
        raw_template = '''
            S_VAR = Set()
            '''
    else:
        raw_template = '''
            S_VAR = Set()
            L_VAR.update(EXPR)
            '''
    template = L.trim(raw_template)

    replacements = {'L_VAR': L.ln(fresh),
                    'S_VAR': L.sn(fresh),
                    'EXPR': node}
    self.pre_stmts.extend(L.pc(template, subst=replacements))
    return L.ln(fresh)
def mainttest_helper(self, maskstr):
    """Return auxmap 'add' maintenance code for relation R and a mask.

    The manager's prefix generator is overridden to always return '_'.
    """
    spec = AuxmapSpec('R', Mask(maskstr))

    def fixed_prefix():
        # Make the prefix '_' so it's easier to read/type.
        return '_'

    self.manager.namegen.next_prefix = fixed_prefix

    return make_auxmap_maint_code(self.manager, spec,
                                  L.ln('e'), 'add')
def visit_Delete(self, node):
    """Rewrite an attribute deletion into an nsdelfield() call.

    The node is returned unchanged when field rewriting is disabled or
    when the node is not an attribute deletion.
    """
    if not self.rewrite_fields or not L.is_delattr(node):
        return node

    container, field = L.get_delattr(node)
    replacements = {'CONT': container, 'FIELD': L.ln(field)}
    return L.pc('''
        CONT.nsdelfield(FIELD)
        ''', subst=replacements)
def make_update_mapval_code(self, mv_snode, mv_lnode, op,
                            val_node, prefix):
    """Produce code to make a new mapval, given an update to the
    corresponding operand. The mapval is read from mv_lnode and
    written to mv_snode.

    op is 'add' or 'remove'. prefix is prepended to the names of the
    temporaries used by the generated code.

    Raises AssertionError if counts are tracked and op is neither
    'add' nor 'remove'.
    """
    # If we don't track counts, the mapvals are the same as
    # the states.
    if not self.incaggr.tracks_counts:
        return self.make_update_state_code(mv_snode, mv_lnode, op,
                                           val_node, prefix)

    statevar = prefix + 'state'
    state_lnode = L.ln(statevar)
    state_snode = L.sn(statevar)
    countvar = prefix + 'count'

    updatestate_code = self.make_update_state_code(
        state_snode, state_lnode, op, val_node, prefix)

    if op == 'add':
        template = 'COUNTVAR + 1'
    elif op == 'remove':
        template = 'COUNTVAR - 1'
    else:
        # Raised explicitly rather than via a bare assert so that the
        # check survives running under -O; otherwise template would be
        # unbound below.
        raise AssertionError('Unknown update op: {!r}'.format(op))
    new_count_node = L.pe(template, subst={'COUNTVAR': L.ln(countvar)})

    # Unpack the (state, count) pair, update the state, and repack it
    # with the adjusted count.
    return L.pc('''
        S_STATE, S_COUNTVAR = MV
        UPDATE_STATE
        S_MV = STATE, NEW_COUNT
        ''', subst={'S_STATE': state_snode,
                    'S_COUNTVAR': L.sn(countvar),
                    'MV': mv_lnode,
                    '<c>UPDATE_STATE': updatestate_code,
                    'STATE': state_lnode,
                    'NEW_COUNT': new_count_node,
                    'S_MV': mv_snode})
def make_removeu_maint(self, prefix):
    """Generate code for before a removal from U.

    prefix names the temporary used by the generated code and is also
    passed, as a string literal, to the generated smdelkey() call. The
    aggregate must use demand (asserted).
    """
    incaggr = self.incaggr
    assert incaggr.has_demand
    spec = incaggr.spec

    mapval_name = prefix + 'val'

    # If we're using half-demand, there's no demand to propagate
    # to the operand. All we need to do is determine whether to
    # do the removal by checking whether the count is 0.
    if incaggr.half_demand:
        half_subst = {
            'S_MV': L.sn(mapval_name),
            'COUNT': self.mapval_proj_count(L.ln(mapval_name)),
            'A': incaggr.name,
            'MASK': incaggr.aggrmask.make_node(),
            'KEY': L.tuplify(incaggr.params),
            'PREFIX': L.Str(prefix),
        }
        return L.pc('''
            S_MV = A.smlookup(MASK, KEY)
            if COUNT == 0:
                A.smdelkey(MASK, KEY, PREFIX)
            ''', subst=half_subst)

    # Generate operand undemand function call, if operand
    # uses demand.
    propagate_code = ()
    if spec.has_oper_demand:
        undemfunc = L.N.undemfunc(spec.oper_demname)
        undemargs = tuple(L.ln(v) for v in spec.oper_demparams)
        call_undemfunc = L.Call(L.ln(undemfunc), undemargs,
                                (), None, None)
        propagate_code = (L.Expr(call_undemfunc),)

    full_subst = {
        'A': incaggr.name,
        'MASK': incaggr.aggrmask.make_node(),
        'KEY': L.tuplify(incaggr.params),
        'PREFIX': L.Str(prefix),
        '<c>PROPAGATE_DEMAND': propagate_code,
    }
    return L.pc('''
        PROPAGATE_DEMAND
        A.smdelkey(MASK, KEY, PREFIX)
        ''', subst=full_subst)
def get_code(self, bindenv, body):
    """Return the wrapped clause's code preceded by a demand query.

    The demand query is an expression statement holding a DemQuery
    node with self.demname, self.demparams and no operand.
    """
    # Just stick a DemQuery node in before the regular code.
    # TODO: This is a little ugly in that it results in
    # littering the code with "None"s. Maybe make a special
    # case in the translation of DemQuery to avoid this.
    inner_code = self.cl.get_code(bindenv, body)

    demargs = tuple(L.ln(p) for p in self.demparams)
    demand_stmt = L.Expr(value=L.DemQuery(self.demname, demargs, None))

    return (demand_stmt,) + inner_code
def make_auxmap_maint_code(manager, spec, elem, addremove):
    """Construct auxmap maintenance code for a set update.

    addremove must be 'add' or 'remove' (asserted). The generated code
    unpacks elem into fresh variables, one per mask component, and
    applies the matching image operation to the auxiliary map.
    """
    assert addremove in ['add', 'remove']

    prefix = manager.namegen.next_prefix()
    mask = spec.mask

    # Create fresh variables for the tuple components.
    compvars = [prefix + str(i) for i in range(1, len(mask) + 1)]
    bvars, uvars, eqs = mask.split_vars(compvars)

    vars_node = L.tuplify(compvars, lval=True)
    map_node = L.ln(spec.map_name)
    bvars_node = L.tuplify(bvars)
    uvars_node = L.tuplify(uvars)

    # If there are equalities, include a conditional check for the
    # constraints being satisfied. If there are wildcards, make the
    # image set manipulation operations reference-counted.
    #
    # Avoid these in cases where we don't have equalities/wildcards,
    # to help reduce constant-factor bloat in code size and running
    # time.

    if mask.has_equalities:
        template = '''
            VARS = ELEM
            if EQCOND:
                MAP.IMGOP(BVARS, UVARS)
            '''
        eqcond = make_vareq_cond(eqs)
    else:
        template = '''
            VARS = ELEM
            MAP.IMGOP(BVARS, UVARS)
            '''
        eqcond = None

    if mask.has_wildcards:
        imgops = {'add': 'rcimgadd', 'remove': 'rcimgremove'}
    else:
        imgops = {'add': 'imgadd', 'remove': 'imgremove'}
    imgop = imgops[addremove]

    return L.pc(template, subst={'@IMGOP': imgop,
                                 'VARS': vars_node,
                                 'ELEM': elem,
                                 'MAP': map_node,
                                 'BVARS': bvars_node,
                                 'UVARS': uvars_node,
                                 'EQCOND': eqcond})
def test(self):
    """Check DemClause round-trip, attributes, rewriting, rating
    and code generation.
    """
    clause = DemClause(EnumClause(['x', 'y'], 'R'), 'f', ['x'])

    # AST round-trip.
    actual_ast = clause.to_AST()
    expected_ast = L.Enumerator(
        L.tuplify(['x', 'y'], lval=True),
        L.DemQuery('f', (L.ln('x'),), L.ln('R')))
    self.assertEqual(actual_ast, expected_ast)
    from_ast = DemClause.from_AST(expected_ast, DemClauseFactory)
    self.assertEqual(from_ast, clause)

    # Attributes.
    self.assertEqual(clause.pat_mask, (True, True))
    self.assertEqual(clause.enumvars_tagsin, ('x',))
    self.assertEqual(clause.enumvars_tagsout, ('y',))

    # Rewriting.
    rewritten = clause.rewrite_subst({'x': 'z'}, DemClauseFactory)
    expected = DemClause(EnumClause(['z', 'y'], 'R'), 'f', ['z'])
    self.assertEqual(rewritten, expected)

    # Fancy rewriting, uses LookupClause.
    lookup_clause = DemClause(LookupClause(['x', 'y'], 'R'), 'f', ['x'])
    rewritten = lookup_clause.rewrite_subst({'x': 'z'}, DemClauseFactory)
    expected = DemClause(LookupClause(['z', 'y'], 'R'), 'f', ['z'])
    self.assertEqual(rewritten, expected)

    # Rating.
    self.assertEqual(clause.rate(['x']), Rate.NORMAL)
    self.assertEqual(clause.rate([]), Rate.UNRUNNABLE)

    # Code generation.
    actual_code = clause.get_code(['x'], L.pc('pass'))
    expected_code = L.pc('''
        DEMQUERY(f, [x], None)
        for y in setmatch(R, 'bu', x):
            pass
        ''')
    self.assertEqual(actual_code, expected_code)
def test(self):
    """Check DemClause round-trip, attributes, rewriting, rating
    and code generation.
    """
    clause = DemClause(EnumClause(['x', 'y'], 'R'), 'f', ['x'])

    # AST round-trip.
    actual_ast = clause.to_AST()
    expected_ast = L.Enumerator(
        L.tuplify(['x', 'y'], lval=True),
        L.DemQuery('f', (L.ln('x'),), L.ln('R')))
    self.assertEqual(actual_ast, expected_ast)
    from_ast = DemClause.from_AST(expected_ast, DemClauseFactory)
    self.assertEqual(from_ast, clause)

    # Attributes.
    self.assertEqual(clause.pat_mask, (True, True))
    self.assertEqual(clause.enumvars_tagsin, ('x',))
    self.assertEqual(clause.enumvars_tagsout, ('y',))

    # Rewriting.
    rewritten = clause.rewrite_subst({'x': 'z'}, DemClauseFactory)
    expected = DemClause(EnumClause(['z', 'y'], 'R'), 'f', ['z'])
    self.assertEqual(rewritten, expected)

    # Fancy rewriting, uses LookupClause.
    lookup_clause = DemClause(LookupClause(['x', 'y'], 'R'), 'f', ['x'])
    rewritten = lookup_clause.rewrite_subst({'x': 'z'}, DemClauseFactory)
    expected = DemClause(LookupClause(['z', 'y'], 'R'), 'f', ['z'])
    self.assertEqual(rewritten, expected)

    # Rating.
    self.assertEqual(clause.rate(['x']), Rate.NORMAL)
    self.assertEqual(clause.rate([]), Rate.UNRUNNABLE)

    # Code generation.
    actual_code = clause.get_code(['x'], L.pc('pass'))
    expected_code = L.pc('''
        DEMQUERY(f, [x], None)
        for y in setmatch(R, 'bu', x):
            pass
        ''')
    self.assertEqual(actual_code, expected_code)
def visit_Tuple(self, node):
    """Replace a tuple with a variable constrained by a new clause.

    A new enumerator clause over the relation returned by make_trel()
    for the tuple's arity relates the variable (first component) to
    the tuple's elements. The clause is appended to self.new_clauses
    and the relation name is added to self.trels.
    """
    # No need to recurse, that's taken care of by the caller
    # of this visitor.

    tupvar = self.tupvar_namer.next()

    trel = make_trel(len(node.elts))
    lhs = L.tuplify((L.sn(tupvar),) + node.elts, lval=True)
    self.new_clauses.append(L.Enumerator(lhs, L.ln(trel)))
    self.trels.add(trel)

    return L.sn(tupvar)
def visit_SetUpdate(self, node):
    """Make updates to self.rel use the result of make_flattup_code().

    Updates to anything other than the name self.rel are left alone
    (None is returned). Otherwise the result is the code from
    make_flattup_code(), which is given the element and a fresh '_ft'
    variable to store into, followed by the update with its element
    replaced by a load of that variable.
    """
    targets_rel = (isinstance(node.target, L.Name) and
                   node.target.id == self.rel)
    if not targets_rel:
        return None

    suffix = next(self.namegen)
    tupvar = '_t' + suffix
    flatvar = '_ft' + suffix

    flatten_code = make_flattup_code(self.tuptype, node.elem,
                                     L.sn(flatvar), tupvar)
    new_update = node._replace(elem=L.ln(flatvar))
    return flatten_code + (new_update,)
def visit_ClassDef(self, node):
    """Ensure the class lists Set among its bases.

    All existing bases must satisfy valid_baseclass() (asserted). Set
    is appended to the bases if it is not already present.
    """
    node = self.generic_visit(node)

    assert all(self.valid_baseclass(b) for b in node.bases), \
        'Illegal base class'

    set_base = L.ln('Set')
    if set_base in node.bases:
        return node
    return node._replace(bases=node.bases + (set_base,))
def visit_ClassDef(self, node):
    """Ensure the class lists Set among its bases.

    All existing bases must satisfy valid_baseclass() (asserted). Set
    is appended to the bases if it is not already present.
    """
    node = self.generic_visit(node)

    assert all(self.valid_baseclass(b) for b in node.bases), \
        'Illegal base class'

    set_base = L.ln('Set')
    if set_base in node.bases:
        return node
    return node._replace(bases=node.bases + (set_base,))
def visit_Assign(self, node):
    """Rewrite an attribute assignment into an nsassignfield() call.

    The node is returned unchanged when field rewriting is disabled or
    when the node is not an attribute assignment.
    """
    if not self.rewrite_fields or not L.is_attrassign(node):
        return node

    container, field, value = L.get_attrassign(node)
    replacements = {'CONT': container,
                    'FIELD': L.ln(field),
                    'VALUE': value}
    return L.pc('''
        CONT.nsassignfield(FIELD, VALUE)
        ''', subst=replacements)
def visit_DelKey(self, node):
    """Rewrite a key deletion into a removal from the map relation.

    When self.use_mapset is false the (visited) node is returned as
    is. Otherwise it is replaced by code removing the triple
    (target, key, target[key]) from the relation named by
    make_maprel().
    """
    node = self.generic_visit(node)
    if not self.use_mapset:
        return node

    replacements = {'MAPSET': L.ln(make_maprel()),
                    'TARGET': node.target,
                    'KEY': node.key}
    return L.pc('''
        MAPSET.remove((TARGET, KEY, TARGET[KEY]))
        ''', subst=replacements)
def visit_SetUpdate(self, node):
    """Make updates to self.rel use the result of make_flattup_code().

    Updates to anything other than the name self.rel are left alone
    (None is returned). Otherwise the result is the code from
    make_flattup_code(), which is given the element and a fresh '_ft'
    variable to store into, followed by the update with its element
    replaced by a load of that variable.
    """
    targets_rel = (isinstance(node.target, L.Name) and
                   node.target.id == self.rel)
    if not targets_rel:
        return None

    suffix = next(self.namegen)
    tupvar = '_t' + suffix
    flatvar = '_ft' + suffix

    flatten_code = make_flattup_code(self.tuptype, node.elem,
                                     L.sn(flatvar), tupvar)
    new_update = node._replace(elem=L.ln(flatvar))
    return flatten_code + (new_update,)
def visit_Tuple(self, node):
    """Replace a tuple with a variable constrained by a new clause.

    A new enumerator clause over the relation returned by make_trel()
    for the tuple's arity relates the variable (first component) to
    the tuple's elements. The clause is appended to self.new_clauses
    and the relation name is added to self.trels.
    """
    # No need to recurse, that's taken care of by the caller
    # of this visitor.

    tupvar = self.tupvar_namer.next()

    trel = make_trel(len(node.elts))
    lhs = L.tuplify((L.sn(tupvar),) + node.elts, lval=True)
    self.new_clauses.append(L.Enumerator(lhs, L.ln(trel)))
    self.trels.add(trel)

    return L.sn(tupvar)