Beispiel #1
0
    def get_code(self, cl, bindenv, body):
        """Emit the code that runs ``body`` for clause ``cl``.

        The strategy is chosen from the bound/unbound mask of the
        clause's LHS variables (``self.lhs_vars(cl)``) with respect to
        ``bindenv``:

        * every variable bound: run ``body`` under an equality test
          between the tuple variable and the tuple of its elements;
        * mask starts with ``'b'``: bind the elements from the tuple
          variable via ``L.bind_by_mask``, then run ``body`` directly
          when no element was already bound, or under the equality
          test otherwise;
        * mask is ``'u'`` followed by one ``'b'`` per element: assign
          the tuple variable from its elements, then run ``body``.

        The first two strategies are additionally wrapped in an
        ``L.HasArity`` check when ``self.use_typecheck`` is set.

        Raises L.TransformationError for any other mask.
        """
        # NOTE(review): the masks tested below suggest lhs_vars() lists
        # the tuple variable first, followed by its elements -- confirm.
        lhs = self.lhs_vars(cl)
        assert_unique(lhs)
        mask = L.mask_from_bounds(lhs, bindenv)

        # Test "tup == (elt1, ..., eltN)", shared by the guarded cases.
        equality = L.Compare(L.Name(cl.tup), L.Eq(), L.tuplify(cl.elts))

        if L.mask_is_allbound(mask):
            needs_typecheck = True
            code = (L.If(equality, body, ()),)

        elif mask.m.startswith('b'):
            needs_typecheck = True
            elts_mask = L.mask_from_bounds(cl.elts, bindenv)
            code = L.bind_by_mask(elts_mask, cl.elts, L.Name(cl.tup))
            if not L.mask_is_allunbound(elts_mask):
                # Some element was already bound, so the binding code
                # alone does not establish equality; test explicitly.
                code += (L.If(equality, body, ()),)
            else:
                code += body

        elif mask == L.mask('u' + 'b' * len(cl.elts)):
            # The tuple is constructed here from its elements, so no
            # arity check is requested for this case.
            needs_typecheck = False
            code = (L.Assign(cl.tup, L.tuplify(cl.elts)),) + body

        else:
            raise L.TransformationError('Cannot emit code for TUP clause '
                                        'that would require an auxiliary '
                                        'map; use demand filtering')

        if not (needs_typecheck and self.use_typecheck):
            return code

        arity_test = L.HasArity(L.Name(cl.tup), len(cl.elts))
        return (L.If(arity_test, code, ()),)
Beispiel #2
0
 def visit_DecompFor(self, node):
     """Replace a DecompFor over a tracked relation with a plain For.

     Nodes whose iterable is not an ``L.Name`` listed in ``self.rels``
     are returned unchanged. Otherwise the loop must have exactly one
     target variable, which becomes the target of the new ``L.For``.

     Raises L.TransformationError if such a loop has zero or several
     target variables.
     """
     iterates_rel = isinstance(node.iter, L.Name) and node.iter.id in self.rels
     if not iterates_rel:
         return node

     if len(node.vars) != 1:
         raise L.TransformationError(
             'Singleton unwrapping requires all DecompFor loops '
             'over relation to have exactly one target variable')

     target = node.vars[0]
     return L.For(target, node.iter, node.body)
Beispiel #3
0
def define_wrap_set(wrapinv, symtab):
    """Define the result relation of a WrapInvariant in the symbol table.

    The operand relation ``wrapinv.oper`` is looked up among the
    relations of ``symtab``. Its type is fed through ``make_wrap_type``
    and the outcome is registered as the type of ``wrapinv.rel``.

    Raises L.TransformationError if the operand relation is unknown.
    """
    relations = symtab.get_relations()
    opersym = relations.get(wrapinv.oper, None)
    if opersym is None:
        template = ('No relation "{}" matching wrapped/'
                    'unwrapped relation "{}"')
        raise L.TransformationError(
            template.format(wrapinv.oper, wrapinv.rel))

    symtab.define_relation(
        wrapinv.rel, type=make_wrap_type(wrapinv, opersym.type))
Beispiel #4
0
def define_set(setfrommap, symtab):
    """Define the relation of a set-from-map invariant in the symbol table.

    The map ``setfrommap.map`` is looked up among the maps of
    ``symtab``. Its type, together with ``setfrommap.mask``, is passed
    to ``make_setfrommap_type`` and the outcome is registered as the
    type of ``setfrommap.rel``.

    Raises L.TransformationError if the map is unknown.
    """
    known_maps = symtab.get_maps()
    mapsym = known_maps.get(setfrommap.map, None)
    if mapsym is None:
        template = 'No map "{}" matching relation "{}"'
        raise L.TransformationError(
            template.format(setfrommap.map, setfrommap.rel))

    symtab.define_relation(
        setfrommap.rel,
        type=make_setfrommap_type(setfrommap.mask, mapsym.type))
Beispiel #5
0
def define_map(auxmap, symtab):
    """Define the map of an auxiliary-map invariant in the symbol table.

    The relation ``auxmap.rel`` is looked up among the relations of
    ``symtab``. Its type is passed, along with ``auxmap`` itself, to
    ``make_auxmap_type`` and the outcome is registered as the type of
    ``auxmap.map``.

    Raises L.TransformationError if the relation is unknown.
    """
    relations = symtab.get_relations()
    relsym = relations.get(auxmap.rel, None)
    if relsym is None:
        template = 'No relation "{}" matching map "{}"'
        raise L.TransformationError(
            template.format(auxmap.rel, auxmap.map))

    symtab.define_map(
        auxmap.map, type=make_auxmap_type(auxmap, relsym.type))
Beispiel #6
0
 def visit_Query(self, node):
     """Rewrite a Query node, keeping all occurrences of a name in sync.

     Children are visited first. The first occurrence of a query name
     is checked against its symbol, handed to ``self.rewrite``, and the
     original/replacement pair is remembered. Later occurrences must
     equal the remembered original and reuse the remembered replacement.

     When the replacement is None the (already visited) node is
     returned as is. Otherwise the replacement either takes the place
     of the whole Query node (``self.expand`` true) or of just its
     ``query`` field.

     Raises L.TransformationError if the name has no symbol, if the
     first occurrence differs from the symbol's node, or if a later
     occurrence differs from the first one.
     """
     node = super().generic_visit(node)

     name, this_occ = node.name, node.query

     if name in self.queries_before:
         # Seen before: verify it matches the earlier occurrence and
         # reuse the replacement computed back then.
         prev_occ = self.queries_before[name]
         if this_occ != prev_occ:
             raise L.TransformationError(
                 'Inconsistent occurrences for query "{}": {}, {}'
                 .format(name, prev_occ, this_occ))
         replacement = self.queries_after[name]

     else:
         # First sighting: fetch the symbol, verify the occurrence
         # agrees with it, run the rewriter, and record the outcome.
         sym = self.symtab.get_queries().get(name, None)
         if sym is None:
             raise L.TransformationError('No symbol for query name "{}"'
                                         .format(name))
         if this_occ != sym.node:
             raise L.TransformationError(
                 'Inconsistent symbol and occurrence for query '
                 '"{}": {}, {}'.format(name, sym.node, this_occ))

         replacement = self.rewrite(sym, name, this_occ)
         self.queries_before[name] = this_occ
         self.queries_after[name] = replacement
         # The symbol is only updated when the Query wrapper is kept.
         if replacement is not None and not self.expand:
             sym.node = replacement

     if replacement is None:
         return node
     if self.expand:
         return replacement
     return node._replace(query=replacement)
Beispiel #7
0
def aggrinv_from_query(symtab, query, result_var):
    """Build the AggrInvariant describing an aggregate query.

    ``query.node`` must be an ``L.Aggr`` or ``L.AggrRestr``. Its operand
    may be wrapped in ``L.Unwrap`` and must then be either a relation
    name or an ``L.ImgLookup`` over a relation name. The relation's
    type, obtained via ``get_rel_type``, supplies the arity used to
    build or sanity-check the mask.

    Raises L.ProgramError for an unrecognized operand form, for an
    operand type that is not a set of tuples, or for an AggrRestr whose
    restriction is not a name. Raises L.TransformationError if an
    AggrRestr's params differ from the operand's params.
    """
    node = query.node
    assert isinstance(node, (L.Aggr, L.AggrRestr))

    op = node.op
    oper = node.value

    # Strip an optional Unwrap around the operand, remembering it.
    unwrap = isinstance(oper, L.Unwrap)
    if unwrap:
        oper = oper.value

    # Extract relation name, mask, and parameters from the operand.
    if isinstance(oper, L.Name):
        # A None mask stands for "all unbound"; it is materialized
        # further down once the arity is known.
        rel, mask, params = oper.id, None, ()
    elif isinstance(oper, L.ImgLookup) and isinstance(oper.set, L.Name):
        rel, mask, params = oper.set.id, oper.mask, oper.bounds
    else:
        raise L.ProgramError('Unknown aggregate form: {}'.format(node))

    # The relation's type tells us its arity.
    t_rel = get_rel_type(symtab, rel)
    is_tuple_set = (isinstance(t_rel, T.Set) and
                    isinstance(t_rel.elt, T.Tuple))
    if not is_tuple_set:
        raise L.ProgramError(
            'Invalid type for aggregate operand: {}'.format(t_rel))
    arity = len(t_rel.elt.elts)

    if mask is not None:
        # An explicit mask has to agree with the relation's arity.
        assert len(mask.m) == arity
    else:
        mask = L.mask('u' * arity)

    restr = None
    if isinstance(node, L.AggrRestr):
        # The restriction's parameters must be the operand's parameters.
        if node.params != params:
            raise L.TransformationError('AggrRestr params do not match '
                                        'ImgLookup params')
        if not isinstance(node.restr, L.Name):
            raise L.ProgramError('Bad AggrRestr restriction expr')
        restr = node.restr.id

    return AggrInvariant(result_var, op, rel, mask, unwrap, params, restr)
Beispiel #8
0
def order_clauses(clausevisitor, clauses):
    """Order clauses greedily by the priorities they report.

    The search starts from an ``OrderState`` built from
    ``clausevisitor`` and the given clauses, and the result of
    ``Planner().get_greedy_answer`` is returned unchanged. The intended
    heuristic is to repeatedly pick the leftmost clause with the best
    (lowest) priority, yielding the ordered clauses as a list.

    Raises L.TransformationError if the planner signals, by raising
    ValueError, that no valid order exists.
    """
    start = OrderState(clausevisitor, set(), [], OrderedSet(clauses))
    try:
        return Planner().get_greedy_answer(start)
    except ValueError:
        rendered = [L.Parser.ts(cl) for cl in clauses]
        raise L.TransformationError(
            'No valid order found for clauses: {}'.format(
                ', '.join(rendered)))
Beispiel #9
0
def get_rel_type(symtab, rel):
    """Return the set type recorded for the relation named ``rel``.

    The symbol's type is joined with ``T.Set(T.Bottom)`` and the joined
    type is what gets validated and returned.

    Raises L.TransformationError if ``rel`` has no symbol. Raises
    L.ProgramError if the joined type is not smaller than
    ``T.Set(T.Top)``, or if its element type is still ``T.Bottom``.
    """
    # NOTE: this helper probably belongs in the type subpackage as a
    # general utility rather than here.
    relsym = symtab.get_symbols().get(rel, None)
    if relsym is None:
        raise L.TransformationError(
            'No symbol info for operand relation {}'.format(rel))

    t_rel = relsym.type.join(T.Set(T.Bottom))

    is_set_type = t_rel.issmaller(T.Set(T.Top))
    if not is_set_type:
        raise L.ProgramError('Bad type for relation {}: {}'.format(rel, t_rel))

    # An element type of Bottom means nothing is known about the
    # tuples yet, which is rejected.
    if t_rel.elt is T.Bottom:
        raise L.ProgramError(
            'Relation must have known tuple element type '
            'before it can be used in aggregate: {}'.format(rel))

    return t_rel
Beispiel #10
0
    def make_structs(self):
        """Create the filter and tag structures for ``self.comp``.

        Must be called while ``self.structs`` is still empty. Only
        clauses whose kind is ``Kind.Member`` are considered. Each such
        clause except the one at index 0 gets a ``Filter``; every
        variable reported by ``tagsout_lhs_vars`` gets a ``Tag`` over
        the filtered form of the clause.

        Raises L.TransformationError if a member clause has no RHS
        relation. Raises L.ProgramError if a filter needs predecessor
        tags but none are available.
        """
        assert len(self.structs) == 0
        ct = self.symtab.clausetools

        for idx, cl in enumerate(self.comp.clauses):
            if ct.kind(cl) is not Kind.Member:
                continue

            # The lhs_vars() result is not used below; the call is kept
            # so behavior is unchanged in case the helper validates.
            _all_vars = ct.lhs_vars(cl)
            in_vars = ct.tagsin_lhs_vars(cl)
            out_vars = ct.tagsout_lhs_vars(cl)
            rel = ct.rhs_rel(cl)
            if rel is None:
                raise L.TransformationError('Cannot generate tags and filter '
                                            'for clause: {}'.format(cl))

            if idx == 0:
                # The clause at index 0 serves as its own filter, so no
                # Filter structure is created for it.
                filtered_cl = cl
            else:
                # Filters over the same relation are numbered from 1.
                ordinal = len(self.filters_by_rel[rel]) + 1
                filter_name = self.get_filter_name(rel, ordinal)
                preds = self.get_preds(idx, in_vars)
                if ct.filter_needs_preds(cl) and len(preds) == 0:
                    raise L.ProgramError('No predecessors tags for filter '
                                         'over clause: {}'.format(
                                             L.Parser.ts(cl)))
                self.add_struct(Filter(idx, filter_name, cl, preds))
                filtered_cl = self.clause_over_filter[filter_name]

            # One tag per output variable; tags for the same variable
            # are numbered from 1.
            for var in out_vars:
                ordinal = len(self.tags_by_var[var]) + 1
                tag_name = self.get_tag_name(var, ordinal)
                self.add_struct(Tag(idx, tag_name, var, filtered_cl))
Beispiel #11
0
 def badnode(self, node):
     """Reject ``node`` as not allowed in the pair-domain conversion.

     Always raises L.TransformationError naming the node's class.
     """
     kind = node.__class__.__name__
     template = ('{} nodes should not be present when '
                 'converting to the pair domain')
     raise L.TransformationError(template.format(kind))