def get_code(self, cl, bindenv, body):
    """Generate code that enumerates a map-membership clause.

    The clause ``cl`` relates ``cl.map``, ``cl.key`` and ``cl.value``.
    A mask is built over the clause's left-hand-side variables from which
    of them are already bound in ``bindenv``. Judging by the ``'bbu'``
    branch (which assigns ``cl.value``), the mask positions are
    (map, key, value). The specialized forms are:

    * all bound -> ``if value == map[key]: <body>``
    * ``'bbu'``  -> ``value = map[key]`` followed by ``<body>``
    * ``'buu'``  -> ``for key, value in map.items(): <body>``

    When one of these forms is used and ``self.use_typecheck`` is set, the
    result is wrapped in ``if IsMap(map): ...``. Any other mask is delegated
    to the superclass implementation and returned unwrapped.

    NOTE(review): the lookup is built as ``DictLookup(map, key, None)``. If
    a ``None`` default means a plain subscript, the all-bound and ``'bbu'``
    forms will raise ``KeyError`` when ``key`` is absent instead of simply
    producing no results. Confirm this is intended.

    :param cl: the map-membership clause.
    :param bindenv: the variables already bound at this point.
    :param body: tuple of statements to run for each match.
    :returns: a tuple of statements.
    """
    # Named to avoid shadowing the builtin ``vars``.
    clause_vars = self.lhs_vars(cl)
    assert_unique(clause_vars)
    mask = L.mask_from_bounds(clause_vars, bindenv)

    lookup_expr = L.DictLookup(L.Name(cl.map), L.Name(cl.key), None)

    if L.mask_is_allbound(mask):
        # Everything bound: just test that the entry matches.
        comparison = L.Compare(L.Name(cl.value), L.Eq(), lookup_expr)
        code = (L.If(comparison, body, ()),)
    elif mask == L.mask('bbu'):
        # Map and key bound: bind the value by direct lookup.
        code = (L.Assign(cl.value, lookup_expr),) + body
    elif mask == L.mask('buu'):
        # Only the map bound: iterate over all of its entries.
        items_expr = L.Parser.pe('_MAP.items()', subst={'_MAP': cl.map})
        code = (L.DecompFor([cl.key, cl.value], items_expr, body),)
    else:
        # No specialized form; the generic code never gets the map
        # type check.
        return super().get_code(cl, bindenv, body)

    if self.use_typecheck:
        code = (L.If(L.IsMap(L.Name(cl.map)), code, ()),)
    return code
def visit_DictDelete(self, node):
    """Replace a dict-entry deletion with a removal from the MAP relation.

    The deletion of ``node.key`` from ``node.target`` becomes two statements:
    a fresh variable is assigned the triple
    ``(target, key, target[key])``, and that triple is then removed from the
    ``N.MAP`` relation.

    When ``self.objrels.MAP`` is falsy, ``None`` is returned. Presumably the
    transformer treats that as "leave the node unchanged"; confirm against
    the transformer base class.
    """
    if self.objrels.MAP:
        current_value = L.DictLookup(node.target, node.key, None)
        entry = L.Tuple([node.target, node.key, current_value])
        entry_var = next(self.fresh_vars)
        assign_entry = L.Assign(entry_var, entry)
        remove_entry = L.RelUpdate(N.MAP, L.SetRemove(), entry_var)
        return (assign_entry, remove_entry)
    return None
def make_setfrommap_maint_func(fresh_vars,
                               setfrommap: SetFromMapInvariant,
                               op: str):
    """Build the maintenance function for a SetFromMap invariant.

    The generated function updates the relation ``setfrommap.rel`` from an
    entry of ``setfrommap.map``:

    * ``op == 'assign'`` produces ``def <name>(_key, _val)``. It adds the
      flattened entry to the relation.
    * ``op == 'delete'`` produces ``def <name>(_key)``. It reads ``_val``
      from the map and removes the flattened entry from the relation. Since
      it reads the current value, it is presumably meant to run before the
      map entry is actually removed; confirm at the call site.

    In both cases ``_key`` is decomposed into fresh names derived from
    ``'_key'``, and those names are combined with ``_val`` according to
    ``setfrommap.mask`` to form the relation element.

    :param fresh_vars: iterator of fresh name prefixes. Exactly one value is
        consumed on success.
    :param setfrommap: the invariant being maintained.
    :param op: ``'assign'`` or ``'delete'``.
    :returns: the function definition produced by ``L.Parser.ps``.
    :raises ValueError: if ``op`` is neither ``'assign'`` nor ``'delete'``.
    """
    # Validate up front. Previously an unknown op raised a bare KeyError from
    # the dispatch dict, and the trailing ``assert ()`` branch was unreachable.
    if op not in ('assign', 'delete'):
        raise ValueError(
            "op must be 'assign' or 'delete', got {!r}".format(op))

    mask = setfrommap.mask
    nb = L.break_mapmask(mask)
    # Fresh variables for components of the key and value.
    key_vars = N.get_subnames('_key', nb)
    decomp_code = (L.DecompAssign(key_vars, L.Name('_key')),)
    # Named to avoid shadowing the builtin ``vars``.
    elem_vars = L.combine_by_mask(mask, key_vars, ['_val'])
    elem = L.tuplify(elem_vars)
    fresh_var_prefix = next(fresh_vars)
    elem_var = fresh_var_prefix + '_elem'
    decomp_code += (L.Assign(elem_var, elem),)

    setopcls = L.SetAdd if op == 'assign' else L.SetRemove
    update_code = (L.RelUpdate(setfrommap.rel, setopcls(), elem_var),)

    func_name = setfrommap.get_maint_func_name(op)

    # In the substitution maps, the '<c>' prefix appears to mark placeholders
    # that are replaced by statement sequences rather than expressions.
    if op == 'assign':
        func = L.Parser.ps('''
            def _FUNC(_key, _val):
                _DECOMP
                _UPDATE
            ''', subst={
                '_FUNC': func_name,
                '<c>_DECOMP': decomp_code,
                '<c>_UPDATE': update_code,
            })
    else:
        lookup_expr = L.DictLookup(L.Name(setfrommap.map),
                                   L.Name('_key'), None)
        func = L.Parser.ps('''
            def _FUNC(_key):
                _val = _LOOKUP
                _DECOMP
                _UPDATE
            ''', subst={
                '_FUNC': func_name,
                '_LOOKUP': lookup_expr,
                '<c>_DECOMP': decomp_code,
                '<c>_UPDATE': update_code,
            })

    return func
def incrementalize_aggr(tree, symtab, query, result_var):
    """Incrementalize the aggregate ``query`` within ``tree``.

    The work proceeds in four stages:

    1. Build the aggregate invariant from ``symtab``, ``query`` and
       ``result_var``, and obtain its handler.
    2. Run ``AggrMaintainer`` over the tree to insert maintenance code, then
       record the generated maintenance functions in ``symtab.maint_funcs``.
    3. Rewrite every occurrence of the aggregate named ``query.name`` into a
       lookup of the invariant's map, keyed by the tuple of its parameters,
       projected through the handler.
    4. Define the invariant's map in ``symtab`` with an inferred ``Map`` type
       and increment the ``'aggrs_transformed'`` statistic.

    :returns: the transformed tree.
    """
    aggrinv = aggrinv_from_query(symtab, query, result_var)
    handler = aggrinv.get_handler()

    # Stage 2: insert code that keeps the invariant's map up to date.
    maintainer = AggrMaintainer(symtab.fresh_names.vars, aggrinv)
    tree = maintainer.process(tree)
    symtab.maint_funcs.update(maintainer.maint_funcs)

    # Stage 3: the lookup default is None under demand, otherwise the
    # handler's zero expression.
    if aggrinv.uses_demand:
        default_expr = None
    else:
        default_expr = handler.make_zero_expr()
    map_entry = L.DictLookup(L.Name(aggrinv.map),
                             L.tuplify(aggrinv.params),
                             default_expr)
    replacement = handler.make_projection_expr(map_entry)

    class AggrExpander(S.QueryRewriter):
        expand = True

        def rewrite_aggr(self, symbol, name, expr):
            # Only the query being incrementalized is replaced. Any other
            # aggregate gets None back, presumably meaning "unchanged".
            if name != query.name:
                return None
            return replacement

    tree = AggrExpander.run(tree, symtab)

    # Stage 4: key type = element component types at the mask's bound
    # positions; value type is decided by the handler.
    rel_type = get_rel_type(symtab, aggrinv.rel)
    bound_types, _unbound_types = L.split_by_mask(aggrinv.mask,
                                                  rel_type.elt.elts)
    map_type = T.Map(T.Tuple(bound_types), handler.result_type(rel_type))
    symtab.define_map(aggrinv.map, type=map_type)

    symtab.stats['aggrs_transformed'] += 1
    return tree