def rewrite_comp(self, symbol, name, comp):
    """Substitute the stored result for occurrences of the target query.

    Uses ``query``, ``result_var`` and ``orig_arity`` from the enclosing
    scope. A parameterless query becomes a plain reference to the result
    variable. A parameterized query becomes an image lookup on the result
    relation, keyed by the query parameters. Any other query name yields
    None (presumably "leave unchanged" under the QueryRewriter protocol).
    """
    if name != query.name:
        return None
    result = L.Name(result_var)
    if query.params == ():
        return result
    # Key components are the query parameters, out of the original arity.
    mask = L.keymask_from_len(len(query.params), orig_arity)
    return L.ImgLookup(result, mask, query.params)
def get_code(self, cl, bindenv, body):
    """Emit code for a TUP clause relating ``cl.tup`` to ``cl.elts``.

    Supported binding patterns:
      - everything bound: test that the tuple equals its components;
      - tuple bound: unpack the unbound components out of the tuple. Then
        test equality, unless every component was unbound;
      - tuple unbound, all components bound: construct the tuple. This
        case is emitted without the arity guard.

    Any other pattern raises L.TransformationError. When
    ``self.use_typecheck`` is set, the first two cases are wrapped in an
    arity check on the tuple.
    """
    lhs = self.lhs_vars(cl)
    assert_unique(lhs)
    mask = L.mask_from_bounds(lhs, bindenv)
    comparison = L.Compare(L.Name(cl.tup), L.Eq(), L.tuplify(cl.elts))

    if L.mask_is_allbound(mask):
        code = (L.If(comparison, body, ()),)
    elif mask.m.startswith('b'):
        elts_mask = L.mask_from_bounds(cl.elts, bindenv)
        code = L.bind_by_mask(elts_mask, cl.elts, L.Name(cl.tup))
        if L.mask_is_allunbound(elts_mask):
            tail = body
        else:
            tail = (L.If(comparison, body, ()),)
        code += tail
    elif mask == L.mask('u' + 'b' * len(cl.elts)):
        return (L.Assign(cl.tup, L.tuplify(cl.elts)),) + body
    else:
        raise L.TransformationError('Cannot emit code for TUP clause '
                                    'that would require an auxiliary '
                                    'map; use demand filtering')

    if self.use_typecheck:
        arity_test = L.HasArity(L.Name(cl.tup), len(cl.elts))
        code = (L.If(arity_test, code, ()),)
    return code
def get_code(self, cl, bindenv, body):
    """Emit code for a MAP clause relating ``cl.map``, ``cl.key`` and
    ``cl.value``.

    Handled masks:
      - all bound: test ``value == map[key]``;
      - 'bbu': bind ``value = map[key]``;
      - 'buu': iterate over ``map.items()``, binding key and value.

    When ``self.use_typecheck`` is set, these cases are wrapped in an
    IsMap guard. Any other mask is delegated to the superclass, with no
    guard.

    NOTE(review): the 'bbu' and all-bound cases use a plain lookup with
    no default, so they presumably assume the key is present. Confirm.
    """
    lhs = self.lhs_vars(cl)
    assert_unique(lhs)
    mask = L.mask_from_bounds(lhs, bindenv)
    lookup_expr = L.DictLookup(L.Name(cl.map), L.Name(cl.key), None)

    if L.mask_is_allbound(mask):
        test = L.Compare(L.Name(cl.value), L.Eq(), lookup_expr)
        code = (L.If(test, body, ()),)
    elif mask == L.mask('bbu'):
        code = (L.Assign(cl.value, lookup_expr),) + body
    elif mask == L.mask('buu'):
        items_expr = L.Parser.pe('_MAP.items()', subst={'_MAP': cl.map})
        code = (L.DecompFor([cl.key, cl.value], items_expr, body),)
    else:
        # Generic handling; emitted without the IsMap guard.
        return super().get_code(cl, bindenv, body)

    if self.use_typecheck:
        code = (L.If(L.IsMap(L.Name(cl.map)), code, ()),)
    return code
def visit_MapAssign(self, node):
    """Attach SetFromMap 'assign' maintenance to a map assignment.

    When no SetFromMap invariant is registered for ``node.map``, the node
    is returned untouched. Otherwise a call ``<maint>(key, value)`` is
    combined with the assignment by L.insert_rel_maint, positioned as it
    would be for a SetAdd.
    """
    sfm = self.setfrommaps_by_map.get(node.map, None)
    if sfm is None:
        return node
    maint_func = sfm.get_maint_func_name('assign')
    args = [L.Name(node.key), L.Name(node.value)]
    call_code = (L.Expr(L.Call(maint_func, args)),)
    return L.insert_rel_maint((node,), call_code, L.SetAdd())
def make_setfrommap_maint_func(fresh_vars,
                               setfrommap: SetFromMapInvariant,
                               op: str):
    """Make the maintenance function for a SetFromMap invariant.

    ``op`` selects the map update being maintained:

      - ``'assign'``: the generated function takes ``(_key, _val)`` and
        adds the corresponding element to ``setfrommap.rel``.
      - ``'delete'``: the generated function takes ``_key``. It reads the
        current value with ``setfrommap.map[_key]``, so it presumably
        must run while the key is still present (confirm at call sites).
        It then removes the corresponding element from
        ``setfrommap.rel``.

    In both cases ``_key`` is unpacked into its components. These are
    recombined with the value according to ``setfrommap.mask``, and the
    resulting element is stored in a fresh variable before the relation
    update.

    Raises AssertionError for any other ``op``.
    """
    # Reject unknown ops up front with a clear message. The old
    # ``assert ()`` fallback was dead code and vanished under ``-O``.
    if op not in ('assign', 'delete'):
        raise AssertionError(
            'Unknown SetFromMap maintenance op: {!r}'.format(op))

    mask = setfrommap.mask
    nb = L.break_mapmask(mask)

    # Fresh variables for components of the key and value.
    key_vars = N.get_subnames('_key', nb)
    decomp_code = (L.DecompAssign(key_vars, L.Name('_key')),)
    elem = L.tuplify(L.combine_by_mask(mask, key_vars, ['_val']))
    elem_var = next(fresh_vars) + '_elem'
    decomp_code += (L.Assign(elem_var, elem),)

    setopcls = L.SetAdd if op == 'assign' else L.SetRemove
    update_code = (L.RelUpdate(setfrommap.rel, setopcls(), elem_var),)
    func_name = setfrommap.get_maint_func_name(op)

    if op == 'assign':
        return L.Parser.ps('''
            def _FUNC(_key, _val):
                _DECOMP
                _UPDATE
            ''', subst={
                '_FUNC': func_name,
                '<c>_DECOMP': decomp_code,
                '<c>_UPDATE': update_code,
            })

    # op == 'delete': recover the value being removed from the map itself.
    lookup_expr = L.DictLookup(L.Name(setfrommap.map), L.Name('_key'), None)
    return L.Parser.ps('''
        def _FUNC(_key):
            _val = _LOOKUP
            _DECOMP
            _UPDATE
        ''', subst={
            '_FUNC': func_name,
            '_LOOKUP': lookup_expr,
            '<c>_DECOMP': decomp_code,
            '<c>_UPDATE': update_code,
        })
def visit_DictLookup(self, node):
    """Replace a simple map lookup ``m[k1, ..., kn]`` with a fresh variable.

    On first sight of a lookup, this allocates a fresh variable and
    records it in ``self.repls``. It also appends a SetFromMapMember
    clause binding that variable to ``self.new_clauses``, and adds the
    matching SetFromMapInvariant to ``self.sfm_invs``. A lookup already
    present in ``self.repls`` reuses its variable. Only lookups on a
    plain map name, with a tuple of names as key and no default, are
    accepted (enforced by assertions).
    """
    node = self.generic_visit(node)

    # Only simple map lookups are allowed.
    assert isinstance(node.value, L.Name)
    assert L.is_tuple_of_names(node.key)
    assert node.default is None

    map_name = node.value.id
    keyvars = L.detuplify(node.key)

    var = self.repls.get(node, None)
    if var is not None:
        return L.Name(var)

    mask = L.mapmask_from_len(len(keyvars))
    rel = N.SA_name(map_name, mask)

    var = next(self.fresh_names)
    self.repls[node] = var

    # Clause binding the new variable, plus the invariant it relies on.
    clause_vars = list(keyvars) + [var]
    self.new_clauses.append(
        L.SetFromMapMember(clause_vars, rel, map_name, mask))
    self.sfm_invs.add(SetFromMapInvariant(rel, map_name, mask))

    return L.Name(var)
def rewrite_with_demand(self, query_sym, node):
    """Return ``node`` (a Comp or Aggr for ``query_sym``) restricted by
    demand. Subqueries are not transformed.

    Queries that do not use demand come back unchanged. Otherwise demand
    comes from a demand set when there are no left clauses, or from a
    demand query built over them. A Comp gets the demand clause
    prepended; an Aggr becomes an AggrRestr.
    """
    symtab = self.symtab
    params = query_sym.demand_params
    if not query_sym.uses_demand:
        return node

    # Build either a demand query (over the left clauses) or a demand set.
    left_clauses = self.get_left_clauses()
    if left_clauses is not None:
        dem_sym = make_demand_query(symtab, query_sym, left_clauses)
        dem_node = dem_sym.make_node()
        dem_clause = L.VarsMember(params, dem_node)
        self.demand_queries.add(dem_sym.name)
    else:
        dem_sym = make_demand_set(symtab, query_sym)
        dem_node = L.Name(dem_sym.name)
        dem_clause = L.RelMember(params, dem_sym.name)
        self.queries_with_usets.add(query_sym.name)

    if isinstance(node, L.Comp):
        return node._replace(clauses=(dem_clause,) + node.clauses)
    if isinstance(node, L.Aggr):
        return L.AggrRestr(node.op, node.value, params, dem_node)
    raise AssertionError(
        'No rule for handling demand for {} node'.format(
            node.__class__.__name__))
def visit_RelUpdate(self, node):
    """Attach auxmap and wrap maintenance calls to a set update.

    For a SetAdd/SetRemove to a relation, a call ``<maint>(elem)`` is
    inserted for every auxmap invariant and then every wrap invariant
    registered for that relation. Other updates pass through unchanged.
    """
    if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
        return node

    code = (node,)
    invariant_groups = (
        self.auxmaps_by_rel.get(node.rel, set()),
        self.wraps_by_rel.get(node.rel, set()),
    )
    for group in invariant_groups:
        for inv in group:
            func_name = inv.get_maint_func_name(node.op)
            call = L.Call(func_name, [L.Name(node.elem)])
            code = L.insert_rel_maint(code, (L.Expr(call),), node.op)
    return code
def visit_MapDelete(self, node):
    """Attach SetFromMap 'delete' maintenance to a map deletion.

    When no SetFromMap invariant is registered for ``node.map``, the node
    is returned untouched. Otherwise a call ``<maint>(key)`` is combined
    with the deletion by L.insert_rel_maint, positioned as it would be
    for a SetRemove.
    """
    sfm = self.setfrommaps_by_map.get(node.map, None)
    if sfm is None:
        return node
    maint_func = sfm.get_maint_func_name('delete')
    call_code = (L.Expr(L.Call(maint_func, [L.Name(node.key)])),)
    return L.insert_rel_maint((node,), call_code, L.SetRemove())
def make_update_state_code(self, prefix, state, op, value):
    """Return code that folds ``value`` into a (total, count) state pair.

    The variable named by ``state`` holds a 2-tuple. Component 0 is the
    running total of values; component 1 is how many values have been
    folded in. A SetAdd adds the value and increments the count; a
    SetRemove subtracts the value and decrements the count.

    ``prefix`` is not used by this handler. It is presumably part of a
    shared handler interface (confirm).

    Raises AssertionError for any other kind of set update.
    """
    if isinstance(op, L.SetAdd):
        template = '''
            _STATE = (index(_STATE, 0) + _VALUE, index(_STATE, 1) + 1)
            '''
    elif isinstance(op, L.SetRemove):
        template = '''
            _STATE = (index(_STATE, 0) - _VALUE, index(_STATE, 1) - 1)
            '''
    else:
        # Previously fell through and failed with UnboundLocalError.
        raise AssertionError(
            'Unexpected set update operation: {!r}'.format(op))
    return L.Parser.pc(template, subst={'_STATE': state,
                                        '_VALUE': L.Name(value)})
def visit_RelUpdate(self, node):
    """Attach result maintenance to a set update on a watched relation.

    For a SetAdd/SetRemove to one of ``self.rels``, a call to the
    maintenance function named by N.get_maint_func_name (for
    ``self.result_var``, this relation and this op) is inserted via
    L.insert_rel_maint. Other updates pass through unchanged.
    """
    is_set_update = isinstance(node.op, (L.SetAdd, L.SetRemove))
    if not is_set_update or node.rel not in self.rels:
        return node
    func_name = N.get_maint_func_name(
        self.result_var, node.rel, L.set_update_name(node.op))
    call_code = (L.Expr(L.Call(func_name, [L.Name(node.elem)])),)
    return L.insert_rel_maint((node,), call_code, node.op)
def rewrite_resexp_with_params(self, comp, params):
    """Prepend parameter variables to a tuple-valued result expression.

    Every name in ``params`` must be an LHS variable of ``comp``, and
    ``comp.resexp`` must be an L.Tuple (both asserted). Returns a copy of
    ``comp`` whose result tuple begins with ``L.Name(p)`` for each
    parameter, followed by the original elements.
    """
    lhs_vars = self.lhs_vars_from_comp(comp)
    assert set(params) <= set(lhs_vars), (
        'params: {}, lhs_vars: {}'.format(params, lhs_vars))
    assert isinstance(comp.resexp, L.Tuple)
    param_elts = tuple(L.Name(p) for p in params)
    return comp._replace(resexp=L.Tuple(param_elts + comp.resexp.elts))
def get_code(self, cl, bindenv, body):
    """Emit code for a membership clause of ``cl.elem`` in ``cl.set``.

    When everything is bound, this emits an ``in`` test. For mask 'bu'
    it emits a loop over the set that binds the element. When
    ``self.use_typecheck`` is set, both are wrapped in an IsSet guard.
    Other masks are delegated to the superclass, with no guard.
    """
    lhs = self.lhs_vars(cl)
    assert_unique(lhs)
    mask = L.mask_from_bounds(lhs, bindenv)

    if L.mask_is_allbound(mask):
        test = L.Compare(L.Name(cl.elem), L.In(), L.Name(cl.set))
        code = (L.If(test, body, ()),)
    elif mask == L.mask('bu'):
        code = (L.For(cl.elem, L.Name(cl.set), body),)
    else:
        return super().get_code(cl, bindenv, body)

    if self.use_typecheck:
        code = (L.If(L.IsSet(L.Name(cl.set)), code, ()),)
    return code
def get_code(self, cl, bindenv, body):
    """Emit code for a clause ranging over the relation ``rhs_rel(cl)``.

    - All variables bound: test tuple membership.
    - All unbound: loop over the relation, destructuring each element.
    - Mixed: loop over the image of the bound variables. The image is
      unwrapped when exactly one variable is unbound.
    """
    lhs = self.lhs_vars(cl)
    rel = self.rhs_rel(cl)
    assert_unique(lhs)
    mask = L.mask_from_bounds(lhs, bindenv)

    if L.mask_is_allbound(mask):
        test = L.Compare(L.tuplify(lhs), L.In(), L.Name(rel))
        return (L.If(test, body, ()),)
    if L.mask_is_allunbound(mask):
        return (L.DecompFor(lhs, L.Name(rel), body),)

    bvars, uvars = L.split_by_mask(mask, lhs)
    lookup = L.ImgLookup(L.Name(rel), mask, bvars)
    # A single unbound variable can be bound directly from the
    # unwrapped image, without tuple destructuring.
    if len(uvars) == 1:
        return (L.For(uvars[0], L.Unwrap(lookup), body),)
    return (L.DecompFor(uvars, lookup, body),)
def wrap_helper(self, node):
    """Redirect a Wrap/Unwrap of a relation name to a matching wrap
    invariant's relation. Does not recurse.

    A Wrap matches invariants whose ``unwrap`` is false. An Unwrap
    matches those whose ``unwrap`` is true. The first match found among
    ``self.wraps_by_rel[rel]`` wins. If nothing matches, or the operand
    is not a plain name, the node is returned unchanged.
    """
    if not isinstance(node.value, L.Name):
        return node

    is_wrap = isinstance(node, L.Wrap)
    is_unwrap = isinstance(node, L.Unwrap)
    for inv in self.wraps_by_rel.get(node.value.id, []):
        if is_unwrap if inv.unwrap else is_wrap:
            return L.Name(inv.rel)
    return node
def get_code(self, cl, bindenv, body):
    """Emit code for a field clause relating ``cl.obj.<cl.attr>`` to
    ``cl.value``.

    When everything is bound, this compares the attribute with the value.
    For mask 'bu' it assigns the attribute to the value variable. When
    ``self.use_typecheck`` is set, both are wrapped in a HasField guard.
    Other masks are delegated to the superclass, with no guard.
    """
    lhs = self.lhs_vars(cl)
    assert_unique(lhs)
    mask = L.mask_from_bounds(lhs, bindenv)

    if L.mask_is_allbound(mask):
        value_node = L.Name(cl.value)
        attr_expr = L.Attribute(L.Name(cl.obj), cl.attr)
        code = (L.If(L.Compare(value_node, L.Eq(), attr_expr), body, ()),)
    elif mask == L.mask('bu'):
        attr_expr = L.Attribute(L.Name(cl.obj), cl.attr)
        code = (L.Assign(cl.value, attr_expr),) + body
    else:
        return super().get_code(cl, bindenv, body)

    if self.use_typecheck:
        code = (L.If(L.HasField(L.Name(cl.obj), cl.attr), code, ()),)
    return code
def visit_SetFromMap(self, node):
    """Replace a SetFromMap over a known map with its materialized relation.

    Children are visited first. A SetFromMap whose map is a plain name
    with a registered invariant becomes a reference to that invariant's
    relation. The masks must agree, otherwise L.ProgramError is raised.
    Anything else is returned unchanged.
    """
    node = self.generic_visit(node)
    if not isinstance(node.map, L.Name):
        return node

    map_name = node.map.id
    sfm = self.setfrommaps_by_map.get(map_name, None)
    if sfm is None:
        return node

    if sfm.mask == node.mask:
        return L.Name(sfm.rel)
    raise L.ProgramError('Multiple SetFromMap expressions on '
                         'same map {}'.format(map_name))
def process(expr):
    """Turn a membership over a known relation into a relation clause.

    ``expr`` is expected to be an ``x in R`` Member whose iterable names
    one of ``symtab.get_relations()``. A tuple-of-names target becomes a
    RelMember. A single-name target becomes a VarsMember over the wrapped
    relation. Any other target raises L.ProgramError. Non-matching
    expressions are passed back as-is. Returns a 3-tuple whose last two
    entries are always empty lists.
    """
    over_relation = (isinstance(expr, L.Member) and
                     isinstance(expr.iter, L.Name) and
                     expr.iter.id in symtab.get_relations())
    if not over_relation:
        return expr, [], []

    target = expr.target
    rel = expr.iter.id
    if L.is_tuple_of_names(target):
        clause = L.RelMember(L.detuplify(target), rel)
    elif isinstance(target, L.Name):
        clause = L.VarsMember([target.id], L.Wrap(L.Name(rel)))
    else:
        raise L.ProgramError('Invalid clause over relation')
    return clause, [], []
def make_auxmap_maint_func(fresh_vars,
                           auxmap: AuxmapInvariant, op: L.setupop):
    """Make maintenance function for auxiliary map.

    The generated function takes ``_elem``, an element of the base
    relation. It splits the element into key and value parts according
    to ``auxmap.mask``, unwrapping either part as configured by
    ``unwrap_key``/``unwrap_value``. It then runs the image-add or
    image-remove code for ``auxmap.map``, chosen by the exact class of
    ``op`` (SetAdd or SetRemove).
    """
    # Fresh variables for components of the element.
    elem_vars = N.get_subnames('_elem', len(auxmap.mask.m))
    decomp_code = (L.DecompAssign(elem_vars, L.Name('_elem')),)

    key_parts, value_parts = L.split_by_mask(auxmap.mask, elem_vars)
    key_expr = L.tuplify(key_parts, unwrap=auxmap.unwrap_key)
    value_expr = L.tuplify(value_parts, unwrap=auxmap.unwrap_value)

    prefix = next(fresh_vars)
    key_var = prefix + '_key'
    value_var = prefix + '_value'

    decomp_code += L.Parser.pc('''
        _KEY_VAR = _KEY
        _VALUE_VAR = _VALUE
        ''', subst={
            '_KEY_VAR': key_var,
            '_KEY': key_expr,
            '_VALUE_VAR': value_var,
            '_VALUE': value_expr,
        })

    dispatch = {L.SetAdd: make_imgadd, L.SetRemove: make_imgremove}
    img_code = dispatch[op.__class__](fresh_vars, auxmap.map,
                                      key_var, value_var)

    return L.Parser.ps('''
        def _FUNC(_elem):
            _DECOMP
            _IMGCODE
        ''', subst={
            '_FUNC': auxmap.get_maint_func_name(op),
            '<c>_DECOMP': decomp_code,
            '<c>_IMGCODE': img_code,
        })
def get_maint_code(self, fresh_var_prefix, fresh_join_names,
                   comp, result_var, update, *,
                   selfjoin=SelfJoin.Without, counted):
    """Return maintenance code for ``comp`` under a relation update.

    ``comp`` need not be a join. For each maintenance join of ``comp``
    with respect to ``update``, this emits a loop that computes the
    result element and applies ``update.op`` to ``result_var``. The
    update is counted unless ``counted`` is False. Each loop takes its
    name from ``fresh_join_names``.
    """
    assert isinstance(update, L.RelUpdate)
    assert isinstance(update.op, (L.SetAdd, L.SetRemove))

    result_elem_var = fresh_var_prefix + '_result'

    # Prefix the comprehension's LHS variables so they cannot collide
    # with names already in scope where the maintenance code is placed.
    def renamer(x):
        return fresh_var_prefix + '_' + x
    comp = self.comp_rename_lhs_vars(comp, renamer)

    body = ((L.Assign(result_elem_var, comp.resexp),) +
            L.rel_update(result_var, update.op, result_elem_var,
                         counted=counted))

    join = self.make_join_from_comp(comp)
    maint_joins = self.get_maint_join_union(
        join, update.rel, L.Name(update.elem), selfjoin=selfjoin)

    code = ()
    for maint_join in maint_joins:
        code += self.get_loop_for_join(maint_join, body,
                                       next(fresh_join_names))
    return code
def visit_RelUpdate(self, node):
    """Lower a set update on a special M / F / MAP relation into a
    direct update on the objects named by the element's components.

    - M relation: element ``(s, v)`` gives a SetUpdate of ``v`` on ``s``.
    - F relation (field from N.get_F): element ``(o, v)`` gives an
      attribute assignment (add) or deletion (remove).
    - MAP relation: element ``(m, k, v)`` gives a dict assignment (add)
      or deletion (remove).

    Other relations are returned unchanged. For ops other than
    SetAdd/SetRemove this returns None (presumably "no change" under the
    transformer protocol; confirm).
    """
    if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
        return None
    is_add = isinstance(node.op, L.SetAdd)

    rel = node.rel
    elem = L.Name(node.elem)

    def component(i):
        return L.Subscript(elem, L.Num(i))

    if N.is_M(rel):
        return (L.SetUpdate(component(0), node.op, component(1)),)

    if N.is_F(rel):
        attr = N.get_F(rel)
        obj = component(0)
        if is_add:
            return (L.AttrAssign(obj, attr, component(1)),)
        return (L.AttrDelete(obj, attr),)

    if N.is_MAP(rel):
        map_expr = component(0)
        key_expr = component(1)
        if is_add:
            return (L.DictAssign(map_expr, key_expr, component(2)),)
        return (L.DictDelete(map_expr, key_expr),)

    return node
def incrementalize_aggr(tree, symtab, query, result_var):
    """Incrementalize the aggregate ``query`` within ``tree``.

    Steps:
      1. insert maintenance code for the aggregate invariant and register
         its maintenance functions;
      2. replace occurrences of the query with a projection of the state
         stored in the result map;
      3. define the result map's symbol with its computed type.

    Returns the transformed tree.
    """
    aggrinv = aggrinv_from_query(symtab, query, result_var)
    handler = aggrinv.get_handler()

    # Maintain the invariant at every relevant update.
    maintainer = AggrMaintainer(symtab.fresh_names.vars, aggrinv)
    tree = maintainer.process(tree)
    symtab.maint_funcs.update(maintainer.maint_funcs)

    # With demand the lookup has no default; otherwise absent keys map
    # to the handler's zero.
    if aggrinv.uses_demand:
        zero = None
    else:
        zero = handler.make_zero_expr()
    state_expr = L.DictLookup(L.Name(aggrinv.map),
                              L.tuplify(aggrinv.params), zero)
    lookup_expr = handler.make_projection_expr(state_expr)

    class AggrExpander(S.QueryRewriter):
        expand = True

        def rewrite_aggr(self, symbol, name, expr):
            return lookup_expr if name == query.name else None

    tree = AggrExpander.run(tree, symtab)

    # The result map is keyed by the bound components of the relation.
    t_rel = get_rel_type(symtab, aggrinv.rel)
    btypes, _ = L.split_by_mask(aggrinv.mask, t_rel.elt.elts)
    t_map = T.Map(T.Tuple(btypes), handler.result_type(t_rel))
    symtab.define_map(aggrinv.map, type=t_map)

    symtab.stats['aggrs_transformed'] += 1
    return tree
def visit_MapDelete(self, node):
    """Translate a MapDelete into a DictDelete on the named map and key."""
    map_expr = L.Name(node.map)
    key_expr = L.Name(node.key)
    return L.DictDelete(map_expr, key_expr)
def visit_MapClear(self, node):
    """Translate a MapClear into a DictClear on the named map."""
    map_expr = L.Name(node.map)
    return L.DictClear(map_expr)
def visit_RelMember(self, node):
    """Translate a RelMember clause into a Member clause.

    The clause's variables become a tuple target, and the relation name
    becomes the iterable.
    """
    target = L.tuplify(node.vars)
    relation = L.Name(node.rel)
    return L.Member(target, relation)
def make_eq_cond(left, right):
    """Build the condition clause ``<left> == <right>`` from two
    variable names."""
    lhs, rhs = L.Name(left), L.Name(right)
    return L.Cond(L.Compare(lhs, L.Eq(), rhs))
'make_eq_cond', 'SelfJoin', 'ClauseTools', 'CoreClauseTools',
]


from enum import Enum

from incoq.util.seq import zip_strict
from incoq.util.collections import OrderedSet, Partitioning

from incoq.compiler.incast import L

from .clause import ClauseVisitor, CoreClauseVisitor, Kind, ShouldFilter


# Pattern for a condition clause of the form <var> == <var>. The two
# variable names are captured under the pattern variables LEFT and RIGHT.
eq_cond_pattern = L.Cond(
    L.Compare(L.Name(L.PatVar('LEFT')), L.Eq(), L.Name(L.PatVar('RIGHT'))))


def match_eq_cond(tree):
    """If tree is a condition clause with form <var> == <var>, return a
    pair of the variables. Otherwise return None.

    The pair is (left name, right name), taken from the LEFT and RIGHT
    captures of ``eq_cond_pattern``.
    """
    result = L.match(eq_cond_pattern, tree)
    if result is None:
        return None
    else:
        return result['LEFT'], result['RIGHT']


def make_eq_cond(left, right):
    """Make a condition of form <var> == <var>."""
    # NOTE(review): in this view the definition ends at its docstring.
    # A full version returning an L.Cond appears elsewhere in the source.
    # Confirm this copy is not a truncated duplicate.
def rewrite_comp(self, symbol, name, comp):
    """Replace the target query with a call to its query function.

    Uses ``query`` and ``func_name`` from the enclosing scope. The call
    passes one name argument per query parameter. Any other query name
    yields None (presumably "leave unchanged" under the QueryRewriter
    protocol).
    """
    if name != query.name:
        return None
    args = [L.Name(p) for p in query.params]
    return L.Call(func_name, args)
def make_aggr_restr_maint_func(fresh_vars, aggrinv, op):
    """Make the function that maintains an aggregate's result map when
    its restriction set gains or loses a key.

    The generated function takes a single argument ``_key``:

      - SetAdd: recompute the state for that key from scratch. Start at
        the handler's zero value, fold in every element of the image of
        ``aggrinv.rel`` under the key, and store the result with
        ``mapassign``.
      - SetRemove: delete the key from ``aggrinv.map``.

    Only valid for invariants that use demand (asserted).
    """
    assert isinstance(op, (L.SetAdd, L.SetRemove))
    assert aggrinv.uses_demand

    if not isinstance(op, L.SetAdd):
        # Key left the restriction set: drop its cached state.
        maint_code = L.Parser.pc('''
            _MAP.mapdelete(_KEY)
            ''', subst={'_MAP': aggrinv.map, '_KEY': '_key'})
    else:
        prefix = next(fresh_vars)
        value_var = prefix + '_value'
        state_var = prefix + '_state'

        key_components = N.get_subnames('_key', len(aggrinv.params))
        decomp_key_code = (L.DecompAssign(key_components, L.Name('_key')),)
        image = L.ImgLookup(L.Name(aggrinv.rel), aggrinv.mask,
                            key_components)

        handler = aggrinv.get_handler()
        zero = handler.make_zero_expr()
        updatestate_code = handler.make_update_state_code(
            prefix, state_var, op, value_var)

        # With unwrap set, each image element is destructured as a 1-tuple.
        if aggrinv.unwrap:
            loop_template = '''
                for (_VALUE,) in _RELLOOKUP:
                    _UPDATESTATE
                '''
        else:
            loop_template = '''
                for _VALUE in _RELLOOKUP:
                    _UPDATESTATE
                '''
        loop_code = L.Parser.pc(loop_template, subst={
            '_VALUE': value_var,
            '_RELLOOKUP': image,
            '<c>_UPDATESTATE': updatestate_code,
        })

        maint_code = L.Parser.pc('''
            _STATE = _ZERO
            _DECOMP_KEY
            _LOOP
            _MAP.mapassign(_KEY, _STATE)
            ''', subst={
                '_MAP': aggrinv.map,
                '_KEY': '_key',
                '_STATE': state_var,
                '_ZERO': zero,
                '<c>_DECOMP_KEY': decomp_key_code,
                '<c>_LOOP': loop_code,
            })

    return L.Parser.ps('''
        def _FUNC(_key):
            _MAINT
        ''', subst={
            '_FUNC': aggrinv.get_restr_maint_func_name(op),
            '<c>_MAINT': maint_code,
        })
def make_aggr_oper_maint_func(fresh_vars, aggrinv, op):
    """Make the maintenance function for an aggregate invariant and a
    given set update operation (add or remove) to the operand.

    The generated function takes one argument, ``_elem``, an element of
    the operand relation. It splits ``_elem`` into key and value parts
    according to ``aggrinv.mask``, then:
      1. reads the key's current state from ``aggrinv.map``;
      2. folds the value in, using the aggregate handler;
      3. deletes the old entry;
      4. stores the new state.

    How this varies with the invariant:
      - Without demand, an add may create a new entry, with the state
        defaulting to the handler's zero. A remove skips storing the new
        state when the handler's empty condition holds.
      - With demand, the entry is assumed to exist already, and the
        whole update is guarded by ``_KEY in _RESTR``.

    Raises AssertionError if ``op`` is not a SetAdd or SetRemove.
    """
    assert isinstance(op, (L.SetAdd, L.SetRemove))

    # Decompose the argument tuple into key and value components,
    # just like in auxmap.py.
    elem_vars = N.get_subnames('_elem', len(aggrinv.mask.m))
    kvars, vvars = L.split_by_mask(aggrinv.mask, elem_vars)
    ktuple = L.tuplify(kvars)
    vtuple = L.tuplify(vvars)
    fresh_var_prefix = next(fresh_vars)
    key = fresh_var_prefix + '_key'
    value = fresh_var_prefix + '_value'
    state = fresh_var_prefix + '_state'

    if aggrinv.unwrap:
        assert len(vvars) == 1
        value_expr = L.Name(vvars[0])
    else:
        value_expr = vtuple

    # Logic specific to aggregate operator.
    handler = aggrinv.get_handler()
    zero = handler.make_zero_expr()
    updatestate_code = handler.make_update_state_code(fresh_var_prefix,
                                                      state, op, value)

    subst = {
        '_KEY': key,
        '_KEY_EXPR': ktuple,
        '_VALUE': value,
        '_VALUE_EXPR': value_expr,
        '_MAP': aggrinv.map,
        '_STATE': state,
        '_ZERO': zero,
    }
    if aggrinv.uses_demand:
        subst['_RESTR'] = aggrinv.restr
    else:
        # Empty conditions are only used when we don't have a
        # restriction set.
        subst['_EMPTY'] = handler.make_empty_cond(state)

    decomp_code = (L.DecompAssign(elem_vars, L.Name('_elem')),)
    decomp_code += L.Parser.pc('''
        _KEY = _KEY_EXPR
        _VALUE = _VALUE_EXPR
        ''', subst=subst)

    # Determine what kind of get/set state code to generate.
    if isinstance(op, L.SetAdd):
        definitely_preexists = aggrinv.uses_demand
        setstate_mayremove = False
    elif isinstance(op, L.SetRemove):
        definitely_preexists = True
        setstate_mayremove = not aggrinv.uses_demand
    else:
        # Explicit raise: unlike ``assert ()``, this survives ``-O``
        # instead of leaving the flags above unbound.
        raise AssertionError(
            'Unexpected set update operation: {!r}'.format(op))

    if definitely_preexists:
        getstate_template = '_STATE = _MAP[_KEY]'
        delstate_template = '_MAP.mapdelete(_KEY)'
    else:
        getstate_template = '_STATE = _MAP.get(_KEY, _ZERO)'
        delstate_template = '''
            if _KEY in _MAP:
                _MAP.mapdelete(_KEY)
            '''
    if setstate_mayremove:
        # Drop the entry entirely once the aggregate state is empty.
        setstate_template = '''
            if not _EMPTY:
                _MAP.mapassign(_KEY, _STATE)
            '''
    else:
        setstate_template = '_MAP.mapassign(_KEY, _STATE)'

    getstate_code = L.Parser.pc(getstate_template, subst=subst)
    delstate_code = L.Parser.pc(delstate_template, subst=subst)
    setstate_code = L.Parser.pc(setstate_template, subst=subst)

    maint_code = (getstate_code + updatestate_code +
                  delstate_code + setstate_code)

    # Guard in test if we have a restriction set.
    if aggrinv.uses_demand:
        maint_subst = dict(subst)
        maint_subst['<c>_MAINT'] = maint_code
        maint_code = L.Parser.pc('''
            if _KEY in _RESTR:
                _MAINT
            ''', subst=maint_subst)

    func_name = aggrinv.get_oper_maint_func_name(op)

    func = L.Parser.ps('''
        def _FUNC(_elem):
            _DECOMP
            _MAINT
        ''', subst={
            '_FUNC': func_name,
            '<c>_DECOMP': decomp_code,
            '<c>_MAINT': maint_code,
        })

    return func