Example #1
0
 def con_lhs_vars_from_comp(self, comp):
     """Return a tuple of the constrained variables of comp.

     A variable is constrained if its first occurrence in the query is
     in a constrained position. Clauses are scanned left to right, and a
     variable is classified the first time it shows up among a clause's
     unconstrained vars or constrained LHS vars; it never changes class
     afterwards.

     For cyclic object queries such as

         {(x, y) for y in x for x in y},

     translating into clauses over M gives two possible sets of
     constrained vars, {x} and {y}. Because of the left-to-right scan,
     {y} is the one chosen.
     """
     constrained = OrderedSet()
     unconstrained = OrderedSet()
     for clause in comp.clauses:
         # Unconstrained occurrences are recorded before constrained ones,
         # so a clause like "x in x" classifies a not-yet-seen x as
         # unconstrained.
         for var in self.uncon_vars(clause):
             if var not in constrained:
                 unconstrained.add(var)
         for var in self.con_lhs_vars(clause):
             if var not in unconstrained:
                 constrained.add(var)
     return tuple(constrained)
Example #2
0
    def process(self, tree):
        """Reset the auxmaps, setfrommaps, and wraps collections to empty
        OrderedSets, run the traversal over tree, and return the three
        collections as a tuple in that order.
        """
        self.auxmaps, self.setfrommaps, self.wraps = (
            OrderedSet(), OrderedSet(), OrderedSet())
        super().process(tree)
        return (self.auxmaps, self.setfrommaps, self.wraps)
Example #3
0
def analyze_functions(tree, funcs, *, allow_recursion=False):
    """Build a FunctionCallGraph over the functions named in funcs.

    Every name in funcs gets an initially empty entry in the graph's
    calls_map and calledby_map. For each function definition in tree
    whose name is in funcs, its parameters and body are stored in
    param_map and body_map, and each call it makes to another name in
    funcs is recorded in both calls_map and calledby_map. Calls made
    inside a function nested within such a definition are credited to
    the enclosing function. Definitions outside such a function whose
    names are not in funcs are skipped entirely.

    The graph's order attribute is set to the order computed by
    topsort_helper. If topsort_helper leaves any functions unordered
    (the call graph among funcs is cyclic) and allow_recursion is False,
    ProgramError is raised, naming a cycle found by get_cycle.
    """
    graph = FunctionCallGraph()
    for fname in funcs:
        graph.calls_map[fname] = OrderedSet()
        graph.calledby_map[fname] = OrderedSet()

    class CallCollector(L.NodeVisitor):

        def process(self, tree):
            # Name of the funcs member whose body is currently being
            # traversed, or None when outside any such function.
            self.current_func = None
            super().process(tree)

        def visit_Fun(self, node):
            if self.current_func is not None:
                # Nested definitions are not separate graph nodes. Keep
                # walking so their calls count for the enclosing function.
                self.generic_visit(node)
                return

            fname = node.name
            if fname in funcs:
                graph.param_map[fname] = node.args
                graph.body_map[fname] = node.body
                self.current_func = fname
                self.generic_visit(node)
                self.current_func = None

        def visit_Call(self, node):
            # Visit child nodes first so nested calls are recorded too.
            self.generic_visit(node)
            caller = self.current_func
            callee = node.func
            if caller is None or callee not in funcs:
                return
            graph.calls_map[caller].add(callee)
            graph.calledby_map[callee].add(caller)

    CallCollector.run(tree)

    # One (callee, caller) pair per call relationship. Presumably this
    # orientation puts callees before their callers in the computed
    # order -- NOTE(review): confirm against topsort_helper.
    edges = []
    for callee, callers in graph.calledby_map.items():
        edges.extend((callee, caller) for caller in callers)

    order, rem_funcs, _rem_edges = topsort_helper(funcs, edges)
    if len(rem_funcs) > 0 and not allow_recursion:
        raise ProgramError('Recursive functions found: ' +
                           str(get_cycle(funcs, edges)))
    graph.order = order
    return graph
Example #4
0
def flatten_retrievals(comp):
    """Flatten the retrievals in a Comp node.

    Return a triple of the new Comp node, an OrderedSet of the field
    names seen, and a bool that is True if any map retrieval was seen.

    Field and map clauses are introduced immediately to the left of
    their first use (or, for the result expression, at the end of the
    clause list).
    """
    # The map names carry some extra fluff to reduce the likelihood of
    # inadvertently creating ambiguous names. Parameter names match the
    # original namers in case the replacer passes them by keyword.
    def field_namer(obj, field):
        return obj + '_' + field

    def map_namer(map, key):
        return 'm_' + map + '_k_' + key

    replacer = RetrievalReplacer(field_namer, map_namer)

    seen_fields = OrderedSet()
    seen_map = False
    seen_field_repls = OrderedSet()
    seen_map_repls = OrderedSet()

    def rewrite_expr(expr):
        """Rewrite any retrievals in expr.

        Return a pair of the new expression and a list of new clauses to
        add for retrievals that had not already been seen.
        """
        nonlocal seen_map
        rewritten = replacer.process(expr)
        # The replacer's field_repls / map_repls are diffed against what
        # has already been emitted; presumably they accumulate across
        # calls -- NOTE(review): confirm in RetrievalReplacer.
        fresh_fields = replacer.field_repls - seen_field_repls
        fresh_maps = replacer.map_repls - seen_map_repls
        clauses = []

        for repl in fresh_fields:
            obj, field, value = repl
            seen_fields.add(field)
            seen_field_repls.add(repl)
            clauses.append(L.Enumerator(L.tuplify((obj, value), lval=True),
                                        L.ln(make_frel(field))))

        for repl in fresh_maps:
            map_name, key, value = repl
            seen_map = True
            seen_map_repls.add(repl)
            clauses.append(L.Enumerator(
                L.tuplify((map_name, key, value), lval=True),
                L.ln(make_maprel())))

        return rewritten, clauses

    # Must run before reading seen_map, which rewrite_expr updates.
    new_comp = L.rewrite_compclauses(comp, rewrite_expr)
    return new_comp, seen_fields, seen_map
Example #5
0
 def rhs_rels_from_comp(self, comp):
     """Return a tuple of the distinct RHS relations used by comp's clauses.

     Relations are listed in order of first appearance. Clauses for which
     rhs_rel() returns None contribute nothing.
     """
     candidates = (self.rhs_rel(cl) for cl in comp.clauses)
     return tuple(OrderedSet(rel for rel in candidates if rel is not None))
def preprocess_var_decls(tree):
    """Eliminate global variable declarations of the form

        S = Set()
        M = Map()

    from the module tree.

    Return a pair of the module (a new node if any declaration was
    removed, otherwise the original) and a list of (variable name,
    type name) pairs, where the type name is 'Set' or 'Map'.

    Any top-level statement matching `X = F()` (a single Name target and
    a call with no arguments to a plain name) is treated as a
    declaration. If F is neither 'Set' nor 'Map', L.ProgramError is
    raised.
    """
    assert isinstance(tree, P.Module)
    decl_pattern = P.Assign(
        [P.Name(P.PatVar('_VAR'), P.Wildcard())],
        P.Call(P.Name(P.PatVar('_KIND'), P.Load()), [], [], None, None))

    decls = OrderedSet()
    kept_stmts = []
    for stmt in tree.body:
        match = P.match(decl_pattern, stmt)
        if match is None:
            kept_stmts.append(stmt)
            continue
        var, kind = match['_VAR'], match['_KIND']
        if kind not in ('Set', 'Map'):
            raise L.ProgramError(
                'Unknown declaration initializer {}'.format(kind))
        decls.add((var, kind))

    # Only rebuild the module when at least one declaration was removed.
    if decls:
        tree = tree._replace(body=kept_stmts)
    return tree, list(decls)
Example #7
0
 def __init__(self, fresh_names):
     """Keep the fresh-name source; replacements and SetFromMap
     invariants start out empty.
     """
     super().__init__()
     self.fresh_names = fresh_names
     # Mapping from nodes to replacement variables.
     self.repls = OrderedDict()
     # SetFromMap invariants.
     self.sfm_invs = OrderedSet()
Example #8
0
    def __init__(self, fresh_vars, auxmaps, setfrommaps, wraps):
        """Store fresh_vars and build lookup indexes over the given
        auxmap, SetFromMap, and wrap invariants.

        Raises L.ProgramError if two SetFromMap invariants are defined
        over the same map.
        """
        super().__init__()
        self.fresh_vars = fresh_vars
        self.maint_funcs = OrderedSet()

        # Auxmap indexes: every auxmap grouped by relation, plus a lookup
        # keyed on (relation, mask, whether the value is a singleton).
        # The second key ignores whether the key is a singleton, so two
        # invariants differing only in that collide here (the later one
        # takes the slot). We may end up using only one while
        # maintaining both.
        by_rel = self.auxmaps_by_rel = OrderedDict()
        by_relmask = self.auxmaps_by_relmask = OrderedDict()
        for aux in auxmaps:
            by_rel.setdefault(aux.rel, []).append(aux)
            by_relmask[(aux.rel, aux.mask, aux.unwrap_value)] = aux

        # At most one SetFromMap invariant is allowed per map.
        sfm_by_map = self.setfrommaps_by_map = OrderedDict()
        for sfm in setfrommaps:
            if sfm.map in sfm_by_map:
                raise L.ProgramError('Multiple SetFromMap invariants on '
                                     'same map {}'.format(sfm.map))
            sfm_by_map[sfm.map] = sfm

        # Wraps grouped by their oper attribute.
        by_oper = self.wraps_by_rel = OrderedDict()
        for wrap in wraps:
            by_oper.setdefault(wrap.oper, []).append(wrap)
Example #9
0
 def __init__(self):
     """Create an empty symbol table with its stats and fresh-name
     generators.
     """
     # Global symbols, in declaration order.
     self.symbols = OrderedDict()

     self.stats = SymbolTable.Stats()

     # Generators for variable names, query names ('Query{}') and
     # inlining names ('_i{}').
     self.fresh_names = SimpleNamespace(
         vars=N.fresh_name_generator(),
         queries=N.fresh_name_generator('Query{}'),
         inline=N.fresh_name_generator('_i{}'),
     )

     # Names of queries that cannot or should not be processed.
     self.ignored_queries = OrderedSet()

     # Names of inserted maintenance functions that can be inlined.
     self.maint_funcs = OrderedSet()
Example #10
0
def without_duplicates(cost):
    """Return a copy of a ProductCost, SumCost, or MinCost whose direct
    terms have repeats removed.

    The deduplicated terms are stored as an OrderedSet, keeping the
    order of first occurrence.
    """
    assert isinstance(cost, (ProductCost, SumCost, MinCost))
    return cost._replace(terms=OrderedSet(cost.terms))
Example #11
0
 def lhs_vars_from_clauses(self, clauses):
     """Return a tuple of all LHS vars appearing in the given clauses,
     in order of first appearance, without duplicates.
     """
     collected = OrderedSet()
     for clause in clauses:
         collected.update(self.lhs_vars(clause))
     return tuple(collected)
Example #12
0
    def process(self, tree):
        """Traverse tree and return the collected top-level functions,
        minus any that were marked as excluded.
        """
        self.toplevel_funcs = OrderedSet()
        self.excluded_funcs = set()

        # Presumably toggled by the visitor while inside a function body.
        # It must be False again once the traversal is done.
        self.infunc = False
        super().process(tree)
        assert not self.infunc

        remaining = self.toplevel_funcs - self.excluded_funcs
        return remaining
Example #13
0
def get_defined_functions(tree):
    """Return an OrderedSet of the names of top-level functions in tree.

    visit_Fun below does not call generic_visit, so function bodies are
    not searched and nested definitions are not included.
    """
    found = OrderedSet()

    class FunNameCollector(L.NodeVisitor):
        def visit_Fun(self, node):
            found.add(node.name)

    FunNameCollector.run(tree)
    return found
Example #14
0
def prefix_locals(tree, prefix, extra_boundvars):
    """Rename the local variables in a block of code by prepending prefix.

    extra_boundvars are treated as bound even if no binding occurrence
    of them appears within tree. Variables found by BindingFinder are
    renamed as well.
    """
    bound = OrderedSet(extra_boundvars)
    bound.update(BindingFinder.run(tree))
    renaming = {name: prefix + name for name in bound}
    return Templater.run(tree, renaming)
Example #15
0
    def __init__(self, clausetools, fresh_vars, fresh_join_names, comp,
                 result_var, *, counted):
        """Record the transformation context for comp.

        self.rels holds whatever clausetools.rhs_rels_from_comp returns
        for comp. self.maint_funcs starts as an empty OrderedSet.
        """
        super().__init__()

        # Helpers and name sources.
        self.clausetools = clausetools
        self.fresh_vars = fresh_vars
        self.fresh_join_names = fresh_join_names

        # The comprehension being handled and how its result is kept.
        self.comp = comp
        self.result_var = result_var
        self.counted = counted

        # Derived state.
        self.rels = self.clausetools.rhs_rels_from_comp(self.comp)
        self.maint_funcs = OrderedSet()
Example #16
0
 def get_uncon_params(self):
     """Return a tuple of the unconstrained parameters. The U-set must
     at minimum contain these parameters.

     The clauses are traversed from left to right, and each
     unconstrained parameter is added to the result the first time it
     appears. The clauses must therefore be runnable in left-to-right
     order. An unconstrained variable that is not a parameter raises
     AssertionError.

     If the query has cyclic constraints, there may be several possible
     minimal sets of parameters. The one matching this left-to-right
     traversal is chosen.
     """
     found = OrderedSet()
     # Vars bound (as enumvars) by some clause already traversed.
     supported = set()
     for clause in self.join.clauses:
         # Vars occurring in this clause in a position the clause does
         # not constrain. For enumerators these are the LHS entries not
         # flagged in con_mask (ignoring '_'). For other clause kinds
         # every var counts.
         if clause.kind is Clause.KIND_ENUM:
             pairs = zip(clause.enumlhs, clause.con_mask)
             uncon_occ = OrderedSet(var for var, is_con in pairs
                                    if not is_con and var != '_')
         else:
             uncon_occ = OrderedSet(clause.vars)

         # Every new unconstrained var must be a parameter.
         for var in uncon_occ - supported:
             if var not in self.params:
                 raise AssertionError('Unconstrained var {} is not a '
                                      'parameter'.format(var))
             found.add(var)

         # This clause's enumvars are now supported/constrained.
         supported.update(clause.enumvars)

     return tuple(found)
Example #17
0
    def process(self, tree):
        """Reset the per-run bookkeeping and delegate to the parent
        process(), returning its result.
        """
        # Outer queries, for which a demand set and a call to a demand
        # function are added.
        self.queries_with_usets = OrderedSet()
        # Map from query name to rewritten AST.
        self.rewrite_cache = {}
        # Names of demand queries that we introduced, which shouldn't be
        # recursed into.
        self.demand_queries = set()

        return super().process(tree)
Example #18
0
def simplify_min_of_sums(mincost):
    """For a min of sums, return a version of this cost where sums that
    dominate other sums are removed.

    mincost must be a MinCost whose terms are SumCosts of ProductCosts.
    Duplicate sums are dropped first. The remaining sums are then
    examined right to left. A sum is removed if it dominates some other
    sum that is still present, as judged by all_products_dominated.
    Because removed sums are no longer compared against, when two sums
    dominate each other the right one is removed and the left one kept.
    """
    assert isinstance(mincost, MinCost)
    assert all(isinstance(s, SumCost) for s in mincost.terms)
    assert all(isinstance(p, ProductCost)
               for s in mincost.terms for p in s.terms)

    survivors = list(OrderedSet(mincost.terms))
    factorcounts = build_factor_counts(
        [prod for sumcost in survivors for prod in sumcost.terms])

    # Iterate over a snapshot because survivors shrinks during the loop.
    for candidate in survivors[::-1]:
        others = OrderedSet(survivors) - {candidate}
        if any(all_products_dominated(other.terms, candidate.terms,
                                      factorcounts)
               for other in others):
            survivors.remove(candidate)

    return mincost._replace(terms=survivors)
Example #19
0
def order_clauses(clausevisitor, clauses):
    """Order clauses according to their reported priorities.

    Use a greedy heuristic: choose the leftmost clause whose priority is
    best (lowest number). Return the ordered clauses as a list.

    clauses may be any iterable. It is materialized once, so one-shot
    iterators work.

    Raises L.TransformationError (chained from the planner's ValueError)
    if no valid order is found.
    """
    # Materialize first. The iterable is read by OrderedSet below and
    # again when building the error message. A one-shot iterator would
    # otherwise yield an empty clause list in that message.
    clauses = list(clauses)
    init = OrderState(clausevisitor, set(), [], OrderedSet(clauses))
    try:
        answer = Planner().get_greedy_answer(init)
    except ValueError as exc:
        s = ', '.join(L.Parser.ts(cl) for cl in clauses)
        # Chain explicitly so the planner's failure is reported as the
        # direct cause rather than as an error raised during handling.
        raise L.TransformationError(
            'No valid order found for clauses: {}'.format(s)) from exc
    return answer
Example #20
0
def simplify_sum_of_products(sumcost):
    """For a sum of products, return a version of this cost where
    products that are dominated by other products are removed.

    sumcost must be a SumCost whose terms are ProductCosts.
    """
    assert isinstance(sumcost, SumCost)
    assert all(isinstance(p, ProductCost) for p in sumcost.terms)

    # Keeping only the terms not dominated by any other term would be
    # wrong: two terms dominated only by each other would both be
    # dropped. Instead, a term is removed as soon as it is found to be
    # dominated, so it can no longer dominate anything else.
    kept = list(OrderedSet(sumcost.terms))
    factorcounts = build_factor_counts(kept)

    # Walk right to left (over a snapshot, since kept shrinks) so that
    # the left occurrence of distinct tied terms survives. Non-distinct
    # tied terms were already removed as duplicates above.
    for prod in kept[::-1]:
        others = OrderedSet(kept) - {prod}
        if all_products_dominated([prod], others, factorcounts):
            kept.remove(prod)

    return sumcost._replace(terms=kept)
Example #21
0
def trivial_simplify(cost):
    """Apply cheap simplifications to Product, Sum, and Min costs.

    Product and Sum drop their Unit() entries. Sum and Min drop
    duplicate entries (via OrderedSet). Any other cost is returned
    unchanged.
    """
    drops_units = isinstance(cost, (Product, Sum))
    dedups = isinstance(cost, (Sum, Min))
    if not (drops_units or dedups):
        return cost

    terms = cost.terms
    if drops_units:
        terms = [t for t in terms if t != Unit()]
    if dedups:
        terms = OrderedSet(terms)
    return cost._replace(terms=terms)
Example #22
0
 def __init__(self, store, height_limit=None, unknown=Top, fixed_vars=None):
     """Set up the type-inference state.

     store maps symbol names to inferred types. Each type may only
     change in a monotonically increasing way. height_limit is the
     maximum height of type terms in the store (None for no limit).
     unknown is the type used for variables not in the store. It should
     be Bottom, Top, or None, where None means an error should be raised
     for unknown variables. fixed_vars names variables whose types
     inference may not change; it defaults to an empty list.
     """
     super().__init__()
     self.store = store
     self.height_limit = height_limit
     # Nodes where the well-typedness constraints are violated.
     self.illtyped = OrderedSet()
     # True if the last call to process() updated the store (or if there
     # was no call so far).
     self.changed = True
     self.unknown = unknown
     self.fixed_vars = [] if fixed_vars is None else fixed_vars
Example #23
0
 def enter(self):
     """Push a new, empty OrderedSet onto self._scope_stack."""
     new_scope = OrderedSet()
     self._scope_stack.append(new_scope)
Example #24
0
 def process(self, tree):
     """Traverse tree and return self.names.

     self.names is reset to an empty OrderedSet before the traversal, so
     each call starts fresh.
     """
     self.names = OrderedSet()
     super().process(tree)
     return self.names
Example #25
0
 def process(self, tree):
     """Traverse tree and return self.vars.

     Before the traversal, self.vars is reset to an empty OrderedSet and
     self.write_ctx to False.
     """
     self.vars, self.write_ctx = OrderedSet(), False
     super().process(tree)
     return self.vars
Example #26
0
 def unique_helper(self, cost):
     """Return cost with its terms deduplicated into an OrderedSet."""
     return cost._replace(terms=OrderedSet(cost.terms))
Example #27
0
 def process(self, tree):
     """Traverse tree and report what the traversal found.

     Return a 4-tuple: the tree returned by the parent process(), the
     use_mset flag, the fields OrderedSet, and the use_mapset flag. All
     three trackers are reset before the traversal.
     """
     self.use_mset, self.fields, self.use_mapset = False, OrderedSet(), False
     new_tree = super().process(tree)
     return new_tree, self.use_mset, self.fields, self.use_mapset
Example #28
0
 def process(self, tree):
     """Traverse tree and return the entries of self.inited that are not
     also in self.disqual.

     Both collections are reset to empty OrderedSets before the
     traversal.
     """
     self.inited, self.disqual = OrderedSet(), OrderedSet()
     super().process(tree)
     qualified = self.inited - self.disqual
     return qualified
Example #29
0
 def bvars_from_scopestack(scope_stack):
     """Return an OrderedSet of every variable bound by any entry of
     scope_stack, in the order the entries are stacked.
     """
     return OrderedSet(chain.from_iterable(scope_stack))
Example #30
0
 def process(self, tree):
     """Traverse tree and return self.usedvars.

     self.usedvars is reset to an empty OrderedSet before the traversal
     begins.
     """
     self.usedvars = OrderedSet()
     super().process(tree)
     return self.usedvars