def visit_Query(self, node):
    """Check that a Query node's annotation, if present, is well-formed.

    The node's children are visited first. A non-None annotation must
    be a dict or frozendict, and each of its keys must appear in
    S.annotations.

    Raises L.ProgramError if the annotation is not a dictionary or if
    it contains an unrecognized key.
    """
    self.generic_visit(node)

    ann = node.ann
    if ann is None:
        # Nothing to validate.
        return

    if not isinstance(ann, (dict, frozendict)):
        raise L.ProgramError('Query annotation must be a '
                             'dictionary')

    for key in ann.keys():
        if key not in S.annotations:
            raise L.ProgramError(
                'Unknown annotation key "{}"'.format(key))
def aggrinv_from_query(symtab, query, result_var):
    """Build the AggrInvariant describing an aggregate query.

    query.node must be an L.Aggr or L.AggrRestr. Its operand, after
    stripping an optional L.Unwrap, must be either a Name (the whole
    relation) or an ImgLookup whose set is a Name. The relation's type
    is looked up through get_rel_type() and must be a Set of Tuple.

    Raises L.ProgramError for an unrecognized operand form, an operand
    whose type is not a set of tuples, or an AggrRestr whose restriction
    is not a Name. Raises L.TransformationError if the AggrRestr params
    differ from the ImgLookup bounds.
    """
    node = query.node
    assert isinstance(node, (L.Aggr, L.AggrRestr))
    operand = node.value
    op = node.op

    # Peel off an enclosing Unwrap, remembering that it was present.
    unwrap = isinstance(operand, L.Unwrap)
    if unwrap:
        operand = operand.value

    # Extract the relation name, mask, and parameters.
    if isinstance(operand, L.Name):
        # A mask of None means all-unbound; it is built once the
        # arity is known.
        rel, mask, params = operand.id, None, ()
    elif (isinstance(operand, L.ImgLookup) and
          isinstance(operand.set, L.Name)):
        rel, mask, params = operand.set.id, operand.mask, operand.bounds
    else:
        raise L.ProgramError('Unknown aggregate form: {}'.format(node))

    # The relation's arity comes from its type information.
    t_rel = get_rel_type(symtab, rel)
    is_set_of_tuples = (isinstance(t_rel, T.Set) and
                        isinstance(t_rel.elt, T.Tuple))
    if not is_set_of_tuples:
        raise L.ProgramError(
            'Invalid type for aggregate operand: {}'.format(t_rel))
    arity = len(t_rel.elt.elts)

    if mask is None:
        mask = L.mask('u' * arity)
    else:
        # An explicit mask has to agree with the arity from the type.
        assert len(mask.m) == arity

    restr = None
    if isinstance(node, L.AggrRestr):
        # The restriction parameters must be the ImgLookup parameters.
        if node.params != params:
            raise L.TransformationError('AggrRestr params do not match '
                                        'ImgLookup params')
        if not isinstance(node.restr, L.Name):
            raise L.ProgramError('Bad AggrRestr restriction expr')
        restr = node.restr.id

    return AggrInvariant(result_var, op, rel, mask, unwrap, params, restr)
def visit_Call(self, node):
    """Rewrite a call to len() into a Count aggregate.

    The node is first transformed by generic_visit(). Calls whose func
    is not 'len' are returned as-is after that step.

    Raises L.ProgramError if len() is not given exactly one argument.
    """
    node = self.generic_visit(node)

    if node.func != 'len':
        return node

    if len(node.args) != 1:
        raise L.ProgramError('Expected one argument for len()')
    return L.Aggr(L.Count(), node.args[0])
def define_symbol(self, name, kind, **kargs):
    """Create a symbol called `name` and register it in self.symbols.

    `kind` is called as kind(name, **kargs) to construct the symbol,
    which is stored under `name` and returned.

    Raises L.ProgramError if a symbol with this name already exists.
    """
    symbols = self.symbols
    if name in symbols:
        raise L.ProgramError('Symbol "{}" already defined'.format(name))

    new_symbol = kind(name, **kargs)
    symbols[name] = new_symbol
    return new_symbol
def transform_auxmaps_stepper(tree, symtab):
    """Run one round of auxiliary-map / SetFromMap / wrap transformation.

    InvariantFinder collects the invariants present in the tree. If
    there are none, the tree is returned unchanged together with False.
    Otherwise the corresponding maps and sets are defined in the symbol
    table, the tree is rewritten by InvariantTransformer, the generated
    maintenance functions are recorded on the symbol table, and the
    pair (new tree, True) is returned.

    Raises L.ProgramError if an auxiliary map is requested over a
    variable that is not among symtab.get_relations().
    """
    auxmaps, setfrommaps, wraps = InvariantFinder.run(tree)
    if not (len(auxmaps) or len(setfrommaps) or len(wraps)):
        return tree, False

    # Validate every auxmap before defining anything.
    for auxmap in auxmaps:
        if auxmap.rel not in symtab.get_relations():
            raise L.ProgramError('Cannot make auxiliary map for image-set '
                                 'lookup over non-relation variable {}'.format(
                                     auxmap.rel))

    for auxmap in auxmaps:
        define_map(auxmap, symtab)
    for sfm in setfrommaps:
        define_set(sfm, symtab)
    for wrap in wraps:
        define_wrap_set(wrap, symtab)

    transformer = InvariantTransformer(symtab.fresh_names.vars,
                                       auxmaps, setfrommaps, wraps)
    new_tree = transformer.process(tree)
    symtab.maint_funcs.update(transformer.maint_funcs)

    # Only auxmaps contribute to this counter.
    symtab.stats['auxmaps_transformed'] += len(auxmaps)

    return new_tree, True
def __init__(self, fresh_vars, auxmaps, setfrommaps, wraps):
    """Store the invariants to transform and build lookup indexes.

    Indexes built here:
      auxmaps_by_rel      relation -> list of auxmaps over it
      auxmaps_by_relmask  (rel, mask, unwrap_value) -> auxmap
      setfrommaps_by_map  map -> its SetFromMap invariant
      wraps_by_rel        wrap.oper -> list of wraps

    Raises L.ProgramError if two SetFromMap invariants share a map.
    """
    super().__init__()
    self.fresh_vars = fresh_vars
    self.maint_funcs = OrderedSet()

    # Indexes over auxmaps for fast retrieval.
    self.auxmaps_by_rel = by_rel = OrderedDict()
    self.auxmaps_by_relmask = by_relmask = OrderedDict()
    for auxmap in auxmaps:
        by_rel.setdefault(auxmap.rel, []).append(auxmap)
        # The key covers relation, mask, and whether the value is a
        # singleton. Whether the key is a singleton is deliberately
        # left out: if two invariants differ only in that respect, we
        # may end up using one of them while maintaining both.
        relmask_key = (auxmap.rel, auxmap.mask, auxmap.unwrap_value)
        by_relmask[relmask_key] = auxmap

    # Index over setfrommaps; at most one invariant per map.
    self.setfrommaps_by_map = OrderedDict()
    for sfm in setfrommaps:
        if sfm.map in self.setfrommaps_by_map:
            raise L.ProgramError('Multiple SetFromMap invariants on '
                                 'same map {}'.format(sfm.map))
        self.setfrommaps_by_map[sfm.map] = sfm

    # Index over wraps, grouped by operand.
    self.wraps_by_rel = by_oper = OrderedDict()
    for wrap in wraps:
        by_oper.setdefault(wrap.oper, []).append(wrap)
def preprocess_var_decls(tree):
    """Strip top-level declarations such as

        S = Set()
        M = Map()

    from the module body.

    Return the (possibly rewritten) tree together with a list of
    (variable name, kind) pairs, where kind is 'Set' or 'Map'. The tree
    is only replaced if at least one declaration was removed.

    Raises L.ProgramError if a matching statement uses an initializer
    other than Set or Map.
    """
    assert isinstance(tree, P.Module)

    # Pattern: <_VAR> = <_KIND>() with no arguments.
    target_pat = P.Name(P.PatVar('_VAR'), P.Wildcard())
    init_pat = P.Call(P.Name(P.PatVar('_KIND'), P.Load()),
                      [], [], None, None)
    pat = P.Assign([target_pat], init_pat)

    decls = OrderedSet()
    kept_stmts = []
    removed_any = False
    for stmt in tree.body:
        match = P.match(pat, stmt)
        if match is None:
            kept_stmts.append(stmt)
            continue

        var = match['_VAR']
        kind = match['_KIND']
        if kind not in ('Set', 'Map'):
            raise L.ProgramError(
                'Unknown declaration initializer {}'.format(kind))
        decls.add((var, kind))
        removed_any = True

    if removed_any:
        tree = tree._replace(body=kept_stmts)
    return tree, list(decls)
def apply_symconfig(self, name, info):
    """Apply configuration attributes to the symbol called `name`.

    `info` is a key-value dictionary of symbol config attributes; it is
    passed as keyword arguments to the symbol's parse_and_update(),
    along with this object as the first positional argument.

    Raises L.ProgramError if there is no symbol with this name.
    """
    symbols = self.symbols
    if name not in symbols:
        raise L.ProgramError('No symbol "{}"'.format(name))
    symbols[name].parse_and_update(self, **info)
def get_rel_type(symtab, rel):
    """Return the set type recorded for relation `rel`.

    The symbol's type is joined with Set<Bottom>. The result must be
    smaller than Set<Top>, and its element type must not be Bottom.
    Note that the value returned is the set type itself; callers read
    the element type from its `elt` attribute.

    Raises L.TransformationError if the symbol table has no entry for
    `rel`. Raises L.ProgramError if the type is not a set type or if
    its element type is still Bottom.
    """
    # This helper probably belongs in the type subpackage as a general
    # utility rather than here.
    relsym = symtab.get_symbols().get(rel)
    if relsym is None:
        raise L.TransformationError(
            'No symbol info for operand relation {}'.format(rel))

    t_rel = relsym.type.join(T.Set(T.Bottom))
    if not t_rel.issmaller(T.Set(T.Top)):
        raise L.ProgramError('Bad type for relation {}: {}'.format(rel, t_rel))

    # Set<Bottom> carries no element information, so it is rejected
    # rather than being given an assumed tuple shape.
    if t_rel.elt is T.Bottom:
        raise L.ProgramError(
            'Relation must have known tuple element type '
            'before it can be used in aggregate: {}'.format(rel))

    return t_rel
def rewrite_comp(self, symbol, name, comp):
    """Reorder the clauses of `comp` according to symbol.clause_reorder.

    clause_reorder is a list of 1-based clause positions. If it is
    None, `comp` is returned unchanged. Otherwise it must be a
    permutation of 1..len(comp.clauses), and the returned comprehension
    has its clauses in that order.

    Raises L.ProgramError if the list is not such a permutation.
    """
    reorder = symbol.clause_reorder
    if reorder is None:
        return comp

    # Convert to 0-based positions.
    indices = [pos - 1 for pos in reorder]
    expected = list(range(len(comp.clauses)))
    if sorted(indices) != expected:
        raise L.ProgramError(
            'Bad clause_reorder list for query: {}, {}'.format(
                name, symbol.clause_reorder))

    return comp._replace(clauses=[comp.clauses[i] for i in indices])
def process(self, tree):
    """Process the tree and, in strict mode, require every query to occur.

    self.found is reset before delegating to the parent class's
    process(). Afterwards, the keys of self.query_name_map that are not
    in self.found are the queries with no matching occurrence.

    Raises L.ProgramError if self.strict is set and any such query
    remains.
    """
    self.found = set()
    tree = super().process(tree)

    missing = set(self.query_name_map.keys()) - self.found
    if not (self.strict and missing):
        return tree

    query_strs = ', '.join([L.Parser.ts(query) for query in missing])
    raise L.ProgramError(
        'No matching occurrence for queries: {}'.format(query_strs))
def process(expr):
    """Rewrite a Member clause over a known relation.

    A Member whose iter is a Name listed in symtab.get_relations()
    becomes a RelMember when its target is a tuple of names, or a
    VarsMember over the wrapped relation when its target is a single
    Name. Anything else is returned untouched. The result is always a
    triple whose second and third components are empty lists.

    Raises L.ProgramError if the target has any other form.
    """
    over_relation = (isinstance(expr, L.Member) and
                     isinstance(expr.iter, L.Name) and
                     expr.iter.id in symtab.get_relations())
    if not over_relation:
        return expr, [], []

    target = expr.target
    rel = expr.iter.id

    if L.is_tuple_of_names(target):
        clause = L.RelMember(L.detuplify(target), rel)
    elif isinstance(target, L.Name):
        clause = L.VarsMember([target.id], L.Wrap(L.Name(rel)))
    else:
        raise L.ProgramError('Invalid clause over relation')

    return clause, [], []
def visit_SetFromMap(self, node):
    """Replace a SetFromMap over a registered map with its relation.

    After generic_visit(), a node whose map is a Name with an entry in
    self.setfrommaps_by_map is replaced by a Name for that invariant's
    relation. Other nodes are returned as transformed by
    generic_visit().

    Raises L.ProgramError if the node's mask differs from the mask of
    the registered invariant for the same map.
    """
    node = self.generic_visit(node)

    if not isinstance(node.map, L.Name):
        return node

    map_name = node.map.id
    sfm = self.setfrommaps_by_map.get(map_name)
    if sfm is None:
        # No invariant registered for this map.
        return node

    if not sfm.mask == node.mask:
        raise L.ProgramError('Multiple SetFromMap expressions on '
                             'same map {}'.format(map_name))
    return L.Name(sfm.rel)
def visit_Query(self, node):
    """Record this query occurrence after visiting its children.

    Queries whose name is in self.ignore (when that is not None) are
    skipped. If self.first is set, self.Done is raised carrying the
    (name, query) pair of this occurrence. Otherwise the query
    expression is stored in self.queries under its name.

    Raises L.ProgramError if the name is already recorded with a
    different expression.
    """
    self.generic_visit(node)

    name = node.name
    if self.ignore is not None and name in self.ignore:
        return

    if self.first:
        raise self.Done((name, node.query))

    # Accumulate, checking repeated occurrences for consistency.
    if name not in self.queries:
        self.queries[name] = node.query
    elif node.query != self.queries[name]:
        raise L.ProgramError('Multiple inconsistent expressions for '
                             'query {}'.format(node.name))
def determine_demand_params(self, node):
    """Decide whether the query at `node` uses demand, and on which params.

    The result is stored on the query's symbol as `uses_demand` and
    `demand_params`, and the symbol's name is added to self.processed.
    A query whose name is already in self.processed is skipped.

    Rules, by case:
      - impl is S.Aux: no demand.
      - comprehension: parameters come from
        determine_comp_demand_params(); demand is used if that result
        is non-empty.
      - aggregate: demand over all of the query's params if any nested
        query uses demand or self.get_left_clauses() is not None;
        otherwise no demand.

    Raises L.ProgramError for any other kind of query.
    """
    if node.name in self.processed:
        return

    symtab = self.symtab
    ct = symtab.clausetools
    query_sym = symtab.get_queries()[node.name]

    if query_sym.impl is S.Aux:
        # Aux implementations never use demand.
        uses_demand, demand_params = False, ()

    elif isinstance(node.query, L.Comp):
        demand_params = determine_comp_demand_params(
            ct, node.query, query_sym.params,
            query_sym.demand_params, query_sym.demand_param_strat)
        uses_demand = len(demand_params) > 0

    elif isinstance(node.query, L.Aggr):
        # An aggregate must be demand-driven on all of its parameters
        # if it sits inside a transformed comprehension or if its
        # operand contains a demand-driven query.
        aggr_in_comp = self.get_left_clauses() is not None
        operand_uses_demand = False
        for nested in find_nested_queries(node):
            if symtab.get_queries()[nested.name].uses_demand:
                operand_uses_demand = True

        uses_demand = operand_uses_demand or aggr_in_comp
        demand_params = query_sym.params if uses_demand else ()

    else:
        kind = query_sym.node.__class__.__name__
        raise L.ProgramError('No rule for analyzing parameters of '
                             '{} query'.format(kind))

    query_sym.uses_demand = uses_demand
    query_sym.demand_params = demand_params
    self.processed.add(query_sym.name)
def make_demand_func(query):
    """Generate the code of the demand function for `query`.

    The shape of the generated function depends on
    query.demand_set_maxsize:
      - None: add the element to the demand set if it is absent.
      - 1: clear the demand set before adding an absent element.
      - any other positive int: while the set holds at least that many
        elements, remove the element returned by peek(), then add the
        new one.

    Returns the result of L.Parser.ps() on the chosen template.
    Raises L.ProgramError for any other value of demand_set_maxsize.
    """
    func = N.get_query_demand_func_name(query.name)
    uset = N.get_query_demand_set_name(query.name)
    maxsize = query.demand_set_maxsize

    subst = {'_FUNC': func, '_U': uset}

    # NOTE(review): the templates are indented to match this function;
    # this assumes L.Parser.ps tolerates uniform leading indentation --
    # confirm.
    if maxsize is None:
        return L.Parser.ps('''
            def _FUNC(_elem):
                if _elem not in _U:
                    _U.reladd(_elem)
            ''', subst=subst)

    if maxsize == 1:
        return L.Parser.ps('''
            def _FUNC(_elem):
                if _elem not in _U:
                    _U.relclear()
                    _U.reladd(_elem)
            ''', subst=subst)

    if not isinstance(maxsize, int) or maxsize <= 0:
        raise L.ProgramError('Invalid value for demand_set_maxsize')

    subst['_MAXSIZE'] = L.Num(maxsize)
    return L.Parser.ps('''
        def _FUNC(_elem):
            if _elem not in _U:
                while len(_U) >= _MAXSIZE:
                    _stale = _U.peek()
                    _U.relremove(_stale)
                _U.reladd(_elem)
        ''', subst=subst)
def transform_query(tree, symtab, query):
    """Transform `tree` according to the implementation chosen for `query`.

    Comprehension queries are dispatched on query.impl to the Aux, Inc,
    or Filtered transformation. Aggregate queries (Aggr or AggrRestr)
    are incrementalized for the Inc implementation, followed by a pass
    that transforms comprehensions containing aggregate map lookups.
    The Normal implementation leaves the tree untouched.

    Returns a pair (tree, success), where success is False exactly when
    the implementation is Normal, i.e. nothing was transformed.

    Raises L.ProgramError if query.node is neither a comprehension nor
    an aggregate. Raises AssertionError if query.impl is not one of the
    implementations handled for that kind of query.
    """
    assert query.impl is not S.Unspecified

    if isinstance(query.node, L.Comp):
        if query.impl == S.Normal:
            success = False
        elif query.impl == S.Aux:
            tree = transform_aux_comp_query(tree, symtab, query)
            success = True
        elif query.impl == S.Inc:
            tree = transform_comp_query(tree, symtab, query)
            success = True
        elif query.impl == S.Filtered:
            symtab.print(' Comp: ' + L.Parser.ts(query.node))
            tree = transform_comp_query_with_filtering(tree, symtab, query)
            success = True
        else:
            # Raised explicitly rather than via `assert ()`: a stripped
            # assert (python -O) would fall through to the return with
            # `success` unbound.
            raise AssertionError(
                'Unexpected impl for comprehension query {}: {}'.format(
                    query.name, query.impl))

    elif isinstance(query.node, (L.Aggr, L.AggrRestr)):
        if query.impl == S.Normal:
            success = False
        elif query.impl == S.Inc:
            result_var = N.A_name(query.name)
            tree = incrementalize_aggr(tree, symtab, query, result_var)
            # Transform any aggregate map lookups inside comprehensions,
            # including incrementalizing the SetFromMaps used in the
            # additional comp clauses.
            tree = transform_comps_with_maps(tree, symtab)
            success = True
        else:
            # Same reasoning as the comprehension case above.
            raise AssertionError(
                'Unexpected impl for aggregate query {}: {}'.format(
                    query.name, query.impl))

    else:
        raise L.ProgramError('Unknown query kind: {}'.format(
            query.node.__class__.__name__))

    return tree, success
def make_structs(self):
    """Populate the tag and filter structures for self.comp.

    Only clauses of kind Kind.Member are considered. For every such
    clause except the one at index 0, a Filter is added; for every
    clause, a Tag is added for each variable returned by
    tagsout_lhs_vars(). Must be called while self.structs is empty.

    Raises L.TransformationError if a member clause has no relation on
    its right-hand side. Raises L.ProgramError if a filter needs
    predecessor tags but none are available.
    """
    assert len(self.structs) == 0
    ct = self.symtab.clausetools

    for i, cl in enumerate(self.comp.clauses):
        if ct.kind(cl) is not Kind.Member:
            continue

        # The full LHS variable list is not used below; the call is
        # kept so the clause is queried exactly as before.
        lhs_vars = ct.lhs_vars(cl)
        in_vars = ct.tagsin_lhs_vars(cl)
        out_vars = ct.tagsout_lhs_vars(cl)
        rel = ct.rhs_rel(cl)
        if rel is None:
            raise L.TransformationError('Cannot generate tags and filter '
                                        'for clause: {}'.format(cl))

        if i == 0:
            # The clause at index 0 acts as its own filter, so no new
            # structure is needed.
            filtered_cl = cl
        else:
            # Generate a filter for this clause.
            filter_index = len(self.filters_by_rel[rel]) + 1
            filter_name = self.get_filter_name(rel, filter_index)
            preds = self.get_preds(i, in_vars)
            if ct.filter_needs_preds(cl) and len(preds) == 0:
                raise L.ProgramError('No predecessors tags for filter '
                                     'over clause: {}'.format(
                                         L.Parser.ts(cl)))
            self.add_struct(Filter(i, filter_name, cl, preds))
            filtered_cl = self.clause_over_filter[filter_name]

        # Generate a tag for each of the clause's outgoing variables.
        for var in out_vars:
            tag_index = len(self.tags_by_var[var]) + 1
            tag_name = self.get_tag_name(var, tag_index)
            self.add_struct(Tag(i, tag_name, var, filtered_cl))
def visit_Query(self, node):
    """Record the parameters of a query occurrence, then visit its children.

    The parameters of this occurrence are obtained from
    self.get_params(). The first time a query name is seen, they are
    stored in self.query_param_map and on the query symbol's `params`
    attribute. Later occurrences must yield equal parameters.

    Raises L.ProgramError if this occurrence's parameters differ from
    those recorded earlier for the same query name.
    """
    querysym = self.symtab.get_queries()[node.name]
    cache = self.query_param_map

    # Analyze parameters from node.
    params = self.get_params(node)

    if node.name not in cache:
        # First occurrence: remember it and update the symbol.
        cache[node.name] = params
        querysym.params = params
    elif params != cache[node.name]:
        # Repeat occurrence that disagrees with what was recorded.
        raise L.ProgramError('Inconsistent parameter info for query '
                             '{}: {}, {}'.format(
                                 querysym.name, cache[node.name], params))

    self.generic_visit(node)
def visit_SetFromMap(self, node):
    """Reject any SetFromMap node reached by this visitor.

    Always raises L.ProgramError, with the node included in the message.
    """
    message = 'Invalid SetFromMap expression: {}'.format(node)
    raise L.ProgramError(message)
def visit_ImgLookup(self, node):
    """Reject any ImgLookup node reached by this visitor.

    Always raises L.ProgramError, with the node included in the message.
    """
    message = 'Invalid ImgLookup expression: {}'.format(node)
    raise L.ProgramError(message)
def __init__(self, *args, **kargs):
    """Validate the combination of `unwrap` and `mask` on this instance.

    The arguments are accepted but not used here; the check reads
    self.unwrap and self.mask, which must already be set.

    Raises L.ProgramError if unwrap is set and the mask string does not
    contain exactly one 'u' component.
    """
    if not self.unwrap:
        return

    unbound_count = self.mask.m.count('u')
    if unbound_count != 1:
        raise L.ProgramError('Aggregates of unwrap() expressions must '
                             'have arity 1')