コード例 #1
0
ファイル: tupletrans.py プロジェクト: dcharkes/incoq
class TupleFlattener(L.NodeTransformer):
    """Replace nested tuple patterns on enumerator LHSs by fresh tuple
    variables, recording an enumerator over a tuple relation for each.

    process() returns a pair of the rewritten tree and the list of
    clauses generated during that call. Tuple relations referenced by
    generated clauses accumulate in self.trels; it is created in
    __init__ and not reset by process().
    """

    def __init__(self, tupvar_namer):
        super().__init__()
        # Source of fresh tuple-variable names, read via its next() method.
        self.tupvar_namer = tupvar_namer
        self.trels = OrderedSet()

    def process(self, tree):
        self.new_clauses = []
        result = super().process(tree)
        return result, self.new_clauses

    def visit_Enumerator(self, node):
        # The RHS is never visited. A tuple LHS keeps its outermost level
        # and only its components are visited; any other LHS is
        # traversed generically.
        target = node.target
        if not isinstance(target, L.Tuple):
            return node._replace(target=self.generic_visit(target))
        flat_elts = self.visit(target.elts)
        return node._replace(target=target._replace(elts=flat_elts))

    def visit_Tuple(self, node):
        # Components of node are not visited here; per the original
        # design, deeper nesting is left to whoever drives this visitor.
        var = self.tupvar_namer.next()
        rel = make_trel(len(node.elts))
        lhs = L.tuplify((L.sn(var),) + node.elts, lval=True)
        self.new_clauses.append(L.Enumerator(lhs, L.ln(rel)))
        self.trels.add(rel)
        return L.sn(var)
コード例 #2
0
ファイル: join.py プロジェクト: jieaozhu/dist_lang_reviews
 def rhs_rels_from_comp(self, comp):
     """Return a tuple of the distinct relations reported by rhs_rel()
     for comp's clauses, in order of first appearance. Clauses for which
     rhs_rel() returns None contribute nothing.
     """
     distinct = OrderedSet()
     candidates = (self.rhs_rel(clause) for clause in comp.clauses)
     distinct.update(rel for rel in candidates if rel is not None)
     return tuple(distinct)
コード例 #3
0
class BindingFinder(L.NodeVisitor):
    """Collect the names of variables that occur in a binding position,
    i.e. positions that would carry a Store context in Python's AST.
    process() returns them as an OrderedSet in discovery order.
    """

    def process(self, tree):
        self.vars = OrderedSet()
        # Not read anywhere in this class; kept for compatibility.
        self.write_ctx = False
        super().process(tree)
        return self.vars

    def _bind(self, node, names):
        # Record the names bound by node, then continue into its children.
        self.vars.update(names)
        self.generic_visit(node)

    def visit_Fun(self, node):
        self._bind(node, node.args)

    def visit_For(self, node):
        self._bind(node, (node.target,))

    def visit_DecompFor(self, node):
        self._bind(node, node.vars)

    def visit_Assign(self, node):
        self._bind(node, (node.target,))

    def visit_DecompAssign(self, node):
        self._bind(node, node.vars)
コード例 #4
0
ファイル: tupletrans.py プロジェクト: IncOQ/incoq
class TupleFlattener(L.NodeTransformer):

    """Flatten nested tuples appearing on the left-hand side of
    enumerators.

    Every nested tuple becomes a fresh variable, and an enumerator over
    the matching tuple relation (from make_trel) is recorded so that the
    variable is bound to that tuple's components. process() returns the
    new tree together with the clauses recorded during that call; the
    relations used are accumulated in self.trels for the lifetime of the
    instance.
    """

    def __init__(self, tupvar_namer):
        super().__init__()
        self.tupvar_namer = tupvar_namer
        self.trels = OrderedSet()

    def process(self, tree):
        self.new_clauses = []
        return super().process(tree), self.new_clauses

    def visit_Enumerator(self, node):
        target = node.target
        if isinstance(target, L.Tuple):
            # Keep the outermost tuple; flatten only its components.
            target = target._replace(elts=self.visit(target.elts))
        else:
            target = self.generic_visit(target)
        # The RHS is deliberately left untouched.
        return node._replace(target=target)

    def visit_Tuple(self, node):
        # Components are not recursed into here; that is left to the
        # caller of this visitor.
        tupvar = self.tupvar_namer.next()
        trel = make_trel(len(node.elts))
        clause = L.Enumerator(
            L.tuplify((L.sn(tupvar),) + node.elts, lval=True),
            L.ln(trel))
        self.new_clauses.append(clause)
        self.trels.add(trel)
        return L.sn(tupvar)
コード例 #5
0
class IdentFinder(L.NodeVisitor):
    """Return an OrderedSet of all identifiers in the specified
    contexts.

    A context names an identifier field of a node type and is written
    '<node type name>.<field name>' (e.g. 'Fun.name'). With invert=True,
    identifiers in every context except the given ones are collected.
    """

    fun_ctxs = ('Fun.name', 'Call.func')
    query_ctxs = ('Query.name', 'ResetDemand.names')
    rel_ctxs = ('RelUpdate.rel', 'RelClear.rel', 'RelMember.rel')

    @classmethod
    def find_functions(cls, tree):
        """Return identifiers occurring in function contexts (fun_ctxs)."""
        return cls().run(tree, contexts=cls.fun_ctxs)

    @classmethod
    def find_vars(cls, tree):
        """Return identifiers occurring outside function and query
        contexts.
        """
        ctxs = cls.fun_ctxs + cls.query_ctxs
        return cls().run(tree, contexts=ctxs, invert=True)

    @classmethod
    def find_non_rel_uses(cls, tree):
        """Return identifiers occurring outside function, query, and
        relation contexts.
        """
        ctxs = cls.fun_ctxs + cls.query_ctxs + cls.rel_ctxs
        return cls().run(tree, contexts=ctxs, invert=True)

    def __init__(self, contexts=None, invert=False):
        """Raise ValueError if a context is malformed or does not name an
        identifier field listed in L.ident_fields.
        """
        super().__init__()
        if contexts is not None:
            # Materialize first: a one-shot iterable would otherwise be
            # exhausted by this validation loop, and every later
            # membership test in generic_visit would silently fail.
            contexts = tuple(contexts)
            for c in contexts:
                # partition() always yields three parts, so a context
                # without a '.' gets the descriptive error below instead
                # of an opaque unpacking error.
                node_name, sep, field_name = c.partition('.')
                known_fields = L.ident_fields.get(node_name, [])
                if not sep or field_name not in known_fields:
                    raise ValueError(
                        'Unknown identifier context "{}"'.format(c))

        # Collection of contexts to include/exclude, each of the form
        # '<node type name>.<field name>'. None means all contexts.
        self.contexts = contexts
        # If True, find identifiers occurring in any context besides the
        # ones given.
        self.invert = bool(invert)

    def process(self, tree):
        self.names = OrderedSet()
        super().process(tree)
        return self.names

    def generic_visit(self, node):
        # Visit children first, then collect this node's identifier fields.
        super().generic_visit(node)
        clsname = node.__class__.__name__
        id_fields = L.ident_fields.get(clsname, [])
        for f in id_fields:
            inctx = (self.contexts is None
                     or clsname + '.' + f in self.contexts)
            # Collect matching contexts normally, non-matching ones when
            # inverted.
            if inctx != self.invert:
                # A field holds a single id, a sequence of ids, or None;
                # normalize to a sequence.
                ids = getattr(node, f)
                if isinstance(ids, str):
                    ids = [ids]
                if ids is not None:
                    self.names.update(ids)
コード例 #6
0
class RetrievalReplacer(L.NodeTransformer):

    """Substitute a variable for every simple field or map retrieval.

    A retrieval is simple when its object (for a field) or its map and
    key (for a subscript) are plain variables; anything else raises
    L.ProgramError. Children are rewritten before their parent, so a
    nested expression such as a.b[c.d].e is handled as long as it is
    built only from variables and retrievals.

    Replacement names come from field_namer(obj, field) and
    map_namer(map, key). After process(), self.field_repls and
    self.map_repls are OrderedSets of (object/map, field/key,
    replacement name) triples describing what was replaced. Both are
    reset at the start of every process() call.
    """

    def __init__(self, field_namer, map_namer):
        super().__init__()
        self.field_namer = field_namer
        self.map_namer = map_namer

    def process(self, tree):
        self.field_repls = OrderedSet()
        self.map_repls = OrderedSet()
        return super().process(tree)

    def visit_Attribute(self, node):
        # Inner retrievals are rewritten first, so the object may already
        # have become a variable.
        node = self.generic_visit(node)
        base = node.value
        if not isinstance(base, L.Name):
            raise L.ProgramError('Non-simple field retrieval', node=node)
        obj_name, field_name = base.id, node.attr
        replacement = self.field_namer(obj_name, field_name)
        self.field_repls.add((obj_name, field_name, replacement))
        return L.Name(replacement, node.ctx)

    def visit_Subscript(self, node):
        node = self.generic_visit(node)
        base, index = node.value, node.slice
        is_simple = (isinstance(base, L.Name)
                     and isinstance(index, L.Index)
                     and isinstance(index.value, L.Name))
        if not is_simple:
            raise L.ProgramError('Non-simple map retrieval', node=node)
        map_name, key_name = base.id, index.value.id
        replacement = self.map_namer(map_name, key_name)
        self.map_repls.add((map_name, key_name, replacement))
        return L.Name(replacement, node.ctx)
コード例 #7
0
ファイル: join.py プロジェクト: jieaozhu/dist_lang_reviews
 def rhs_rels_from_comp(self, comp):
     """Collect, without duplicates and in first-seen order, the non-None
     results of rhs_rel() over comp's clauses; return them as a tuple.
     """
     found = OrderedSet()
     for clause in comp.clauses:
         rel = self.rhs_rel(clause)
         if rel is None:
             continue
         found.add(rel)
     return tuple(found)
コード例 #8
0
def preprocess_var_decls(tree):
    """Strip top-level declarations of the form

        S = Set()
        M = Map()

    from a Module. Return a pair of the tree (replaced only if something
    was stripped) and a list of distinct (variable name, kind) pairs in
    order of first appearance, where kind is 'Set' or 'Map'. Raise
    L.ProgramError for any other initializer matched by the pattern.
    """
    assert isinstance(tree, P.Module)
    decl_pattern = P.Assign(
        [P.Name(P.PatVar('_VAR'), P.Wildcard())],
        P.Call(P.Name(P.PatVar('_KIND'), P.Load()), [], [], None, None))

    found = OrderedSet()
    kept = []
    stripped_any = False
    for stmt in tree.body:
        bindings = P.match(decl_pattern, stmt)
        if bindings is None:
            kept.append(stmt)
            continue
        var_name = bindings['_VAR']
        kind = bindings['_KIND']
        if kind not in ['Set', 'Map']:
            raise L.ProgramError(
                'Unknown declaration initializer {}'.format(kind))
        found.add((var_name, kind))
        stripped_any = True

    if stripped_any:
        tree = tree._replace(body=kept)
    return tree, list(found)
コード例 #9
0
def preprocess_var_decls(tree):
    """Remove module-level declarations shaped like

        S = Set()
        M = Map()

    and report them. Return (tree, decls): tree is the Module with those
    statements dropped (the original object if none matched), and decls
    is a list of distinct (variable name, "Set" or "Map") pairs in order
    of first appearance. A matching zero-argument call of any other name
    raises L.ProgramError.
    """
    assert isinstance(tree, P.Module)
    target = P.Name(P.PatVar("_VAR"), P.Wildcard())
    initializer = P.Call(P.Name(P.PatVar("_KIND"), P.Load()), [], [], None, None)
    pattern = P.Assign([target], initializer)

    decls = OrderedSet()
    remaining = []
    dropped = False
    for stmt in tree.body:
        m = P.match(pattern, stmt)
        if m is None:
            remaining.append(stmt)
        else:
            var, kind = m["_VAR"], m["_KIND"]
            if kind not in ("Set", "Map"):
                raise L.ProgramError("Unknown declaration initializer {}".format(kind))
            decls.add((var, kind))
            dropped = True

    return (tree._replace(body=remaining) if dropped else tree), list(decls)
コード例 #10
0
ファイル: tools.py プロジェクト: jieaozhu/dist_lang_reviews
class BindingFinder(L.NodeVisitor):

    """Find every variable name that is bound somewhere in a tree, i.e.
    that would appear with a Store context in Python's AST. The result
    of process() is an OrderedSet in the order the bindings were met.
    """

    def process(self, tree):
        self.vars = OrderedSet()
        self.write_ctx = False  # not consulted within this class
        super().process(tree)
        return self.vars

    def _record(self, node, *names):
        """Add names to the result, then traverse node's children."""
        self.vars.update(names)
        self.generic_visit(node)

    def visit_Fun(self, node):
        # Parameters are bound by the definition.
        self._record(node, *node.args)

    def visit_For(self, node):
        self._record(node, node.target)

    def visit_DecompFor(self, node):
        self._record(node, *node.vars)

    def visit_Assign(self, node):
        self._record(node, node.target)

    def visit_DecompAssign(self, node):
        self._record(node, *node.vars)
コード例 #11
0
    def __init__(self, fresh_vars, auxmaps, setfrommaps, wraps):
        super().__init__()
        self.fresh_vars = fresh_vars
        # Names of maintenance functions generated by this transformer.
        self.maint_funcs = OrderedSet()

        # auxmaps_by_rel groups every auxmap under its relation.
        # auxmaps_by_relmask is keyed by (relation, mask, unwrap_value).
        # Whether the key is a singleton is deliberately not part of that
        # key, so two invariants differing only in that share one entry
        # (the later one wins) and we may use only one while still
        # maintaining both.
        self.auxmaps_by_rel = OrderedDict()
        self.auxmaps_by_relmask = OrderedDict()
        for auxmap in auxmaps:
            self.auxmaps_by_rel.setdefault(auxmap.rel, []).append(auxmap)
            key = (auxmap.rel, auxmap.mask, auxmap.unwrap_value)
            self.auxmaps_by_relmask[key] = auxmap

        # At most one SetFromMap invariant is allowed per map.
        by_map = OrderedDict()
        self.setfrommaps_by_map = by_map
        for sfm in setfrommaps:
            if sfm.map in by_map:
                raise L.ProgramError('Multiple SetFromMap invariants on '
                                     'same map {}'.format(sfm.map))
            by_map[sfm.map] = sfm

        # Wraps grouped under wrap.oper.
        self.wraps_by_rel = OrderedDict()
        for wrap in wraps:
            self.wraps_by_rel.setdefault(wrap.oper, []).append(wrap)
コード例 #12
0
ファイル: comp.py プロジェクト: jieaozhu/dist_lang_reviews
 def __init__(self, fresh_names):
     super().__init__()
     # Fresh-name source supplied by the caller.
     self.fresh_names = fresh_names
     # Mapping from replaced nodes to their replacement variables.
     self.repls = OrderedDict()
     # SetFromMap invariants found so far.
     self.sfm_invs = OrderedSet()
コード例 #13
0
ファイル: trans.py プロジェクト: jieaozhu/dist_lang_reviews
 class Finder(L.NodeVisitor):
     """Gather the symtab symbols of all Query nodes in a tree into an
     OrderedSet, inner queries before the queries containing them.
     """

     def process(self, tree):
         self.syms = OrderedSet()
         super().process(tree)
         return self.syms

     def visit_Query(self, node):
         # Children first, so nested queries are recorded earlier.
         self.generic_visit(node)
         symbol = symtab.get_symbols()[node.name]
         self.syms.add(symbol)
コード例 #14
0
ファイル: trans.py プロジェクト: jieaozhu/dist_lang_reviews
 def process(self, tree):
     # Start every run with fresh, empty collections; presumably the
     # visit methods (not shown here) fill them during traversal.
     self.auxmaps, self.setfrommaps, self.wraps = (
         OrderedSet(), OrderedSet(), OrderedSet())
     super().process(tree)
     return (self.auxmaps, self.setfrommaps, self.wraps)
コード例 #15
0
ファイル: join.py プロジェクト: jieaozhu/dist_lang_reviews
 def lhs_vars_from_clauses(self, clauses):
     """Return, as a tuple, every LHS variable of the given clauses (as
     reported by lhs_vars()), in order of first appearance and without
     repeats.
     """
     collected = OrderedSet()
     for clause in clauses:
         collected.update(self.lhs_vars(clause))
     return tuple(collected)
コード例 #16
0
ファイル: join.py プロジェクト: jieaozhu/dist_lang_reviews
 def lhs_vars_from_clauses(self, clauses):
     """Merge the lhs_vars() of each clause, left to right, into a
     duplicate-free tuple that keeps first-occurrence order.
     """
     merged = OrderedSet()
     for clause_vars in map(self.lhs_vars, clauses):
         merged.update(clause_vars)
     return tuple(merged)
コード例 #17
0
class AggrMaintainer(L.NodeTransformer):
    """Insert maintenance code for one aggregate invariant (aggrinv).

    Maintenance functions are prepended to the module, relation updates
    to the operand (and, with demand, to the restriction set) get calls
    to them, and relation clears trigger a clear of the aggregate map.
    Names of the generated functions accumulate in self.maint_funcs.
    """

    def __init__(self, fresh_vars, aggrinv):
        super().__init__()
        self.fresh_vars = fresh_vars
        self.aggrinv = aggrinv
        self.maint_funcs = OrderedSet()

    def visit_Module(self, node):
        node = self.generic_visit(node)

        aggrinv = self.aggrinv
        fv = self.fresh_vars
        ops = [L.SetAdd(), L.SetRemove()]
        # Operand maintenance functions always; restriction maintenance
        # functions only when the invariant uses demand.
        funcs = [make_aggr_oper_maint_func(fv, aggrinv, op) for op in ops]
        if aggrinv.uses_demand:
            funcs.extend(make_aggr_restr_maint_func(fv, aggrinv, op)
                         for op in ops)
        funcs = tuple(funcs)

        self.maint_funcs.update(L.get_defined_functions(funcs))
        return node._replace(body=funcs + node.body)

    def visit_RelUpdate(self, node):
        if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
            return node

        aggrinv = self.aggrinv
        if node.rel == aggrinv.rel:
            func = aggrinv.get_oper_maint_func_name(node.op)
        elif aggrinv.uses_demand and node.rel == aggrinv.restr:
            func = aggrinv.get_restr_maint_func_name(node.op)
        else:
            return node
        return L.insert_rel_maint_call(node, func)

    def visit_RelClear(self, node):
        # Clear the aggregate map when its demand set is cleared (demand
        # in use) or when its operand is cleared (no demand).
        aggrinv = self.aggrinv
        trigger = aggrinv.restr if aggrinv.uses_demand else aggrinv.rel
        if node.rel == trigger:
            clear_code = (L.MapClear(aggrinv.map),)
            return L.insert_rel_maint((node,), clear_code, L.SetRemove())
        return node
コード例 #18
0
ファイル: trans.py プロジェクト: jieaozhu/dist_lang_reviews
class AggrMaintainer(L.NodeTransformer):
    """Transformer that wires maintenance of a single aggregate invariant
    into a program: it adds the maintenance functions at the top of the
    module and hooks relation updates and clears that affect the
    aggregate. Generated function names are collected in maint_funcs.
    """

    def __init__(self, fresh_vars, aggrinv):
        super().__init__()
        self.fresh_vars = fresh_vars
        self.aggrinv = aggrinv
        self.maint_funcs = OrderedSet()

    def visit_Module(self, node):
        node = self.generic_visit(node)

        ops = [L.SetAdd(), L.SetRemove()]
        makers = [make_aggr_oper_maint_func]
        if self.aggrinv.uses_demand:
            makers.append(make_aggr_restr_maint_func)
        # Order: operand add/remove, then (with demand) restriction
        # add/remove.
        funcs = tuple(make(self.fresh_vars, self.aggrinv, op)
                      for make in makers for op in ops)

        self.maint_funcs.update(L.get_defined_functions(funcs))
        return node._replace(body=funcs + node.body)

    def visit_RelUpdate(self, node):
        if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
            return node

        aggrinv = self.aggrinv
        if node.rel == aggrinv.rel:
            get_name = aggrinv.get_oper_maint_func_name
        elif aggrinv.uses_demand and node.rel == aggrinv.restr:
            get_name = aggrinv.get_restr_maint_func_name
        else:
            return node
        return L.insert_rel_maint_call(node, get_name(node.op))

    def visit_RelClear(self, node):
        # With demand, the map is cleared along with the demand set;
        # without it, along with the operand relation.
        aggrinv = self.aggrinv
        if aggrinv.uses_demand:
            watched = aggrinv.restr
        else:
            watched = aggrinv.rel
        if not node.rel == watched:
            return node

        return L.insert_rel_maint((node,), (L.MapClear(aggrinv.map),),
                                  L.SetRemove())
コード例 #19
0
    class Finder(L.NodeVisitor):
        """Collect, in visiting order, the symtab symbol of each Query
        node; queries nested inside another are recorded before it.
        """

        def process(self, tree):
            self.syms = OrderedSet()
            super().process(tree)
            return self.syms

        def visit_Query(self, node):
            self.generic_visit(node)
            symbols = symtab.get_symbols()
            self.syms.add(symbols[node.name])
コード例 #20
0
    def process(self, tree):
        # Per-run state. infunc presumably tracks whether traversal is
        # inside a function body; it must be back to False at the end.
        self.toplevel_funcs = OrderedSet()
        self.excluded_funcs = set()
        self.infunc = False

        super().process(tree)
        assert not self.infunc

        # Top-level functions found, minus any that were excluded.
        remaining = self.toplevel_funcs - self.excluded_funcs
        return remaining
コード例 #21
0
def prefix_locals(tree, prefix, extra_boundvars):
    """Rename every local variable of a block of code to prefix + name.

    A variable counts as local if it is bound somewhere in tree (per
    BindingFinder) or is listed in extra_boundvars, even when it has no
    binding occurrence in tree.
    """
    bound = OrderedSet(extra_boundvars)
    bound.update(BindingFinder.run(tree))
    renaming = {name: prefix + name for name in bound}
    return Templater.run(tree, renaming)
コード例 #22
0
def prefix_locals(tree, prefix, extra_boundvars):
    """Return tree with its local variables renamed by prepending prefix.

    Locals are the extra_boundvars (treated as bound even without a
    binding occurrence in tree) followed by the variables BindingFinder
    finds bound in tree.
    """
    localvars = OrderedSet(extra_boundvars)
    for name in BindingFinder.run(tree):
        localvars.add(name)
    mapping = {}
    for name in localvars:
        mapping[name] = prefix + name
    return Templater.run(tree, mapping)
コード例 #23
0
ファイル: tupletrans.py プロジェクト: dcharkes/incoq
    class Flattener(L.QueryMapper):
        """Apply flatten_tuples_comp to every comprehension. process()
        returns the new tree and the OrderedSet of all tuple relations
        the rewritten comprehensions use.
        """

        def process(self, tree):
            self.trels = OrderedSet()
            new_tree = super().process(tree)
            return new_tree, self.trels

        def map_Comp(self, node):
            flat_comp, comp_trels = flatten_tuples_comp(node)
            self.trels.update(comp_trels)
            return flat_comp
コード例 #24
0
def analyze_functions(tree, funcs, *, allow_recursion=False):
    """Produce a FunctionCallGraph for the functions named in funcs.

    For each name in funcs that is defined in tree, the graph records
    its parameters and body, the functions in funcs it calls, and the
    functions in funcs that call it. A function defined inside an
    analyzed function is not analyzed separately; calls made from it are
    attributed to the enclosing function. Calls to names outside funcs
    are ignored.

    graph.order is the order that topsort_helper produces over the
    (callee, caller) edges of the graph.

    If allow_recursion is False and the call graph restricted to funcs
    is cyclic, ProgramError is raised, naming a cycle.
    """
    graph = FunctionCallGraph()
    # Set view of funcs: membership is tested at every Fun and Call node,
    # which would be O(len(funcs)) each if funcs is a list.
    func_set = set(funcs)

    for func in funcs:
        graph.calls_map[func] = OrderedSet()
        graph.calledby_map[func] = OrderedSet()

    class Visitor(L.NodeVisitor):
        def process(self, tree):
            # Analyzed function currently being traversed, or None.
            self.current_func = None
            super().process(tree)

        def visit_Fun(self, node):
            if self.current_func is not None:
                # Don't analyze nested functions on their own; keep
                # traversing so their calls count for the enclosing one.
                self.generic_visit(node)
                return

            name = node.name
            if name not in func_set:
                return

            graph.param_map[name] = node.args
            graph.body_map[name] = node.body

            self.current_func = name
            self.generic_visit(node)
            self.current_func = None

        def visit_Call(self, node):
            self.generic_visit(node)

            source = self.current_func
            target = node.func
            if source is not None and target in func_set:
                graph.calls_map[source].add(target)
                graph.calledby_map[target].add(source)

    Visitor.run(tree)

    # One (callee, caller) edge per recorded call relationship.
    edges = [(callee, caller)
             for callee, callers in graph.calledby_map.items()
             for caller in callers]
    order, rem_funcs, _rem_edges = topsort_helper(funcs, edges)
    # Functions the topological sort could not place indicate a cycle.
    if rem_funcs and not allow_recursion:
        raise ProgramError('Recursive functions found: ' +
                           str(get_cycle(funcs, edges)))
    graph.order = order

    return graph
コード例 #25
0
ファイル: tupletrans.py プロジェクト: IncOQ/incoq
 class Flattener(L.QueryMapper):
     """Flatten tuples in every comprehension via flatten_tuples_comp,
     collecting the tuple relations they need into one OrderedSet.
     """

     def process(self, tree):
         self.trels = OrderedSet()
         return super().process(tree), self.trels

     def map_Comp(self, node):
         rewritten, used_trels = flatten_tuples_comp(node)
         self.trels.update(used_trels)
         return rewritten
コード例 #26
0
    def __init__(self, clausetools, fresh_vars, fresh_join_names, comp,
                 result_var, *, counted):
        super().__init__()
        # Collaborators and the comprehension being maintained.
        self.clausetools = clausetools
        self.fresh_vars = fresh_vars
        self.fresh_join_names = fresh_join_names
        self.comp = comp
        self.result_var = result_var
        self.counted = counted

        # Names of generated maintenance functions.
        self.maint_funcs = OrderedSet()
        # RHS relations of the comprehension, as computed by clausetools.
        self.rels = self.clausetools.rhs_rels_from_comp(self.comp)
コード例 #27
0
def flatten_retrievals(comp):
    """Flatten the retrievals in a Comp node. Return a triple of the new
    Comp node, an OrderedSet of the fields seen, and a bool indicating
    whether a map was seen.

    Each retrieval is replaced by a variable (via RetrievalReplacer), and
    a field or map clause binding that variable is introduced immediately
    to the left of its first use (or, for the result expression, at the
    end of the clause list).
    """
    # Naming scheme for replacement variables. Map names get a little
    # extra decoration to reduce the likelihood of inadvertently creating
    # ambiguous names.
    def field_namer(obj, field):
        return obj + '_' + field

    def map_namer(map_name, key):
        return 'm_' + map_name + '_k_' + key

    replacer = RetrievalReplacer(field_namer, map_namer)

    seen_fields = OrderedSet()
    seen_map = False
    seen_field_repls = OrderedSet()
    seen_map_repls = OrderedSet()

    def process(expr):
        """Rewrite the retrievals in expr. Return the new expression and
        a list of clauses for replacements not produced by an earlier
        call.
        """
        nonlocal seen_map
        new_expr = replacer.process(expr)
        # Compute both "unseen" sets before updating any bookkeeping.
        fresh_fields = replacer.field_repls - seen_field_repls
        fresh_maps = replacer.map_repls - seen_map_repls
        clauses = []

        for obj, field, var in fresh_fields:
            seen_fields.add(field)
            seen_field_repls.add((obj, field, var))
            lhs = L.tuplify((obj, var), lval=True)
            clauses.append(L.Enumerator(lhs, L.ln(make_frel(field))))

        for map_name, key, var in fresh_maps:
            seen_map = True
            seen_map_repls.add((map_name, key, var))
            lhs = L.tuplify((map_name, key, var), lval=True)
            clauses.append(L.Enumerator(lhs, L.ln(make_maprel())))

        return new_expr, clauses

    flat_comp = L.rewrite_compclauses(comp, process)
    return flat_comp, seen_fields, seen_map
コード例 #28
0
ファイル: comp.py プロジェクト: jieaozhu/dist_lang_reviews
class AggrMapReplacer(L.NodeTransformer):

    """Replace map lookups with fresh variables bound by SetFromMap
    clauses.

    Meant to be applied, with one shared instance, to each clause and
    the result expression of a comprehension. self.repls persists across
    process() calls, so a lookup node seen before reuses its variable.

    process() returns a triple: the transformed tree, the list of
    clauses created during that call, and an empty list
    (NOTE(review): presumably the "clauses to add after" slot expected
    by L.rewrite_comp — confirm). The SetFromMap invariants needed by
    the new clauses accumulate in self.sfm_invs.
    """

    def __init__(self, fresh_names):
        super().__init__()
        # Iterator of fresh variable names, consumed with next().
        self.fresh_names = fresh_names
        # Mapping from lookup nodes to their replacement variables.
        self.repls = OrderedDict()
        # SetFromMap invariants for every clause created so far.
        self.sfm_invs = OrderedSet()

    def process(self, tree):
        self.new_clauses = []
        new_tree = super().process(tree)
        return new_tree, self.new_clauses, []

    def visit_DictLookup(self, node):
        node = self.generic_visit(node)

        # Only lookups on a plain map variable, keyed by a tuple of
        # names, with no default, are supported.
        assert isinstance(node.value, L.Name)
        assert L.is_tuple_of_names(node.key)
        assert node.default is None
        map_name = node.value.id
        keyvars = L.detuplify(node.key)

        var = self.repls.get(node)
        if var is not None:
            return L.Name(var)

        mask = L.mapmask_from_len(len(keyvars))
        rel = N.SA_name(map_name, mask)
        var = next(self.fresh_names)
        self.repls[node] = var
        # Bind the variable via a SetFromMap clause over the keys, and
        # record the invariant that clause depends on.
        self.new_clauses.append(
            L.SetFromMapMember(list(keyvars) + [var], rel, map_name, mask))
        self.sfm_invs.add(SetFromMapInvariant(rel, map_name, mask))
        return L.Name(var)
コード例 #29
0
    def process(self, tree):
        # Outer queries that receive a demand set and a call to a demand
        # function.
        self.queries_with_usets = OrderedSet()
        # Query name -> rewritten AST.
        self.rewrite_cache = {}
        # Names of demand queries introduced here; these are not
        # recursed into.
        self.demand_queries = set()

        result = super().process(tree)
        return result
コード例 #30
0
ファイル: join.py プロジェクト: jieaozhu/dist_lang_reviews
 def con_lhs_vars_from_comp(self, comp):
     """Return a tuple of constrained variables, in order of first
     being classified as constrained.
     
     A variable is constrained if its first occurrence in the query
     is in a constrained position. Positions are classified per clause
     by self.uncon_vars() and self.con_lhs_vars().
     
     For cyclic object queries like
     
         {(x, y) for y in x for x in y},
     
     after translating into clauses over M, there are two possible
     sets of constrained vars: {x} and {y}. This function processes
     clauses left-to-right, so {y} will be chosen.
     """
     # Variables already classified; once a variable lands in one set
     # it is never added to the other.
     uncon = OrderedSet()
     con = OrderedSet()
     for cl in comp.clauses:
         # The new unconstrained vars are the ones that occur in
         # unconstrained positions and are not already known to be
         # constrained.
         uncon.update(v for v in self.uncon_vars(cl) if v not in con)
         # Vice versa for the constrained vars. Uncons are processed
         # first so that "x in x" makes x unconstrained if it hasn't
         # been seen before.
         con.update(v for v in self.con_lhs_vars(cl) if v not in uncon)
     return tuple(con)
コード例 #31
0
ファイル: comp.py プロジェクト: jieaozhu/dist_lang_reviews
class AggrMapReplacer(L.NodeTransformer):
    """Turn each map lookup into a fresh variable bound by a SetFromMap
    clause.

    One instance is shared across all clauses and the result expression
    of a comprehension; replacements (self.repls) survive between
    process() calls so repeated lookups share a variable.

    process() returns (tree, clauses created in this call, []). The third
    element is always an empty list (NOTE(review): appears to be a
    placeholder required by the caller's protocol — confirm). The
    SetFromMap invariants to be transformed are collected in
    self.sfm_invs rather than returned.
    """

    def __init__(self, fresh_names):
        super().__init__()
        self.fresh_names = fresh_names
        # node -> replacement variable name
        self.repls = OrderedDict()
        # SetFromMapInvariant objects needed by the generated clauses
        self.sfm_invs = OrderedSet()

    def process(self, tree):
        self.new_clauses = []
        return super().process(tree), self.new_clauses, []

    def visit_DictLookup(self, node):
        node = self.generic_visit(node)

        # Only simple lookups: plain map variable, tuple-of-names key,
        # no default.
        assert isinstance(node.value, L.Name)
        assert L.is_tuple_of_names(node.key)
        assert node.default is None
        map_name = node.value.id
        keyvars = L.detuplify(node.key)

        var = self.repls.get(node, None)
        if var is None:
            var = self._introduce_var(node, map_name, keyvars)
        return L.Name(var)

    def _introduce_var(self, node, map_name, keyvars):
        """Allocate a fresh variable for node, remember it, and emit the
        SetFromMap clause and invariant that bind it. Return the name.
        """
        mask = L.mapmask_from_len(len(keyvars))
        rel = N.SA_name(map_name, mask)
        var = next(self.fresh_names)
        self.repls[node] = var
        clause_vars = list(keyvars) + [var]
        self.new_clauses.append(
            L.SetFromMapMember(clause_vars, rel, map_name, mask))
        self.sfm_invs.add(SetFromMapInvariant(rel, map_name, mask))
        return var
コード例 #32
0
ファイル: domaintrans.py プロジェクト: IncOQ/incoq
 class Flattener(L.QueryMapper):
     """Run flatten_comp (with input_rels from the enclosing scope) on
     every comprehension, OR-ing together the mset/mapset flags it reports
     and merging the fields it reports.
     """

     def process(self, tree):
         self.use_mset = False
         self.fields = OrderedSet()
         self.use_mapset = False
         result = super().process(tree)
         return result, self.use_mset, self.fields, self.use_mapset

     def map_Comp(self, node):
         flat, mset, fields, mapset = flatten_comp(node, input_rels)
         self.use_mset |= mset
         self.fields.update(fields)
         self.use_mapset |= mapset
         return flat
コード例 #33
0
    class Flattener(L.QueryMapper):
        """Flatten each comprehension with flatten_comp, aggregating
        whether any needed the mset or the map set and which fields
        were used. process() returns (tree, use_mset, fields,
        use_mapset).
        """

        def process(self, tree):
            self.use_mset = False
            self.fields = OrderedSet()
            self.use_mapset = False
            new_tree = super().process(tree)
            return new_tree, self.use_mset, self.fields, self.use_mapset

        def map_Comp(self, node):
            result = flatten_comp(node, input_rels)
            comp_out, wants_mset, comp_fields, wants_mapset = result
            self.use_mset |= wants_mset
            self.fields.update(comp_fields)
            self.use_mapset |= wants_mapset
            return comp_out
コード例 #34
0
ファイル: comp.py プロジェクト: jieaozhu/dist_lang_reviews
class AggrMapCompRewriter(S.QueryRewriter):
    """Rewrite comprehension queries so that map lookups coming from
    aggregates become SetFromMap clauses. process() returns the new tree
    and an OrderedSet of the SetFromMap invariants that must still be
    transformed.
    """

    def process(self, tree):
        self.sfm_invs = OrderedSet()
        new_tree = super().process(tree)
        return new_tree, self.sfm_invs

    def rewrite_comp(self, symbol, name, comp):
        # A single replacer per comprehension, handed to L.rewrite_comp.
        replacer = AggrMapReplacer(self.symtab.fresh_names.vars)
        new_comp = L.rewrite_comp(comp, replacer.process)
        self.sfm_invs.update(replacer.sfm_invs)
        return new_comp
コード例 #35
0
def without_duplicates(cost):
    """Drop repeated direct terms from a ProductCost, SumCost, or MinCost.

    The first occurrence of each term is kept, so the relative order of
    the surviving terms is unchanged. Only the top-level terms are
    deduplicated; nested costs are left as they are.
    """
    assert isinstance(cost, (ProductCost, SumCost, MinCost))
    return cost._replace(terms=OrderedSet(cost.terms))
コード例 #36
0
ファイル: trans.py プロジェクト: jieaozhu/dist_lang_reviews
 def __init__(self, fresh_vars, auxmaps, setfrommaps, wraps):
     super().__init__()
     self.fresh_vars = fresh_vars
     self.maint_funcs = OrderedSet()
     
     # Two indexes over the auxmap invariants:
     #   auxmaps_by_rel     -- relation -> list of auxmaps on it
     #   auxmaps_by_relmask -- (rel, mask, unwrap_value) -> auxmap
     # The second key deliberately leaves out unwrap_key. Two invariants
     # that differ only in that flag share one entry (the later one wins),
     # so only one of them may be used while both are still maintained.
     self.auxmaps_by_rel = by_rel = OrderedDict()
     self.auxmaps_by_relmask = by_relmask = OrderedDict()
     for am in auxmaps:
         by_rel.setdefault(am.rel, []).append(am)
         by_relmask[am.rel, am.mask, am.unwrap_value] = am
     
     # SetFromMap invariants keyed by map name; at most one per map.
     self.setfrommaps_by_map = sfm_index = OrderedDict()
     for inv in setfrommaps:
         if inv.map in sfm_index:
             raise L.ProgramError('Multiple SetFromMap invariants on '
                                  'same map {}'.format(inv.map))
         sfm_index[inv.map] = inv
     
     # Wrap invariants grouped under their wrap.oper value.
     self.wraps_by_rel = wraps_index = OrderedDict()
     for w in wraps:
         wraps_index.setdefault(w.oper, []).append(w)
コード例 #37
0
 def __init__(self):
     self.symbols = OrderedDict()
     """Global symbols, in declaration order."""
     
     self.stats = SymbolTable.Stats()
     
     # Fresh-name generators: plain variables, query names ("Query{}"),
     # and names for inlined code ("_i{}").
     self.fresh_names = SimpleNamespace(
         vars=N.fresh_name_generator(),
         queries=N.fresh_name_generator('Query{}'),
         inline=N.fresh_name_generator('_i{}'),
     )
     
     self.ignored_queries = OrderedSet()
     """Names of queries that should be skipped, either because they
     cannot be processed or because they should not be."""
     
     self.maint_funcs = OrderedSet()
     """Names of inserted maintenance functions that are eligible for
     inlining."""
コード例 #38
0
ファイル: comp.py プロジェクト: jieaozhu/dist_lang_reviews
 def __init__(self, fresh_names):
     super().__init__()
     # Generator used to draw names for replacement variables.
     self.fresh_names = fresh_names
     
     self.repls = OrderedDict()
     """Replacement variable chosen for each replaced node, keyed by
     the node."""
     
     self.sfm_invs = OrderedSet()
     """SetFromMap invariants gathered while replacing."""
コード例 #39
0
ファイル: comp.py プロジェクト: jieaozhu/dist_lang_reviews
class AggrMapCompRewriter(S.QueryRewriter):
    
    """Flatten map lookups from aggregates in comprehension queries.
    
    Each comprehension is passed through an AggrMapReplacer so that
    such lookups become SetFromMap clauses. process() returns a pair
    (new_tree, sfm_invs), where sfm_invs is an OrderedSet of the
    SetFromMap invariants that still need to be transformed.
    """
    
    def process(self, tree):
        # Collected across every rewritten comprehension.
        self.sfm_invs = OrderedSet()
        result = super().process(tree)
        return result, self.sfm_invs
    
    def rewrite_comp(self, symbol, name, comp):
        fresh_vars = self.symtab.fresh_names.vars
        replacer = AggrMapReplacer(fresh_vars)
        new_comp = L.rewrite_comp(comp, replacer.process)
        # Keep the invariants the replacer produced for this query.
        self.sfm_invs.update(replacer.sfm_invs)
        return new_comp
コード例 #40
0
ファイル: inline.py プロジェクト: IncOQ/incoq
 def process(self, tree):
     """Traverse tree and return the top-level functions that were not
     excluded (self.toplevel_funcs minus self.excluded_funcs)."""
     self.toplevel_funcs = OrderedSet()
     self.excluded_funcs = set()
     # Whether the traversal is currently inside a function. It must be
     # back to False once the whole tree has been visited.
     self.infunc = False
     
     super().process(tree)
     assert not self.infunc
     
     candidates = self.toplevel_funcs - self.excluded_funcs
     return candidates
コード例 #41
0
def get_defined_functions(tree):
    """Return an OrderedSet of the names of top-level functions in a
    block of code.

    Names are collected in traversal order. Because the visitor does not
    recurse into a function's body, functions nested inside another
    function are not reported.
    """
    found = OrderedSet()

    class _FunNameCollector(L.NodeVisitor):
        # No generic_visit() here: stop at each Fun node.
        def visit_Fun(self, node):
            found.add(node.name)

    _FunNameCollector.run(tree)
    return found
コード例 #42
0
        def rewrite_comp(self, symbol, name, comp):
            tools = self.symtab.clausetools

            # The parameters of this query and of every subquery found in
            # it must survive the pattern rewrite untouched. Otherwise the
            # params recorded for those queries in the symbol table would
            # no longer match their definitions.
            protected = OrderedSet.from_union(
                sub.params for sub in Finder.run(comp))
            protected.update(symbol.params)
            return tools.rewrite_with_patterns(comp, protected)
コード例 #43
0
ファイル: join.py プロジェクト: jieaozhu/dist_lang_reviews
 def con_lhs_vars_from_comp(self, comp):
     """Return a tuple of the constrained variables of comp.
     
     A variable is constrained if its first occurrence in the query is
     in a constrained position. Clauses are processed left to right.
     
     For a cyclic object query such as
     
         {(x, y) for y in x for x in y},
     
     the clauses over M admit two possible sets of constrained vars,
     {x} and {y}. Left-to-right processing chooses {y}.
     """
     seen_uncon = OrderedSet()
     seen_con = OrderedSet()
     for clause in comp.clauses:
         # An unconstrained occurrence counts unless the var is already
         # known to be constrained. These are recorded before the
         # constrained ones, so a clause like "x in x" leaves a
         # previously unseen x unconstrained.
         seen_uncon.update([v for v in self.uncon_vars(clause)
                            if v not in seen_con])
         seen_con.update([v for v in self.con_lhs_vars(clause)
                          if v not in seen_uncon])
     return tuple(seen_con)
コード例 #44
0
ファイル: compspec.py プロジェクト: dcharkes/incoq
 def get_uncon_params(self):
     """Return the unconstrained parameters as a tuple. The U-set must
     contain at least these parameters.
     
     Clauses are scanned left to right. A variable that occurs in a
     clause without being constrained by it, and that is not already
     supported by an enumvar of an earlier clause, is added to the
     result (at most once). Such a variable must be a parameter;
     otherwise the clauses cannot run left to right and an
     AssertionError is raised.
     
     With cyclic constraints there may be several minimal parameter
     sets. The one induced by this left-to-right scan is returned.
     """
     found = []
     supported = set()
     for clause in self.join.clauses:
         # Occurrences this clause does not constrain. For enumerators,
         # these are the LHS positions whose con_mask entry is false
         # (skipping '_' entries); for other clauses, all of its vars.
         if clause.kind is Clause.KIND_ENUM:
             pairs = zip(clause.enumlhs, clause.con_mask)
             uncon_occ = OrderedSet(var for var, bound in pairs
                                    if not bound and var != '_')
         else:
             uncon_occ = OrderedSet(clause.vars)
         
         for var in uncon_occ - supported:
             if var not in self.params:
                 raise AssertionError('Unconstrained var {} is not a '
                                      'parameter'.format(var))
             if var not in found:
                 found.append(var)
         
         # This clause's enumvars are supported for all later clauses.
         supported.update(clause.enumvars)
     
     return tuple(found)
コード例 #45
0
ファイル: trans.py プロジェクト: jieaozhu/dist_lang_reviews
 def __init__(self, clausetools, fresh_vars, fresh_join_names,
              comp, result_var, *, counted):
     super().__init__()
     # Keep the construction arguments as-is.
     self.clausetools = clausetools
     self.fresh_vars = fresh_vars
     self.fresh_join_names = fresh_join_names
     self.comp = comp
     self.result_var = result_var
     self.counted = counted
     
     # Relations on the RHS of comp's clauses, as computed by clausetools.
     self.rels = clausetools.rhs_rels_from_comp(comp)
     # Starts empty; presumably filled with generated maintenance
     # function names during processing.
     self.maint_funcs = OrderedSet()
コード例 #46
0
 def process(self, tree):
     self.queries_with_usets = OrderedSet()
     """Outer queries that get a demand set and a call to a demand
     function added.
     """
     self.rewrite_cache = {}
     """Map from query name to its rewritten AST."""
     self.demand_queries = set()
     """Names of the demand queries introduced here. These are not
     recursed into.
     """
     
     result = super().process(tree)
     return result
コード例 #47
0
def simplify_min_of_sums(mincost):
    """Simplify a MinCost whose terms are SumCosts of ProductCosts.

    Duplicate sums are dropped first. Then, examining candidates from
    last to first, a sum is removed when
    all_products_dominated(other.terms, candidate.terms, factorcounts)
    holds for some other sum that is still present, i.e. sums that
    dominate other sums are removed.
    """
    assert isinstance(mincost, MinCost)
    sums = mincost.terms
    assert all(isinstance(s, SumCost) for s in sums)
    assert all(isinstance(p, ProductCost) for s in sums for p in s.terms)

    kept = list(OrderedSet(sums))
    factorcounts = build_factor_counts(
        [prod for s in kept for prod in s.terms])

    # Walk a reversed snapshot because `kept` shrinks as we go. A sum
    # that has been removed can no longer eliminate later candidates.
    for candidate in kept[::-1]:
        others = OrderedSet(kept) - {candidate}
        if any(all_products_dominated(other.terms, candidate.terms,
                                      factorcounts)
               for other in others):
            kept.remove(candidate)

    return mincost._replace(terms=kept)
コード例 #48
0
ファイル: trans.py プロジェクト: jieaozhu/dist_lang_reviews
 def rewrite_comp(self, symbol, name, comp):
     tools = self.symtab.clausetools
     
     # The parameters of this query and of every subquery found in it
     # must survive the pattern rewrite untouched. Otherwise the params
     # recorded for those queries in the symbol table would no longer
     # match their definitions.
     protected = OrderedSet.from_union(
         sub.params for sub in Finder.run(comp))
     protected.update(symbol.params)
     return tools.rewrite_with_patterns(comp, protected)
コード例 #49
0
    def process(self, tree):
        # Collections filled in while the tree is visited; returned in
        # this order.
        self.auxmaps, self.setfrommaps, self.wraps = (
            OrderedSet(), OrderedSet(), OrderedSet())

        super().process(tree)

        return (self.auxmaps, self.setfrommaps, self.wraps)
コード例 #50
0
 def __init__(self, store, height_limit=None, unknown=Top, fixed_vars=None):
     super().__init__()
     self.store = store
     """Mapping from symbol names to their inferred types. A type may
     only ever change by increasing monotonically.
     """
     self.height_limit = height_limit
     """Maximum height of type terms in the store, or None for no limit."""
     self.illtyped = OrderedSet()
     """Nodes at which the well-typedness constraints are violated."""
     self.changed = True
     """Whether the most recent process() call updated the store; also
     True before process() has ever been called.
     """
     self.unknown = unknown
     """Type given to variables missing from the store: Bottom, Top, or
     None. None means an error is raised for unknown variables.
     """
     self.fixed_vars = [] if fixed_vars is None else fixed_vars
     """Names of variables whose types inference may not change."""
コード例 #51
0
ファイル: auxmap.py プロジェクト: IncOQ/incoq
class RelmatchQueryFinder(L.NodeVisitor):
    
    """Collect the auxmap specs used by relmatch queries and set-map
    lookups.
    
    process() returns an OrderedSet with the spec of every SetMatch node
    accepted by is_relmatch() and every SMLookup node accepted by
    is_relsmlookup(). A node's children are visited before the node
    itself is examined.
    """
    
    def process(self, tree):
        self.specs = OrderedSet()
        super().process(tree)
        return self.specs
    
    def _record_spec(self, node, matches, extract):
        """Visit node's children, then add its spec to self.specs if
        matches(node) holds. extract(node) returns a (spec, key) pair."""
        self.generic_visit(node)
        if not matches(node):
            return
        spec, _key = extract(node)
        self.specs.add(spec)
    
    def visit_SetMatch(self, node):
        self._record_spec(node, is_relmatch, get_relmatch)
    
    def visit_SMLookup(self, node):
        self._record_spec(node, is_relsmlookup, get_relsmlookup)
コード例 #52
0
 def __init__(self, store, height_limit=None, unknown=Top,
              fixed_vars=None):
     super().__init__()
     self.store = store
     """Mapping from symbol names to their inferred types. A type may
     only ever change by increasing monotonically.
     """
     self.height_limit = height_limit
     """Maximum height of type terms in the store, or None for no limit."""
     self.illtyped = OrderedSet()
     """Nodes at which the well-typedness constraints are violated."""
     self.changed = True
     """Whether the most recent process() call updated the store; also
     True before process() has ever been called.
     """
     self.unknown = unknown
     """Type given to variables missing from the store: Bottom, Top, or
     None. None means an error is raised for unknown variables.
     """
     self.fixed_vars = [] if fixed_vars is None else fixed_vars
     """Names of variables whose types inference may not change."""
コード例 #53
0
ファイル: tools.py プロジェクト: jieaozhu/dist_lang_reviews
 def process(self, tree):
     # Reset per-run state before traversing. write_ctx presumably
     # tracks whether the visitor is inside a binding context.
     self.vars = OrderedSet()
     self.write_ctx = False
     super().process(tree)
     found = self.vars
     return found
コード例 #54
0
ファイル: trans.py プロジェクト: jieaozhu/dist_lang_reviews
class InvariantTransformer(L.AdvNodeTransformer):
    
    """Insert maintenance functions, insert calls to these functions
    at updates, and replace expressions with uses of stored results.
    
    Three kinds of invariants are handled, as passed to the constructor:
    
      - auxmaps: SetAdd/SetRemove on auxmap.rel gets a call to the
        auxmap's maintenance function; RelClear of that relation also
        clears auxmap.map; matching ImgLookup nodes become lookups into
        auxmap.map.
      - setfrommaps: MapAssign/MapDelete on sfm.map gets a call to the
        matching maintenance function; MapClear also clears sfm.rel;
        SetFromMap expressions over the map become the name sfm.rel.
      - wraps: SetAdd/SetRemove on the relation gets a call to the
        wrap's maintenance function; RelClear also clears wrap.rel;
        matching Wrap/Unwrap expressions become the name wrap.rel.
    
    All generated maintenance functions are prepended to the module body
    and their names are recorded in self.maint_funcs.
    """
    
    # There can be at most one SetFromMap mask for a given map.
    # Multiple distinct masks would have to differ on their arity,
    # which would be a type error.
    
    def __init__(self, fresh_vars, auxmaps, setfrommaps, wraps):
        super().__init__()
        # Name generator passed to the make_*_maint_func helpers.
        self.fresh_vars = fresh_vars
        # Names of the maintenance functions inserted by visit_Module.
        self.maint_funcs = OrderedSet()
        
        # Index over auxmaps for fast retrieval.
        self.auxmaps_by_rel = OrderedDict()
        self.auxmaps_by_relmask = OrderedDict()
        for auxmap in auxmaps:
            self.auxmaps_by_rel.setdefault(auxmap.rel, []).append(auxmap)
            # Index by relation, mask, and whether value is a singleton.
            # Not indexed by whether key is a singleton; if there are
            # two separate invariants that differ only on that, we may
            # end up only using one but maintaining both.
            stats = (auxmap.rel, auxmap.mask, auxmap.unwrap_value)
            self.auxmaps_by_relmask[stats] = auxmap
        
        # Index over setfrommaps.
        self.setfrommaps_by_map = sfm_by_map = OrderedDict()
        for sfm in setfrommaps:
            if sfm.map in sfm_by_map:
                raise L.ProgramError('Multiple SetFromMap invariants on '
                                     'same map {}'.format(sfm.map))
            sfm_by_map[sfm.map] = sfm
        
        # Wrap invariants grouped under their wrap.oper value.
        self.wraps_by_rel = OrderedDict()
        for wrap in wraps:
            self.wraps_by_rel.setdefault(wrap.oper, []).append(wrap)
    
    def in_loop_rhs(self):
        """Return True if the current node is the iter of a For node.
        Note that this checks the _visit_stack, and is sensitive to
        where in the recursive visiting we are.
        """
        stack = self._visit_stack
        if len(stack) < 2:
            return False
        # Each stack entry is unpacked as (node, field, index). Presumably
        # the field of the top entry is the parent field through which the
        # current node was reached -- confirm against AdvNodeTransformer.
        _current, field, _index = stack[-1]
        parent, _field, _index = stack[-2]
        return isinstance(parent, L.For) and field == 'iter'
    
    def visit_Module(self, node):
        node = self.generic_visit(node)
        
        # Build the maintenance functions in a fixed order: add/remove
        # for each auxmap, then assign/delete for each SetFromMap, then
        # add/remove for each wrap.
        funcs = []
        for auxmaps in self.auxmaps_by_rel.values():
            for auxmap in auxmaps:
                for op in [L.SetAdd(), L.SetRemove()]:
                    func = make_auxmap_maint_func(self.fresh_vars, auxmap, op)
                    funcs.append(func)
        for sfm in self.setfrommaps_by_map.values():
            for op in ['assign', 'delete']:
                func = make_setfrommap_maint_func(self.fresh_vars, sfm, op)
                funcs.append(func)
        for wraps in self.wraps_by_rel.values():
            for wrap in wraps:
                for op in [L.SetAdd(), L.SetRemove()]:
                    func = make_wrap_maint_func(self.fresh_vars, wrap, op)
                    funcs.append(func)
        
        func_names = L.get_defined_functions(tuple(funcs))
        self.maint_funcs.update(func_names)
        
        # The generated functions go before the original module body.
        node = node._replace(body=tuple(funcs) + node.body)
        return node
    
    def visit_RelUpdate(self, node):
        # Only set additions/removals are maintained; other ops pass through.
        if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
            return node
        
        code = (node,)
        
        # Exactly where each call lands relative to the update is decided
        # by L.insert_rel_maint (it is given the op).
        auxmaps = self.auxmaps_by_rel.get(node.rel, set())
        for auxmap in auxmaps:
            func_name = auxmap.get_maint_func_name(node.op)
            call_code = (L.Expr(L.Call(func_name, [L.Name(node.elem)])),)
            code = L.insert_rel_maint(code, call_code, node.op)
        
        wraps = self.wraps_by_rel.get(node.rel, set())
        for wrap in wraps:
            func_name = wrap.get_maint_func_name(node.op)
            call_code = (L.Expr(L.Call(func_name, [L.Name(node.elem)])),)
            code = L.insert_rel_maint(code, call_code, node.op)
        
        return code
    
    def visit_RelClear(self, node):
        # Clearing a relation also clears every structure derived from it;
        # each clear is inserted as for a removal (L.SetRemove()).
        code = (node,)
        
        auxmaps = self.auxmaps_by_rel.get(node.rel, set())
        for auxmap in auxmaps:
            clear_code = (L.MapClear(auxmap.map),)
            code = L.insert_rel_maint(code, clear_code, L.SetRemove())
        
        wraps = self.wraps_by_rel.get(node.rel, set())
        for wrap in wraps:
            clear_code = (L.RelClear(wrap.rel),)
            code = L.insert_rel_maint(code, clear_code, L.SetRemove())
        
        return code
    
    def visit_MapAssign(self, node):
        sfm = self.setfrommaps_by_map.get(node.map, None)
        if sfm is None:
            return node
        
        # Call the 'assign' maintenance function with the key and value,
        # inserted as for an addition.
        code = (node,)
        func_name = sfm.get_maint_func_name('assign')
        call_code = (L.Expr(L.Call(func_name, [L.Name(node.key),
                                               L.Name(node.value)])),)
        code = L.insert_rel_maint(code, call_code, L.SetAdd())
        return code
    
    def visit_MapDelete(self, node):
        sfm = self.setfrommaps_by_map.get(node.map, None)
        if sfm is None:
            return node
        
        # Call the 'delete' maintenance function with the key, inserted as
        # for a removal.
        code = (node,)
        func_name = sfm.get_maint_func_name('delete')
        call_code = (L.Expr(L.Call(func_name, [L.Name(node.key)])),)
        code = L.insert_rel_maint(code, call_code, L.SetRemove())
        return code
    
    def visit_MapClear(self, node):
        sfm = self.setfrommaps_by_map.get(node.map, None)
        if sfm is None:
            return node
        
        # Clearing the map also clears the relation derived from it.
        code = (node,)
        clear_code = (L.RelClear(sfm.rel),)
        code = L.insert_rel_maint(code, clear_code, L.SetRemove())
        return code
    
    def imglookup_helper(self, node, *, in_unwrap, on_loop_rhs):
        """Return the replacement for an ImgLookup node, or None if
        no replacement is applicable.
        
        in_unwrap selects auxmaps whose unwrap_value flag matches.
        on_loop_rhs chooses the fallback for a missing key: an empty tuple
        when the lookup is a For loop's iter, otherwise Set().
        """
        # Only lookups over a plain relation name can use an auxmap.
        if not isinstance(node.set, L.Name):
            return None
        rel = node.set.id
        
        stats = (rel, node.mask, in_unwrap)
        auxmap = self.auxmaps_by_relmask.get(stats, None)
        if auxmap is None:
            return None
        
        key = L.tuplify(node.bounds, unwrap=auxmap.unwrap_key)
        empty = L.Parser.pe('()') if on_loop_rhs else L.Parser.pe('Set()')
        return L.Parser.pe('_MAP[_KEY] if _KEY in _MAP else _EMPTY',
                           subst={'_MAP': auxmap.map,
                                  '_KEY': key,
                                  '_EMPTY': empty})
    
    def visit_ImgLookup(self, node):
        node = self.generic_visit(node)
        
        # A bare ImgLookup can only use auxmaps without unwrap_value.
        repl = self.imglookup_helper(node, in_unwrap=False,
                                     on_loop_rhs=self.in_loop_rhs())
        if repl is not None:
            node = repl
        
        return node
    
    def visit_SetFromMap(self, node):
        node = self.generic_visit(node)
        
        if not isinstance(node.map, L.Name):
            return node
        map = node.map.id
        
        sfm = self.setfrommaps_by_map.get(map, None)
        if sfm is None:
            return node
        
        # A second, different mask on the same map is rejected (see the
        # class-level note on SetFromMap masks).
        if not sfm.mask == node.mask:
            raise L.ProgramError('Multiple SetFromMap expressions on '
                                 'same map {}'.format(map))
        return L.Name(sfm.rel)
    
    def wrap_helper(self, node):
        """Process a wrap invariant at a Wrap or Unwrap node.
        Don't recurse.
        
        Returns L.Name(wrap.rel) for the first wrap on the operand's
        relation whose unwrap flag matches the node kind, otherwise the
        node unchanged.
        """
        if not isinstance(node.value, L.Name):
            return node
        rel = node.value.id
        
        wraps = self.wraps_by_rel.get(rel, [])
        for wrap in wraps:
            if ((isinstance(node, L.Wrap) and not wrap.unwrap) or
                (isinstance(node, L.Unwrap) and wrap.unwrap)):
                return L.Name(wrap.rel)
        
        return node
    
    def visit_Unwrap(self, node):
        # Handle special case of an auxmap with the unwrap_value flag.
        if isinstance(node.value, L.ImgLookup):
            # Recurse over children below the ImgLookup.
            value = self.generic_visit(node.value)
            node = node._replace(value=value)
            # See if an auxmap invariant applies.
            repl = self.imglookup_helper(value, in_unwrap=True,
                                         on_loop_rhs=self.in_loop_rhs())
            if repl is not None:
                return repl
        
        else:
            # Don't run in case we already did generic_visit() above
            # and didn't return.
            node = self.generic_visit(node)
        
        return self.wrap_helper(node)
    
    def visit_Wrap(self, node):
        node = self.generic_visit(node)
        node = self.wrap_helper(node)
        return node
コード例 #55
0
ファイル: domaintrans.py プロジェクト: IncOQ/incoq
 def process(self, tree):
     # Accumulators reset before each traversal and returned alongside
     # the rewritten tree.
     self.use_mset = False
     self.fields = OrderedSet()
     self.use_mapset = False
     new_tree = super().process(tree)
     return (new_tree, self.use_mset, self.fields, self.use_mapset)
コード例 #56
0
ファイル: tools.py プロジェクト: jieaozhu/dist_lang_reviews
class IdentFinder(L.NodeVisitor):
    
    """Return an OrderedSet of all identifiers in the specified
    contexts.
    
    A context is a string '<node type name>.<field name>' naming an
    identifier field listed in L.ident_fields. With invert=True, the
    finder instead collects identifiers from every identifier field
    except the given contexts.
    """
    
    fun_ctxs = ('Fun.name', 'Call.func')
    query_ctxs = ('Query.name', 'ResetDemand.names')
    rel_ctxs = ('RelUpdate.rel', 'RelClear.rel', 'RelMember.rel')
    
    @classmethod
    def find_functions(cls, tree):
        """Return identifiers in the Fun.name and Call.func contexts."""
        return cls().run(tree, contexts=cls.fun_ctxs)
    
    @classmethod
    def find_vars(cls, tree):
        """Return identifiers in every context except the function and
        query contexts."""
        ctxs = (cls.fun_ctxs + cls.query_ctxs)
        return cls().run(tree, contexts=ctxs, invert=True)
    
    @classmethod
    def find_non_rel_uses(cls, tree):
        """Like find_vars(), but also excluding the RelUpdate.rel,
        RelClear.rel, and RelMember.rel contexts."""
        ctxs = (cls.fun_ctxs + cls.query_ctxs + cls.rel_ctxs)
        return cls().run(tree, contexts=ctxs, invert=True)
    
    def __init__(self, contexts=None, invert=False):
        """Raise ValueError if any context is malformed or does not name
        a field listed in L.ident_fields."""
        if contexts is not None:
            for c in contexts:
                # partition() always yields three parts. A context with no
                # dot, or with several dots, therefore fails the lookup
                # below and gets the descriptive ValueError, rather than an
                # opaque tuple-unpacking error from split().
                node_name, _sep, field_name = c.partition('.')
                if field_name not in L.ident_fields.get(node_name, []):
                    raise ValueError('Unknown identifier context "{}"'
                                     .format(c))
        
        self.contexts = contexts
        """Collection of contexts to include/exclude. Each context is
        a string of the form '<node type name>.<field name>'. A value
        of None is equivalent to specifying all contexts.
        """
        self.invert = bool(invert)
        """If True, find identifiers that occur in any context besides
        the ones given.
        """
    
    def process(self, tree):
        self.names = OrderedSet()
        super().process(tree)
        return self.names
    
    def generic_visit(self, node):
        # Recurse first, so identifiers from children are recorded before
        # this node's own.
        super().generic_visit(node)
        clsname = node.__class__.__name__
        id_fields = L.ident_fields.get(clsname, [])
        for f in id_fields:
            inctx = (self.contexts is None or
                     clsname + '.' + f in self.contexts)
            # Collect when exactly one of "in a listed context" and
            # "inverted" holds.
            if inctx != self.invert:
                # Normalize for either one id or a sequence of ids
                # (the field may also be None).
                ids = getattr(node, f)
                if isinstance(ids, str):
                    ids = [ids]
                if ids is not None:
                    self.names.update(ids)
コード例 #57
0
ファイル: tools.py プロジェクト: jieaozhu/dist_lang_reviews
 def process(self, tree):
     # Fresh result set for this run, filled in during traversal.
     self.names = OrderedSet()
     super().process(tree)
     collected = self.names
     return collected
コード例 #58
0
ファイル: rewritings.py プロジェクト: IncOQ/incoq
class RelationFinder(L.NodeVisitor):
    
    """Find variables that can be statically inferred to be relations,
    i.e. sets that are unaliased and top-level.
    
    A variable R qualifies only if it has a global-scope initialization
    of one of these forms:
    
        R = Set()
        R = incoq.runtime.Set()
        R = set()
    
    and every other occurrence of R is one of:
    
        - the target of a SetUpdate
        
        - the RHS of a membership clause (condition clauses included)
        
        - the RHS of a For loop
    
    process() returns the names recorded as initialized this way,
    minus those seen in any other position.
    """
    
    def process(self, tree):
        # Names seen with an allowed top-level initializer.
        self.inited = OrderedSet()
        # Names seen in some disallowed position.
        self.disqual = OrderedSet()
        super().process(tree)
        return self.inited - self.disqual
    
    # Scope tracking: self.toplevel is True except while visiting the
    # inside of a function or class definition.
    
    def visit_Module(self, node):
        self.toplevel = True
        self.generic_visit(node)
    
    def nontoplevel_helper(self, node):
        outer = self.toplevel
        self.toplevel = False
        self.generic_visit(node)
        self.toplevel = outer
    
    visit_FunctionDef = nontoplevel_helper
    visit_ClassDef = nontoplevel_helper
    
    def visit_Assign(self, node):
        allowed_inits = [L.pe(src) for src in
                         ('Set()', 'incoq.runtime.Set()', 'set()')]
        # A top-level relation initializer records the name and is not
        # recursed into, so that name is not disqualified here.
        if self.toplevel and L.is_varassign(node):
            name, value = L.get_varassign(node)
            if value in allowed_inits:
                self.inited.add(name)
                return
        
        self.generic_visit(node)
    
    def visit_SetUpdate(self, node):
        # A bare-name target is an allowed use, so it is skipped.
        if not isinstance(node.target, L.Name):
            self.generic_visit(node)
            return
        self.visit(node.elem)
    
    def visit_For(self, node):
        # A bare-name iterable is an allowed use, so it is skipped.
        if not isinstance(node.iter, L.Name):
            self.generic_visit(node)
            return
        for part in (node.target, node.body, node.orelse):
            self.visit(part)
    
    def visit_Comp(self, node):
        # The params and options are never visited. In each clause, a
        # bare-name RHS of an enumerator or of an `in` condition is an
        # allowed use and is skipped.
        self.visit(node.resexp)
        for cl in node.clauses:
            self.visit(self._clause_part_to_visit(cl))
    
    @staticmethod
    def _clause_part_to_visit(cl):
        """Return the part of clause cl that may contain disqualifying
        uses: the LHS of a membership over a bare name, otherwise the
        whole clause."""
        if isinstance(cl, L.Enumerator) and isinstance(cl.iter, L.Name):
            return cl.target
        is_name_membership = (isinstance(cl, L.Compare) and
                              len(cl.ops) == len(cl.comparators) == 1 and
                              isinstance(cl.ops[0], L.In) and
                              isinstance(cl.comparators[0], L.Name))
        if is_name_membership:
            return cl.left
        return cl
    
    def visit_Name(self, node):
        # Any Name reaching this point is in a disallowed position.
        self.disqual.add(node.id)
コード例 #59
0
ファイル: rewritings.py プロジェクト: IncOQ/incoq
 def process(self, tree):
     # Return the names added to self.inited during traversal that were
     # never added to self.disqual.
     self.inited = OrderedSet()
     self.disqual = OrderedSet()
     super().process(tree)
     qualified = self.inited - self.disqual
     return qualified