Beispiel #1
0
 def visit_Query(self, node):
     """Check that a Query node's annotation, if present, is well-formed.

     Children are visited first. An annotation of None is accepted
     as-is. Otherwise the annotation must be a dict or frozendict, and
     every key must appear in S.annotations; a violation of either rule
     raises L.ProgramError.
     """
     self.generic_visit(node)

     ann = node.ann
     if ann is None:
         return

     if not isinstance(ann, (dict, frozendict)):
         raise L.ProgramError('Query annotation must be a '
                              'dictionary')

     for key in ann:
         if key in S.annotations:
             continue
         raise L.ProgramError(
             'Unknown annotation key "{}"'.format(key))
Beispiel #2
0
def aggrinv_from_query(symtab, query, result_var):
    """Build the AggrInvariant describing an aggregate query.

    The query's node must be an Aggr or AggrRestr. Its operand,
    optionally wrapped in Unwrap, must be either a Name or an ImgLookup
    whose set is a Name. The relation's arity is read from its type,
    obtained through get_rel_type().

    Raises L.ProgramError for an unrecognized operand form, for an
    operand type that is not a set of tuples, or for an AggrRestr whose
    restriction is not a Name. Raises L.TransformationError when the
    AggrRestr params differ from the ImgLookup bounds. Errors raised by
    get_rel_type() propagate unchanged.
    """
    node = query.node

    assert isinstance(node, (L.Aggr, L.AggrRestr))
    operand = node.value
    op = node.op

    # Strip an enclosing Unwrap, remembering whether one was present.
    unwrap = isinstance(operand, L.Unwrap)
    if unwrap:
        operand = operand.value

    # Extract relation name, mask, and parameters from the operand.
    if isinstance(operand, L.Name):
        # No mask given; an all-unbound one is built once arity is known.
        rel, mask, params = operand.id, None, ()
    elif isinstance(operand, L.ImgLookup) and isinstance(operand.set, L.Name):
        rel, mask, params = operand.set.id, operand.mask, operand.bounds
    else:
        raise L.ProgramError('Unknown aggregate form: {}'.format(node))

    # The relation must be typed as a set of tuples; the tuple width
    # gives the arity.
    t_rel = get_rel_type(symtab, rel)
    is_tuple_set = isinstance(t_rel, T.Set) and isinstance(t_rel.elt, T.Tuple)
    if not is_tuple_set:
        raise L.ProgramError(
            'Invalid type for aggregate operand: {}'.format(t_rel))
    arity = len(t_rel.elt.elts)

    if mask is not None:
        # An explicit mask has to agree with the arity found above.
        assert len(mask.m) == arity
    else:
        mask = L.mask('u' * arity)

    restr = None
    if isinstance(node, L.AggrRestr):
        # The restriction's parameters must be exactly the ImgLookup's.
        if node.params != params:
            raise L.TransformationError('AggrRestr params do not match '
                                        'ImgLookup params')
        if not isinstance(node.restr, L.Name):
            raise L.ProgramError('Bad AggrRestr restriction expr')
        restr = node.restr.id

    return AggrInvariant(result_var, op, rel, mask, unwrap, params, restr)
 def visit_Call(self, node):
     """Rewrite a call to len() as a Count aggregate over its argument.

     Children are transformed first. Calls to anything other than
     'len' are returned as transformed. A len() call with other than
     one argument raises L.ProgramError.
     """
     node = self.generic_visit(node)

     if node.func != 'len':
         return node

     if len(node.args) != 1:
         raise L.ProgramError('Expected one argument for len()')
     return L.Aggr(L.Count(), node.args[0])
Beispiel #4
0
 def define_symbol(self, name, kind, **kargs):
     """Create a symbol by calling kind(name, **kargs) and register it.

     The new symbol is stored in self.symbols under name and returned.
     Raises L.ProgramError if name is already registered.
     """
     if name in self.symbols:
         message = 'Symbol "{}" already defined'.format(name)
         raise L.ProgramError(message)

     created = kind(name, **kargs)
     self.symbols[name] = created
     return created
Beispiel #5
0
def transform_auxmaps_stepper(tree, symtab):
    """Run one round of auxiliary-map style transformation on tree.

    InvariantFinder collects the auxmap, SetFromMap, and wrap
    invariants present in the tree. If none are found, the tree is
    returned unchanged with False. Otherwise each invariant gets its
    symbol defined, the tree is rewritten by InvariantTransformer, the
    generated maintenance functions and the 'auxmaps_transformed'
    counter are recorded on symtab, and the new tree is returned with
    True.

    Raises L.ProgramError if an auxmap is over a name that is not a
    relation in symtab.
    """
    auxmaps, setfrommaps, wraps = InvariantFinder.run(tree)
    if all(len(group) == 0 for group in (auxmaps, setfrommaps, wraps)):
        return tree, False

    # Validate every auxmap before defining anything, so a bad one does
    # not leave a partial set of definitions behind.
    for auxmap in auxmaps:
        if auxmap.rel in symtab.get_relations():
            continue
        raise L.ProgramError('Cannot make auxiliary map for image-set '
                             'lookup over non-relation variable {}'.format(
                                 auxmap.rel))

    for auxmap in auxmaps:
        define_map(auxmap, symtab)
    for sfm in setfrommaps:
        define_set(sfm, symtab)
    for wrap in wraps:
        define_wrap_set(wrap, symtab)

    transformer = InvariantTransformer(
        symtab.fresh_names.vars, auxmaps, setfrommaps, wraps)
    new_tree = transformer.process(tree)
    symtab.maint_funcs.update(transformer.maint_funcs)

    symtab.stats['auxmaps_transformed'] += len(auxmaps)

    return new_tree, True
Beispiel #6
0
    def __init__(self, fresh_vars, auxmaps, setfrommaps, wraps):
        """Store the invariants to transform, indexed for lookup.

        Builds these indexes:
          - auxmaps_by_rel: relation -> list of auxmaps over it
          - auxmaps_by_relmask: (rel, mask, unwrap_value) -> auxmap
          - setfrommaps_by_map: map -> its single SetFromMap invariant
          - wraps_by_rel: wrap operand -> list of wraps over it

        Raises L.ProgramError if two SetFromMap invariants share a map.
        """
        super().__init__()
        self.fresh_vars = fresh_vars
        self.maint_funcs = OrderedSet()

        # Auxmap indexes.
        self.auxmaps_by_rel = by_rel = OrderedDict()
        self.auxmaps_by_relmask = by_relmask = OrderedDict()
        for auxmap in auxmaps:
            by_rel.setdefault(auxmap.rel, []).append(auxmap)
            # The key covers relation, mask, and whether the value is a
            # singleton, but not whether the key is a singleton. Two
            # invariants differing only in the latter collapse to one
            # entry here, so only one may be used while both are
            # maintained.
            key = (auxmap.rel, auxmap.mask, auxmap.unwrap_value)
            by_relmask[key] = auxmap

        # SetFromMap index; at most one invariant per map.
        self.setfrommaps_by_map = sfm_index = OrderedDict()
        for sfm in setfrommaps:
            target = sfm.map
            if target in sfm_index:
                raise L.ProgramError('Multiple SetFromMap invariants on '
                                     'same map {}'.format(target))
            sfm_index[target] = sfm

        # Wrap index, grouped by operand.
        self.wraps_by_rel = wrap_index = OrderedDict()
        for wrap in wraps:
            bucket = wrap_index.setdefault(wrap.oper, [])
            bucket.append(wrap)
def preprocess_var_decls(tree):
    """Strip top-level declarations such as

        S = Set()
        M = Map()

    from the module body. Return the (possibly updated) tree together
    with a list of (variable name, kind name) pairs, where the kind
    name is 'Set' or 'Map'. Duplicate pairs are reported once.

    Raises L.ProgramError if a matching statement uses an initializer
    other than Set or Map.
    """
    assert isinstance(tree, P.Module)

    # Pattern: <_VAR> = <_KIND>() with no arguments.
    target_pat = P.Name(P.PatVar('_VAR'), P.Wildcard())
    init_pat = P.Call(P.Name(P.PatVar('_KIND'), P.Load()), [], [], None,
                      None)
    pat = P.Assign([target_pat], init_pat)

    decls = OrderedSet()
    kept = []
    removed_any = False
    for stmt in tree.body:
        match = P.match(pat, stmt)
        if match is None:
            kept.append(stmt)
            continue

        var, kind = match['_VAR'], match['_KIND']
        if kind not in ('Set', 'Map'):
            raise L.ProgramError(
                'Unknown declaration initializer {}'.format(kind))
        decls.add((var, kind))
        removed_any = True

    # Only rebuild the module when something was actually removed.
    if removed_any:
        tree = tree._replace(body=kept)
    return tree, list(decls)
Beispiel #8
0
 def apply_symconfig(self, name, info):
     """Apply a dictionary of config attributes to the named symbol.

     The entries of info are forwarded as keyword arguments to the
     symbol's parse_and_update(), along with this object. Raises
     L.ProgramError if no symbol called name is registered.
     """
     symbols = self.symbols
     if name not in symbols:
         raise L.ProgramError('No symbol "{}"'.format(name))
     symbols[name].parse_and_update(self, **info)
Beispiel #9
0
def get_rel_type(symtab, rel):
    """Return the set type recorded for relation rel.

    The symbol's type is joined with Set<Bottom> and must come out no
    larger than Set<Top>. The returned value is the whole set type, not
    just its element type; callers read .elt themselves.

    Raises L.TransformationError if rel has no symbol, and
    L.ProgramError if the type is not a set type or if its element type
    is still Bottom (i.e. unknown).
    """
    # NOTE: this would fit better as a general helper in the type
    # subpackage.
    relsym = symtab.get_symbols().get(rel)
    if relsym is None:
        raise L.TransformationError(
            'No symbol info for operand relation {}'.format(rel))

    t_rel = relsym.type.join(T.Set(T.Bottom))
    if not t_rel.issmaller(T.Set(T.Top)):
        raise L.ProgramError('Bad type for relation {}: {}'.format(rel, t_rel))

    # Set<Bottom> says nothing about the tuple shape, so it is rejected
    # rather than guessed at.
    if t_rel.elt is T.Bottom:
        raise L.ProgramError(
            'Relation must have known tuple element type '
            'before it can be used in aggregate: {}'.format(rel))
    return t_rel
 def rewrite_comp(self, symbol, name, comp):
     """Reorder a comprehension's clauses per symbol.clause_reorder.

     clause_reorder is a list of 1-based clause positions. When it is
     None the comprehension is returned unchanged. Otherwise it must be
     a permutation of all clause positions, else L.ProgramError is
     raised.
     """
     order = symbol.clause_reorder
     if order is None:
         return comp

     # Convert the 1-based positions to 0-based indices.
     positions = [pos - 1 for pos in order]
     expected = list(range(len(comp.clauses)))
     if sorted(positions) != expected:
         raise L.ProgramError(
             'Bad clause_reorder list for query: {}, {}'.format(
             name, order))

     reordered = [comp.clauses[pos] for pos in positions]
     return comp._replace(clauses=reordered)
Beispiel #11
0
 def process(self, tree):
     """Process tree, then report queries that never occurred in it.

     self.found is reset before delegating to the parent class's
     process(). Afterwards, any key of self.query_name_map absent from
     self.found counts as unmatched; in strict mode this raises
     L.ProgramError listing the unmatched queries.
     """
     self.found = set()
     tree = super().process(tree)

     missing = set(self.query_name_map.keys()) - self.found
     if not (self.strict and missing):
         return tree

     rendered = [L.Parser.ts(query) for query in missing]
     raise L.ProgramError(
         'No matching occurrence for queries: {}'.format(
             ', '.join(rendered)))
Beispiel #12
0
    def process(expr):
        """Convert a Member clause over a relation into a relation clause.

        Returns a triple (clause, [], []). Expressions that are not a
        Member whose iter is a Name listed in symtab's relations are
        passed through untouched. A tuple-of-names target yields a
        RelMember; a single Name target yields a VarsMember over the
        wrapped relation. Any other target raises L.ProgramError.
        """
        is_rel_member = (
            isinstance(expr, L.Member) and
            isinstance(expr.iter, L.Name) and
            expr.iter.id in symtab.get_relations())
        if not is_rel_member:
            return expr, [], []

        target = expr.target
        rel = expr.iter.id

        if L.is_tuple_of_names(target):
            new_clause = L.RelMember(L.detuplify(target), rel)
        elif isinstance(target, L.Name):
            new_clause = L.VarsMember([target.id], L.Wrap(L.Name(rel)))
        else:
            raise L.ProgramError('Invalid clause over relation')

        return new_clause, [], []
Beispiel #13
0
    def visit_SetFromMap(self, node):
        """Replace a SetFromMap over a known map with its relation name.

        Children are transformed first. The node is returned as-is when
        its map is not a Name or has no registered SetFromMap invariant.
        If an invariant exists but its mask differs from the node's,
        L.ProgramError is raised.
        """
        node = self.generic_visit(node)

        map_expr = node.map
        if not isinstance(map_expr, L.Name):
            return node
        map_name = map_expr.id

        sfm = self.setfrommaps_by_map.get(map_name, None)
        if sfm is None:
            return node

        if sfm.mask == node.mask:
            return L.Name(sfm.rel)
        raise L.ProgramError('Multiple SetFromMap expressions on '
                             'same map {}'.format(map_name))
Beispiel #14
0
    def visit_Query(self, node):
        """Record the expression of each Query node encountered.

        Children are visited first. Queries named in self.ignore are
        skipped. When self.first is set, self.Done is raised carrying
        (name, query) for the first non-ignored query. Otherwise the
        query is stored in self.queries under its name; a later
        occurrence with a different expression raises L.ProgramError.
        """
        self.generic_visit(node)

        ignored = self.ignore
        if ignored is not None and node.name in ignored:
            return

        if self.first:
            raise self.Done((node.name, node.query))

        known = self.queries
        if node.name not in known:
            known[node.name] = node.query
        elif node.query != known[node.name]:
            raise L.ProgramError('Multiple inconsistent expressions for '
                                 'query {}'.format(node.name))
Beispiel #15
0
    def determine_demand_params(self, node):
        """Compute and store demand info on the symbol for a query node.

        Sets uses_demand and demand_params on the query's symbol and
        marks the query as processed. Queries already in self.processed
        are left alone.

          - Aux impl: no demand.
          - Comp: delegated to determine_comp_demand_params(); demand is
            used iff the result is non-empty.
          - Aggr: demand over all params if a nested query uses demand
            or self.get_left_clauses() is not None; otherwise none.

        Raises L.ProgramError for any other kind of query.
        """
        if node.name in self.processed:
            return

        symtab = self.symtab
        ct = symtab.clausetools
        query_sym = symtab.get_queries()[node.name]

        if query_sym.impl is S.Aux:
            # Aux implementations never use demand.
            uses_demand, demand_params = False, ()

        elif isinstance(node.query, L.Comp):
            demand_params = determine_comp_demand_params(
                ct, node.query, query_sym.params, query_sym.demand_params,
                query_sym.demand_param_strat)
            uses_demand = len(demand_params) > 0

        elif isinstance(node.query, L.Aggr):
            # The aggregate must be demand-driven for all parameters if
            # its operand contains a demand-driven query, or if it
            # appears in a transformed comprehension.
            aggr_in_comp = self.get_left_clauses() is not None

            # Look up every nested query (no short-circuit) before
            # combining the flags.
            inner_flags = [symtab.get_queries()[q.name].uses_demand
                           for q in find_nested_queries(node)]
            operand_uses_demand = any(inner_flags)

            uses_demand = operand_uses_demand or aggr_in_comp
            demand_params = query_sym.params if uses_demand else ()

        else:
            kind = query_sym.node.__class__.__name__
            raise L.ProgramError('No rule for analyzing parameters of '
                                 '{} query'.format(kind))

        query_sym.uses_demand = uses_demand
        query_sym.demand_params = demand_params
        self.processed.add(query_sym.name)
Beispiel #16
0
def make_demand_func(query):
    """Generate the demand function for a query.

    The generated function adds its argument to the query's demand set
    if it is not already there. The shape depends on
    query.demand_set_maxsize:

      - None: the element is simply added.
      - 1: the set is cleared before the element is added.
      - any other positive int: elements obtained via peek() are
        removed while the set's size is at least the limit, then the
        element is added.

    Any other value raises L.ProgramError.
    """
    func = N.get_query_demand_func_name(query.name)
    uset = N.get_query_demand_set_name(query.name)

    maxsize = query.demand_set_maxsize

    # Substitutions common to every template.
    subst = {'_FUNC': func, '_U': uset}

    if maxsize is None:
        template = '''
            def _FUNC(_elem):
                if _elem not in _U:
                    _U.reladd(_elem)
            '''
    elif maxsize == 1:
        template = '''
            def _FUNC(_elem):
                if _elem not in _U:
                    _U.relclear()
                    _U.reladd(_elem)
            '''
    elif isinstance(maxsize, int) and maxsize > 0:
        template = '''
            def _FUNC(_elem):
                if _elem not in _U:
                    while len(_U) >= _MAXSIZE:
                        _stale = _U.peek()
                        _U.relremove(_stale)
                    _U.reladd(_elem)
            '''
        subst['_MAXSIZE'] = L.Num(maxsize)
    else:
        raise L.ProgramError('Invalid value for demand_set_maxsize')

    return L.Parser.ps(template, subst=subst)
Beispiel #17
0
def transform_query(tree, symtab, query):
    """Transform tree according to the impl selected for query.

    Dispatches on the kind of query.node (comprehension or aggregate)
    and then on query.impl. Returns a pair (tree, success): success is
    False when impl is S.Normal, in which case the tree is returned
    untouched, and True when a transformation was applied.

    Raises:
        AssertionError: if query.impl is S.Unspecified, or is not an
            impl supported for this kind of query.
        L.ProgramError: if query.node is neither a comprehension nor an
            aggregate.
    """
    impl = query.impl
    node = query.node

    # These checks are raised explicitly instead of through ``assert``
    # statements: asserts are stripped under ``python -O``, which would
    # let an unsupported impl fall through to the return below with
    # ``success`` never assigned.
    if impl is S.Unspecified:
        raise AssertionError('Query impl must be specified before '
                             'transformation')

    if isinstance(node, L.Comp):

        if impl == S.Normal:
            success = False

        elif impl == S.Aux:
            tree = transform_aux_comp_query(tree, symtab, query)
            success = True

        elif impl == S.Inc:
            tree = transform_comp_query(tree, symtab, query)
            success = True

        elif impl == S.Filtered:
            symtab.print('  Comp: ' + L.Parser.ts(node))
            tree = transform_comp_query_with_filtering(tree, symtab, query)
            success = True

        else:
            raise AssertionError(
                'Unsupported impl for comprehension query: {}'.format(impl))

    elif isinstance(node, (L.Aggr, L.AggrRestr)):

        if impl == S.Normal:
            success = False

        elif impl == S.Inc:
            result_var = N.A_name(query.name)
            tree = incrementalize_aggr(tree, symtab, query, result_var)
            # Transform any aggregate map lookups inside comprehensions,
            # including incrementalizing the SetFromMaps used in the
            # additional comp clauses.
            tree = transform_comps_with_maps(tree, symtab)
            success = True

        else:
            raise AssertionError(
                'Unsupported impl for aggregate query: {}'.format(impl))

    else:
        raise L.ProgramError('Unknown query kind: {}'.format(
            node.__class__.__name__))
    return tree, success
Beispiel #18
0
    def make_structs(self):
        """Create the filter and tag structures for the comprehension.

        Must be called while self.structs is empty. Clauses are handled
        in order; only those of kind Member produce structures. Every
        such clause except the one at index 0 gets a Filter, and each of
        its tag-out LHS variables gets a Tag defined over the clause's
        filtered form.

        Raises L.TransformationError if a clause has no right-hand-side
        relation, and L.ProgramError if a filter needs predecessor tags
        but none are available.
        """
        assert len(self.structs) == 0
        ct = self.symtab.clausetools

        for index, clause in enumerate(self.comp.clauses):
            if ct.kind(clause) is not Kind.Member:
                continue

            # The full LHS variable list is queried but not consulted
            # below.
            lhs_vars = ct.lhs_vars(clause)
            tag_in_vars = ct.tagsin_lhs_vars(clause)
            tag_out_vars = ct.tagsout_lhs_vars(clause)
            rel = ct.rhs_rel(clause)
            if rel is None:
                raise L.TransformationError('Cannot generate tags and filter '
                                            'for clause: {}'.format(clause))

            if index == 0:
                # The first clause serves as its own filter, so no new
                # structure is created for it.
                tagged_clause = clause
            else:
                # Filters are numbered per relation, starting from 1.
                ordinal = len(self.filters_by_rel[rel]) + 1
                filter_name = self.get_filter_name(rel, ordinal)
                preds = self.get_preds(index, tag_in_vars)
                if ct.filter_needs_preds(clause) and len(preds) == 0:
                    raise L.ProgramError('No predecessors tags for filter '
                                         'over clause: {}'.format(
                                             L.Parser.ts(clause)))
                self.add_struct(Filter(index, filter_name, clause, preds))
                tagged_clause = self.clause_over_filter[filter_name]

            # Tags are numbered per variable, starting from 1.
            for var in tag_out_vars:
                ordinal = len(self.tags_by_var[var]) + 1
                tag_name = self.get_tag_name(var, ordinal)
                self.add_struct(Tag(index, tag_name, var, tagged_clause))
Beispiel #19
0
    def visit_Query(self, node):
        """Record, or cross-check, the parameters of a Query occurrence.

        Parameters are computed with self.get_params() before the node's
        children are visited. The first occurrence of a query name stores
        them in self.query_param_map and on the query's symbol. Every
        later occurrence must yield equal parameters, otherwise
        L.ProgramError is raised.
        """
        querysym = self.symtab.get_queries()[node.name]
        cache = self.query_param_map

        # Analyze parameters from node.
        params = self.get_params(node)

        if node.name not in cache:
            # First occurrence: remember it and update the symbol.
            cache[node.name] = params
            querysym.params = params
        else:
            # Already analyzed: this occurrence must agree with the
            # recorded parameters.
            expected = cache[node.name]
            if params != expected:
                raise L.ProgramError('Inconsistent parameter info for query '
                                     '{}: {}, {}'.format(
                                         querysym.name, expected, params))

        self.generic_visit(node)
Beispiel #20
0
 def visit_SetFromMap(self, node):
     """Reject any SetFromMap node by raising L.ProgramError."""
     message = 'Invalid SetFromMap expression: {}'.format(node)
     raise L.ProgramError(message)
Beispiel #21
0
 def visit_ImgLookup(self, node):
     """Reject any ImgLookup node by raising L.ProgramError."""
     message = 'Invalid ImgLookup expression: {}'.format(node)
     raise L.ProgramError(message)
Beispiel #22
0
 def __init__(self, *args, **kargs):
     """Validate the invariant: unwrap requires exactly one 'u' in the mask.

     Raises L.ProgramError when self.unwrap is set but self.mask.m does
     not contain exactly one 'u' component.
     """
     # NOTE(review): self.unwrap and self.mask are read here without
     # being assigned in this method, and args/kargs are not used;
     # presumably the fields are populated before __init__ runs by the
     # enclosing class -- confirm against the class definition.
     if self.unwrap and self.mask.m.count('u') != 1:
         raise L.ProgramError('Aggregates of unwrap() expressions must '
                              'have arity 1')