Beispiel #1
0
 def rewrite_comp(self, symbol, name, comp):
     """Replace the target query with a read of its stored result.

     Comprehensions other than the one named by query.name are left
     alone (None is returned). With no parameters the result variable
     itself is returned; otherwise an image lookup keyed on the
     query's parameters is returned.
     """
     if name != query.name:
         return None

     if query.params == ():
         return L.Name(result_var)

     key_mask = L.keymask_from_len(len(query.params), orig_arity)
     return L.ImgLookup(L.Name(result_var), key_mask, query.params)
Beispiel #2
0
    def get_code(self, cl, bindenv, body):
        """Return code that runs body for a TUP clause.

        The shape of the emitted code depends on which of the clause's
        left-hand-side variables are bound in bindenv:

          - all bound: test the tuple against its components;
          - tuple bound: bind the unbound components from the tuple,
            testing for equality unless every component was unbound;
          - only the components bound: assign the tuple from them.

        Any other binding pattern raises L.TransformationError. In the
        first two cases the code is additionally guarded by an arity
        check when self.use_typecheck is set.
        """
        lhs = self.lhs_vars(cl)
        assert_unique(lhs)
        bound_mask = L.mask_from_bounds(lhs, bindenv)

        equality_test = L.Compare(L.Name(cl.tup), L.Eq(),
                                  L.tuplify(cl.elts))

        # Only the pure-assignment case below can skip the arity check.
        wants_arity_check = True

        if L.mask_is_allbound(bound_mask):
            code = (L.If(equality_test, body, ()), )

        elif bound_mask.m.startswith('b'):
            elts_mask = L.mask_from_bounds(cl.elts, bindenv)
            code = L.bind_by_mask(elts_mask, cl.elts, L.Name(cl.tup))
            # If some component was already bound, binding alone is not
            # enough; the whole tuple must still be compared.
            code += (body if L.mask_is_allunbound(elts_mask) else
                     (L.If(equality_test, body, ()), ))

        elif bound_mask == L.mask('u' + 'b' * len(cl.elts)):
            code = (L.Assign(cl.tup, L.tuplify(cl.elts)), ) + body
            wants_arity_check = False

        else:
            raise L.TransformationError('Cannot emit code for TUP clause '
                                        'that would require an auxiliary '
                                        'map; use demand filtering')

        if wants_arity_check and self.use_typecheck:
            arity_test = L.HasArity(L.Name(cl.tup), len(cl.elts))
            code = (L.If(arity_test, code, ()), )

        return code
Beispiel #3
0
    def get_code(self, cl, bindenv, body):
        """Return code that runs body for a map-lookup clause.

        Handled binding patterns of the left-hand-side variables:

          - all bound: compare the value against the map lookup;
          - 'bbu': assign the value from the map lookup;
          - 'buu': iterate over the map's items.

        Every other pattern is delegated to the superclass. The code
        for the handled patterns is wrapped in an IsMap test when
        self.use_typecheck is set.
        """
        lhs = self.lhs_vars(cl)
        assert_unique(lhs)
        bound_mask = L.mask_from_bounds(lhs, bindenv)

        map_lookup = L.DictLookup(L.Name(cl.map), L.Name(cl.key), None)

        if L.mask_is_allbound(bound_mask):
            value_matches = L.Compare(L.Name(cl.value), L.Eq(), map_lookup)
            code = (L.If(value_matches, body, ()), )

        elif bound_mask == L.mask('bbu'):
            code = (L.Assign(cl.value, map_lookup), ) + body

        elif bound_mask == L.mask('buu'):
            all_items = L.Parser.pe('_MAP.items()', subst={'_MAP': cl.map})
            code = (L.DecompFor([cl.key, cl.value], all_items, body), )

        else:
            # Not a pattern handled here; no type check is applied.
            return super().get_code(cl, bindenv, body)

        if self.use_typecheck:
            code = (L.If(L.IsMap(L.Name(cl.map)), code, ()), )

        return code
Beispiel #4
0
    def visit_MapAssign(self, node):
        """Attach a maintenance call to an assignment on a tracked map.

        Nodes whose map has no entry in self.setfrommaps_by_map are
        returned unchanged.
        """
        invariant = self.setfrommaps_by_map.get(node.map, None)
        if invariant is None:
            return node

        maint_name = invariant.get_maint_func_name('assign')
        maint_args = [L.Name(node.key), L.Name(node.value)]
        maint_call = L.Expr(L.Call(maint_name, maint_args))
        # Maintenance is placed as for an addition to a relation.
        return L.insert_rel_maint((node, ), (maint_call, ), L.SetAdd())
Beispiel #5
0
def make_setfrommap_maint_func(fresh_vars, setfrommap: SetFromMapInvariant,
                               op: str):
    """Make the maintenance function for a SetFromMap invariant.

    op is 'assign' or 'delete'. The generated function decomposes
    _key into its components, combines them with _val into an element
    tuple according to the invariant's mask, and adds that element to
    (for 'assign') or removes it from (for 'delete') the invariant's
    relation. The 'assign' function takes (_key, _val); the 'delete'
    function takes only _key and first looks the value up in the map.
    """
    map_mask = setfrommap.mask
    num_key_parts = L.break_mapmask(map_mask)
    # Fresh variables for components of the key and value.
    key_parts = N.get_subnames('_key', num_key_parts)
    unpack_key = L.DecompAssign(key_parts, L.Name('_key'))

    elem_parts = L.combine_by_mask(map_mask, key_parts, ['_val'])
    elem_expr = L.tuplify(elem_parts)
    elem_var = next(fresh_vars) + '_elem'

    decomp_code = (unpack_key, L.Assign(elem_var, elem_expr))

    op_class = {'assign': L.SetAdd, 'delete': L.SetRemove}[op]
    update_code = (L.RelUpdate(setfrommap.rel, op_class(), elem_var), )

    func_name = setfrommap.get_maint_func_name(op)

    if op == 'assign':
        assign_subst = {
            '_FUNC': func_name,
            '<c>_DECOMP': decomp_code,
            '<c>_UPDATE': update_code
        }
        func = L.Parser.ps('''
            def _FUNC(_key, _val):
                _DECOMP
                _UPDATE
            ''',
                           subst=assign_subst)
    elif op == 'delete':
        # The deleted value is not passed in, so read it from the map.
        current_value = L.DictLookup(L.Name(setfrommap.map),
                                     L.Name('_key'), None)
        delete_subst = {
            '_FUNC': func_name,
            '_LOOKUP': current_value,
            '<c>_DECOMP': decomp_code,
            '<c>_UPDATE': update_code
        }
        func = L.Parser.ps('''
            def _FUNC(_key):
                _val = _LOOKUP
                _DECOMP
                _UPDATE
            ''',
                           subst=delete_subst)
    else:
        assert ()

    return func
Beispiel #6
0
    def visit_DictLookup(self, node):
        """Replace a simple map lookup with a variable.

        The first time a given lookup node is seen, a fresh variable is
        recorded for it in self.repls, a SetFromMapMember clause binding
        that variable is appended to self.new_clauses, and the matching
        SetFromMapInvariant is added to self.sfm_invs. Later occurrences
        of the same node reuse the recorded variable.
        """
        node = self.generic_visit(node)

        # Only simple map lookups are allowed.
        assert isinstance(node.value, L.Name)
        assert L.is_tuple_of_names(node.key)
        assert node.default is None
        map_name = node.value.id
        key_names = L.detuplify(node.key)

        known = self.repls.get(node, None)
        if known is not None:
            return L.Name(known)

        mask = L.mapmask_from_len(len(key_names))
        rel = N.SA_name(map_name, mask)

        # Create a fresh variable and remember it for this lookup.
        fresh = next(self.fresh_names)
        self.repls[node] = fresh

        # Construct a clause to bind it.
        clause_vars = [*key_names, fresh]
        self.new_clauses.append(
            L.SetFromMapMember(clause_vars, rel, map_name, mask))

        # Construct a corresponding SetFromMap invariant.
        self.sfm_invs.add(SetFromMapInvariant(rel, map_name, mask))

        return L.Name(fresh)
Beispiel #7
0
    def rewrite_with_demand(self, query_sym, node):
        """Return the demand-transformed version of a query's node.

        node is the Comp or Aggr node associated with query_sym;
        subqueries are not transformed. If the query does not use
        demand, node is returned as is. Depending on whether left
        clauses are available, either a demand set or a demand query is
        created and recorded on self. An AssertionError is raised for
        any node that is neither a Comp nor an Aggr.
        """
        symtab = self.symtab
        demand_params = query_sym.demand_params

        if not query_sym.uses_demand:
            return node

        # Make a demand query if there are left clauses, else a
        # demand set.
        left_clauses = self.get_left_clauses()
        if left_clauses is not None:
            dem_sym = make_demand_query(symtab, query_sym, left_clauses)
            dem_node = dem_sym.make_node()
            dem_clause = L.VarsMember(demand_params, dem_node)
            self.demand_queries.add(dem_sym.name)
        else:
            dem_sym = make_demand_set(symtab, query_sym)
            dem_node = L.Name(dem_sym.name)
            dem_clause = L.RelMember(demand_params, dem_sym.name)
            self.queries_with_usets.add(query_sym.name)

        # Determine the rewritten node.
        if isinstance(node, L.Comp):
            return node._replace(clauses=(dem_clause, ) + node.clauses)
        if isinstance(node, L.Aggr):
            return L.AggrRestr(node.op, node.value, demand_params, dem_node)

        raise AssertionError(
            'No rule for handling demand for {} node'.format(
                node.__class__.__name__))
Beispiel #8
0
    def visit_RelUpdate(self, node):
        """Insert maintenance calls around a relation add/remove.

        One call is inserted for each auxiliary map and then for each
        wrap invariant registered for the updated relation. Updates
        that are neither SetAdd nor SetRemove are returned unchanged.
        """
        if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
            return node

        def with_maint(code, invariant):
            # Add the call to this invariant's maintenance function.
            maint_name = invariant.get_maint_func_name(node.op)
            maint_call = L.Expr(L.Call(maint_name, [L.Name(node.elem)]))
            return L.insert_rel_maint(code, (maint_call, ), node.op)

        code = (node, )

        for auxmap in self.auxmaps_by_rel.get(node.rel, set()):
            code = with_maint(code, auxmap)

        for wrap in self.wraps_by_rel.get(node.rel, set()):
            code = with_maint(code, wrap)

        return code
Beispiel #9
0
    def visit_MapDelete(self, node):
        """Attach a maintenance call to a deletion on a tracked map.

        Nodes whose map has no entry in self.setfrommaps_by_map are
        returned unchanged.
        """
        invariant = self.setfrommaps_by_map.get(node.map, None)
        if invariant is None:
            return node

        maint_name = invariant.get_maint_func_name('delete')
        maint_call = L.Expr(L.Call(maint_name, [L.Name(node.key)]))
        # Maintenance is placed as for a removal from a relation.
        return L.insert_rel_maint((node, ), (maint_call, ), L.SetRemove())
Beispiel #10
0
 def make_update_state_code(self, prefix, state, op, value):
     """Return code that folds one operand update into the state.

     The state is treated as a pair. For a SetAdd, the value is added
     to component 0 and component 1 is incremented; for a SetRemove,
     the value is subtracted from component 0 and component 1 is
     decremented.

     prefix is not used by this implementation. state is substituted
     for _STATE in the template; value is a variable name that is
     wrapped in L.Name and substituted for _VALUE.

     Raises AssertionError if op is neither a SetAdd nor a SetRemove
     (previously this surfaced as an UnboundLocalError on template).
     """
     value = L.Name(value)
     if isinstance(op, L.SetAdd):
         template = '''
             _STATE = (index(_STATE, 0) + _VALUE, index(_STATE, 1) + 1)
             '''
     elif isinstance(op, L.SetRemove):
         template = '''
             _STATE = (index(_STATE, 0) - _VALUE, index(_STATE, 1) - 1)
             '''
     else:
         # Fail with a clear message rather than leaving template
         # unbound.
         raise AssertionError(
             'No rule for updating state for {} operation'.format(
                 op.__class__.__name__))
     return L.Parser.pc(template, subst={'_STATE': state, '_VALUE': value})
Beispiel #11
0
    def visit_RelUpdate(self, node):
        """Insert a maintenance call around an update to a tracked relation.

        The node is returned unchanged unless it is a SetAdd/SetRemove
        on one of the relations in self.rels.
        """
        if (not isinstance(node.op, (L.SetAdd, L.SetRemove))
                or node.rel not in self.rels):
            return node

        update_kind = L.set_update_name(node.op)
        maint_name = N.get_maint_func_name(self.result_var, node.rel,
                                           update_kind)

        maint_call = L.Expr(L.Call(maint_name, [L.Name(node.elem)]))
        return L.insert_rel_maint((node, ), (maint_call, ), node.op)
Beispiel #12
0
    def rewrite_resexp_with_params(self, comp, params):
        """Return comp with the parameter variables prepended to its
        result expression.

        The result expression must be a tuple expression, and every
        parameter must be among the comprehension's left-hand-side
        variables; both conditions are asserted.
        """
        lhs_vars = self.lhs_vars_from_comp(comp)
        assert set(params) <= set(lhs_vars), \
            'params: {}, lhs_vars: {}'.format(params, lhs_vars)
        assert isinstance(comp.resexp, L.Tuple)

        param_names = tuple(map(L.Name, params))
        extended = L.Tuple(param_names + comp.resexp.elts)
        return comp._replace(resexp=extended)
Beispiel #13
0
    def get_code(self, cl, bindenv, body):
        """Return code that runs body for a set-membership clause.

        With all left-hand-side variables bound, a membership test is
        emitted; with mask 'bu', a loop over the set is emitted. Every
        other pattern is delegated to the superclass. The code for the
        handled patterns is wrapped in an IsSet test when
        self.use_typecheck is set.
        """
        lhs = self.lhs_vars(cl)
        assert_unique(lhs)
        bound_mask = L.mask_from_bounds(lhs, bindenv)

        if L.mask_is_allbound(bound_mask):
            is_member = L.Compare(L.Name(cl.elem), L.In(), L.Name(cl.set))
            code = (L.If(is_member, body, ()), )

        elif bound_mask == L.mask('bu'):
            code = (L.For(cl.elem, L.Name(cl.set), body), )

        else:
            # Not a pattern handled here; no type check is applied.
            return super().get_code(cl, bindenv, body)

        if self.use_typecheck:
            code = (L.If(L.IsSet(L.Name(cl.set)), code, ()), )

        return code
Beispiel #14
0
    def get_code(self, cl, bindenv, body):
        """Return code that runs body for a relation-membership clause.

        All variables bound gives a membership test, all unbound gives
        a decomposing loop over the relation, and a mixed pattern gives
        a loop over an image lookup keyed on the bound variables.
        """
        lhs = self.lhs_vars(cl)
        rel = self.rhs_rel(cl)
        assert_unique(lhs)
        bound_mask = L.mask_from_bounds(lhs, bindenv)

        if L.mask_is_allbound(bound_mask):
            is_member = L.Compare(L.tuplify(lhs), L.In(), L.Name(rel))
            return (L.If(is_member, body, ()), )

        if L.mask_is_allunbound(bound_mask):
            return (L.DecompFor(lhs, L.Name(rel), body), )

        bound, unbound = L.split_by_mask(bound_mask, lhs)
        image = L.ImgLookup(L.Name(rel), bound_mask, bound)
        if len(unbound) != 1:
            return (L.DecompFor(unbound, image, body), )
        # Optimize in the case where there's only one unbound.
        return (L.For(unbound[0], L.Unwrap(image), body), )
Beispiel #15
0
    def wrap_helper(self, node):
        """Replace a Wrap or Unwrap node by the relation of a matching
        wrap invariant, if there is one. Does not recurse.

        Only nodes whose operand is a plain name are considered. A Wrap
        node matches an invariant whose unwrap flag is false; an Unwrap
        node matches one whose unwrap flag is true. Without a match the
        node is returned unchanged.
        """
        operand = node.value
        if not isinstance(operand, L.Name):
            return node

        for invariant in self.wraps_by_rel.get(operand.id, []):
            matches_wrap = isinstance(node, L.Wrap) and not invariant.unwrap
            matches_unwrap = isinstance(node, L.Unwrap) and invariant.unwrap
            if matches_wrap or matches_unwrap:
                return L.Name(invariant.rel)

        return node
Beispiel #16
0
    def get_code(self, cl, bindenv, body):
        """Return code that runs body for a field-retrieval clause.

        With all left-hand-side variables bound, the value is compared
        against the object's attribute; with mask 'bu', the value is
        assigned from the attribute. Every other pattern is delegated
        to the superclass. The code for the handled patterns is wrapped
        in a HasField test when self.use_typecheck is set.
        """
        lhs = self.lhs_vars(cl)
        assert_unique(lhs)
        bound_mask = L.mask_from_bounds(lhs, bindenv)

        if L.mask_is_allbound(bound_mask):
            value_matches = L.Compare(L.Name(cl.value), L.Eq(),
                                      L.Attribute(L.Name(cl.obj), cl.attr))
            code = (L.If(value_matches, body, ()), )

        elif bound_mask == L.mask('bu'):
            field_read = L.Attribute(L.Name(cl.obj), cl.attr)
            code = (L.Assign(cl.value, field_read), ) + body

        else:
            # Not a pattern handled here; no type check is applied.
            return super().get_code(cl, bindenv, body)

        if self.use_typecheck:
            code = (L.If(L.HasField(L.Name(cl.obj), cl.attr), code, ()), )

        return code
Beispiel #17
0
    def visit_SetFromMap(self, node):
        """Replace a SetFromMap expression over a tracked map with the
        name of the invariant's relation.

        Expressions whose map is not a plain name, or whose map has no
        registered invariant, are returned as visited. L.ProgramError
        is raised if the expression's mask differs from the
        invariant's.
        """
        node = self.generic_visit(node)

        if not isinstance(node.map, L.Name):
            return node
        map_name = node.map.id

        invariant = self.setfrommaps_by_map.get(map_name, None)
        if invariant is None:
            return node

        if invariant.mask == node.mask:
            return L.Name(invariant.rel)

        raise L.ProgramError('Multiple SetFromMap expressions on '
                             'same map {}'.format(map_name))
Beispiel #18
0
    def process(expr):
        """Turn a membership over a known relation into a clause.

        Returns a triple whose last two items are empty lists. The
        first item is expr itself unless expr is a Member whose iter is
        the name of a relation in the symbol table; in that case it is
        a RelMember (tuple-of-names target) or a VarsMember over the
        wrapped relation (single-name target). Any other target raises
        L.ProgramError.
        """
        if (not isinstance(expr, L.Member)
                or not isinstance(expr.iter, L.Name)
                or expr.iter.id not in symtab.get_relations()):
            return expr, [], []

        target, rel = expr.target, expr.iter.id

        if L.is_tuple_of_names(target):
            return L.RelMember(L.detuplify(target), rel), [], []

        if isinstance(target, L.Name):
            return L.VarsMember([target.id], L.Wrap(L.Name(rel))), [], []

        raise L.ProgramError('Invalid clause over relation')
Beispiel #19
0
def make_auxmap_maint_func(fresh_vars, auxmap: AuxmapInvariant, op: L.setupop):
    """Make the maintenance function for an auxiliary map.

    The generated function takes an element _elem, decomposes it into
    one variable per mask component, builds the key and value from the
    parts selected by the mask, and then runs the image-add (for a
    SetAdd) or image-remove (for a SetRemove) code on the map.
    """
    # Fresh variables for components of the element.
    parts = N.get_subnames('_elem', len(auxmap.mask.m))
    unpack_elem = (L.DecompAssign(parts, L.Name('_elem')), )

    key_parts, value_parts = L.split_by_mask(auxmap.mask, parts)
    key_expr = L.tuplify(key_parts, unwrap=auxmap.unwrap_key)
    value_expr = L.tuplify(value_parts, unwrap=auxmap.unwrap_value)

    prefix = next(fresh_vars)
    key_var = prefix + '_key'
    value_var = prefix + '_value'

    decomp_code = unpack_elem + L.Parser.pc('''
        _KEY_VAR = _KEY
        _VALUE_VAR = _VALUE
        ''',
                                            subst={
                                                '_KEY_VAR': key_var,
                                                '_KEY': key_expr,
                                                '_VALUE_VAR': value_var,
                                                '_VALUE': value_expr
                                            })

    # Pick the code generator that matches the kind of update.
    generators = {
        L.SetAdd: make_imgadd,
        L.SetRemove: make_imgremove
    }
    make_img_code = generators[op.__class__]
    img_code = make_img_code(fresh_vars, auxmap.map, key_var, value_var)

    func_name = auxmap.get_maint_func_name(op)

    return L.Parser.ps('''
        def _FUNC(_elem):
            _DECOMP
            _IMGCODE
        ''',
                       subst={
                           '_FUNC': func_name,
                           '<c>_DECOMP': decomp_code,
                           '<c>_IMGCODE': img_code
                       })
Beispiel #20
0
    def get_maint_code(self,
                       fresh_var_prefix,
                       fresh_join_names,
                       comp,
                       result_var,
                       update,
                       *,
                       selfjoin=SelfJoin.Without,
                       counted):
        """Return the maintenance code for a comprehension and an
        update to a relation.

        The comprehension need not be a join. The returned code applies
        the update to the stored result variable, looped once for each
        maintenance join.

        If counted is False, generate non-counted set updates.
        """
        assert isinstance(update, L.RelUpdate)
        assert isinstance(update.op, (L.SetAdd, L.SetRemove))

        def add_prefix(var):
            return fresh_var_prefix + '_' + var

        result_elem_var = fresh_var_prefix + '_result'
        # Prefix LHS vars in the comp to guarantee fresh names for their
        # use in maintenance code.
        comp = self.comp_rename_lhs_vars(comp, add_prefix)

        # Loop body: compute the result element, then update the result.
        body = (L.Assign(result_elem_var, comp.resexp), ) + L.rel_update(
            result_var, update.op, result_elem_var, counted=counted)

        join = self.make_join_from_comp(comp)
        maint_joins = self.get_maint_join_union(join,
                                                update.rel,
                                                L.Name(update.elem),
                                                selfjoin=selfjoin)

        loops = ()
        for maint_join in maint_joins:
            loop_name = next(fresh_join_names)
            loops += self.get_loop_for_join(maint_join, body, loop_name)

        return loops
Beispiel #21
0
 def visit_RelUpdate(self, node):
     """Turn an update to an M, F or MAP relation into the matching
     set, attribute or dictionary operation on the element's components.

     Updates to other relations are returned as is; updates that are
     neither SetAdd nor SetRemove yield None.
     """
     if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
         return None
     adding = isinstance(node.op, L.SetAdd)

     rel = node.rel
     elem = L.Name(node.elem)

     def component(index):
         # Expression for the index-th component of the element.
         return L.Subscript(elem, L.Num(index))

     if N.is_M(rel):
         target_set = component(0)
         member = component(1)
         return (L.SetUpdate(target_set, node.op, member),)

     if N.is_F(rel):
         attr = N.get_F(rel)
         obj = component(0)
         value = component(1)
         if adding:
             return (L.AttrAssign(obj, attr, value),)
         return (L.AttrDelete(obj, attr),)

     if N.is_MAP(rel):
         target_map = component(0)
         key = component(1)
         value = component(2)
         if adding:
             return (L.DictAssign(target_map, key, value),)
         return (L.DictDelete(target_map, key),)

     return node
Beispiel #22
0
def incrementalize_aggr(tree, symtab, query, result_var):
    """Incrementalize an aggregate query and return the new tree.

    Forms the aggregate invariant, transforms the tree to maintain it,
    replaces occurrences of the query with a lookup into the result
    map, defines the result map's symbol with its type, and bumps the
    'aggrs_transformed' statistic.
    """
    # Form the invariant.
    aggrinv = aggrinv_from_query(symtab, query, result_var)
    handler = aggrinv.get_handler()

    # Transform to maintain it.
    maintainer = AggrMaintainer(symtab.fresh_names.vars, aggrinv)
    tree = maintainer.process(tree)
    symtab.maint_funcs.update(maintainer.maint_funcs)

    # Transform occurrences of the aggregate. The lookup gets a zero
    # default only when demand is not used.
    if aggrinv.uses_demand:
        default = None
    else:
        default = handler.make_zero_expr()
    lookup_expr = handler.make_projection_expr(
        L.DictLookup(L.Name(aggrinv.map), L.tuplify(aggrinv.params),
                     default))

    class AggrExpander(S.QueryRewriter):
        expand = True

        def rewrite_aggr(self, symbol, name, expr):
            return lookup_expr if name == query.name else None

    tree = AggrExpander.run(tree, symtab)

    # Determine the result map's type and define its symbol.
    rel_type = get_rel_type(symtab, aggrinv.rel)
    key_types, _ = L.split_by_mask(aggrinv.mask, rel_type.elt.elts)
    map_type = T.Map(T.Tuple(key_types), handler.result_type(rel_type))
    symtab.define_map(aggrinv.map, type=map_type)

    symtab.stats['aggrs_transformed'] += 1

    return tree
Beispiel #23
0
 def visit_MapDelete(self, node):
     """Return a DictDelete on the named map and key."""
     target_map = L.Name(node.map)
     target_key = L.Name(node.key)
     return L.DictDelete(target_map, target_key)
Beispiel #24
0
 def visit_MapClear(self, node):
     """Return a DictClear on the named map."""
     target_map = L.Name(node.map)
     return L.DictClear(target_map)
Beispiel #25
0
 def visit_RelMember(self, node):
     """Return a Member of the tuple of variables over the named
     relation.
     """
     target = L.tuplify(node.vars)
     relation = L.Name(node.rel)
     return L.Member(target, relation)
Beispiel #26
0
def make_eq_cond(left, right):
    """Return a condition clause of form <var> == <var> for the two
    given variable names.
    """
    equality = L.Compare(L.Name(left), L.Eq(), L.Name(right))
    return L.Cond(equality)
Beispiel #27
0
    'make_eq_cond',
    'SelfJoin',
    'ClauseTools',
    'CoreClauseTools',
]

from enum import Enum

from incoq.util.seq import zip_strict
from incoq.util.collections import OrderedSet, Partitioning
from incoq.compiler.incast import L

from .clause import ClauseVisitor, CoreClauseVisitor, Kind, ShouldFilter

# Pattern for a condition clause of form <var> == <var>. The pattern
# variables LEFT and RIGHT stand for the two variable names.
eq_cond_pattern = L.Cond(
    L.Compare(L.Name(L.PatVar('LEFT')), L.Eq(), L.Name(L.PatVar('RIGHT'))))


def match_eq_cond(tree):
    """Return the pair of variables of a <var> == <var> condition
    clause, or None if tree does not have that form.
    """
    bindings = L.match(eq_cond_pattern, tree)
    if bindings is not None:
        return bindings['LEFT'], bindings['RIGHT']
    return None


def make_eq_cond(left, right):
    """Make a condition of form <var> == <var>."""
Beispiel #28
0
 def rewrite_comp(self, symbol, name, comp):
     """Replace the target query with a call to func_name on the
     query's parameters; other comprehensions yield None.
     """
     if name != query.name:
         return None
     call_args = list(map(L.Name, query.params))
     return L.Call(func_name, call_args)
Beispiel #29
0
def make_aggr_restr_maint_func(fresh_vars, aggrinv, op):
    """Make the maintenance function for an aggregate invariant and
    an update to its restriction set.

    op must be a SetAdd or SetRemove, and the invariant must use
    demand; both are asserted. The generated function takes a single
    argument _key.

    For a SetAdd, the function starts from the handler's zero state,
    folds in every value found by an image lookup on the invariant's
    relation for the key's components, and assigns the final state to
    the map at _key. For a SetRemove, it deletes _key from the map.
    """
    assert isinstance(op, (L.SetAdd, L.SetRemove))
    assert aggrinv.uses_demand

    if isinstance(op, L.SetAdd):
        fresh_var_prefix = next(fresh_vars)
        value = fresh_var_prefix + '_value'
        state = fresh_var_prefix + '_state'
        # One variable per parameter of the invariant.
        keyvars = N.get_subnames('_key', len(aggrinv.params))

        decomp_key_code = (L.DecompAssign(keyvars, L.Name('_key')), )
        rellookup = L.ImgLookup(L.Name(aggrinv.rel), aggrinv.mask, keyvars)

        handler = aggrinv.get_handler()
        zero = handler.make_zero_expr()
        updatestate_code = handler.make_update_state_code(
            fresh_var_prefix, state, op, value)

        # With unwrap set, the loop target is a one-element tuple
        # pattern, so the single component is bound to the value.
        if aggrinv.unwrap:
            loop_template = '''
                for (_VALUE,) in _RELLOOKUP:
                    _UPDATESTATE
                '''
        else:
            loop_template = '''
                for _VALUE in _RELLOOKUP:
                    _UPDATESTATE
                '''

        loop_code = L.Parser.pc(loop_template,
                                subst={
                                    '_VALUE': value,
                                    '_RELLOOKUP': rellookup,
                                    '<c>_UPDATESTATE': updatestate_code
                                })

        # Initialize the state, accumulate over the loop, then store
        # the result under the key.
        maint_code = L.Parser.pc('''
            _STATE = _ZERO
            _DECOMP_KEY
            _LOOP
            _MAP.mapassign(_KEY, _STATE)
            ''',
                                 subst={
                                     '_MAP': aggrinv.map,
                                     '_KEY': '_key',
                                     '_STATE': state,
                                     '_ZERO': zero,
                                     '<c>_DECOMP_KEY': decomp_key_code,
                                     '<c>_LOOP': loop_code
                                 })

    else:
        # SetRemove: drop the map entry for the key.
        maint_code = L.Parser.pc('''
            _MAP.mapdelete(_KEY)
            ''',
                                 subst={
                                     '_MAP': aggrinv.map,
                                     '_KEY': '_key'
                                 })

    func_name = aggrinv.get_restr_maint_func_name(op)

    func = L.Parser.ps('''
        def _FUNC(_key):
            _MAINT
        ''',
                       subst={
                           '_FUNC': func_name,
                           '<c>_MAINT': maint_code
                       })

    return func
Beispiel #30
0
def make_aggr_oper_maint_func(fresh_vars, aggrinv, op):
    """Make the maintenance function for an aggregate invariant and a
    given set update operation (add or remove) to the operand.

    op must be a SetAdd or SetRemove (asserted). The generated
    function takes a single argument _elem, decomposes it into key and
    value parts according to the invariant's mask, and then reads the
    state for the key, updates it, deletes the existing map entry and
    stores the new state. When the invariant uses demand, all of that
    is guarded by a test that the key is in the restriction set.
    """
    assert isinstance(op, (L.SetAdd, L.SetRemove))

    # Decompose the argument tuple into key and value components,
    # just like in auxmap.py.

    vars = N.get_subnames('_elem', len(aggrinv.mask.m))
    kvars, vvars = L.split_by_mask(aggrinv.mask, vars)
    ktuple = L.tuplify(kvars)
    vtuple = L.tuplify(vvars)
    fresh_var_prefix = next(fresh_vars)
    key = fresh_var_prefix + '_key'
    value = fresh_var_prefix + '_value'
    state = fresh_var_prefix + '_state'

    # With unwrap set, the value is the single value component itself
    # rather than a tuple of the value components.
    if aggrinv.unwrap:
        assert len(vvars) == 1
        value_expr = L.Name(vvars[0])
    else:
        value_expr = vtuple

    # Logic specific to aggregate operator.
    handler = aggrinv.get_handler()
    zero = handler.make_zero_expr()
    updatestate_code = handler.make_update_state_code(fresh_var_prefix, state,
                                                      op, value)

    # Substitutions shared by all the templates below.
    subst = {
        '_KEY': key,
        '_KEY_EXPR': ktuple,
        '_VALUE': value,
        '_VALUE_EXPR': value_expr,
        '_MAP': aggrinv.map,
        '_STATE': state,
        '_ZERO': zero
    }

    if aggrinv.uses_demand:
        subst['_RESTR'] = aggrinv.restr
    else:
        # Empty conditions are only used when we don't have a
        # restriction set.
        subst['_EMPTY'] = handler.make_empty_cond(state)

    decomp_code = (L.DecompAssign(vars, L.Name('_elem')), )
    decomp_code += L.Parser.pc('''
        _KEY = _KEY_EXPR
        _VALUE = _VALUE_EXPR
        ''',
                               subst=subst)

    # Determine what kind of get/set state code to generate.
    if isinstance(op, L.SetAdd):
        definitely_preexists = aggrinv.uses_demand
        setstate_mayremove = False
    elif isinstance(op, L.SetRemove):
        definitely_preexists = True
        setstate_mayremove = not aggrinv.uses_demand
    else:
        # An empty tuple is falsy, so this assertion always fails. It
        # cannot be reached given the isinstance assertion above.
        assert ()

    if definitely_preexists:
        # Read and delete the entry directly.
        getstate_template = '_STATE = _MAP[_KEY]'
        delstate_template = '_MAP.mapdelete(_KEY)'
    else:
        # The entry may be missing: read with the zero state as the
        # default and delete only if the key is present.
        getstate_template = '_STATE = _MAP.get(_KEY, _ZERO)'
        delstate_template = '''
            if _KEY in _MAP:
                _MAP.mapdelete(_KEY)
            '''

    if setstate_mayremove:
        # Store the new state only if the empty condition is false.
        setstate_template = '''
            if not _EMPTY:
                _MAP.mapassign(_KEY, _STATE)
            '''
    else:
        setstate_template = '_MAP.mapassign(_KEY, _STATE)'

    getstate_code = L.Parser.pc(getstate_template, subst=subst)
    delstate_code = L.Parser.pc(delstate_template, subst=subst)
    setstate_code = L.Parser.pc(setstate_template, subst=subst)

    # Order: read state, update it, delete the old entry, store the new.
    maint_code = (getstate_code + updatestate_code + delstate_code +
                  setstate_code)

    # Guard in test if we have a restriction set.
    if aggrinv.uses_demand:
        # Copy so that the shared subst dict is not modified.
        maint_subst = dict(subst)
        maint_subst['<c>_MAINT'] = maint_code
        maint_code = L.Parser.pc('''
            if _KEY in _RESTR:
                _MAINT
            ''',
                                 subst=maint_subst)

    func_name = aggrinv.get_oper_maint_func_name(op)

    func = L.Parser.ps('''
        def _FUNC(_elem):
            _DECOMP
            _MAINT
        ''',
                       subst={
                           '_FUNC': func_name,
                           '<c>_DECOMP': decomp_code,
                           '<c>_MAINT': maint_code
                       })

    return func