def rewrite_comp(self, symbol, name, comp):
    """Replace a reference to the target query with a read of its result.

    If *name* differs from ``query.name``, return ``None``. Presumably the
    caller treats ``None`` as "no rewrite"; confirm against the caller.

    Otherwise return an expression over ``L.Name(result_var)``:
      * when ``query.params`` equals the empty tuple, the name node itself;
      * otherwise an ``L.ImgLookup`` on that name. The lookup uses the mask
        from ``L.keymask_from_len(len(query.params), orig_arity)`` and is
        keyed by ``query.params``.

    ``query``, ``result_var`` and ``orig_arity`` are free variables bound in
    an enclosing scope that is not visible here. *symbol* and *comp* are not
    used by this method.
    """
    # Guard: only the comprehension that belongs to this query is rewritten.
    if name != query.name:
        return None

    params = query.params
    # Compared against () explicitly, not via truthiness. An empty list
    # would not take this branch.
    if params == ():
        return L.Name(result_var)

    key_mask = L.keymask_from_len(len(params), orig_arity)
    return L.ImgLookup(L.Name(result_var), key_mask, params)
def get_code(self, cl, bindenv, body):
    """Generate the statements that match clause *cl* and then run *body*.

    The clause's left-hand-side variables come from ``self.lhs_vars(cl)``.
    Its relation name comes from ``self.rhs_rel(cl)``. A mask is computed
    from which of those variables appear in *bindenv*, and the shape of the
    emitted code depends on that mask:

      * all bound: ``if (vars) in rel: body``;
      * none bound: a ``DecompFor`` over ``rel`` that binds every variable;
      * mixed: a loop over an ``ImgLookup`` of ``rel`` keyed by the bound
        variables. With exactly one unbound variable it is a plain ``For``
        over the unwrapped lookup; otherwise it is a ``DecompFor``.

    Returns a 1-tuple holding the generated statement node.
    """
    lhs = self.lhs_vars(cl)
    rel_name = self.rhs_rel(cl)
    # Presumably rejects clauses that repeat a variable; verify in assert_unique.
    assert_unique(lhs)
    mask = L.mask_from_bounds(lhs, bindenv)

    if L.mask_is_allbound(mask):
        # Nothing to enumerate; just test membership of the bound tuple.
        membership = L.Compare(L.tuplify(lhs), L.In(), L.Name(rel_name))
        return (L.If(membership, body, ()),)

    if L.mask_is_allunbound(mask):
        return (L.DecompFor(lhs, L.Name(rel_name), body),)

    bound_vars, unbound_vars = L.split_by_mask(mask, lhs)
    image = L.ImgLookup(L.Name(rel_name), mask, bound_vars)
    # A single unbound variable can iterate the unwrapped image directly,
    # instead of decomposing 1-tuples.
    if len(unbound_vars) == 1:
        return (L.For(unbound_vars[0], L.Unwrap(image), body),)
    return (L.DecompFor(unbound_vars, image, body),)
def make_aggr_restr_maint_func(fresh_vars, aggrinv, op):
    """Make the maintenance function for an aggregate invariant and an
    update to its restriction set.

    Parameters:
        fresh_vars: iterator of fresh name prefixes. One prefix is drawn
            with ``next()``, and only when *op* is a ``L.SetAdd``.
        aggrinv: the aggregate invariant. It must have ``uses_demand`` set.
            Its ``params``, ``rel``, ``mask``, ``unwrap`` and ``map``
            attributes, ``get_handler()`` and
            ``get_restr_maint_func_name(op)`` are read.
        op: an ``L.SetAdd`` or ``L.SetRemove`` on the restriction set.

    Returns:
        The parsed ``def <name>(_key): ...`` statement. For an add, the body
        starts the state at the handler's zero expression. It decomposes
        ``_key`` into key variables and folds every value of the relation's
        image under those keys into the state. It then stores the state with
        ``mapassign``. For a remove, the body is just ``mapdelete(_key)``.
    """
    assert isinstance(op, (L.SetAdd, L.SetRemove))
    assert aggrinv.uses_demand

    if not isinstance(op, L.SetAdd):
        # The key left the restriction set; drop its aggregate entry.
        delete_subst = {
            '_MAP': aggrinv.map,
            '_KEY': '_key',
        }
        maint_code = L.Parser.pc('''
            _MAP.mapdelete(_KEY)
            ''', subst=delete_subst)
    else:
        prefix = next(fresh_vars)
        value_var = prefix + '_value'
        state_var = prefix + '_state'
        key_components = N.get_subnames('_key', len(aggrinv.params))

        unpack_key = (L.DecompAssign(key_components, L.Name('_key')),)
        image = L.ImgLookup(L.Name(aggrinv.rel), aggrinv.mask,
                            key_components)

        handler = aggrinv.get_handler()
        zero_expr = handler.make_zero_expr()
        # Recompute the aggregate from scratch by replaying each value in
        # the image as an update.
        update_code = handler.make_update_state_code(
            prefix, state_var, op, value_var)

        if aggrinv.unwrap:
            # Image elements are 1-tuples; unpack the single component.
            loop_template = '''
                for (_VALUE,) in _RELLOOKUP:
                    _UPDATESTATE
                '''
        else:
            loop_template = '''
                for _VALUE in _RELLOOKUP:
                    _UPDATESTATE
                '''
        loop_subst = {
            '_VALUE': value_var,
            '_RELLOOKUP': image,
            '<c>_UPDATESTATE': update_code,
        }
        loop_code = L.Parser.pc(loop_template, subst=loop_subst)

        add_subst = {
            '_MAP': aggrinv.map,
            '_KEY': '_key',
            '_STATE': state_var,
            '_ZERO': zero_expr,
            '<c>_DECOMP_KEY': unpack_key,
            '<c>_LOOP': loop_code,
        }
        maint_code = L.Parser.pc('''
            _STATE = _ZERO
            _DECOMP_KEY
            _LOOP
            _MAP.mapassign(_KEY, _STATE)
            ''', subst=add_subst)

    func_subst = {
        '_FUNC': aggrinv.get_restr_maint_func_name(op),
        '<c>_MAINT': maint_code,
    }
    return L.Parser.ps('''
        def _FUNC(_key):
            _MAINT
        ''', subst=func_subst)