def get_code(self, cl, bindenv, body):
    """Emit the statements that run ``body`` under a TUP clause.

    ``cl`` is the clause (its ``tup`` and ``elts`` attributes are used),
    ``bindenv`` the variables already bound, and ``body`` the code to
    execute for a match. The result is a sequence of statements.

    Raises L.TransformationError when the binding pattern is one this
    method does not handle (the message points at demand filtering).
    """
    lhs = self.lhs_vars(cl)
    assert_unique(lhs)
    mask = L.mask_from_bounds(lhs, bindenv)
    tup_matches = L.Compare(L.Name(cl.tup), L.Eq(), L.tuplify(cl.elts))

    if L.mask_is_allbound(mask):
        # Nothing left to bind: just test the tuple against its parts.
        stmts = (L.If(tup_matches, body, ()), )
        typecheck = True

    elif mask.m.startswith('b'):
        # First mask position is bound (presumably the tuple variable;
        # the ordering comes from lhs_vars). Bind the unbound components
        # out of the tuple, and re-check equality unless every
        # component was freshly bound.
        comp_mask = L.mask_from_bounds(cl.elts, bindenv)
        stmts = L.bind_by_mask(comp_mask, cl.elts, L.Name(cl.tup))
        stmts += (body if L.mask_is_allunbound(comp_mask)
                  else (L.If(tup_matches, body, ()), ))
        typecheck = True

    elif mask == L.mask('u' + 'b' * len(cl.elts)):
        # Only the first position is unbound: build the tuple from its
        # components. No arity check is requested on this path.
        stmts = (L.Assign(cl.tup, L.tuplify(cl.elts)), )
        stmts += body
        typecheck = False

    else:
        raise L.TransformationError('Cannot emit code for TUP clause '
                                    'that would require an auxiliary '
                                    'map; use demand filtering')

    if typecheck and self.use_typecheck:
        # Guard everything behind an arity test on the tuple.
        arity_ok = L.HasArity(L.Name(cl.tup), len(cl.elts))
        stmts = (L.If(arity_ok, stmts, ()), )
    return stmts
def visit_DecompFor(self, node):
    """Turn a DecompFor over a tracked relation into a plain For.

    Loops whose iterable is not a Name listed in ``self.rels`` are
    returned unchanged. Raises L.TransformationError if a loop over
    such a relation does not have exactly one target variable.
    """
    source = node.iter
    over_rel = isinstance(source, L.Name) and source.id in self.rels
    if not over_rel:
        return node

    if len(node.vars) != 1:
        raise L.TransformationError(
            'Singleton unwrapping requires all DecompFor loops '
            'over relation to have exactly one target variable')

    # The single target variable becomes the loop variable directly.
    return L.For(node.vars[0], node.iter, node.body)
def define_wrap_set(wrapinv, symtab):
    """Register the relation of a WrapInvariant in the symbol table.

    The new relation's type is derived by make_wrap_type() from the
    type of the operand relation ``wrapinv.oper``. Raises
    L.TransformationError if the symbol table has no such relation.
    """
    operand_sym = symtab.get_relations().get(wrapinv.oper, None)
    if operand_sym is None:
        message = ('No relation "{}" matching wrapped/'
                   'unwrapped relation "{}"'.format(
                       wrapinv.oper, wrapinv.rel))
        raise L.TransformationError(message)

    symtab.define_relation(
        wrapinv.rel, type=make_wrap_type(wrapinv, operand_sym.type))
def define_set(setfrommap, symtab):
    """Register the relation of a set-from-map in the symbol table.

    The relation's type is computed by make_setfrommap_type() from
    ``setfrommap.mask`` and the type of the map ``setfrommap.map``.
    Raises L.TransformationError if the symbol table has no such map.
    """
    # Look up the underlying map's symbol.
    map_symbol = symtab.get_maps().get(setfrommap.map, None)
    if map_symbol is None:
        message = 'No map "{}" matching relation "{}"'.format(
            setfrommap.map, setfrommap.rel)
        raise L.TransformationError(message)

    symtab.define_relation(
        setfrommap.rel,
        type=make_setfrommap_type(setfrommap.mask, map_symbol.type))
def define_map(auxmap, symtab):
    """Register the map of an auxiliary map in the symbol table.

    The map's type is computed by make_auxmap_type() from ``auxmap``
    and the type of the relation ``auxmap.rel``. Raises
    L.TransformationError if the symbol table has no such relation.
    """
    # Look up the underlying relation's symbol.
    rel_symbol = symtab.get_relations().get(auxmap.rel, None)
    if rel_symbol is None:
        message = 'No relation "{}" matching map "{}"'.format(
            auxmap.rel, auxmap.map)
        raise L.TransformationError(message)

    symtab.define_map(
        auxmap.map, type=make_auxmap_type(auxmap, rel_symbol.type))
def visit_Query(self, node):
    """Rewrite one Query occurrence, consistently across occurrences.

    Children are processed first. The first occurrence of a query name
    is checked against its symbol, passed to ``self.rewrite()``, and
    the outcome is cached; later occurrences must equal the first one
    and reuse the cached outcome. A ``None`` outcome leaves the node
    as it is. Otherwise the node is either replaced outright (when
    ``self.expand`` is set) or has its ``query`` field replaced.

    Raises L.TransformationError for a missing symbol or for
    occurrences that disagree with the symbol or with each other.
    """
    node = super().generic_visit(node)

    qname = node.name
    occurrence = node.query

    if qname in self.queries_before:
        # Seen before: it must match the earlier occurrence, and we
        # reuse whatever replacement was decided back then.
        earlier = self.queries_before[qname]
        if occurrence != earlier:
            raise L.TransformationError(
                'Inconsistent occurrences for query "{}": {}, {}'
                .format(qname, earlier, occurrence))
        replacement = self.queries_after[qname]

    else:
        # First sighting: validate against the symbol, rewrite, and
        # record both the original and the result.
        sym = self.symtab.get_queries().get(qname, None)
        if sym is None:
            raise L.TransformationError('No symbol for query name "{}"'
                                        .format(qname))
        if occurrence != sym.node:
            raise L.TransformationError(
                'Inconsistent symbol and occurrence for query '
                '"{}": {}, {}'.format(qname, sym.node, occurrence))

        replacement = self.rewrite(sym, qname, occurrence)
        self.queries_before[qname] = occurrence
        self.queries_after[qname] = replacement
        # The symbol only tracks the new query when we keep the Query
        # wrapper around (i.e. when not expanding).
        if not (replacement is None or self.expand):
            sym.node = replacement

    if replacement is None:
        return node
    if self.expand:
        return replacement
    return node._replace(query=replacement)
def aggrinv_from_query(symtab, query, result_var):
    """Build the AggrInvariant describing an aggregate query.

    ``query.node`` must be an L.Aggr or L.AggrRestr whose operand is a
    relation Name or an ImgLookup over a relation Name, optionally
    wrapped in L.Unwrap. The relation's arity is taken from its type
    as reported by get_rel_type().

    Raises L.ProgramError for an unrecognized operand form, a
    non-set-of-tuples operand type, or a non-Name restriction, and
    L.TransformationError when AggrRestr params differ from the
    ImgLookup params.
    """
    node = query.node
    assert isinstance(node, (L.Aggr, L.AggrRestr))

    op = node.op
    operand = node.value

    # Peel off an Unwrap, remembering that it was there.
    unwrap = isinstance(operand, L.Unwrap)
    if unwrap:
        operand = operand.value

    # Extract relation name, mask, and parameters from the operand.
    if isinstance(operand, L.Name):
        rel = operand.id
        # No mask given; an all-unbound one is built once the arity
        # is known.
        mask = None
        params = ()
    elif (isinstance(operand, L.ImgLookup) and
          isinstance(operand.set, L.Name)):
        rel = operand.set.id
        mask = operand.mask
        params = operand.bounds
    else:
        raise L.ProgramError('Unknown aggregate form: {}'.format(node))

    # The relation's type tells us its arity.
    t_rel = get_rel_type(symtab, rel)
    if not isinstance(t_rel, T.Set) or not isinstance(t_rel.elt, T.Tuple):
        raise L.ProgramError(
            'Invalid type for aggregate operand: {}'.format(t_rel))
    arity = len(t_rel.elt.elts)

    if mask is None:
        mask = L.mask('u' * arity)
    else:
        # An explicit mask has to agree with the arity.
        assert len(mask.m) == arity

    restr = None
    if isinstance(node, L.AggrRestr):
        # The restriction's parameters must be exactly the lookup's.
        if node.params != params:
            raise L.TransformationError('AggrRestr params do not match '
                                        'ImgLookup params')
        if not isinstance(node.restr, L.Name):
            raise L.ProgramError('Bad AggrRestr restriction expr')
        restr = node.restr.id

    return AggrInvariant(result_var, op, rel, mask, unwrap, params, restr)
def order_clauses(clausevisitor, clauses):
    """Order clauses according to their reported priorities.

    Use a greedy heuristic: Choose the leftmost clause whose priority
    is best (lowest number). Return the ordered clauses as produced by
    the planner.

    Raises L.TransformationError if the planner reports (via
    ValueError) that no valid order exists. The planner's error is
    attached as the explicit cause so it is not lost in tracebacks.
    """
    init = OrderState(clausevisitor, set(), [], OrderedSet(clauses))
    try:
        answer = Planner().get_greedy_answer(init)
    except ValueError as exc:
        s = ', '.join(L.Parser.ts(cl) for cl in clauses)
        raise L.TransformationError(
            'No valid order found for clauses: {}'.format(s)) from exc
    return answer
def get_rel_type(symtab, rel):
    """Return the set type of relation ``rel`` from the symbol table.

    The symbol's type is joined with Set(Bottom) and must be smaller
    than Set(Top); its element type must not be Bottom.

    Raises L.TransformationError if ``rel`` has no symbol, and
    L.ProgramError if the type is not a set or its element type is
    still Bottom.
    """
    # TODO: this helper probably belongs in the type subpackage as a
    # general utility rather than next to the aggregate code.
    symbol = symtab.get_symbols().get(rel, None)
    if symbol is None:
        raise L.TransformationError(
            'No symbol info for operand relation {}'.format(rel))

    t_rel = symbol.type.join(T.Set(T.Bottom))
    is_set = t_rel.issmaller(T.Set(T.Top))
    if not is_set:
        raise L.ProgramError('Bad type for relation {}: {}'.format(rel, t_rel))

    # An element type of Bottom means the tuple shape is not known,
    # which is rejected here rather than guessed.
    if t_rel.elt is T.Bottom:
        raise L.ProgramError(
            'Relation must have known tuple element type '
            'before it can be used in aggregate: {}'.format(rel))

    return t_rel
def make_structs(self):
    """Populate the structures based on the comprehension.

    Walks the comprehension's clauses in order, skipping those whose
    kind is not Kind.Member. Every such clause except the one at index
    0 gets a Filter; each variable reported by tagsout_lhs_vars() gets
    a Tag over the filtered clause (or the clause itself at index 0).

    Must be called while ``self.structs`` is empty. Raises
    L.TransformationError for a clause with no right-hand relation and
    L.ProgramError for a filter that needs predecessor tags but has
    none.
    """
    assert len(self.structs) == 0
    ct = self.symtab.clausetools

    for index, clause in enumerate(self.comp.clauses):
        if ct.kind(clause) is not Kind.Member:
            continue

        # Result unused; the call is kept so behavior matches exactly.
        ct.lhs_vars(clause)
        tags_in = ct.tagsin_lhs_vars(clause)
        tags_out = ct.tagsout_lhs_vars(clause)
        rel = ct.rhs_rel(clause)
        if rel is None:
            raise L.TransformationError('Cannot generate tags and filter '
                                        'for clause: {}'.format(clause))

        if index == 0:
            # First clause acts as its own filter; no new structure.
            filtered_clause = clause
        else:
            # Number filters per relation, starting from 1.
            filter_no = len(self.filters_by_rel[rel]) + 1
            filter_name = self.get_filter_name(rel, filter_no)
            preds = self.get_preds(index, tags_in)
            if ct.filter_needs_preds(clause) and len(preds) == 0:
                raise L.ProgramError('No predecessors tags for filter '
                                     'over clause: {}'.format(
                                         L.Parser.ts(clause)))
            self.add_struct(Filter(index, filter_name, clause, preds))
            # Read back the clause that ranges over the new filter.
            filtered_clause = self.clause_over_filter[filter_name]

        # One tag per outgoing variable, numbered per variable from 1.
        for var in tags_out:
            tag_no = len(self.tags_by_var[var]) + 1
            tag_name = self.get_tag_name(var, tag_no)
            self.add_struct(Tag(index, tag_name, var, filtered_clause))
def badnode(self, node):
    """Reject ``node``: its kind may not appear in this conversion.

    Always raises L.TransformationError naming the node's class.
    """
    kind_name = node.__class__.__name__
    message = ('{} nodes should not be present when '
               'converting to the pair domain'.format(kind_name))
    raise L.TransformationError(message)