def flatten_set_clause(cl, input_rels):
    """Rewrite a membership clause so that it ranges over the M-set.

    Handles enumerators, and condition clauses of the form
    `item in cont`. The rewrite happens only when the container is not a
    comprehension, a special relation, or one of input_rels. A condition
    clause remains a condition clause after the rewrite.

    Return a pair (clause, changed): the possibly rewritten clause and a
    bool telling whether the rewrite was applied.
    """
    def should_trans(rhs):
        # Comprehensions are never flattened.
        if isinstance(rhs, L.Comp):
            return False
        # Special relations and input relations are left as they are.
        if isinstance(rhs, L.Name):
            return not (is_specialrel(rhs.id) or rhs.id in input_rels)
        return True

    if isinstance(cl, L.Enumerator) and should_trans(cl.iter):
        # `item in cont` becomes `(cont, item) in _M`; cont now sits on
        # the enumerator's left-hand side, so it needs store context.
        store_cont = L.ContextSetter.run(cl.iter, L.Store)
        pair_lhs = L.tuplify((store_cont, cl.target), lval=True)
        return L.Enumerator(pair_lhs, L.ln(make_mrel())), True

    if isinstance(cl, L.expr) and L.is_cmp(cl):
        item, op, cont = L.get_cmp(cl)
        if isinstance(op, L.In) and should_trans(cont):
            pair = L.tuplify((cont, item))
            return L.cmp(pair, L.In(), L.ln(make_mrel())), True

    return cl, False
def make_retrieval_code(self):
    """Build the expression that reads the aggregate's result value.

    If the aggregate uses demand, the lookup is wrapped in a DEMQUERY
    on the aggregate's parameters. Otherwise a default lookup is used
    that falls back to the zero map value. In both cases the lookup
    result is passed through make_proj_mapval_code().
    """
    aggr = self.incaggr
    if aggr.has_demand:
        param_list = L.List(tuple(L.ln(p) for p in aggr.params), L.Load())
        template = '''
            DEMQUERY(NAME, PARAMS_L, RES.smlookup(AGGRMASK, PARAMS_T))
            '''
        substitutions = {
            'NAME': aggr.name,
            'PARAMS_L': param_list,
            'PARAMS_T': L.tuplify(aggr.params),
            'RES': aggr.name,
            'AGGRMASK': aggr.aggrmask.make_node(),
        }
    else:
        template = '''
            RES.smdeflookup(AGGRMASK, PARAMS_T, ZERO)
            '''
        substitutions = {
            'RES': aggr.name,
            'AGGRMASK': aggr.aggrmask.make_node(),
            'PARAMS_T': L.tuplify(aggr.params),
            'ZERO': self.make_zero_mapval_expr(),
        }
    lookup = L.pe(template, subst=substitutions)
    return self.make_proj_mapval_code(lookup)
def make_bindmatch(rel, mask, bvars, uvars, body):
    """Generate code that runs body for each match of mask against rel.

    Three shapes are produced:
      - every component bound and no equalities: a membership test;
      - every component unbound and no wildcards: a plain loop over rel;
      - anything else: a loop over setmatch(rel, mask, bvars).
    """
    fully_bound = mask.is_allbound and not mask.has_equalities
    fully_unbound = mask.is_allunbound and not mask.has_wildcards

    if fully_bound:
        template = L.trim('''
            if BVARS in REL:
                BODY
            ''')
    elif fully_unbound:
        template = L.trim('''
            for UVARS in REL:
                BODY
            ''')
    else:
        template = L.trim('''
            for UVARS in setmatch(REL, MASK, BVARS):
                BODY
            ''')

    substitutions = {
        'REL': L.ln(rel),
        'MASK': mask.make_node(),
        'BVARS': L.tuplify(bvars),
        'UVARS': L.tuplify(uvars, lval=True),
        '<c>BODY': body,
    }
    return L.pc(template, subst=substitutions)
def make_tuplematch(val, mask, bvars, uvars, body):
    """Generate code that runs body if the single tuple val matches mask.

    Three shapes are produced:
      - every component bound: an equality test against val;
      - every component unbound and no wildcards: a direct unpacking
        assignment;
      - anything else: a loop over setmatch on the singleton set {val}.
    """
    if mask.is_allbound:
        template = L.trim('''
            if BVARS == VAL:
                BODY
            ''')
    elif mask.is_allunbound and not mask.has_wildcards:
        template = L.trim('''
            UVARS = VAL
            BODY
            ''')
    else:
        template = L.trim('''
            for UVARS in setmatch({VAL}, MASK, BVARS):
                BODY
            ''')

    substitutions = dict(
        VAL=val,
        MASK=mask.make_node(),
        BVARS=L.tuplify(bvars),
        UVARS=L.tuplify(uvars, lval=True),
    )
    substitutions['<c>BODY'] = body
    return L.pc(template, subst=substitutions)
def make_retrieval_code(self):
    """Return the expression that retrieves the aggregate's value.

    With demand, the result-map lookup is wrapped in a DEMQUERY over the
    aggregate's parameters. Without demand, smdeflookup is used with the
    zero map value as the default. The lookup is then projected through
    make_proj_mapval_code().
    """
    aggr = self.incaggr
    key_tuple = L.tuplify(aggr.params)

    if not aggr.has_demand:
        lookup = L.pe('''
            RES.smdeflookup(AGGRMASK, PARAMS_T, ZERO)
            ''', subst={
                'RES': aggr.name,
                'AGGRMASK': aggr.aggrmask.make_node(),
                'PARAMS_T': key_tuple,
                'ZERO': self.make_zero_mapval_expr(),
            })
        return self.make_proj_mapval_code(lookup)

    param_names = L.List(tuple(L.ln(p) for p in aggr.params), L.Load())
    lookup = L.pe('''
        DEMQUERY(NAME, PARAMS_L, RES.smlookup(AGGRMASK, PARAMS_T))
        ''', subst={
            'NAME': aggr.name,
            'PARAMS_L': param_names,
            'PARAMS_T': key_tuple,
            'RES': aggr.name,
            'AGGRMASK': aggr.aggrmask.make_node(),
        })
    return self.make_proj_mapval_code(lookup)
def process(expr): """Rewrite any retrievals in the given expression. Return a pair of the new expression, and a list of new clauses to be added for any retrievals not already seen. """ nonlocal seen_map new_expr = replacer.process(expr) new_field_repls = replacer.field_repls - seen_field_repls new_map_repls = replacer.map_repls - seen_map_repls new_clauses = [] for repl in new_field_repls: obj, field, value = repl seen_fields.add(field) seen_field_repls.add(repl) new_cl = L.Enumerator(L.tuplify((obj, value), lval=True), L.ln(make_frel(field))) new_clauses.append(new_cl) for repl in new_map_repls: map, key, value = repl seen_map = True seen_map_repls.add(repl) new_cl = L.Enumerator(L.tuplify((map, key, value), lval=True), L.ln(make_maprel())) new_clauses.append(new_cl) return new_expr, new_clauses
def make_addu_maint(self, prefix):
    """Generate code for after an addition to U.

    Half-demand: the generated code only ensures that an entry exists
    for the key. If the lookup yields None, the entry is initialized to
    the zero map value.

    Otherwise: the map value is computed from scratch. It starts at the
    zero value and folds in every matching element of the operand
    relation via make_update_state_code(). The result is stored under
    the key. Finally the operand's demand function is called, if the
    operand uses demand.
    """
    aggr = self.incaggr
    assert aggr.has_demand
    spec = aggr.spec
    mapval_name = prefix + 'val'
    elem_name = prefix + 'elem'

    if aggr.half_demand:
        # No demand to propagate to the operand in this mode.
        return L.pc('''
            S_MV = A.smdeflookup(MASK, KEY, None)
            if MV is None:
                A.smassignkey(MASK, KEY, ZERO, PREFIX)
            ''', subst={
                'A': L.ln(aggr.name),
                'S_MV': L.sn(mapval_name),
                'MV': L.ln(mapval_name),
                'MASK': aggr.aggrmask.make_node(),
                'KEY': L.tuplify(aggr.params),
                'ZERO': self.make_zero_mapval_expr(),
                'PREFIX': L.Str(prefix),
            })

    fold_code = self.make_update_state_code(
        L.sn(mapval_name), L.ln(mapval_name), 'add',
        L.ln(elem_name), prefix)

    propagate_code = ()
    if spec.has_oper_demand:
        dem_name = L.N.demfunc(spec.oper_demname)
        dem_args = tuple(L.ln(v) for v in spec.oper_demparams)
        dem_call = L.Call(L.ln(dem_name), dem_args, (), None, None)
        propagate_code = (L.Expr(dem_call),)

    substitutions = {
        'S_MV': L.sn(mapval_name),
        'ZERO': self.make_zero_mapval_expr(),
        'S_ELEM': L.sn(elem_name),
        'R': spec.rel,
        'RELMASK': spec.relmask.make_node(),
        'PARAMS': L.tuplify(spec.params),
        '<c>UPDATE_MAPVAL': fold_code,
        'A': L.ln(aggr.name),
        'AGGRMASK': aggr.aggrmask.make_node(),
        'KEY': L.tuplify(aggr.params),
        'MV': L.ln(mapval_name),
        'PREFIX': L.Str(prefix),
        '<c>PROPAGATE_DEMAND': propagate_code,
    }
    return L.pc('''
        S_MV = ZERO
        for S_ELEM in setmatch(R, RELMASK, PARAMS):
            UPDATE_MAPVAL
        A.smassignkey(AGGRMASK, KEY, MV, PREFIX)
        PROPAGATE_DEMAND
        ''', subst=substitutions)
def make_auxmap_maint_code(manager, spec, elem, addremove):
    """Construct auxmap maintenance code for a set update.

    The updated element is unpacked into fresh variables. Its bound and
    unbound parts (per spec.mask) are then added to, or removed from,
    the image set of the auxiliary map named spec.map_name.

    Extra machinery is emitted only when needed, to keep generated code
    small and fast:
      - with equalities, the image operation is guarded by an equality
        check on the unpacked variables;
      - with wildcards, the reference-counted image operations are used.
    """
    assert addremove in ['add', 'remove']
    prefix = manager.namegen.next_prefix()
    mask = spec.mask

    # Fresh names for each tuple component: <prefix>1 .. <prefix>n.
    component_vars = [prefix + str(n) for n in range(1, len(mask) + 1)]
    bvars, uvars, eqs = mask.split_vars(component_vars)

    if not mask.has_equalities:
        template = '''
            VARS = ELEM
            MAP.IMGOP(BVARS, UVARS)
            '''
        eqcond = None
    else:
        template = '''
            VARS = ELEM
            if EQCOND:
                MAP.IMGOP(BVARS, UVARS)
            '''
        eqcond = make_vareq_cond(eqs)

    plain_ops = {'add': 'imgadd', 'remove': 'imgremove'}
    refcounted_ops = {'add': 'rcimgadd', 'remove': 'rcimgremove'}
    op_table = refcounted_ops if mask.has_wildcards else plain_ops
    imgop = op_table[addremove]

    return L.pc(template, subst={
        '@IMGOP': imgop,
        'VARS': L.tuplify(component_vars, lval=True),
        'ELEM': elem,
        'MAP': L.ln(spec.map_name),
        'BVARS': L.tuplify(bvars),
        'UVARS': L.tuplify(uvars),
        'EQCOND': eqcond,
    })
def test_objclausefactory(self):
    # An enumerator over _M is recognized as an MClause.
    expected_m = MClause('S', 'x')
    m_ast = L.Enumerator(L.tuplify(['S', 'x'], lval=True), L.pe('_M'))
    parsed_m = ObjClauseFactory.from_AST(m_ast)
    self.assertEqual(parsed_m, expected_m)

    # The no-typecheck factory maps an enumerator over _F_f to the
    # no-typecheck field clause class.
    expected_f = FClause_NoTC('o', 'v', 'f')
    f_ast = L.Enumerator(L.tuplify(['o', 'v'], lval=True), L.pe('_F_f'))
    parsed_f = ObjClauseFactory_NoTC.from_AST(f_ast)
    self.assertEqual(parsed_f, expected_f)
    self.assertIsInstance(parsed_f, FClause_NoTC)
def make_removeu_maint(self, prefix):
    """Generate code for before a removal from U.

    Half-demand: the generated code looks up the current map value and
    deletes the entry only when its count projection is 0.

    Otherwise: the operand's undemand function is called first (if the
    operand uses demand), and then the entry is deleted unconditionally.
    """
    aggr = self.incaggr
    assert aggr.has_demand
    spec = aggr.spec
    mapval_name = prefix + 'val'

    if aggr.half_demand:
        # No demand to propagate to the operand in this mode.
        half_subst = {
            'S_MV': L.sn(mapval_name),
            'COUNT': self.mapval_proj_count(L.ln(mapval_name)),
            'A': aggr.name,
            'MASK': aggr.aggrmask.make_node(),
            'KEY': L.tuplify(aggr.params),
            'PREFIX': L.Str(prefix),
        }
        return L.pc('''
            S_MV = A.smlookup(MASK, KEY)
            if COUNT == 0:
                A.smdelkey(MASK, KEY, PREFIX)
            ''', subst=half_subst)

    if spec.has_oper_demand:
        undem_name = L.N.undemfunc(spec.oper_demname)
        undem_args = tuple(L.ln(v) for v in spec.oper_demparams)
        undem_call = L.Call(L.ln(undem_name), undem_args, (), None, None)
        propagate_code = (L.Expr(undem_call),)
    else:
        propagate_code = ()

    full_subst = {
        'A': aggr.name,
        'MASK': aggr.aggrmask.make_node(),
        'KEY': L.tuplify(aggr.params),
        'PREFIX': L.Str(prefix),
        '<c>PROPAGATE_DEMAND': propagate_code,
    }
    return L.pc('''
        PROPAGATE_DEMAND
        A.smdelkey(MASK, KEY, PREFIX)
        ''', subst=full_subst)
def visit_SetUpdate(self, node):
    """Attach demand maintenance to updates of one of self.at_rels.

    Updates to other relations are left alone; this returns None for
    them. For a tracked relation, the generated post-update code calls
    the demand function ('add') or the undemand function (any other op)
    once per element of the delta set named self.delta_name, then clears
    that set. The code is emitted after the update even for removals;
    where the U-set update happens should not matter as long as the
    invariants hold at that point.
    """
    rel_name = L.get_name(node.target)
    if rel_name not in self.at_rels:
        return

    prefix = self.manager.namegen.next_prefix()
    proj_vars = [prefix + v for v in self.projvars]

    name_maker = L.N.demfunc if node.op == 'add' else L.N.undemfunc
    funcname = name_maker(self.demname)

    call_args = tuple(L.ln(v) for v in proj_vars)
    call_func = L.Call(L.ln(funcname), call_args, (), None, None)

    postcode = L.pc('''
        for S_PROJVARS in DELTA.elements():
            CALL_FUNC
        DELTA.clear()
        ''', subst={
            'S_PROJVARS': L.tuplify(proj_vars, lval=True),
            'DELTA': self.delta_name,
            'CALL_FUNC': call_func,
        })

    return self.with_outer_maint(node, funcname, L.ts(node), (), postcode)
def test_enumclause_setmatch(self):
    # A clause enumerating a setmatch converts back to a plain enumerator
    # over the underlying relation, with the bound key put back in place.
    target = L.tuplify(['y'], lval=True)
    matcher = L.SetMatch(L.ln('R'), 'bu', L.ln('x'))
    converted = EnumClause.from_AST(L.Enumerator(target, matcher),
                                    DummyFactory)
    expected = EnumClause(('x', 'y'), 'R')
    self.assertEqual(converted, expected)
def test_deltaclause(self):
    delta_cl = DeltaClause(('x', 'y'), 'R', L.pe('e'), 1)

    # AST round trip.
    produced_ast = delta_cl.to_AST()
    expected_ast = L.Enumerator(L.tuplify(['x', 'y'], lval=True),
                                L.pe('deltamatch(R, "bb", e, 1)'))
    self.assertEqual(produced_ast, expected_ast)
    self.assertEqual(DeltaClause.from_AST(expected_ast, DummyFactory),
                     delta_cl)

    # Attributes.
    self.assertEqual(delta_cl.rel, 'R')

    # Simple mask: the delta element is unpacked directly.
    simple_code = delta_cl.get_code([], L.pc('pass'))
    self.assertEqual(simple_code, L.pc('''
        x, y = e
        pass
        '''))

    # Mask with an equality and a wildcard: goes through setmatch.
    fancy_cl = DeltaClause(('x', 'x', '_'), 'R', L.pe('e'), 1)
    fancy_code = fancy_cl.get_code([], L.pc('pass'))
    self.assertEqual(fancy_code, L.pc('''
        for x in setmatch(deltamatch(R, 'b1w', e, 1), 'u1w', ()):
            pass
        '''))
def to_comp(self, options):
    """Build the L.Comp node corresponding to this spec.

    The result expression is the tuple of enumvars, and the clauses are
    the clauses' AST forms. If a delta is attached, its keys are merged
    into the options passed to the node.
    """
    clause_nodes = tuple([clause.to_AST() for clause in self.clauses])
    comp_options = (options if self.delta is None
                    else self.delta.updateopts(options))
    return L.Comp(L.tuplify(self.enumvars), clause_nodes, (), comp_options)
def test_enumclause_basic(self):
    enum_cl = EnumClause(('x', 'y', 'x', '_'), 'R')

    # Parsing from a membership expression.
    parsed = EnumClause.from_expr(L.pe('(x, y, x, _) in R'))
    self.assertEqual(parsed, enum_cl)

    # AST round trip.
    produced_ast = enum_cl.to_AST()
    expected_ast = L.Enumerator(
        L.tuplify(['x', 'y', 'x', '_'], lval=True), L.ln('R'))
    self.assertEqual(produced_ast, expected_ast)
    self.assertEqual(EnumClause.from_AST(expected_ast, DummyFactory),
                     enum_cl)

    # Attributes.
    self.assertFalse(enum_cl.isdelta)
    expected_attrs = [
        ('enumlhs', ('x', 'y', 'x', '_')),
        ('enumvars', ('x', 'y')),
        ('pat_mask', (True, True, True, True)),
        ('enumrel', 'R'),
    ]
    for attr_name, expected in expected_attrs:
        self.assertEqual(getattr(enum_cl, attr_name), expected)
    self.assertTrue(enum_cl.has_wildcards)
    self.assertEqual(enum_cl.vars, ('x', 'y'))
    self.assertEqual(enum_cl.eqvars, None)
    self.assertTrue(enum_cl.robust)
    self.assertEqual(enum_cl.demname, None)
    self.assertEqual(enum_cl.demparams, ())
def uset_to_comp(ds, uset, factory, first_clause):
    """Convert a U-set to a comprehension.

    If uset.preds is not None, the clauses enumerate the variable of each
    tag named in uset.preds (tags are looked up in ds.tags). Otherwise
    a copy of uset.pred_clauses is used, so uset itself is not modified.

    If that leaves no clauses, an all-wildcard enumerator over
    first_clause's relation is used as the only clause. This acts as an
    emptiness test; first_clause is expected to be an enumerator over
    a U-set.

    Return the comprehension produced by CompSpec.to_comp(), whose
    result expression is the tuple of uset.vars.
    """
    if uset.preds is not None:
        tags_by_name = {t.name: t for t in ds.tags}
        clauses = []
        for tname in uset.preds:
            tag = tags_by_name[tname]
            clauses.append(EnumClause([tag.var], tag.name))
    else:
        # Copy: the emptiness-test clause appended below must not leak
        # into the U-set's own pred_clauses.
        clauses = list(uset.pred_clauses)

    if len(clauses) == 0:
        assert first_clause.kind is Clause.KIND_ENUM
        assert uset.i != 0, 'Can\'t make demand invariant for inner ' \
                            'query; it is the first clause'
        wildcard_lhs = tuple('_' for _ in first_clause.enumlhs)
        clauses.append(EnumClause(wildcard_lhs, first_clause.enumrel))

    spec = CompSpec(Join(clauses, factory, None), L.tuplify(uset.vars), [])
    return spec.to_comp({})
def to_AST(self):
    """Return this clause as an enumerator over a singleton set.

    The last lhs element is the enumerator's target. The iterated set
    holds one SMLookup on self.rel, keyed by the preceding lhs elements,
    with a mask built from the key length.
    """
    key_len = len(self.lhs) - 1
    mask_str = Mask.from_keylen(key_len).make_node().s
    key_node = L.tuplify(self.lhs[:-1])
    lookup = L.SMLookup(L.ln(self.rel), mask_str, key_node, None)
    target = L.sn(self.lhs[-1])
    return L.Enumerator(target, L.Set((lookup,)))
def get_subquery_demnames(spec):
    """Build U-set invariants for the demand-using subqueries of a comp.

    Only enumerator clauses are considered. For each one that has
    demand, the invariant is the conjunction of the enumerator clauses
    to its left, with any demand wrappers among them unwrapped one
    level (via .cl). Every demand parameter of the subquery must be an
    enumvar of one of those clauses; an AssertionError is raised if not.

    Return a list of (demand name, invariant CompSpec) pairs, in clause
    order.
    """
    enum_clauses = [cl for cl in spec.join.clauses if cl.kind is cl.KIND_ENUM]
    result = []
    for i, cl in enumerate(enum_clauses):
        if not cl.has_demand:
            continue

        # Clauses to the left of this one. Demand clauses are replaced by
        # the clause they wrap. A separate comprehension variable avoids
        # reusing the outer loop index `i`.
        demclauses = [left.cl if left.has_demand else left
                      for left in enum_clauses[:i]]

        # Every demand parameter must be bound somewhere to the left.
        boundvars = {v for demcl in demclauses for v in demcl.enumvars}
        unboundparams = set(cl.demparams) - boundvars
        assert len(unboundparams) == 0, \
            'Subquery parameter(s) {} not bound in clause to left ' \
            'of occurrence'.format(unboundparams)

        new_join = Join(demclauses, spec.join.factory, None)
        new_spec = CompSpec(new_join, L.tuplify(cl.demparams), ())
        result.append((cl.demname, new_spec))
    return result
def get_res_code(self):
    """Return an expression that looks up the comprehension's result.

    Without parameters this is just the result set's name. With
    parameters it is a setmatch on the result set: the leading parameter
    components are bound and the remaining ones unbound. This case
    requires the result expression to be a tuple.
    """
    inccomp = self.inccomp
    params = inccomp.comp.params

    if len(params) == 0:
        return L.ln(inccomp.name)

    resexp = inccomp.spec.resexp
    assert isinstance(resexp, L.Tuple)
    num_bound = len(params)
    num_unbound = len(resexp.elts) - num_bound
    mask_node = Mask('b' * num_bound + 'u' * num_unbound).make_node()

    return L.pe('''
        setmatch(RES, MASK, PARAMS)
        ''', subst={
            'RES': L.ln(inccomp.name),
            'MASK': mask_node,
            'PARAMS': L.tuplify(params),
        })
def structures_to_comps(ds, factory):
    """Convert tags and filters to the comprehensions that define them.

    Return a list of (name, comp) pairs, in the order of ds.structs
    (dependency order). U-set structures are skipped.

    - A tag becomes a comp over its single enumerator, with s.var as
      the result.
    - A filter becomes a comp over an enumerator for each predecessor tag
      plus its own enumerator (wildcards replaced by fresh variables),
      with that enumerator's lhs tuple as the result.

    Raises AssertionError for any other structure kind.
    """
    tags_by_name = {t.name: t for t in ds.tags}
    result = []
    for s in ds.structs:
        if s.kind is KIND_TAG:
            cl = EnumClause(s.lhs, s.rel)
            spec = CompSpec(Join([cl], factory, None), L.ln(s.var), [])
        elif s.kind is KIND_FILTER:
            cls = []
            for tname in s.preds:
                t = tags_by_name[tname]
                cls.append(EnumClause([t.var], t.name))
            # Be sure to replace wildcards with fresh vars.
            lhs = inst_wildcards(s.lhs)
            cls.append(EnumClause(lhs, s.rel))
            spec = CompSpec(Join(cls, factory, None), L.tuplify(lhs), [])
        elif s.kind is KIND_USET:
            continue
        else:
            # Explicit raise instead of `assert ()`. Under -O the assert
            # is stripped, and a stale `spec` from an earlier iteration
            # would be appended under this structure's name.
            raise AssertionError(
                'Unknown structure kind: {!r}'.format(s.kind))
        result.append((s.name, spec.to_comp({})))
    return result
def to_AST(self):
    """Render this clause as `var in {SMLookup(rel, mask, keys)}`.

    The final lhs element is the enumerated variable. The preceding lhs
    elements form the lookup key, and the mask is built from the key
    count.
    """
    key_vars = self.lhs[:-1]
    value_var = self.lhs[-1]
    lookup_mask = Mask.from_keylen(len(self.lhs) - 1)
    lookup = L.SMLookup(L.ln(self.rel),
                        lookup_mask.make_node().s,
                        L.tuplify(key_vars),
                        None)
    singleton = L.Set((lookup,))
    return L.Enumerator(L.sn(value_var), singleton)
def test_singletonclause(self):
    single_cl = SingletonClause(('x', 'y'), L.pe('e'))

    # Parsing from an equality expression.
    parsed = SingletonClause.from_expr(L.pe('(x, y) == e'))
    self.assertEqual(single_cl, parsed)

    # AST round trip.
    produced_ast = single_cl.to_AST()
    expected_ast = L.Enumerator(L.tuplify(['x', 'y'], lval=True),
                                L.pe('{e}'))
    self.assertEqual(produced_ast, expected_ast)
    self.assertEqual(SingletonClause.from_AST(expected_ast, DummyFactory),
                     single_cl)

    # Attributes.
    self.assertEqual(single_cl.enumvars, ('x', 'y'))

    # Code generation: a direct unpacking assignment.
    generated = single_cl.get_code([], L.pc('pass'))
    self.assertEqual(generated, L.pc('''
        x, y = e
        pass
        '''))
def updateopts(self, options):
    """Return a copy of options with the delta-describing keys set.

    The input mapping is not modified. The element and the lhs tuple
    are stored in their source-text form (via L.ts).
    """
    updated = dict(options)
    updated.update({
        '_deltarel': self.rel,
        '_deltaelem': L.ts(self.elem),
        '_deltalhs': L.ts(L.tuplify(self.lhs, lval=True)),
        '_deltaop': self.op,
    })
    return updated
def visit_Enumerator(self, node):
    """Rewrite the target of enumerators over self.rel.

    Children are visited first. If the enumerator iterates directly over
    the name self.rel, its target is replaced by a tuple of the clause
    variables computed by get_clause_vars(). Otherwise None is returned
    (presumably meaning "no change" to this visitor framework — confirm).
    """
    node = self.generic_visit(node)
    source = node.iter
    if isinstance(source, L.Name) and source.id == self.rel:
        clause_vars = get_clause_vars(node, self.tuptype)
        return node._replace(target=L.tuplify(clause_vars, lval=True))
    return None
def get_code(self, bindenv, body):
    """Generate the wrapped clause's code, skipping the excluded element.

    body is guarded by an inequality test between the clause's lhs tuple
    and self.excl. The guarded body is then handed to the wrapped
    clause's own get_code().
    """
    substitutions = {
        'LHS': L.tuplify(self.cl.enumlhs),
        'EXCL': self.excl,
        '<c>BODY': body,
    }
    guarded = L.pc('''
        if LHS != EXCL:
            BODY
        ''', subst=substitutions)
    return self.cl.get_code(bindenv, guarded)
def make_removeu_maint(self, prefix):
    """Generate code for before a removal from U.

    Half-demand: look up the map value and delete the entry only when
    its count projection is 0; there is no operand demand to propagate.

    Otherwise: call the operand's undemand function (if the operand uses
    demand) and then delete the entry.
    """
    aggr = self.incaggr
    assert aggr.has_demand
    spec = aggr.spec
    mapval_name = prefix + 'val'

    if aggr.half_demand:
        return L.pc('''
            S_MV = A.smlookup(MASK, KEY)
            if COUNT == 0:
                A.smdelkey(MASK, KEY, PREFIX)
            ''', subst={
                'S_MV': L.sn(mapval_name),
                'COUNT': self.mapval_proj_count(L.ln(mapval_name)),
                'A': aggr.name,
                'MASK': aggr.aggrmask.make_node(),
                'KEY': L.tuplify(aggr.params),
                'PREFIX': L.Str(prefix),
            })

    propagate_code = ()
    if spec.has_oper_demand:
        undem_name = L.N.undemfunc(spec.oper_demname)
        undem_call = L.Call(L.ln(undem_name),
                            tuple(L.ln(v) for v in spec.oper_demparams),
                            (), None, None)
        propagate_code = (L.Expr(undem_call),)

    return L.pc('''
        PROPAGATE_DEMAND
        A.smdelkey(MASK, KEY, PREFIX)
        ''', subst={
            'A': aggr.name,
            'MASK': aggr.aggrmask.make_node(),
            'KEY': L.tuplify(aggr.params),
            'PREFIX': L.Str(prefix),
            '<c>PROPAGATE_DEMAND': propagate_code,
        })
def get_code(self, bindenv, body):
    """Generate code matching the delta element self.val against lhs.

    With wildcards in the binding mask, the match loops over setmatch
    applied to a DeltaMatch node (built with an all-bound delta mask and
    self.limit). Otherwise make_tuplematch() matches self.val directly.
    """
    delta_mask = Mask.from_vars(self.lhs, self.lhs)
    bind_mask = Mask.from_vars(self.lhs, bindenv)
    bvars, uvars, _ = bind_mask.split_vars(self.lhs)

    if not bind_mask.has_wildcards:
        return make_tuplematch(self.val, bind_mask, bvars, uvars, body)

    # Could this be expressed more readably, e.g. as an If-guard on the
    # deltamatch?
    delta_node = L.DeltaMatch(L.ln(self.rel), delta_mask.make_node().s,
                              self.val, self.limit)
    return L.pc('''
        for UVARS in setmatch(VAL, MASK, BVARS):
            BODY
        ''', subst={
            'VAL': delta_node,
            'MASK': bind_mask.make_node(),
            'BVARS': L.tuplify(bvars),
            'UVARS': L.tuplify(uvars, lval=True),
            '<c>BODY': body,
        })
def for_rels_union_code(vars, iters, body, tempname, *,
                        verify_disjoint=False):
    """Generate code that runs body once per element of the union of iters.

    vars are bound to the components of each element. With exactly one
    iterable this delegates to for_rel_code(). Otherwise the generated
    code does the following:
      1. it collects every element into a fresh set named tempname, which
         removes duplicates;
      2. it runs body once per element of that set;
      3. it deletes the temporary.

    If verify_disjoint is true, the generated code asserts that no
    element is already in the temporary set before add()-ing it.
    Otherwise it uses nsadd() (presumably an add that tolerates
    already-present elements — confirm against the runtime set type).

    iters must be non-empty.
    """
    assert len(iters) > 0
    if len(iters) == 1:
        return for_rel_code(vars, iters[0], body)

    # The accumulation template depends only on verify_disjoint, so pick
    # it once instead of re-selecting and re-trimming it per iterable.
    if verify_disjoint:
        accum_template = L.trim('''
            for S_VARS in ITER:
                assert VARS not in TEMPSET
                TEMPSET.add(VARS)
            ''')
    else:
        accum_template = L.trim('''
            for S_VARS in ITER:
                TEMPSET.nsadd(VARS)
            ''')

    code = L.pc('''
        TEMPSET = set()
        ''', subst={'TEMPSET': tempname})

    for source in iters:
        code += L.pc(accum_template,
                     subst={'S_VARS': L.tuplify(vars, lval=True),
                            'ITER': source,
                            'TEMPSET': tempname,
                            'VARS': L.tuplify(vars)})

    code += L.pc('''
        for VARS in TEMPSET:
            BODY
        del D_TEMPSET
        ''', subst={'VARS': L.tuplify(vars, lval=True),
                    'TEMPSET': tempname,
                    '<c>BODY': body,
                    'D_TEMPSET': L.dn(tempname)})

    return code
def visit_Tuple(self, node):
    """Replace a tuple expression with a fresh tuple variable.

    A clause `(tupvar, *elts) in _TUPn` is recorded in self.new_clauses,
    and the tuple relation is added to self.trels. Children are not
    visited here; the caller of this visitor handles recursion.
    """
    tuple_var = self.tupvar_namer.next()
    tuple_rel = make_trel(len(node.elts))

    lhs = L.tuplify((L.sn(tuple_var),) + node.elts, lval=True)
    self.new_clauses.append(L.Enumerator(lhs, L.ln(tuple_rel)))
    self.trels.add(tuple_rel)

    return L.sn(tuple_var)
def test_objclausefactory(self):
    tup_ast = L.Enumerator(L.tuplify(['t', 'x', 'y'], lval=True),
                           L.pe('_TUP2'))

    # The typechecking factory yields a TClause.
    self.assertEqual(TupClauseFactory.from_AST(tup_ast),
                     TClause('t', ['x', 'y']))

    # The no-typecheck factory yields the no-typecheck variant.
    no_tc_expected = TClause_NoTC('t', ['x', 'y'])
    no_tc_parsed = TupClauseFactory_NoTC.from_AST(tup_ast)
    self.assertEqual(no_tc_parsed, no_tc_expected)
    self.assertIsInstance(no_tc_parsed, TClause_NoTC)
def for_rel_code(vars, iter, body):
    """Generate a loop that runs body once per element of iter.

    Each element is unpacked into vars, so iter should evaluate to a
    relation of arity len(vars).
    """
    loop_target = L.tuplify(vars, lval=True)
    substitutions = {'VARS': loop_target, 'ITER': iter, '<c>BODY': body}
    return L.pc('''
        for VARS in ITER:
            BODY
        ''', subst=substitutions)
def visit_Tuple(self, node):
    """Flatten a tuple expression into a clause over a tuple relation.

    Records `(tupvar, *elts) in _TUPn` in self.new_clauses, registers
    the relation in self.trels, and returns the fresh variable in
    store context. There is no recursion here; the caller handles it.
    """
    fresh = self.tupvar_namer.next()
    rel = make_trel(len(node.elts))
    components = (L.sn(fresh),) + node.elts
    clause = L.Enumerator(L.tuplify(components, lval=True), L.ln(rel))
    self.new_clauses.append(clause)
    self.trels.add(rel)
    return L.sn(fresh)
def get_code(self, bindenv, body):
    """Delegate to the wrapped clause, with body guarded against self.excl.

    The guard compares the wrapped clause's lhs tuple with self.excl and
    runs body only when they differ.
    """
    lhs_tuple = L.tuplify(self.cl.enumlhs)
    guard = L.pc('''
        if LHS != EXCL:
            BODY
        ''', subst={'LHS': lhs_tuple, 'EXCL': self.excl, '<c>BODY': body})
    return self.cl.get_code(bindenv, guard)
def test_mapclause(self):
    map_cl = MapClause('m', 'k', 'v')

    # Parsing from a membership expression over _MAP.
    self.assertEqual(MapClause.from_expr(L.pe('(m, k, v) in _MAP')),
                     map_cl)

    # AST round trip.
    produced_ast = map_cl.to_AST()
    expected_ast = L.Enumerator(L.tuplify(('m', 'k', 'v'), lval=True),
                                L.ln('_MAP'))
    self.assertEqual(produced_ast, expected_ast)
    self.assertEqual(MapClause.from_AST(expected_ast, ObjClauseFactory),
                     map_cl)

    # Attributes.
    self.assertEqual(map_cl.enumlhs, ('m', 'k', 'v'))
    self.assertEqual(map_cl.pat_mask, (False, True, True))
    self.assertEqual(map_cl.enumvars_tagsin, ('m',))
    self.assertEqual(map_cl.enumvars_tagsout, ('k', 'v'))

    # Rate with nothing bound.
    self.assertEqual(map_cl.rate([]), Rate.UNRUNNABLE)

    # Code with only the map bound.
    self.assertEqual(map_cl.get_code(['m'], L.pc('pass')), L.pc('''
        if isinstance(m, Map):
            for k, v in m.items():
                pass
        '''))

    # Code with the map and key bound.
    self.assertEqual(map_cl.get_code(['m', 'k'], L.pc('pass')), L.pc('''
        if isinstance(m, Map):
            if k in m:
                v = m[k]
                pass
        '''))

    # Code from the no-typecheck variant skips the isinstance guard.
    no_tc_cl = MapClause_NoTC('m', 'k', 'v')
    self.assertEqual(no_tc_cl.get_code(['m'], L.pc('pass')), L.pc('''
        for k, v in m.items():
            pass
        '''))
def get_code(self, bindenv, body):
    """Generate code that matches the delta value against this clause.

    Without wildcards in the binding mask, make_tuplematch() handles
    self.val directly. With wildcards, the code loops over setmatch
    applied to a DeltaMatch node built with the all-bound delta mask.
    """
    full_mask = Mask.from_vars(self.lhs, self.lhs)
    mask = Mask.from_vars(self.lhs, bindenv)
    bound, unbound, _unused_eqs = mask.split_vars(self.lhs)

    if mask.has_wildcards:
        # Can this be streamlined into something more readable,
        # like expressing the deltamatch as an If-guard?
        matched = L.DeltaMatch(L.ln(self.rel), full_mask.make_node().s,
                               self.val, self.limit)
        substitutions = {
            'VAL': matched,
            'MASK': mask.make_node(),
            'BVARS': L.tuplify(bound),
            'UVARS': L.tuplify(unbound, lval=True),
            '<c>BODY': body,
        }
        return L.pc('''
            for UVARS in setmatch(VAL, MASK, BVARS):
                BODY
            ''', subst=substitutions)

    return make_tuplematch(self.val, mask, bound, unbound, body)
def __init__(self, aggrop, rel, relmask, params, oper_demname,
             oper_demparams):
    """Compute self.node, the AST node representing this aggregate.

    The node is an Aggregate with operator aggrop, applied to the
    operand. The operand is:
      - the relation name rel;
      - wrapped in a SetMatch on params (with relmask) if params is
        non-empty;
      - then wrapped in a DemQuery on oper_demname/oper_demparams if
        oper_demname is not None.

    aggrop must be one of 'count', 'sum', 'min', 'max'.
    """
    # Validate the argument itself. The original check read
    # self.aggrop, which this method never assigns.
    assert aggrop in ['count', 'sum', 'min', 'max']

    node = L.ln(rel)
    if len(params) > 0:
        node = L.SetMatch(node, relmask.make_node().s, L.tuplify(params))
    if oper_demname is not None:
        node = L.DemQuery(oper_demname,
                          [L.ln(p) for p in oper_demparams], node)
    self.node = L.Aggregate(node, aggrop, None)
def test_clausefactory(self):
    # from_AST(): `R - {e}` becomes a SubClause.
    sub_ast = L.Enumerator(L.tuplify(['x', 'y'], lval=True),
                           L.pe('R - {e}'))
    self.assertEqual(ClauseFactory.from_AST(sub_ast),
                     SubClause(EnumClause(('x', 'y'), 'R'), L.pe('e')))

    # rewrite_subst(): renames variables inside the wrapped clause.
    renamed = ClauseFactory.rewrite_subst(
        SubClause(EnumClause(('x', 'y'), 'R'), L.pe('e')), {'x': 'z'})
    self.assertEqual(renamed,
                     SubClause(EnumClause(('z', 'y'), 'R'), L.pe('e')))

    # bind(): produces a DeltaClause with limit 1.
    bound = ClauseFactory.bind(EnumClause(('x', 'y'), 'R'), L.pe('e'),
                               augmented=False)
    self.assertEqual(bound, DeltaClause(['x', 'y'], 'R', L.pe('e'), 1))

    # subtract(): wraps in a SubClause.
    subtracted = ClauseFactory.subtract(EnumClause(('x', 'y'), 'R'),
                                        L.pe('e'))
    self.assertEqual(subtracted,
                     SubClause(EnumClause(('x', 'y'), 'R'), L.pe('e')))

    # augment(): wraps in an AugClause.
    augmented = ClauseFactory.augment(EnumClause(['x', 'y'], 'R'),
                                      L.pe('e'))
    self.assertEqual(augmented,
                     AugClause(EnumClause(['x', 'y'], 'R'), L.pe('e')))

    # rewrite_rel(): swaps the underlying relation.
    relabeled = ClauseFactory.rewrite_rel(
        SubClause(EnumClause(('x', 'y'), 'R'), L.pe('e')), 'S')
    self.assertEqual(relabeled,
                     SubClause(EnumClause(('x', 'y'), 'S'), L.pe('e')))

    # membercond_to_enum(): membership condition -> enumerator.
    as_enum = ClauseFactory.membercond_to_enum(
        CondClause(L.pe('(x, y) in R')))
    self.assertEqual(as_enum, EnumClause(('x', 'y'), 'R'))

    # enum_to_membercond(): enumerator -> membership condition.
    as_cond = ClauseFactory.enum_to_membercond(EnumClause(('x', 'y'), 'R'))
    self.assertEqual(as_cond, CondClause(L.pe('(x, y) in R')))
def without_params(self, flat=False):
    """Return a parameterless copy of this CompSpec.

    The parameters become ordinary locals: the new result expression is
    a tuple that starts with the parameters. If flat is false, the old
    result expression follows as a single element. If flat is true, the
    old result expression must be a tuple, and its elements are spliced
    in after the parameters.
    """
    if not flat:
        trailing = (self.resexp,)
    else:
        assert isinstance(self.resexp, L.Tuple)
        trailing = self.resexp.elts
    combined = L.tuplify(self.params + trailing)
    return self._replace(resexp=combined, params=())
def make_projkey(self, val):
    """Build a key expression from the bound components of a tuple value.

    Intended for masks with no 'u' components. For every 'b' component at
    position i, the key gets the subscript expression val[i]. 'w'
    (wildcard) and digit (presumably equality) components are skipped.

    Raises AssertionError if the mask has any other component, such as
    'u'.
    """
    components = []
    for i, c in enumerate(self.parts):
        if c == 'b':
            # val[i]
            components.append(L.Subscript(val, L.Index(L.Num(i)), L.Load()))
        elif c == 'w' or c.isdigit():
            # Contributes nothing to the key.
            continue
        else:
            # Explicit raise: `assert ()` is stripped under -O, which would
            # silently build a wrong key instead of failing.
            raise AssertionError(
                'make_projkey() got unsupported mask component {!r}'
                .format(c))
    return L.tuplify(components)
def make_projkey(self, val):
    """Project a tuple value onto this mask's bound components.

    Only meaningful for masks with no 'u' components. The key is a tuple
    of val[i] for each position i whose component is 'b'. Wildcard ('w')
    and digit (presumably equality-constraint) components are skipped.

    Raises AssertionError for any other component, such as 'u'.
    """
    components = []
    for i, c in enumerate(self.parts):
        if c == 'b':
            # val[i]
            components.append(L.Subscript(val, L.Index(L.Num(i)), L.Load()))
        elif c == 'w' or c.isdigit():
            pass
        else:
            # A bare `assert ()` would vanish under -O and let unsupported
            # masks produce an incorrect key, so raise explicitly.
            raise AssertionError(
                'make_projkey() got unsupported mask component {!r}'
                .format(c))
    return L.tuplify(components)