Example #1
0
 def con_lhs_vars_from_comp(self, comp):
     """Return a tuple of the constrained variables of comp.

     A variable is constrained when its first occurrence in the
     query is in a constrained position.

     Clauses are scanned left to right. For cyclic object queries
     such as

         {(x, y) for y in x for x in y},

     whose clauses over M admit either {x} or {y} as the constrained
     set, this means {y} is the one returned.
     """
     seen_uncon = OrderedSet()
     seen_con = OrderedSet()
     for clause in comp.clauses:
         # Unconstrained occurrences are recorded first, so a clause
         # like "x in x" leaves a previously unseen x unconstrained.
         # A var already known to be constrained is not reclassified.
         fresh_uncon = [v for v in self.uncon_vars(clause)
                        if v not in seen_con]
         seen_uncon.update(fresh_uncon)
         # Symmetrically, a var already known to be unconstrained is
         # never added to the constrained set.
         fresh_con = [v for v in self.con_lhs_vars(clause)
                      if v not in seen_uncon]
         seen_con.update(fresh_con)
     return tuple(seen_con)
Example #2
0
class IdentFinder(L.NodeVisitor):
    """Return an OrderedSet of all identifiers in the specified
    contexts.

    A context is a string '<node type name>.<field name>' naming a
    field listed for that node type in L.ident_fields.
    """

    fun_ctxs = ('Fun.name', 'Call.func')
    query_ctxs = ('Query.name', 'ResetDemand.names')
    rel_ctxs = ('RelUpdate.rel', 'RelClear.rel', 'RelMember.rel')

    @classmethod
    def find_functions(cls, tree):
        """Return identifiers occurring in function-name contexts."""
        return cls().run(tree, contexts=cls.fun_ctxs)

    @classmethod
    def find_vars(cls, tree):
        """Return identifiers occurring in any context other than
        function-name and query-name contexts.
        """
        ctxs = cls.fun_ctxs + cls.query_ctxs
        return cls().run(tree, contexts=ctxs, invert=True)

    @classmethod
    def find_non_rel_uses(cls, tree):
        """Return identifiers occurring in any context other than
        function-name, query-name, and relation contexts.
        """
        ctxs = cls.fun_ctxs + cls.query_ctxs + cls.rel_ctxs
        return cls().run(tree, contexts=ctxs, invert=True)

    def __init__(self, contexts=None, invert=False):
        """Validate and store the context filter.

        Raises ValueError if any context does not name a known
        identifier field.
        """
        super().__init__()
        if contexts is not None:
            for c in contexts:
                # partition() never fails to unpack, so malformed
                # strings (no dot, extra dots) reach the intended
                # "Unknown identifier context" error instead of an
                # opaque unpacking ValueError.
                node_name, _sep, field_name = c.partition('.')
                if field_name not in L.ident_fields.get(node_name, []):
                    raise ValueError(
                        'Unknown identifier context "{}"'.format(c))

        # Collection of contexts to include/exclude. Each context is a
        # string of the form '<node type name>.<field name>'. A value
        # of None is equivalent to specifying all contexts.
        self.contexts = contexts
        # If True, find identifiers that occur in any context besides
        # the ones given.
        self.invert = bool(invert)

    def process(self, tree):
        self.names = OrderedSet()
        super().process(tree)
        return self.names

    def generic_visit(self, node):
        super().generic_visit(node)
        clsname = node.__class__.__name__
        for f in L.ident_fields.get(clsname, []):
            inctx = (self.contexts is None
                     or clsname + '.' + f in self.contexts)
            # Collect in-context fields normally, or out-of-context
            # fields when inverted.
            if inctx == self.invert:
                continue
            ids = getattr(node, f)
            if ids is None:
                continue
            # Normalize for either one id or a sequence of ids.
            if isinstance(ids, str):
                ids = [ids]
            self.names.update(ids)
Example #3
0
 def con_lhs_vars_from_comp(self, comp):
     """Return a tuple of the constrained variables of comp.

     A variable is constrained when its first occurrence in the
     query is in a constrained position.

     Clauses are scanned left to right. For cyclic object queries
     such as

         {(x, y) for y in x for x in y},

     whose clauses over M admit either {x} or {y} as the constrained
     set, this means {y} is the one returned.
     """
     uncon_seen = OrderedSet()
     con_seen = OrderedSet()
     for clause in comp.clauses:
         # Classify unconstrained occurrences before constrained ones,
         # so "x in x" makes a previously unseen x unconstrained.
         newly_uncon = [var for var in self.uncon_vars(clause)
                        if var not in con_seen]
         uncon_seen.update(newly_uncon)
         newly_con = [var for var in self.con_lhs_vars(clause)
                      if var not in uncon_seen]
         con_seen.update(newly_con)
     return tuple(con_seen)
Example #4
0
 def lhs_vars_from_clauses(self, clauses):
     """Return a tuple of all LHS vars appearing in the given clauses,
     in order, without duplicates.
     """
     # Local deliberately not named `vars`, to avoid shadowing the
     # builtin.
     result = OrderedSet()
     for cl in clauses:
         result.update(self.lhs_vars(cl))
     return tuple(result)
Example #5
0
 def lhs_vars_from_clauses(self, clauses):
     """Return a tuple of all LHS vars appearing in the given clauses,
     in order, without duplicates.
     """
     # Local deliberately not named `vars`, to avoid shadowing the
     # builtin.
     result = OrderedSet()
     for cl in clauses:
         result.update(self.lhs_vars(cl))
     return tuple(result)
Example #6
0
class AggrMaintainer(L.NodeTransformer):
    """Insert maintenance code for a single aggregate invariant.

    Maintenance functions are prepended to the module. Calls to them
    are inserted at SetAdd/SetRemove updates to the aggregate's operand
    relation and, when demand is used, to its restr relation. The
    aggregate's map is cleared when the relevant relation is cleared.
    """

    def __init__(self, fresh_vars, aggrinv):
        super().__init__()
        self.fresh_vars = fresh_vars
        self.aggrinv = aggrinv
        # Names of the maintenance functions added by visit_Module().
        self.maint_funcs = OrderedSet()

    def visit_Module(self, node):
        node = self.generic_visit(node)

        aggrinv = self.aggrinv
        fv = self.fresh_vars
        ops = [L.SetAdd(), L.SetRemove()]
        # Operand maintenance functions come first, then (with demand)
        # restr maintenance functions, each for add then remove.
        makers = [make_aggr_oper_maint_func]
        if aggrinv.uses_demand:
            makers.append(make_aggr_restr_maint_func)
        funcs = tuple(make(fv, aggrinv, op)
                      for make in makers
                      for op in ops)

        self.maint_funcs.update(L.get_defined_functions(funcs))
        return node._replace(body=funcs + node.body)

    def visit_RelUpdate(self, node):
        if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
            return node

        aggrinv = self.aggrinv
        if node.rel == aggrinv.rel:
            func = aggrinv.get_oper_maint_func_name(node.op)
        elif aggrinv.uses_demand and node.rel == aggrinv.restr:
            func = aggrinv.get_restr_maint_func_name(node.op)
        else:
            return node
        return L.insert_rel_maint_call(node, func)

    def visit_RelClear(self, node):
        # Clear the map when the relation it is driven by is cleared:
        # the demand (restr) relation if demand is used, otherwise the
        # operand relation.
        aggrinv = self.aggrinv
        trigger = aggrinv.restr if aggrinv.uses_demand else aggrinv.rel
        if node.rel == trigger:
            clear_code = (L.MapClear(aggrinv.map),)
            return L.insert_rel_maint((node,), clear_code, L.SetRemove())
        return node
Example #7
0
class AggrMaintainer(L.NodeTransformer):
    """Insert maintenance code for a single aggregate invariant.

    Maintenance functions are prepended to the module. Calls to them
    are inserted at SetAdd/SetRemove updates to the aggregate's operand
    relation and, when demand is used, to its restr relation. The
    aggregate's map is cleared when the relevant relation is cleared.
    """

    def __init__(self, fresh_vars, aggrinv):
        super().__init__()
        self.fresh_vars = fresh_vars
        self.aggrinv = aggrinv
        # Names of the maintenance functions added by visit_Module().
        self.maint_funcs = OrderedSet()

    def visit_Module(self, node):
        node = self.generic_visit(node)

        aggrinv = self.aggrinv
        fv = self.fresh_vars
        ops = [L.SetAdd(), L.SetRemove()]
        # Operand maintenance functions come first, then (with demand)
        # restr maintenance functions, each for add then remove.
        makers = [make_aggr_oper_maint_func]
        if aggrinv.uses_demand:
            makers.append(make_aggr_restr_maint_func)
        funcs = tuple(make(fv, aggrinv, op)
                      for make in makers
                      for op in ops)

        self.maint_funcs.update(L.get_defined_functions(funcs))
        return node._replace(body=funcs + node.body)

    def visit_RelUpdate(self, node):
        if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
            return node

        aggrinv = self.aggrinv
        if node.rel == aggrinv.rel:
            func = aggrinv.get_oper_maint_func_name(node.op)
        elif aggrinv.uses_demand and node.rel == aggrinv.restr:
            func = aggrinv.get_restr_maint_func_name(node.op)
        else:
            return node
        return L.insert_rel_maint_call(node, func)

    def visit_RelClear(self, node):
        # Clear the map when the relation it is driven by is cleared:
        # the demand (restr) relation if demand is used, otherwise the
        # operand relation.
        aggrinv = self.aggrinv
        trigger = aggrinv.restr if aggrinv.uses_demand else aggrinv.rel
        if node.rel == trigger:
            clear_code = (L.MapClear(aggrinv.map), )
            return L.insert_rel_maint((node, ), clear_code, L.SetRemove())
        return node
Example #8
0
def prefix_locals(tree, prefix, extra_boundvars):
    """Return tree with each local variable renamed to prefix + name.

    The locals are extra_boundvars followed by every name found in a
    binding occurrence within tree (via BindingFinder). The
    extra_boundvars count as bound even if tree never binds them.
    """
    localvars = OrderedSet(extra_boundvars)
    localvars.update(BindingFinder.run(tree))
    renaming = {}
    for name in localvars:
        renaming[name] = prefix + name
    return Templater.run(tree, renaming)
Example #9
0
def prefix_locals(tree, prefix, extra_boundvars):
    """Return tree with each local variable renamed to prefix + name.

    The locals are extra_boundvars followed by every name found in a
    binding occurrence within tree (via BindingFinder). The
    extra_boundvars count as bound even if tree never binds them.
    """
    localvars = OrderedSet(extra_boundvars)
    localvars.update(BindingFinder.run(tree))
    renaming = {}
    for name in localvars:
        renaming[name] = prefix + name
    return Templater.run(tree, renaming)
Example #10
0
    class Flattener(L.QueryMapper):
        """Apply flatten_tuples_comp() to every comprehension.

        process() returns a pair of the rewritten tree and the
        accumulated relations (trels) reported by flatten_tuples_comp().
        """

        def process(self, tree):
            self.trels = OrderedSet()
            result = super().process(tree)
            return result, self.trels

        def map_Comp(self, node):
            flat_comp, comp_trels = flatten_tuples_comp(node)
            self.trels.update(comp_trels)
            return flat_comp
Example #11
0
 class Flattener(L.QueryMapper):
     """Apply flatten_tuples_comp() to every comprehension.

     process() returns a pair of the rewritten tree and the
     accumulated relations (trels) reported by flatten_tuples_comp().
     """

     def process(self, tree):
         self.trels = OrderedSet()
         result = super().process(tree)
         return result, self.trels

     def map_Comp(self, node):
         flat_comp, comp_trels = flatten_tuples_comp(node)
         self.trels.update(comp_trels)
         return flat_comp
Example #12
0
    class Flattener(L.QueryMapper):
        """Apply flatten_comp() (with the enclosing input_rels) to every
        comprehension, accumulating what the flattened queries use.

        process() returns (tree, use_mset, fields, use_mapset).
        """

        def process(self, tree):
            self.use_mset = False
            self.fields = OrderedSet()
            self.use_mapset = False
            result = super().process(tree)
            return (result, self.use_mset, self.fields, self.use_mapset)

        def map_Comp(self, node):
            flattened = flatten_comp(node, input_rels)
            flat_comp, comp_mset, comp_fields, comp_mapset = flattened
            # Flags are OR-ed across all comprehensions seen so far.
            self.use_mset |= comp_mset
            self.fields.update(comp_fields)
            self.use_mapset |= comp_mapset
            return flat_comp
Example #13
0
class AggrMapCompRewriter(S.QueryRewriter):
    """Rewrite comprehension queries so that map lookups from aggregates
    are flattened into SetFromMap clauses.

    process() returns a pair of the new tree and the set of SetFromMap
    invariants that still need to be transformed.
    """

    def process(self, tree):
        self.sfm_invs = OrderedSet()
        new_tree = super().process(tree)
        return new_tree, self.sfm_invs

    def rewrite_comp(self, symbol, name, comp):
        # Use a fresh replacer per comprehension, then merge the
        # invariants it collected into this rewriter's set.
        replacer = AggrMapReplacer(self.symtab.fresh_names.vars)
        new_comp = L.rewrite_comp(comp, replacer.process)
        self.sfm_invs.update(replacer.sfm_invs)
        return new_comp
Example #14
0
 class Flattener(L.QueryMapper):
     """Apply flatten_comp() (with the enclosing input_rels) to every
     comprehension, accumulating what the flattened queries use.

     process() returns (tree, use_mset, fields, use_mapset).
     """

     def process(self, tree):
         self.use_mset = False
         self.fields = OrderedSet()
         self.use_mapset = False
         result = super().process(tree)
         return (result, self.use_mset, self.fields, self.use_mapset)

     def map_Comp(self, node):
         flattened = flatten_comp(node, input_rels)
         flat_comp, comp_mset, comp_fields, comp_mapset = flattened
         # Flags are OR-ed across all comprehensions seen so far.
         self.use_mset |= comp_mset
         self.fields.update(comp_fields)
         self.use_mapset |= comp_mapset
         return flat_comp
Example #15
0
class AggrMapCompRewriter(S.QueryRewriter):
    """Rewrite comprehension queries so that map lookups from aggregates
    are flattened into SetFromMap clauses.

    process() returns a pair of the new tree and the set of SetFromMap
    invariants that still need to be transformed.
    """

    def process(self, tree):
        self.sfm_invs = OrderedSet()
        new_tree = super().process(tree)
        return new_tree, self.sfm_invs

    def rewrite_comp(self, symbol, name, comp):
        # Use a fresh replacer per comprehension, then merge the
        # invariants it collected into this rewriter's set.
        replacer = AggrMapReplacer(self.symtab.fresh_names.vars)
        new_comp = L.rewrite_comp(comp, replacer.process)
        self.sfm_invs.update(replacer.sfm_invs)
        return new_comp
Example #16
0
class BindingFinder(L.NodeVisitor):
    """Collect, as an OrderedSet, the names of variables that appear in
    a binding context (i.e. that would have a Store context in Python's
    AST): function parameters, loop targets, and assignment targets.
    """

    def process(self, tree):
        self.vars = OrderedSet()
        self.write_ctx = False
        super().process(tree)
        return self.vars

    # Every visitor records the names its node binds *before*
    # descending into the node's children.

    def visit_Fun(self, node):
        self.vars.update(node.args)
        self.generic_visit(node)

    def _visit_single_target(self, node):
        # Nodes binding exactly one name, stored in `target`.
        self.vars.add(node.target)
        self.generic_visit(node)

    def _visit_multi_target(self, node):
        # Nodes binding several names, stored in `vars`.
        self.vars.update(node.vars)
        self.generic_visit(node)

    visit_For = _visit_single_target
    visit_Assign = _visit_single_target
    visit_DecompFor = _visit_multi_target
    visit_DecompAssign = _visit_multi_target
Example #17
0
class BindingFinder(L.NodeVisitor):
    """Collect, as an OrderedSet, the names of variables that appear in
    a binding context (i.e. that would have a Store context in Python's
    AST): function parameters, loop targets, and assignment targets.
    """

    def process(self, tree):
        self.vars = OrderedSet()
        self.write_ctx = False
        super().process(tree)
        return self.vars

    # Every visitor records the names its node binds *before*
    # descending into the node's children.

    def visit_Fun(self, node):
        self.vars.update(node.args)
        self.generic_visit(node)

    def _visit_single_target(self, node):
        # Nodes binding exactly one name, stored in `target`.
        self.vars.add(node.target)
        self.generic_visit(node)

    def _visit_multi_target(self, node):
        # Nodes binding several names, stored in `vars`.
        self.vars.update(node.vars)
        self.generic_visit(node)

    visit_For = _visit_single_target
    visit_Assign = _visit_single_target
    visit_DecompFor = _visit_multi_target
    visit_DecompAssign = _visit_multi_target
Example #18
0
class InvariantTransformer(L.AdvNodeTransformer):
    """Insert maintenance functions, insert calls to these functions
    at updates, and replace expressions with uses of stored results.

    Three kinds of invariants are handled:

      - auxmaps: maintained at RelUpdate/RelClear of auxmap.rel, and
        used to replace matching ImgLookup expressions;
      - setfrommaps: maintained at MapAssign/MapDelete/MapClear of
        sfm.map, and used to replace SetFromMap expressions;
      - wraps: maintained at RelUpdate/RelClear of wrap.oper, and used
        to replace Wrap/Unwrap expressions over that relation.
    """

    # There can be at most one SetFromMap mask for a given map.
    # Multiple distinct masks would have to differ on their arity,
    # which would be a type error.

    def __init__(self, fresh_vars, auxmaps, setfrommaps, wraps):
        super().__init__()
        self.fresh_vars = fresh_vars
        # Names of the maintenance functions added by visit_Module().
        self.maint_funcs = OrderedSet()

        # Index over auxmaps for fast retrieval.
        self.auxmaps_by_rel = OrderedDict()
        self.auxmaps_by_relmask = OrderedDict()
        for auxmap in auxmaps:
            self.auxmaps_by_rel.setdefault(auxmap.rel, []).append(auxmap)
            # Index by relation, mask, and whether value is a singleton.
            # Not indexed by whether key is a singleton; if there are
            # two separate invariants that differ only on that, we may
            # end up only using one but maintaining both.
            stats = (auxmap.rel, auxmap.mask, auxmap.unwrap_value)
            self.auxmaps_by_relmask[stats] = auxmap

        # Index over setfrommaps; at most one invariant per map.
        self.setfrommaps_by_map = sfm_by_map = OrderedDict()
        for sfm in setfrommaps:
            if sfm.map in sfm_by_map:
                raise L.ProgramError('Multiple SetFromMap invariants on '
                                     'same map {}'.format(sfm.map))
            sfm_by_map[sfm.map] = sfm

        # Index over wraps, keyed by their operand relation.
        self.wraps_by_rel = OrderedDict()
        for wrap in wraps:
            self.wraps_by_rel.setdefault(wrap.oper, []).append(wrap)

    def in_loop_rhs(self):
        """Return True if the current node is the iter of a For node.
        Note that this checks the _visit_stack, and is sensitive to
        where in the recursive visiting we are.
        """
        stack = self._visit_stack
        if len(stack) < 2:
            return False
        _current, field, _index = stack[-1]
        parent, _field, _index = stack[-2]
        return isinstance(parent, L.For) and field == 'iter'

    def visit_Module(self, node):
        node = self.generic_visit(node)

        fv = self.fresh_vars
        funcs = []
        for auxmaps in self.auxmaps_by_rel.values():
            for auxmap in auxmaps:
                for op in [L.SetAdd(), L.SetRemove()]:
                    funcs.append(make_auxmap_maint_func(fv, auxmap, op))
        for sfm in self.setfrommaps_by_map.values():
            for op in ['assign', 'delete']:
                funcs.append(make_setfrommap_maint_func(fv, sfm, op))
        for wraps in self.wraps_by_rel.values():
            for wrap in wraps:
                for op in [L.SetAdd(), L.SetRemove()]:
                    funcs.append(make_wrap_maint_func(fv, wrap, op))

        funcs = tuple(funcs)
        self.maint_funcs.update(L.get_defined_functions(funcs))
        # Prepend the maintenance functions to the module body.
        return node._replace(body=funcs + node.body)

    def visit_RelUpdate(self, node):
        if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
            return node

        # Auxmap maintenance calls are inserted first, then wrap
        # maintenance calls; both expose get_maint_func_name(op).
        invs = (self.auxmaps_by_rel.get(node.rel, []) +
                self.wraps_by_rel.get(node.rel, []))
        code = (node,)
        for inv in invs:
            func_name = inv.get_maint_func_name(node.op)
            call_code = (L.Expr(L.Call(func_name, [L.Name(node.elem)])),)
            code = L.insert_rel_maint(code, call_code, node.op)
        return code

    def visit_RelClear(self, node):
        # Clearing an operand relation clears every auxmap and wrap
        # relation derived from it.
        code = (node,)
        for auxmap in self.auxmaps_by_rel.get(node.rel, []):
            clear_code = (L.MapClear(auxmap.map),)
            code = L.insert_rel_maint(code, clear_code, L.SetRemove())
        for wrap in self.wraps_by_rel.get(node.rel, []):
            clear_code = (L.RelClear(wrap.rel),)
            code = L.insert_rel_maint(code, clear_code, L.SetRemove())
        return code

    def visit_MapAssign(self, node):
        sfm = self.setfrommaps_by_map.get(node.map, None)
        if sfm is None:
            return node

        code = (node,)
        func_name = sfm.get_maint_func_name('assign')
        call_code = (L.Expr(L.Call(func_name, [L.Name(node.key),
                                               L.Name(node.value)])),)
        code = L.insert_rel_maint(code, call_code, L.SetAdd())
        return code

    def visit_MapDelete(self, node):
        sfm = self.setfrommaps_by_map.get(node.map, None)
        if sfm is None:
            return node

        code = (node,)
        func_name = sfm.get_maint_func_name('delete')
        call_code = (L.Expr(L.Call(func_name, [L.Name(node.key)])),)
        code = L.insert_rel_maint(code, call_code, L.SetRemove())
        return code

    def visit_MapClear(self, node):
        sfm = self.setfrommaps_by_map.get(node.map, None)
        if sfm is None:
            return node

        code = (node,)
        clear_code = (L.RelClear(sfm.rel),)
        code = L.insert_rel_maint(code, clear_code, L.SetRemove())
        return code

    def imglookup_helper(self, node, *, in_unwrap, on_loop_rhs):
        """Return the replacement for an ImgLookup node, or None if
        no replacement is applicable.

        in_unwrap selects auxmaps whose unwrap_value flag matches.
        on_loop_rhs selects the empty fallback: () when the lookup is
        the iter of a For loop, Set() otherwise.
        """
        if not isinstance(node.set, L.Name):
            return None
        rel = node.set.id

        stats = (rel, node.mask, in_unwrap)
        auxmap = self.auxmaps_by_relmask.get(stats, None)
        if auxmap is None:
            return None

        key = L.tuplify(node.bounds, unwrap=auxmap.unwrap_key)
        empty = L.Parser.pe('()') if on_loop_rhs else L.Parser.pe('Set()')
        return L.Parser.pe('_MAP[_KEY] if _KEY in _MAP else _EMPTY',
                           subst={'_MAP': auxmap.map,
                                  '_KEY': key,
                                  '_EMPTY': empty})

    def visit_ImgLookup(self, node):
        node = self.generic_visit(node)

        repl = self.imglookup_helper(node, in_unwrap=False,
                                     on_loop_rhs=self.in_loop_rhs())
        if repl is not None:
            node = repl

        return node

    def visit_SetFromMap(self, node):
        node = self.generic_visit(node)

        if not isinstance(node.map, L.Name):
            return node
        # Not named `map`, to avoid shadowing the builtin.
        map_name = node.map.id

        sfm = self.setfrommaps_by_map.get(map_name, None)
        if sfm is None:
            return node

        if sfm.mask != node.mask:
            raise L.ProgramError('Multiple SetFromMap expressions on '
                                 'same map {}'.format(map_name))
        return L.Name(sfm.rel)

    def wrap_helper(self, node):
        """Process a wrap invariant at a Wrap or Unwrap node.
        Don't recurse.
        """
        if not isinstance(node.value, L.Name):
            return node
        rel = node.value.id

        wraps = self.wraps_by_rel.get(rel, [])
        for wrap in wraps:
            if ((isinstance(node, L.Wrap) and not wrap.unwrap) or
                (isinstance(node, L.Unwrap) and wrap.unwrap)):
                return L.Name(wrap.rel)

        return node

    def visit_Unwrap(self, node):
        # Handle special case of an auxmap with the unwrap_value flag.
        if isinstance(node.value, L.ImgLookup):
            # Recurse over children below the ImgLookup.
            value = self.generic_visit(node.value)
            node = node._replace(value=value)
            # See if an auxmap invariant applies.
            repl = self.imglookup_helper(value, in_unwrap=True,
                                         on_loop_rhs=self.in_loop_rhs())
            if repl is not None:
                return repl
            # NOTE(review): if no unwrap_value auxmap matches, the
            # ImgLookup is left as-is; it is not retried with
            # in_unwrap=False the way visit_ImgLookup would. Confirm
            # this is intended.

        else:
            # Don't run in case we already did generic_visit() above
            # and didn't return.
            node = self.generic_visit(node)

        return self.wrap_helper(node)

    def visit_Wrap(self, node):
        node = self.generic_visit(node)
        node = self.wrap_helper(node)
        return node
Example #19
0
class CompMaintainer(L.NodeTransformer):

    """Insert maintenance functions for one comprehension, and calls to
    them at SetAdd/SetRemove updates and clears of its relations.
    """

    def __init__(self, clausetools, fresh_vars, fresh_join_names,
                 comp, result_var, *,
                 counted):
        super().__init__()
        self.clausetools = clausetools
        self.fresh_vars = fresh_vars
        self.fresh_join_names = fresh_join_names
        self.comp = comp
        self.result_var = result_var
        self.counted = counted

        # Relations reported by rhs_rels_from_comp(); updates and
        # clears of these trigger maintenance.
        self.rels = clausetools.rhs_rels_from_comp(comp)
        # Names of the maintenance functions added by visit_Module().
        self.maint_funcs = OrderedSet()

    def visit_Module(self, node):
        node = self.generic_visit(node)

        funcs = []
        for rel in self.rels:
            for op in (L.SetAdd(), L.SetRemove()):
                # Each maintenance function gets its own fresh prefix.
                prefix = next(self.fresh_vars)
                funcs.append(make_comp_maint_func(
                    self.clausetools, prefix, self.fresh_join_names,
                    self.comp, self.result_var, rel, op,
                    counted=self.counted))

        funcs = tuple(funcs)
        self.maint_funcs.update(L.get_defined_functions(funcs))
        # Prepend the maintenance functions to the module body.
        return node._replace(body=funcs + node.body)

    def visit_RelUpdate(self, node):
        relevant = (isinstance(node.op, (L.SetAdd, L.SetRemove))
                    and node.rel in self.rels)
        if not relevant:
            return node

        op_name = L.set_update_name(node.op)
        func_name = N.get_maint_func_name(self.result_var, node.rel, op_name)
        call = L.Expr(L.Call(func_name, [L.Name(node.elem)]))
        return L.insert_rel_maint((node,), (call,), node.op)

    def visit_RelClear(self, node):
        if node.rel in self.rels:
            # Clearing any operand relation clears the result.
            clear_code = (L.RelClear(self.result_var),)
            return L.insert_rel_maint((node,), clear_code, L.SetRemove())
        return node
Example #20
0
class InvariantTransformer(L.AdvNodeTransformer):
    """Insert maintenance functions, insert calls to these functions
    at updates, and replace expressions with uses of stored results.

    Three kinds of invariants are handled:

      - auxmaps: maintained at RelUpdate/RelClear of auxmap.rel, and
        used to replace matching ImgLookup expressions;
      - setfrommaps: maintained at MapAssign/MapDelete/MapClear of
        sfm.map, and used to replace SetFromMap expressions;
      - wraps: maintained at RelUpdate/RelClear of wrap.oper, and used
        to replace Wrap/Unwrap expressions over that relation.
    """

    # There can be at most one SetFromMap mask for a given map.
    # Multiple distinct masks would have to differ on their arity,
    # which would be a type error.

    def __init__(self, fresh_vars, auxmaps, setfrommaps, wraps):
        super().__init__()
        self.fresh_vars = fresh_vars
        # Names of the maintenance functions added by visit_Module().
        self.maint_funcs = OrderedSet()

        # Index over auxmaps for fast retrieval.
        self.auxmaps_by_rel = OrderedDict()
        self.auxmaps_by_relmask = OrderedDict()
        for auxmap in auxmaps:
            self.auxmaps_by_rel.setdefault(auxmap.rel, []).append(auxmap)
            # Index by relation, mask, and whether value is a singleton.
            # Not indexed by whether key is a singleton; if there are
            # two separate invariants that differ only on that, we may
            # end up only using one but maintaining both.
            stats = (auxmap.rel, auxmap.mask, auxmap.unwrap_value)
            self.auxmaps_by_relmask[stats] = auxmap

        # Index over setfrommaps; at most one invariant per map.
        self.setfrommaps_by_map = sfm_by_map = OrderedDict()
        for sfm in setfrommaps:
            if sfm.map in sfm_by_map:
                raise L.ProgramError('Multiple SetFromMap invariants on '
                                     'same map {}'.format(sfm.map))
            sfm_by_map[sfm.map] = sfm

        # Index over wraps, keyed by their operand relation.
        self.wraps_by_rel = OrderedDict()
        for wrap in wraps:
            self.wraps_by_rel.setdefault(wrap.oper, []).append(wrap)

    def in_loop_rhs(self):
        """Return True if the current node is the iter of a For node.
        Note that this checks the _visit_stack, and is sensitive to
        where in the recursive visiting we are.
        """
        stack = self._visit_stack
        if len(stack) < 2:
            return False
        _current, field, _index = stack[-1]
        parent, _field, _index = stack[-2]
        return isinstance(parent, L.For) and field == 'iter'

    def visit_Module(self, node):
        node = self.generic_visit(node)

        fv = self.fresh_vars
        funcs = []
        for auxmaps in self.auxmaps_by_rel.values():
            for auxmap in auxmaps:
                for op in [L.SetAdd(), L.SetRemove()]:
                    funcs.append(make_auxmap_maint_func(fv, auxmap, op))
        for sfm in self.setfrommaps_by_map.values():
            for op in ['assign', 'delete']:
                funcs.append(make_setfrommap_maint_func(fv, sfm, op))
        for wraps in self.wraps_by_rel.values():
            for wrap in wraps:
                for op in [L.SetAdd(), L.SetRemove()]:
                    funcs.append(make_wrap_maint_func(fv, wrap, op))

        funcs = tuple(funcs)
        self.maint_funcs.update(L.get_defined_functions(funcs))
        # Prepend the maintenance functions to the module body.
        return node._replace(body=funcs + node.body)

    def visit_RelUpdate(self, node):
        if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
            return node

        # Auxmap maintenance calls are inserted first, then wrap
        # maintenance calls; both expose get_maint_func_name(op).
        invs = (self.auxmaps_by_rel.get(node.rel, []) +
                self.wraps_by_rel.get(node.rel, []))
        code = (node, )
        for inv in invs:
            func_name = inv.get_maint_func_name(node.op)
            call_code = (L.Expr(L.Call(func_name, [L.Name(node.elem)])), )
            code = L.insert_rel_maint(code, call_code, node.op)
        return code

    def visit_RelClear(self, node):
        # Clearing an operand relation clears every auxmap and wrap
        # relation derived from it.
        code = (node, )
        for auxmap in self.auxmaps_by_rel.get(node.rel, []):
            clear_code = (L.MapClear(auxmap.map), )
            code = L.insert_rel_maint(code, clear_code, L.SetRemove())
        for wrap in self.wraps_by_rel.get(node.rel, []):
            clear_code = (L.RelClear(wrap.rel), )
            code = L.insert_rel_maint(code, clear_code, L.SetRemove())
        return code

    def visit_MapAssign(self, node):
        sfm = self.setfrommaps_by_map.get(node.map, None)
        if sfm is None:
            return node

        code = (node, )
        func_name = sfm.get_maint_func_name('assign')
        call_code = (L.Expr(
            L.Call(func_name,
                   [L.Name(node.key), L.Name(node.value)])), )
        code = L.insert_rel_maint(code, call_code, L.SetAdd())
        return code

    def visit_MapDelete(self, node):
        sfm = self.setfrommaps_by_map.get(node.map, None)
        if sfm is None:
            return node

        code = (node, )
        func_name = sfm.get_maint_func_name('delete')
        call_code = (L.Expr(L.Call(func_name, [L.Name(node.key)])), )
        code = L.insert_rel_maint(code, call_code, L.SetRemove())
        return code

    def visit_MapClear(self, node):
        sfm = self.setfrommaps_by_map.get(node.map, None)
        if sfm is None:
            return node

        code = (node, )
        clear_code = (L.RelClear(sfm.rel), )
        code = L.insert_rel_maint(code, clear_code, L.SetRemove())
        return code

    def imglookup_helper(self, node, *, in_unwrap, on_loop_rhs):
        """Return the replacement for an ImgLookup node, or None if
        no replacement is applicable.

        in_unwrap selects auxmaps whose unwrap_value flag matches.
        on_loop_rhs selects the empty fallback: () when the lookup is
        the iter of a For loop, Set() otherwise.
        """
        if not isinstance(node.set, L.Name):
            return None
        rel = node.set.id

        stats = (rel, node.mask, in_unwrap)
        auxmap = self.auxmaps_by_relmask.get(stats, None)
        if auxmap is None:
            return None

        key = L.tuplify(node.bounds, unwrap=auxmap.unwrap_key)
        empty = L.Parser.pe('()') if on_loop_rhs else L.Parser.pe('Set()')
        return L.Parser.pe('_MAP[_KEY] if _KEY in _MAP else _EMPTY',
                           subst={
                               '_MAP': auxmap.map,
                               '_KEY': key,
                               '_EMPTY': empty
                           })

    def visit_ImgLookup(self, node):
        node = self.generic_visit(node)

        repl = self.imglookup_helper(node,
                                     in_unwrap=False,
                                     on_loop_rhs=self.in_loop_rhs())
        if repl is not None:
            node = repl

        return node

    def visit_SetFromMap(self, node):
        node = self.generic_visit(node)

        if not isinstance(node.map, L.Name):
            return node
        # Not named `map`, to avoid shadowing the builtin.
        map_name = node.map.id

        sfm = self.setfrommaps_by_map.get(map_name, None)
        if sfm is None:
            return node

        if sfm.mask != node.mask:
            raise L.ProgramError('Multiple SetFromMap expressions on '
                                 'same map {}'.format(map_name))
        return L.Name(sfm.rel)

    def wrap_helper(self, node):
        """Process a wrap invariant at a Wrap or Unwrap node.
        Don't recurse.
        """
        if not isinstance(node.value, L.Name):
            return node
        rel = node.value.id

        wraps = self.wraps_by_rel.get(rel, [])
        for wrap in wraps:
            if ((isinstance(node, L.Wrap) and not wrap.unwrap)
                    or (isinstance(node, L.Unwrap) and wrap.unwrap)):
                return L.Name(wrap.rel)

        return node

    def visit_Unwrap(self, node):
        # Handle special case of an auxmap with the unwrap_value flag.
        if isinstance(node.value, L.ImgLookup):
            # Recurse over children below the ImgLookup.
            value = self.generic_visit(node.value)
            node = node._replace(value=value)
            # See if an auxmap invariant applies.
            repl = self.imglookup_helper(value,
                                         in_unwrap=True,
                                         on_loop_rhs=self.in_loop_rhs())
            if repl is not None:
                return repl
            # NOTE(review): if no unwrap_value auxmap matches, the
            # ImgLookup is left as-is; it is not retried with
            # in_unwrap=False the way visit_ImgLookup would. Confirm
            # this is intended.

        else:
            # Don't run in case we already did generic_visit() above
            # and didn't return.
            node = self.generic_visit(node)

        return self.wrap_helper(node)

    def visit_Wrap(self, node):
        node = self.generic_visit(node)
        node = self.wrap_helper(node)
        return node
Example #21
0
File: util.py Project: IncOQ/incoq
class VarsFinder(NodeVisitor):
    
    """Simple finder of variables (Name nodes).
    
    Call process(tree) to get an OrderedSet of the identifiers found,
    in the order they were first added.
    
    Flags:
    
        ignore_store:
            Name nodes with Store context are ignored, as are the
            target Names of update operations (SetUpdate,
            RCSetRefUpdate, AssignKey, DelKey). As an exception, Name
            nodes on the LHS of Enumerators are not ignored. This is
            to ensure safety under pattern matching semantics.
        ignore_functions:
            Name nodes used as the function of a Call are ignored.
        ignore_rels:
            Names that appear to be relations are ignored: the Name
            iterated over by an Enumerator, the Name on the right of a
            single-operator ``in`` condition in a Comp, and the Name
            target of a SetMatch. An Aggregate whose value is a Name is
            skipped entirely.
    
    The builtins None, True, and False are always excluded, as they
    are NameConstants, not variables.
    """
    
    class _IGNORE(iast.AST):
        # Placeholder node substituted for a Call's function so that
        # generic_visit() does not record it. Defined once here instead
        # of building a new class on every visit_Call invocation.
        _meta = True
    
    # Placeholder value substituted for an update's target so that the
    # target Name is not visited. A plain (non-AST) object.
    _IGNORE_TARGET = object()
    
    def __init__(self, *,
                 ignore_store=False,
                 ignore_functions=False,
                 ignore_rels=False):
        super().__init__()
        self.ignore_store = ignore_store
        self.ignore_functions = ignore_functions
        self.ignore_rels = ignore_rels
    
    def process(self, tree):
        """Visit tree and return the OrderedSet of variable names found."""
        self.usedvars = OrderedSet()
        super().process(tree)
        return self.usedvars
    
    def visit_Name(self, node):
        self.generic_visit(node)
        # Store-context Names are skipped only under ignore_store.
        if not (self.ignore_store and isinstance(node.ctx, Store)):
            self.usedvars.add(node.id)
    
    def visit_Call(self, node):
        """Visit a call, hiding a Name callee when ignore_functions is set."""
        if isinstance(node.func, Name) and self.ignore_functions:
            # Arguments are still visited; only the function name is hidden.
            self.generic_visit(node._replace(func=self._IGNORE()))
        
        else:
            self.generic_visit(node)
    
    def visit_Enumerator(self, node):
        """Record LHS variables and, unless it is an ignored rel, the RHS."""
        if is_vartuple(node.target):
            # Bypass ignore_store: enumerator LHS variables always count.
            target_vars = get_vartuple(node.target)
            self.usedvars.update(target_vars)
        else:
            self.visit(node.target)
        
        if not (self.ignore_rels and isinstance(node.iter, Name)):
            self.visit(node.iter)
    
    def visit_Comp(self, node):
        """Visit the result expression and each clause.
        
        With ignore_rels, a clause of the form ``<expr> in <Name>`` (a
        Compare with a single In operator) only has its left side
        visited, so the relation name is not recorded.
        """
        self.visit(node.resexp)
        # Hack to ensure we don't grab rels on RHS of
        # membership conditions.
        for i in node.clauses:
            if (self.ignore_rels and
                isinstance(i, Compare) and
                len(i.ops) == len(i.comparators) == 1 and
                isinstance(i.ops[0], In) and
                isinstance(i.comparators[0], Name)):
                self.visit(i.left)
            else:
                self.visit(i)
    
    def visit_Aggregate(self, node):
        """Skip aggregates over a Name operand when ignore_rels is set."""
        if isinstance(node.value, Name) and self.ignore_rels:
            return
        else:
            self.generic_visit(node)
    
    def visit_SetMatch(self, node):
        """With ignore_rels and a Name target, visit only the key."""
        if isinstance(node.target, Name) and self.ignore_rels:
            self.visit(node.key)
        else:
            self.generic_visit(node)
    
    def update_helper(self, node):
        """Visit an update, hiding a Name target when ignore_store is set."""
        if isinstance(node.target, Name) and self.ignore_store:
            self.generic_visit(node._replace(target=self._IGNORE_TARGET))
        
        else:
            self.generic_visit(node)
    
    visit_SetUpdate = update_helper
    visit_RCSetRefUpdate = update_helper
    visit_AssignKey = update_helper
    visit_DelKey = update_helper
Example #22
0
class CompMaintainer(L.NodeTransformer):
    """Add maintenance code for one comprehension's result relation.

    For each relation returned by clausetools.rhs_rels_from_comp(comp), a
    maintenance function is generated for both SetAdd and SetRemove (via
    make_comp_maint_func) and prepended to the module body. The generated
    function names are accumulated in self.maint_funcs.

    RelUpdate nodes (SetAdd/SetRemove) on one of those relations get a
    call to the matching maintenance function inserted around them.
    RelClear nodes on one of those relations are paired with a RelClear
    of the result relation.
    """

    def __init__(self, clausetools, fresh_vars, fresh_join_names, comp,
                 result_var, *, counted):
        super().__init__()
        self.clausetools = clausetools
        # Iterator yielding a fresh variable prefix per generated function.
        self.fresh_vars = fresh_vars
        self.fresh_join_names = fresh_join_names
        self.comp = comp
        self.result_var = result_var
        self.counted = counted

        # Relations whose updates trigger maintenance of result_var.
        self.rels = self.clausetools.rhs_rels_from_comp(self.comp)
        # Names of the maintenance functions added by visit_Module.
        self.maint_funcs = OrderedSet()

    def visit_Module(self, node):
        """Transform the body, then prepend the maintenance functions."""
        node = self.generic_visit(node)

        # One function per (relation, operation) pair, additions first.
        generated = tuple(
            make_comp_maint_func(self.clausetools,
                                 next(self.fresh_vars),
                                 self.fresh_join_names,
                                 self.comp,
                                 self.result_var,
                                 rel,
                                 op,
                                 counted=self.counted)
            for rel in self.rels
            for op in [L.SetAdd(), L.SetRemove()])

        self.maint_funcs.update(L.get_defined_functions(generated))
        return node._replace(body=generated + node.body)

    def visit_RelUpdate(self, node):
        """Attach a maintenance call to a set update on a watched rel."""
        watched = (isinstance(node.op, (L.SetAdd, L.SetRemove))
                   and node.rel in self.rels)
        if not watched:
            return node

        maint_name = N.get_maint_func_name(self.result_var, node.rel,
                                           L.set_update_name(node.op))
        maint_call = L.Expr(L.Call(maint_name, [L.Name(node.elem)]))
        return L.insert_rel_maint((node, ), (maint_call, ), node.op)

    def visit_RelClear(self, node):
        """Pair a clear of a watched rel with a clear of the result."""
        if node.rel in self.rels:
            return L.insert_rel_maint((node, ),
                                      (L.RelClear(self.result_var), ),
                                      L.SetRemove())
        return node
Example #23
0
class VarsFinder(NodeVisitor):
    """Simple finder of variables (Name nodes).
    
    Call process(tree) to get an OrderedSet of the identifiers found,
    in the order they were first added.
    
    Flags:
    
        ignore_store:
            Name nodes with Store context are ignored, as are the
            target Names of update operations (SetUpdate,
            RCSetRefUpdate, AssignKey, DelKey). As an exception, Name
            nodes on the LHS of Enumerators are not ignored. This is
            to ensure safety under pattern matching semantics.
        ignore_functions:
            Name nodes used as the function of a Call are ignored.
        ignore_rels:
            Names that appear to be relations are ignored: the Name
            iterated over by an Enumerator, the Name on the right of a
            single-operator ``in`` condition in a Comp, and the Name
            target of a SetMatch. An Aggregate whose value is a Name is
            skipped entirely.
    
    The builtins None, True, and False are always excluded, as they
    are NameConstants, not variables.
    """

    class _IGNORE(iast.AST):
        # Placeholder node substituted for a Call's function so that
        # generic_visit() does not record it. Defined once here instead
        # of building a new class on every visit_Call invocation.
        _meta = True

    # Placeholder value substituted for an update's target so that the
    # target Name is not visited. A plain (non-AST) object.
    _IGNORE_TARGET = object()

    def __init__(self,
                 *,
                 ignore_store=False,
                 ignore_functions=False,
                 ignore_rels=False):
        super().__init__()
        self.ignore_store = ignore_store
        self.ignore_functions = ignore_functions
        self.ignore_rels = ignore_rels

    def process(self, tree):
        """Visit tree and return the OrderedSet of variable names found."""
        self.usedvars = OrderedSet()
        super().process(tree)
        return self.usedvars

    def visit_Name(self, node):
        self.generic_visit(node)
        # Store-context Names are skipped only under ignore_store.
        if not (self.ignore_store and isinstance(node.ctx, Store)):
            self.usedvars.add(node.id)

    def visit_Call(self, node):
        """Visit a call, hiding a Name callee when ignore_functions is set."""
        if isinstance(node.func, Name) and self.ignore_functions:
            # Arguments are still visited; only the function name is hidden.
            self.generic_visit(node._replace(func=self._IGNORE()))

        else:
            self.generic_visit(node)

    def visit_Enumerator(self, node):
        """Record LHS variables and, unless it is an ignored rel, the RHS."""
        if is_vartuple(node.target):
            # Bypass ignore_store: enumerator LHS variables always count.
            target_vars = get_vartuple(node.target)
            self.usedvars.update(target_vars)
        else:
            self.visit(node.target)

        if not (self.ignore_rels and isinstance(node.iter, Name)):
            self.visit(node.iter)

    def visit_Comp(self, node):
        """Visit the result expression and each clause.
        
        With ignore_rels, a clause of the form ``<expr> in <Name>`` (a
        Compare with a single In operator) only has its left side
        visited, so the relation name is not recorded.
        """
        self.visit(node.resexp)
        # Hack to ensure we don't grab rels on RHS of
        # membership conditions.
        for i in node.clauses:
            if (self.ignore_rels and isinstance(i, Compare)
                    and len(i.ops) == len(i.comparators) == 1
                    and isinstance(i.ops[0], In)
                    and isinstance(i.comparators[0], Name)):
                self.visit(i.left)
            else:
                self.visit(i)

    def visit_Aggregate(self, node):
        """Skip aggregates over a Name operand when ignore_rels is set."""
        if isinstance(node.value, Name) and self.ignore_rels:
            return
        else:
            self.generic_visit(node)

    def visit_SetMatch(self, node):
        """With ignore_rels and a Name target, visit only the key."""
        if isinstance(node.target, Name) and self.ignore_rels:
            self.visit(node.key)
        else:
            self.generic_visit(node)

    def update_helper(self, node):
        """Visit an update, hiding a Name target when ignore_store is set."""
        if isinstance(node.target, Name) and self.ignore_store:
            self.generic_visit(node._replace(target=self._IGNORE_TARGET))

        else:
            self.generic_visit(node)

    visit_SetUpdate = update_helper
    visit_RCSetRefUpdate = update_helper
    visit_AssignKey = update_helper
    visit_DelKey = update_helper
Example #24
0
class IdentFinder(L.NodeVisitor):
    
    """Return an OrderedSet of all identifiers in the specified
    contexts.
    """
    
    fun_ctxs = ('Fun.name', 'Call.func')
    query_ctxs = ('Query.name', 'ResetDemand.names')
    rel_ctxs = ('RelUpdate.rel', 'RelClear.rel', 'RelMember.rel')
    
    @classmethod
    def find_functions(cls, tree):
        return cls().run(tree, contexts=cls.fun_ctxs)
    
    @classmethod
    def find_vars(cls, tree):
        ctxs = (cls.fun_ctxs + cls.query_ctxs)
        return cls().run(tree, contexts=ctxs, invert=True)
    
    @classmethod
    def find_non_rel_uses(cls, tree):
        ctxs = (cls.fun_ctxs + cls.query_ctxs + cls.rel_ctxs)
        return cls().run(tree, contexts=ctxs, invert=True)
    
    def __init__(self, contexts=None, invert=False):
        if contexts is not None:
            for c in contexts:
                node_name, field_name = c.split('.')
                if not field_name in L.ident_fields.get(node_name, []):
                    raise ValueError('Unknown identifier context "{}"'
                                     .format(c))
        
        self.contexts = contexts
        """Collection of contexts to include/exclude. Each context is
        a string of the form '<node type name>.<field name>'. A value
        of None is equivalent to specifying all contexts.
        """
        self.invert = bool(invert)
        """If True, find identifiers that occur in any context besides
        the ones given.
        """
    
    def process(self, tree):
        self.names = OrderedSet()
        super().process(tree)
        return self.names
    
    def generic_visit(self, node):
        super().generic_visit(node)
        clsname = node.__class__.__name__
        id_fields = L.ident_fields.get(clsname, [])
        for f in id_fields:
            inctx = (self.contexts is None or
                     clsname + '.' + f in self.contexts)
            if inctx != self.invert:
                # Normalize for either one id or a sequence of ids.
                ids = getattr(node, f)
                if isinstance(ids, str):
                    ids = [ids]
                if ids is not None:
                    self.names.update(ids)