def visit_Member(self, node):
    """Turn a membership over a known relation into a RelMember clause.

    Children are visited first.  The rewrite applies only when the
    right-hand side is a Name listed in ``self.rels`` and the target
    is a tuple of names; otherwise the visited node is returned as-is.
    """
    node = self.generic_visit(node)

    source = node.iter
    over_known_rel = isinstance(source, L.Name) and source.id in self.rels
    if over_known_rel and L.is_tuple_of_names(node.target):
        target_vars = L.detuplify(node.target)
        return L.RelMember(target_vars, source.id)

    return node
def functionally_determines(self, cl, bindenv):
    """Report whether the clause functionally determines its variables.

    The 'bu' binding pattern over the clause's left-hand-side variables
    is always answered with True; every other pattern is decided by the
    parent implementation.
    """
    bound_pattern = L.mask_from_bounds(self.lhs_vars(cl), bindenv)
    if bound_pattern == L.mask('bu'):
        return True
    return super().functionally_determines(cl, bindenv)
def get_priority(self, cl, bindenv):
    """Return the join-ordering priority of the clause.

    The 'bu' binding pattern over the clause's left-hand-side variables
    gets Priority.Constant; every other pattern is ranked by the parent
    implementation.
    """
    bound_pattern = L.mask_from_bounds(self.lhs_vars(cl), bindenv)
    if bound_pattern == L.mask('bu'):
        return Priority.Constant
    return super().get_priority(cl, bindenv)
def py_preprocess(tree):
    """Partially preprocess a Python AST and convert it to IncAST.

    Returns a triple ``(tree, decls, info)``: the IncAST tree, the
    relation declarations, and the parsed symbol/query information.
    The queries stored in ``info.query_info`` are converted to IncAST
    as well.
    """
    tree_to_tree_steps = (
        # Replace the source strings of QUERY directives with their
        # parsed Python ASTs.  As long as the remaining steps are
        # functional (treat equal ASTs equally), a later step that
        # changes an occurrence of a query changes the copy held in
        # the QUERY directive too.
        preprocess_query_directives,
        # Accept syntactic sugar that IncAST would otherwise reject.
        preprocess_constructs,
        # Drop the runtime library import and its qualifiers.
        preprocess_runtime_import,
        # Drop the main boilerplate.
        preprocess_main_call,
    )
    for step in tree_to_tree_steps:
        tree = step(tree)

    # Extract relation declarations, then symbol info.
    tree, decls = preprocess_var_decls(tree)
    tree, info = preprocess_directives(tree)

    # Convert both the program and the parsed queries to IncAST.
    tree = L.import_incast(tree)
    converted = []
    for query, value in info.query_info:
        converted.append((L.import_incast(query), value))
    info.query_info = converted

    return tree, decls, info
def get_code(self, cl, bindenv, body):
    """Emit code that runs ``body`` for a clause backed by the map ``cl.map``.

    Two binding patterns are answered directly from the map:

    - all variables bound: test that the key is present and maps to
      the value;
    - binding pattern equal to ``cl.mask``: look the key up and bind
      the value variable.

    Any other pattern is delegated to the parent implementation.
    """
    assert_unique(cl.vars)
    bound_pattern = L.mask_from_bounds(cl.vars, bindenv)
    keyvars, unbound_part = L.split_by_mask(cl.mask, cl.vars)
    valvar = unbound_part[0]

    # The all-unbound case could be served by iterating over
    # dict.items(), but that needs a fresh variable to decompose the
    # key tuple, so it is left to the parent implementation.

    if L.mask_is_allbound(bound_pattern):
        test = L.Parser.pe(
            '_KEY in _MAP and _MAP[_KEY] == _VALUE',
            subst={'_MAP': cl.map,
                   '_KEY': L.tuplify(keyvars),
                   '_VALUE': valvar})
        return (L.If(test, body, ()),)

    if bound_pattern == cl.mask:
        return L.Parser.pc('''
            if _KEY in _MAP:
                _VALUE = _MAP[_KEY]
                _BODY
            ''', subst={'_MAP': cl.map,
                        '_KEY': L.tuplify(keyvars),
                        '_VALUE': valvar,
                        '<c>_BODY': body})

    return super().get_code(cl, bindenv, body)
def make_demand_query(symtab, query, left_clauses):
    """Define the demand query for ``query`` and return its symbol.

    The demand query is a comprehension over ``left_clauses`` whose
    result is the tuple of ``query.demand_params``; its left-hand-side
    variables are renamed with a fresh prefix.  As a side effect,
    ``query.demand_query`` is set to the new query's name.
    """
    name = N.get_query_demand_query_name(query.name)

    result_tuple = L.tuplify(query.demand_params)
    result_type = T.Set(symtab.analyze_expr_type(result_tuple))

    comp = L.Comp(result_tuple, left_clauses)
    prefix = next(symtab.fresh_names.vars)

    def add_prefix(var):
        return prefix + var

    comp = symtab.clausetools.comp_rename_lhs_vars(comp, add_prefix)

    demquery_sym = symtab.define_query(name, type=result_type,
                                       node=comp, impl=query.impl)
    query.demand_query = name
    return demquery_sym
def get_code(self, cl, bindenv, body):
    """Emit code that runs ``body`` for a (map, key, value) clause.

    Handled binding patterns:

    - all bound: compare the value with ``map[key]``;
    - 'bbu': assign ``map[key]`` to the value variable;
    - 'buu': loop over ``map.items()``.

    These three are wrapped in an ``IsMap`` guard when
    ``self.use_typecheck`` is set.  Other patterns are delegated to the
    parent implementation and returned without a guard.
    """
    lhs = self.lhs_vars(cl)
    assert_unique(lhs)
    bound_pattern = L.mask_from_bounds(lhs, bindenv)

    lookup = L.DictLookup(L.Name(cl.map), L.Name(cl.key), None)

    if L.mask_is_allbound(bound_pattern):
        test = L.Compare(L.Name(cl.value), L.Eq(), lookup)
        code = (L.If(test, body, ()),)
    elif bound_pattern == L.mask('bbu'):
        code = (L.Assign(cl.value, lookup),) + body
    elif bound_pattern == L.mask('buu'):
        items = L.Parser.pe('_MAP.items()', subst={'_MAP': cl.map})
        code = (L.DecompFor([cl.key, cl.value], items, body),)
    else:
        # The parent's code is never wrapped in a type check.
        return super().get_code(cl, bindenv, body)

    if self.use_typecheck:
        code = (L.If(L.IsMap(L.Name(cl.map)), code, ()),)
    return code
def visit_DictLookup(self, node):
    """Replace a simple map lookup by a variable bound through a new clause.

    The first time a given lookup node is seen, a fresh variable is
    recorded for it in ``self.repls``, a SetFromMapMember clause binding
    that variable is appended to ``self.new_clauses``, and the matching
    SetFromMapInvariant is added to ``self.sfm_invs``.  Later
    occurrences of the same node reuse the recorded variable.
    """
    node = self.generic_visit(node)

    # Only simple map lookups are allowed.
    assert isinstance(node.value, L.Name)
    assert L.is_tuple_of_names(node.key)
    assert node.default is None

    map_name = node.value.id
    keyvars = L.detuplify(node.key)

    var = self.repls.get(node)
    if var is not None:
        return L.Name(var)

    mask = L.mapmask_from_len(len(keyvars))
    rel = N.SA_name(map_name, mask)

    # Fresh variable standing for the lookup's result.
    var = next(self.fresh_names)
    self.repls[node] = var

    # Clause that binds the fresh variable.
    clause_vars = list(keyvars) + [var]
    self.new_clauses.append(
        L.SetFromMapMember(clause_vars, rel, map_name, mask))

    # Invariant that the new clause relies on.
    self.sfm_invs.add(SetFromMapInvariant(rel, map_name, mask))

    return L.Name(var)
def rewrite_comp(self, symbol, name, comp):
    """Replace the target query's comprehension by a result-variable lookup.

    ``query``, ``result_var`` and ``orig_arity`` are taken from the
    enclosing scope.  Comprehensions of any other query yield None.
    """
    if name == query.name:
        result = L.Name(result_var)
        if query.params == ():
            # No parameters: the result variable is the whole answer.
            return result
        key_mask = L.keymask_from_len(len(query.params), orig_arity)
        return L.ImgLookup(result, key_mask, query.params)
    return None
def convert_subquery_clause(clause):
    """Convert a VarsMember clause over a subquery result into a RelMember.

    Two right-hand-side shapes are recognized:

    - a Name node;
    - an ImgLookup node on a Name, using a keymask.

    Anything else, including clauses that are not VarsMember, is
    returned unchanged.
    """
    if not isinstance(clause, L.VarsMember):
        return clause

    source = clause.iter

    # Plain relation name.
    if isinstance(source, L.Name):
        return L.RelMember(clause.vars, source.id)

    # Keyed image lookup on a relation name.
    keyed_lookup = (isinstance(source, L.ImgLookup) and
                    isinstance(source.set, L.Name) and
                    L.is_keymask(source.mask))
    if not keyed_lookup:
        return clause

    nb, nu = L.break_keymask(source.mask)
    assert nb == len(source.bounds)
    assert nu == len(clause.vars)
    return L.RelMember(source.bounds + clause.vars, source.set.id)
def visit_Member(self, node):
    """Recognize WithoutMember and SingMember forms of a Member clause.

    ``<target> in <expr> - {<elem>}`` becomes a WithoutMember wrapping
    the inner membership; it is built before recursing so that the
    inner clause is visited in its reoriented form.
    ``<vars> in {<elem>}`` becomes a SingMember after recursing.
    """
    source = node.iter

    # <target> in <expr> - {<elem>}
    is_set_minus_singleton = (isinstance(source, L.BinOp) and
                              isinstance(source.op, L.Sub) and
                              isinstance(source.right, L.Set) and
                              len(source.right.elts) == 1)
    if is_set_minus_singleton:
        inner_clause = L.Member(node.target, source.left)
        wrapped = L.WithoutMember(inner_clause, source.right.elts[0])
        return self.generic_visit(wrapped)

    node = self.generic_visit(node)

    # <vars> in {<elem>}
    is_singleton_member = (L.is_tuple_of_names(node.target) and
                           isinstance(node.iter, L.Set) and
                           len(node.iter.elts) == 1)
    if is_singleton_member:
        return L.SingMember(L.detuplify(node.target), node.iter.elts[0])

    return node
def make_comp_maint_func(clausetools, fresh_var_prefix, fresh_join_names,
                         comp, result_var, rel, op, *,
                         counted):
    """Build the function that maintains a comprehension's result.

    The function takes one argument, ``_elem``, and its body is the
    maintenance code that ``clausetools`` generates for the update
    ``op`` of ``_elem`` on ``rel``.  ``op`` must be a SetAdd or
    SetRemove.
    """
    assert isinstance(op, (L.SetAdd, L.SetRemove))

    func_name = N.get_maint_func_name(result_var, rel,
                                      L.set_update_name(op))

    trigger = L.RelUpdate(rel, op, '_elem')
    body = clausetools.get_maint_code(fresh_var_prefix, fresh_join_names,
                                      comp, result_var, trigger,
                                      counted=counted)

    substitutions = {'_FUNC': func_name, '<c>_CODE': body}
    return L.Parser.ps('''
        def _FUNC(_elem):
            _CODE
        ''', subst=substitutions)
def get_loop_for_join(self, comp, body, query_name):
    """Return code that runs ``body`` once per element of a join.

    ``comp`` must be a join.  It is wrapped in a Query node named
    ``query_name`` and iterated by a single DecompFor loop over the
    comprehension's left-hand-side variables.
    """
    assert self.is_join(comp)

    loop_vars = self.lhs_vars_from_comp(comp)
    wrapped_join = L.Query(query_name, comp, None)
    loop = L.DecompFor(loop_vars, wrapped_join, body)
    return (loop,)
def visit_Call(self, node):
    """Rewrite ``len(x)`` into a Count aggregate over ``x``.

    Raises L.ProgramError when ``len`` is called with a number of
    arguments other than one.  Calls to other functions are returned
    as visited.
    """
    node = self.generic_visit(node)

    if node.func == 'len':
        if len(node.args) != 1:
            raise L.ProgramError('Expected one argument for len()')
        operand = node.args[0]
        return L.Aggr(L.Count(), operand)

    return node
def rewrite_with_demand(self, query_sym, node):
    """Return the demand-transformed version of a query's Comp or Aggr node.

    Subqueries are not transformed.  Queries that do not use demand
    are returned unchanged.  When ``self.get_left_clauses()`` yields
    None a demand set is created, otherwise a demand query is created
    from those clauses.  A Comp gets the demand clause prepended; an
    Aggr becomes an AggrRestr.  Any other node type raises
    AssertionError.
    """
    symtab = self.symtab
    params = query_sym.demand_params

    if not query_sym.uses_demand:
        return node

    # Choose between a demand query and a demand set.
    outer_clauses = self.get_left_clauses()
    if outer_clauses is not None:
        dem_sym = make_demand_query(symtab, query_sym, outer_clauses)
        dem_node = dem_sym.make_node()
        dem_clause = L.VarsMember(params, dem_node)
        self.demand_queries.add(dem_sym.name)
    else:
        dem_sym = make_demand_set(symtab, query_sym)
        dem_node = L.Name(dem_sym.name)
        dem_clause = L.RelMember(params, dem_sym.name)
        self.queries_with_usets.add(query_sym.name)

    # Build the rewritten node.
    if isinstance(node, L.Comp):
        return node._replace(clauses=(dem_clause,) + node.clauses)
    if isinstance(node, L.Aggr):
        return L.AggrRestr(node.op, node.value, params, dem_node)
    raise AssertionError(
        'No rule for handling demand for {} node'.format(
            node.__class__.__name__))
def visit_Member(self, node):
    """Turn a membership of a tuple of names in a Query into a VarsMember.

    Children are visited first; memberships of any other shape are
    returned as visited.
    """
    node = self.generic_visit(node)

    over_query = (L.is_tuple_of_names(node.target) and
                  isinstance(node.iter, L.Query))
    if not over_query:
        return node

    return L.VarsMember(L.detuplify(node.target), node.iter)
def visit_RelClear(self, node):
    """Clear the result variable whenever a tracked relation is cleared.

    Clears of relations outside ``self.rels`` are returned unchanged.
    Otherwise a RelClear of ``self.result_var`` is inserted around the
    original statement as removal-style maintenance.
    """
    if node.rel in self.rels:
        original = (node,)
        clear_result = (L.RelClear(self.result_var),)
        return L.insert_rel_maint(original, clear_result, L.SetRemove())
    return node
def visit_DecompFor(self, node):
    """Replace a one-variable DecompFor over a tracked relation by a For.

    Loops whose iterable is not a Name in ``self.rels`` are returned
    unchanged.  Raises L.TransformationError if such a loop does not
    have exactly one target variable.
    """
    loops_over_rel = (isinstance(node.iter, L.Name) and
                      node.iter.id in self.rels)
    if not loops_over_rel:
        return node

    if len(node.vars) != 1:
        raise L.TransformationError(
            'Singleton unwrapping requires all DecompFor loops '
            'over relation to have exactly one target variable')

    only_var = node.vars[0]
    return L.For(only_var, node.iter, node.body)
def visit_AttrAssign(self, node):
    """Replace an attribute assignment by an addition to its F relation.

    Attributes not listed in ``self.objrels.Fs`` yield None.  Otherwise
    the (object, value) pair is stored in a fresh variable, which is
    then added to the relation named by ``N.F(node.attr)``.
    """
    if node.attr in self.objrels.Fs:
        pair = L.Tuple([node.obj, node.value])
        elem_var = next(self.fresh_vars)
        bind = L.Assign(elem_var, pair)
        update = L.RelUpdate(N.F(node.attr), L.SetAdd(), elem_var)
        return (bind, update)
    return None
def visit_DictAssign(self, node):
    """Replace a dictionary assignment by an addition to the MAP relation.

    Yields None when ``self.objrels.MAP`` is falsy.  Otherwise the
    (map, key, value) triple is stored in a fresh variable, which is
    then added to the relation ``N.MAP``.
    """
    if self.objrels.MAP:
        triple = L.Tuple([node.target, node.key, node.value])
        elem_var = next(self.fresh_vars)
        bind = L.Assign(elem_var, triple)
        update = L.RelUpdate(N.MAP, L.SetAdd(), elem_var)
        return (bind, update)
    return None
def visit_DictDelete(self, node):
    """Replace a dictionary key deletion by a removal from the MAP relation.

    Yields None when ``self.objrels.MAP`` is falsy.  The removed triple
    is (map, key, map[key]); the value component is a DictLookup of the
    key being deleted.
    """
    if self.objrels.MAP:
        current_value = L.DictLookup(node.target, node.key, None)
        triple = L.Tuple([node.target, node.key, current_value])
        elem_var = next(self.fresh_vars)
        bind = L.Assign(elem_var, triple)
        update = L.RelUpdate(N.MAP, L.SetRemove(), elem_var)
        return (bind, update)
    return None
def visit_AttrDelete(self, node):
    """Replace an attribute deletion by a removal from its F relation.

    Attributes not listed in ``self.objrels.Fs`` yield None.  The
    removed pair is (object, object.attr); the value component is an
    Attribute node reading the attribute being deleted.
    """
    if node.attr in self.objrels.Fs:
        current_value = L.Attribute(node.obj, node.attr)
        pair = L.Tuple([node.obj, current_value])
        elem_var = next(self.fresh_vars)
        bind = L.Assign(elem_var, pair)
        update = L.RelUpdate(N.F(node.attr), L.SetRemove(), elem_var)
        return (bind, update)
    return None
def rewrite_aggr(self, symbol, name, aggr):
    """Wrap-then-unwrap the operand of an aggregate over a non-tuple relation.

    Applies only when the operand is a Name.  Its relation symbol is
    looked up in ``symtab`` (from the enclosing scope); unless the
    relation's type is a Set of Tuple, the operand is replaced by
    ``Unwrap(Wrap(operand))``.  The aggregate is always returned.
    """
    operand = aggr.value
    if not isinstance(operand, L.Name):
        return aggr

    rel_type = symtab.get_relations()[operand.id].type
    is_set_of_tuples = (isinstance(rel_type, T.Set) and
                        isinstance(rel_type.elt, T.Tuple))
    if is_set_of_tuples:
        return aggr

    return aggr._replace(value=L.Unwrap(L.Wrap(operand)))
def get_priority(self, cl, bindenv):
    """Rank the clause by how many of its variables are already bound.

    All bound gives Priority.Constant, none bound gives
    Priority.Unpreferred, and a mix gives Priority.Normal.
    """
    bound_pattern = L.mask_from_bounds(self.lhs_vars(cl), bindenv)

    if L.mask_is_allbound(bound_pattern):
        return Priority.Constant
    if L.mask_is_allunbound(bound_pattern):
        return Priority.Unpreferred
    return Priority.Normal
def visit_MapClear(self, node):
    """Clear the derived relation whenever its source map is cleared.

    Maps with no entry in ``self.setfrommaps_by_map`` are left alone.
    Otherwise a RelClear of the invariant's relation is inserted around
    the original statement as removal-style maintenance.
    """
    invariant = self.setfrommaps_by_map.get(node.map)
    if invariant is None:
        return node

    clear_rel = (L.RelClear(invariant.rel),)
    return L.insert_rel_maint((node,), clear_rel, L.SetRemove())
def visit_SetUpdate(self, node):
    """Replace a set add/remove by the same update on the M relation.

    Yields None for operations other than SetAdd/SetRemove and when
    ``self.objrels.M`` is falsy.  Otherwise the (set, element) pair is
    stored in a fresh variable and ``node.op`` is applied to ``N.M``
    with that variable.
    """
    if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
        return None
    if not self.objrels.M:
        return None

    pair = L.Tuple([node.target, node.value])
    elem_var = next(self.fresh_vars)
    bind = L.Assign(elem_var, pair)
    update = L.RelUpdate(N.M, node.op, elem_var)
    return (bind, update)
def visit_Query(self, node):
    """Validate the annotation attached to a Query node.

    Children are visited first.  A missing annotation (None) is
    accepted.  Raises L.ProgramError when the annotation is neither a
    dict nor a frozendict, or when it has a key not in
    ``S.annotations``.
    """
    self.generic_visit(node)

    annotation = node.ann
    if annotation is None:
        return

    if not isinstance(annotation, (dict, frozendict)):
        raise L.ProgramError('Query annotation must be a '
                             'dictionary')

    for key in annotation:
        if key not in S.annotations:
            raise L.ProgramError(
                'Unknown annotation key "{}"'.format(key))
def visit_MapDelete(self, node):
    """Call the 'delete' maintenance function when a tracked map loses a key.

    Maps with no entry in ``self.setfrommaps_by_map`` are left alone.
    Otherwise a call passing the deleted key is inserted around the
    original statement as removal-style maintenance.
    """
    invariant = self.setfrommaps_by_map.get(node.map)
    if invariant is None:
        return node

    maint_name = invariant.get_maint_func_name('delete')
    maint_call = L.Expr(L.Call(maint_name, [L.Name(node.key)]))
    return L.insert_rel_maint((node,), (maint_call,), L.SetRemove())
def Tuple_helper(self, node):
    """Introduce a TUPMember clause for a tuple of names and return its name.

    The name comes from ``self.tuple_namer``.  The tuple's arity is
    appended to ``self.objrels.TUPs`` and the new clause is put at the
    front of ``self.after_clauses``.  Raises L.ProgramError when
    ``node`` is not a tuple of names.
    """
    if not L.is_tuple_of_names(node):
        raise L.ProgramError('Non-simple tuple expression: {}'
                             .format(node))

    components = L.detuplify(node)
    tup_name = self.tuple_namer(components)
    tup_clause = L.TUPMember(tup_name, components)

    self.objrels.TUPs.append(len(components))
    self.after_clauses.insert(0, tup_clause)
    return tup_name
def process(expr):
    """Convert a single-name Member clause into a VarsMember clause.

    Returns a triple whose second and third components are always
    empty lists.  Anything that is not a Member with a Name target is
    passed through unchanged.  An Unwrap on the right-hand side is
    stripped; any other right-hand side is put inside a Wrap.
    """
    single_var_member = (isinstance(expr, L.Member) and
                         isinstance(expr.target, L.Name))
    if not single_var_member:
        return expr, [], []

    source = expr.iter
    rhs = source.value if isinstance(source, L.Unwrap) else L.Wrap(source)
    return L.VarsMember([expr.target.id], rhs), [], []
def aggrinv_from_query(symtab, query, result_var):
    """Build the AggrInvariant describing an aggregate query.

    ``query.node`` must be an Aggr or AggrRestr whose operand, after
    stripping an optional Unwrap, is either a relation Name or an
    ImgLookup on a relation Name.

    Raises L.ProgramError for an unrecognized operand form, for an
    operand whose type is not a set of tuples, and for an AggrRestr
    whose restriction is not a Name.  Raises L.TransformationError when
    the AggrRestr parameters differ from the ImgLookup parameters.
    """
    node = query.node
    assert isinstance(node, (L.Aggr, L.AggrRestr))

    # Strip an Unwrap around the operand, remembering that it was there.
    operand = node.value
    unwrap = isinstance(operand, L.Unwrap)
    if unwrap:
        operand = operand.value

    # Relation, mask and parameters.
    if isinstance(operand, L.Name):
        rel = operand.id
        # All-unbound mask; built once the arity is known.
        mask = None
        params = ()
    elif (isinstance(operand, L.ImgLookup) and
          isinstance(operand.set, L.Name)):
        rel = operand.set.id
        mask = operand.mask
        params = operand.bounds
    else:
        raise L.ProgramError('Unknown aggregate form: {}'.format(node))

    # The relation's arity comes from its type.
    rel_type = get_rel_type(symtab, rel)
    is_set_of_tuples = (isinstance(rel_type, T.Set) and
                        isinstance(rel_type.elt, T.Tuple))
    if not is_set_of_tuples:
        raise L.ProgramError(
            'Invalid type for aggregate operand: {}'.format(rel_type))
    arity = len(rel_type.elt.elts)

    if mask is None:
        mask = L.mask('u' * arity)
    else:
        # An explicit mask must agree with the relation's arity.
        assert len(mask.m) == arity

    restr = None
    if isinstance(node, L.AggrRestr):
        # The restriction parameters must be the lookup parameters.
        if node.params != params:
            raise L.TransformationError('AggrRestr params do not match '
                                        'ImgLookup params')
        if not isinstance(node.restr, L.Name):
            raise L.ProgramError('Bad AggrRestr restriction expr')
        restr = node.restr.id

    return AggrInvariant(result_var, node.op, rel, mask,
                         unwrap, params, restr)
def visit_Fun(self, node):
    """Prefix a function body with comments stating its cost.

    ``func_costs`` and ``symtab`` come from the enclosing scope.
    Functions without an entry in ``func_costs`` are returned as
    visited.  Two comments are added: the raw cost and the cost after
    ``rewrite_cost_using_types``.
    """
    node = self.generic_visit(node)

    if node.name not in func_costs:
        return node

    raw_cost = func_costs[node.name]
    typed_cost = rewrite_cost_using_types(raw_cost, symtab)

    raw_text = PrettyPrinter.run(raw_cost)
    typed_text = PrettyPrinter.run(typed_cost)

    header = (
        L.Comment('Cost: O({})'.format(raw_text)),
        L.Comment(' O({})'.format(typed_text)),
    )
    return node._replace(body=header + node.body)
def visit_RelUpdate(self, node):
    """Insert a maintenance call next to an update of a tracked relation.

    Only SetAdd/SetRemove updates of relations in ``self.rels`` are
    affected.  The inserted statement calls the maintenance function
    for (``self.result_var``, relation, operation) with the updated
    element.
    """
    is_add_or_remove = isinstance(node.op, (L.SetAdd, L.SetRemove))
    if not is_add_or_remove or node.rel not in self.rels:
        return node

    maint_name = N.get_maint_func_name(
        self.result_var, node.rel, L.set_update_name(node.op))
    maint_call = L.Expr(L.Call(maint_name, [L.Name(node.elem)]))
    return L.insert_rel_maint((node,), (maint_call,), node.op)
def visit_RelUpdate(self, node):
    """Pair an update of a tracked relation with its maintenance call.

    Updates that are neither SetAdd nor SetRemove, and updates of
    relations outside ``self.rels``, are returned unchanged.
    """
    if isinstance(node.op, (L.SetAdd, L.SetRemove)):
        if node.rel in self.rels:
            update_kind = L.set_update_name(node.op)
            maint_name = N.get_maint_func_name(self.result_var, node.rel,
                                               update_kind)
            elem_arg = L.Name(node.elem)
            maintenance = (L.Expr(L.Call(maint_name, [elem_arg])),)
            return L.insert_rel_maint((node,), maintenance, node.op)
    return node
def rewrite_resexp_with_params(self, comp, params):
    """Prepend the parameter variables to a tuple result expression.

    ``params`` must all be left-hand-side variables of ``comp`` and
    ``comp.resexp`` must be a Tuple; both are checked by assertion.
    Returns a copy of ``comp`` with the extended result expression.
    """
    lhs_vars = self.lhs_vars_from_comp(comp)
    assert set(params).issubset(set(lhs_vars)), \
        'params: {}, lhs_vars: {}'.format(params, lhs_vars)
    assert isinstance(comp.resexp, L.Tuple)

    param_names = tuple(L.Name(p) for p in params)
    extended = L.Tuple(param_names + comp.resexp.elts)
    return comp._replace(resexp=extended)
def get_code(self, cl, bindenv, body):
    """Emit code that runs ``body`` for a tuple-decomposition clause.

    Handled binding patterns:

    - all bound: compare the tuple with its components;
    - tuple bound: bind the unbound components from the tuple, then
      compare unless no component was bound beforehand;
    - tuple unbound, every component bound: assign the tuple.

    The first two are wrapped in a ``HasArity`` guard when
    ``self.use_typecheck`` is set.  Any other pattern raises
    L.TransformationError.
    """
    lhs = self.lhs_vars(cl)
    assert_unique(lhs)
    bound_pattern = L.mask_from_bounds(lhs, bindenv)

    tup_equals_elts = L.Compare(L.Name(cl.tup), L.Eq(), L.tuplify(cl.elts))

    if L.mask_is_allbound(bound_pattern):
        code = (L.If(tup_equals_elts, body, ()),)

    elif bound_pattern.m.startswith('b'):
        elts_pattern = L.mask_from_bounds(cl.elts, bindenv)
        code = L.bind_by_mask(elts_pattern, cl.elts, L.Name(cl.tup))
        if L.mask_is_allunbound(elts_pattern):
            code += body
        else:
            code += (L.If(tup_equals_elts, body, ()),)

    elif bound_pattern == L.mask('u' + 'b' * len(cl.elts)):
        # Constructing the tuple needs no arity guard.
        code = (L.Assign(cl.tup, L.tuplify(cl.elts)),)
        code += body
        return code

    else:
        raise L.TransformationError('Cannot emit code for TUP clause '
                                    'that would require an auxiliary '
                                    'map; use demand filtering')

    if self.use_typecheck:
        guard = L.HasArity(L.Name(cl.tup), len(cl.elts))
        code = (L.If(guard, code, ()),)
    return code
def visit_RelUpdate(self, node):
    """Insert the aggregate's maintenance call next to a relevant update.

    SetAdd/SetRemove updates of the aggregate's operand relation call
    the operand maintenance function.  When the invariant uses demand,
    updates of its restriction relation call the restriction
    maintenance function.  Every other update is returned unchanged.
    """
    if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
        return node

    invariant = self.aggrinv
    if node.rel == invariant.rel:
        maint_name = invariant.get_oper_maint_func_name(node.op)
    elif invariant.uses_demand and node.rel == invariant.restr:
        maint_name = invariant.get_restr_maint_func_name(node.op)
    else:
        return node

    return L.insert_rel_maint_call(node, maint_name)
def visit_RelClear(self, node):
    """Clear every auxiliary map and wrap relation derived from a relation.

    For each auxmap registered for ``node.rel`` a MapClear is inserted,
    then for each wrap a RelClear, all as removal-style maintenance
    around the original statement.
    """
    auxmaps = self.auxmaps_by_rel.get(node.rel, set())
    wraps = self.wraps_by_rel.get(node.rel, set())

    # Auxmaps are processed before wraps.
    clear_stmts = [L.MapClear(auxmap.map) for auxmap in auxmaps]
    clear_stmts += [L.RelClear(wrap.rel) for wrap in wraps]

    code = (node,)
    for stmt in clear_stmts:
        code = L.insert_rel_maint(code, (stmt,), L.SetRemove())
    return code
def transform_source(input_source, *, options=None, query_options=None):
    """Transform a module's Python source code.

    Returns ``(source, symtab)``: the transformed source, ending in a
    newline, and the symbol table.  Three statistics are stored in
    ``symtab.stats``: 'lines', 'ast_nodes', and 'time' (the process
    time spent in ``transform_ast``).
    """
    parsed = P.Parser.p(input_source)

    started = process_time()
    transformed, symtab = transform_ast(parsed, options=options,
                                        query_options=query_options)
    finished = process_time()

    # Text files should end with a trailing newline.
    source = P.Parser.ts(transformed) + '\n'

    stats = symtab.stats
    stats['lines'] = get_loc_source(source)
    # L.tree_size() is meant for IncASTs but should work for Python
    # ASTs too.  The emitted source is parsed again so that Comment
    # pseudo-nodes are not counted.
    stats['ast_nodes'] = L.tree_size(P.Parser.p(source))
    stats['time'] = finished - started

    return source, symtab
def flatten_memberships(comp):
    """Rewrite set memberships (Member nodes) of a comprehension as MMember.

    Returns ``(tree, objrels)``: the rewritten comprehension and an
    ObjRelations whose first field says whether an M set is needed.
    A Member between two Names becomes an MMember.  A Member of a Name
    in an Unwrap is a subquery clause and is kept as a Member for now.
    Any other Member raises L.ProgramError.
    """
    needs_M = False

    def process(clause):
        nonlocal needs_M

        if not isinstance(clause, L.Member):
            return clause, [], []

        if (isinstance(clause.target, L.Name) and
                isinstance(clause.iter, L.Name)):
            # MMember.
            needs_M = True
            clause = L.MMember(clause.iter.id, clause.target.id)
        elif (isinstance(clause.target, L.Name) and
                isinstance(clause.iter, L.Unwrap)):
            # Subquery clause, leave as Member for now.
            pass
        else:
            raise L.ProgramError('Cannot flatten Member clause: {}'
                                 .format(clause))

        return clause, [], []

    tree = L.rewrite_comp(comp, process)
    return tree, ObjRelations(needs_M, [], False, [])
def make_auxmap_type(auxmapinv, reltype):
    """Determine the type of an auxiliary map from its relation's type.

    The relation type is joined with {(Bottom, ..., Bottom)} of the
    mask's arity, giving the smallest type that is at least as big as
    ``reltype`` and has the right arity.  If the result fits below
    {(Top, ..., Top)}, it has the form {(T1, ..., Tn)} and the map goes
    from the bound components to a set of the unbound components,
    unwrapping either side as ``auxmapinv`` requests.  Otherwise (for
    instance ``reltype`` is {Top} or has the wrong arity) the map type
    is {Top: Top}.
    """
    mask = auxmapinv.mask
    arity = len(mask.m)

    lower = T.Set(T.Tuple([T.Bottom] * arity))
    upper = T.Set(T.Tuple([T.Top] * arity))

    normalized = reltype.join(lower)
    if not normalized.issmaller(upper):
        return T.Map(T.Top, T.Top)

    assert (isinstance(normalized, T.Set) and
            isinstance(normalized.elt, T.Tuple) and
            len(normalized.elt.elts) == arity)

    bound_types, unbound_types = L.split_by_mask(mask, normalized.elt.elts)
    if auxmapinv.unwrap_key:
        key_type = bound_types[0]
    else:
        key_type = T.Tuple(bound_types)
    if auxmapinv.unwrap_value:
        value_type = unbound_types[0]
    else:
        value_type = T.Tuple(unbound_types)

    return T.Map(key_type, T.Set(value_type))
def is_duplicate_safe(clausetools, comp):
    """Return whether duplicates can be ruled out analytically.

    False when the result expression is not injective; otherwise the
    answer of ``clausetools.all_vars_determined`` for the variables
    found in the result expression.
    """
    result_expr = comp.resexp
    if L.is_injective(result_expr):
        result_vars = L.IdentFinder.find_vars(result_expr)
        return clausetools.all_vars_determined(comp.clauses, result_vars)
    return False
def make_setfrommap_type(mask, maptype):
    """Determine the type of the relation derived from a map.

    The map type is joined with {(Bottom, ..., Bottom): Bottom}, whose
    key arity is the number of bound mask positions, giving the
    smallest type that is at least as big as ``maptype`` and has the
    right key arity.  If the result fits below {(Top, ..., Top): Top},
    it has the form {(K1, ..., Kn): V} and the relation is a set of
    tuples interleaving the key types and V according to the mask.
    Otherwise (for instance ``maptype`` is {Top: Top} or its key is not
    a tuple of the right arity) the relation type is {Top}.

    The mask must have exactly one unbound position.
    """
    nb = mask.m.count('b')
    assert mask.m.count('u') == 1

    lower = T.Map(T.Tuple([T.Bottom] * nb), T.Bottom)
    upper = T.Map(T.Tuple([T.Top] * nb), T.Top)

    normalized = maptype.join(lower)
    if not normalized.issmaller(upper):
        return T.Set(T.Top)

    assert (isinstance(normalized, T.Map) and
            isinstance(normalized.key, T.Tuple) and
            len(normalized.key.elts) == nb)

    component_types = L.combine_by_mask(mask, normalized.key.elts,
                                        [normalized.value])
    return T.Set(T.Tuple(component_types))
def is_join(self, comp):
    """Return whether the comprehension's result is exactly its LHS variables.

    The result expression must detuplify without ValueError, and the
    sorted result variables must equal the sorted left-hand-side
    variables of the clauses.
    """
    clause_vars = self.lhs_vars_from_clauses(comp.clauses)
    try:
        result_vars = L.detuplify(comp.resexp)
    except ValueError:
        return False
    else:
        return sorted(result_vars) == sorted(clause_vars)
def make_setfrommap_maint_func(fresh_vars,
                               setfrommap: SetFromMapInvariant,
                               op: str):
    """Build the maintenance function for a map-derived relation.

    ``op`` is 'assign' or 'delete'; any other value raises KeyError.
    The 'assign' function takes ``(_key, _val)`` and adds the combined
    element to the relation.  The 'delete' function takes ``(_key)``,
    reads the current value from the map, and removes the combined
    element.
    """
    mask = setfrommap.mask
    nb = L.break_mapmask(mask)

    # Names for the components of the key; the value is '_val'.
    key_vars = N.get_subnames('_key', nb)
    elem_expr = L.tuplify(L.combine_by_mask(mask, key_vars, ['_val']))

    elem_var = next(fresh_vars) + '_elem'
    decomp_code = (L.DecompAssign(key_vars, L.Name('_key')),
                   L.Assign(elem_var, elem_expr))

    # Unknown operations fail here with KeyError, so only 'assign' and
    # 'delete' reach the template selection below.
    update_ops = {'assign': L.SetAdd, 'delete': L.SetRemove}
    update_op = update_ops[op]()
    update_code = (L.RelUpdate(setfrommap.rel, update_op, elem_var),)

    func_name = setfrommap.get_maint_func_name(op)

    if op == 'assign':
        return L.Parser.ps('''
            def _FUNC(_key, _val):
                _DECOMP
                _UPDATE
            ''', subst={'_FUNC': func_name,
                        '<c>_DECOMP': decomp_code,
                        '<c>_UPDATE': update_code})

    # op == 'delete'
    lookup_expr = L.DictLookup(L.Name(setfrommap.map),
                               L.Name('_key'), None)
    return L.Parser.ps('''
        def _FUNC(_key):
            _val = _LOOKUP
            _DECOMP
            _UPDATE
        ''', subst={'_FUNC': func_name,
                    '_LOOKUP': lookup_expr,
                    '<c>_DECOMP': decomp_code,
                    '<c>_UPDATE': update_code})
def visit_RelClear(self, node):
    """Also clear ``self.result_var`` when a relation in ``self.rels`` is cleared.

    The extra RelClear is inserted around the original statement as
    removal-style maintenance; clears of other relations are returned
    unchanged.
    """
    tracked = node.rel in self.rels
    if not tracked:
        return node

    result_clear = L.RelClear(self.result_var)
    return L.insert_rel_maint((node,), (result_clear,), L.SetRemove())
def get_code(self, cl, bindenv, body):
    """Emit code that runs ``body`` for a membership in a one-element set.

    All variables bound: compare the variable tuple with the element.
    None bound: decompose the element into the variables.  Some bound:
    bind the unbound ones from the element, then compare.
    """
    assert_unique(cl.vars)
    bound_pattern = L.mask_from_bounds(cl.vars, bindenv)

    vars_equal_value = L.Compare(L.tuplify(cl.vars), L.Eq(), cl.value)
    guarded_body = (L.If(vars_equal_value, body, ()),)

    if L.mask_is_allbound(bound_pattern):
        return guarded_body

    if L.mask_is_allunbound(bound_pattern):
        return (L.DecompAssign(cl.vars, cl.value),) + body

    code = L.bind_by_mask(bound_pattern, cl.vars, cl.value)
    code += guarded_body
    return code
def match_eq_cond(tree):
    """Match a condition clause of the form ``<var> == <var>``.

    Returns the pair of variables bound to 'LEFT' and 'RIGHT' by
    ``eq_cond_pattern``, or None when the tree does not match.
    """
    bindings = L.match(eq_cond_pattern, tree)
    if bindings is not None:
        return bindings['LEFT'], bindings['RIGHT']
    return None
def get_code(self, cl, bindenv, body):
    """Emit code for a clause that excludes one value from an inner clause.

    ``body`` is guarded so that it only runs when the clause's
    variables differ from ``cl.value``; code for the inner clause
    ``cl.cl`` is then generated around the guarded body.
    """
    excluded_vars = L.tuplify(self.visitor.lhs_vars(cl))
    substitutions = {'_VARS': excluded_vars,
                     '_VALUE': cl.value,
                     '<c>_BODY': body}
    guarded_body = L.Parser.pc('''
        if _VARS != _VALUE:
            _BODY
        ''', subst=substitutions)
    return self.visitor.get_code(cl.cl, bindenv, guarded_body)
def visit_MapClear(self, node):
    """Keep a map-derived relation in sync when its map is cleared.

    If ``self.setfrommaps_by_map`` has an invariant for ``node.map``,
    a RelClear of that invariant's relation is inserted around the
    original statement as removal-style maintenance.  Other maps are
    returned unchanged.
    """
    invariant = self.setfrommaps_by_map.get(node.map)
    if invariant is not None:
        original = (node,)
        rel_clear = (L.RelClear(invariant.rel),)
        return L.insert_rel_maint(original, rel_clear, L.SetRemove())
    return node
def visit_RelUpdate(self, node):
    """Insert maintenance calls for every auxmap and wrap of a relation.

    Only SetAdd/SetRemove updates are affected.  For each auxmap
    registered for ``node.rel``, then for each wrap, a call to that
    invariant's maintenance function for ``node.op`` is inserted with
    the updated element as argument.
    """
    if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
        return node

    auxmaps = self.auxmaps_by_rel.get(node.rel, set())
    wraps = self.wraps_by_rel.get(node.rel, set())

    code = (node,)
    # Auxmaps are processed before wraps.
    for invariant in (*auxmaps, *wraps):
        maint_name = invariant.get_maint_func_name(node.op)
        maint_call = L.Expr(L.Call(maint_name, [L.Name(node.elem)]))
        code = L.insert_rel_maint(code, (maint_call,), node.op)
    return code
def visit_MapDelete(self, node):
    """Keep a map-derived relation in sync when a key is deleted from its map.

    If ``self.setfrommaps_by_map`` has an invariant for ``node.map``,
    a call to its 'delete' maintenance function with the deleted key
    is inserted around the original statement as removal-style
    maintenance.  Other maps are returned unchanged.
    """
    invariant = self.setfrommaps_by_map.get(node.map)
    if invariant is not None:
        key_arg = L.Name(node.key)
        delete_call = L.Call(invariant.get_maint_func_name('delete'),
                             [key_arg])
        maintenance = (L.Expr(delete_call),)
        return L.insert_rel_maint((node,), maintenance, L.SetRemove())
    return node