def make_demand_func(query):
    """Build the demand function definition for *query*.

    The generated function ``_FUNC(_elem)`` (named by
    ``N.get_query_demand_func_name``) adds ``_elem`` to the query's demand
    set (named by ``N.get_query_demand_set_name``) when it is not already a
    member. How the set is bounded depends on ``query.demand_set_maxsize``:

    * ``None`` -- unbounded; the element is simply added.
    * ``1`` -- the set is cleared before the new element is added.
    * an int ``n > 1`` -- while the set holds ``n`` or more elements, an
      element obtained from ``peek()`` is removed; then the new element is
      added.

    Raises L.ProgramError if ``demand_set_maxsize`` is neither None nor a
    positive int.
    """
    func = N.get_query_demand_func_name(query.name)
    uset = N.get_query_demand_set_name(query.name)
    subst = {'_FUNC': func, '_U': uset}

    maxsize = query.demand_set_maxsize

    # Validate before dispatching on the value. Checking `maxsize == 1`
    # first let non-ints that compare equal to 1 (e.g. 1.0) through, while
    # every other non-int value was rejected.
    if maxsize is not None and (not isinstance(maxsize, int)
                                or maxsize <= 0):
        raise L.ProgramError('Invalid value for demand_set_maxsize')

    if maxsize is None:
        code = L.Parser.ps('''
            def _FUNC(_elem):
                if _elem not in _U:
                    _U.reladd(_elem)
            ''', subst=subst)
    elif maxsize == 1:
        code = L.Parser.ps('''
            def _FUNC(_elem):
                if _elem not in _U:
                    _U.relclear()
                    _U.reladd(_elem)
            ''', subst=subst)
    else:
        subst['_MAXSIZE'] = L.Num(maxsize)
        code = L.Parser.ps('''
            def _FUNC(_elem):
                if _elem not in _U:
                    while len(_U) >= _MAXSIZE:
                        _stale = _U.peek()
                        _U.relremove(_stale)
                    _U.reladd(_elem)
            ''', subst=subst)

    return code
Example #2
0
 def change_rhs(self, cl, query_name):
     """Return a version of clause *cl* whose RHS relation is renamed to
     the result-set relation of the query named *query_name*.
     """
     tools = self.symtab.clausetools
     result_set = N.get_resultset_name(query_name)
     # The renaming callback ignores the old relation name and always
     # yields the result-set name.
     return tools.rename_rhs_rel(cl, lambda _old_rel: result_set)
Example #3
0
def make_comp_maint_func(clausetools, fresh_var_prefix, fresh_join_names, comp,
                         result_var, rel, op, *, counted):
    """Build the function that maintains *result_var* for comprehension
    *comp* when *op* (a SetAdd or SetRemove) is applied to relation *rel*.

    The returned definition takes one parameter, ``_elem``. Its body is the
    code produced by ``clausetools.get_maint_code`` for the update
    ``RelUpdate(rel, op, '_elem')``.
    """
    assert isinstance(op, (L.SetAdd, L.SetRemove))

    maint_name = N.get_maint_func_name(result_var, rel,
                                       L.set_update_name(op))
    triggering_update = L.RelUpdate(rel, op, '_elem')
    maint_body = clausetools.get_maint_code(
        fresh_var_prefix, fresh_join_names, comp, result_var,
        triggering_update, counted=counted)

    return L.Parser.ps('''
        def _FUNC(_elem):
            _CODE
        ''', subst={'_FUNC': maint_name, '<c>_CODE': maint_body})
Example #4
0
 def visit_Unwrap(self, node):
     """Record the invariant needed to maintain an Unwrap node.

     If the operand is an ImgLookup for which imglookup_helper builds an
     AuxmapInvariant, that invariant is added with unwrap_value=True and
     nothing else is recorded. Otherwise, if the operand is a plain Name,
     an unwrapping WrapInvariant over it is added.
     """
     if not isinstance(node.value, L.ImgLookup):
         self.generic_visit(node)
     else:
         # Visit only below the ImgLookup; the lookup itself is absorbed
         # into the auxmap invariant when one can be built.
         lookup = node.value
         self.generic_visit(lookup)
         candidate = self.imglookup_helper(lookup)
         if candidate is not None:
             self.auxmaps.add(candidate._replace(unwrap_value=True))
             return

     # Fallback: treat this as an ordinary unwrap of a named operand.
     if isinstance(node.value, L.Name):
         operand = node.value.id
         self.wraps.add(
             WrapInvariant(N.get_unwrap_name(operand), operand, True))
Example #5
0
    def visit_DictLookup(self, node):
        """Replace a simple map lookup ``map[keys]`` with a fresh variable.

        The first time a given (post-visit) lookup node is seen, a fresh
        name is drawn from self.fresh_names and recorded in self.repls.
        A SetFromMapMember clause binding it is appended to
        self.new_clauses, and a SetFromMapInvariant is added to
        self.sfm_invs. Later occurrences reuse the recorded name.
        """
        node = self.generic_visit(node)

        # Only simple map lookups are allowed.
        assert isinstance(node.value, L.Name)
        assert L.is_tuple_of_names(node.key)
        assert node.default is None
        map_name = node.value.id
        key_names = L.detuplify(node.key)

        fresh = self.repls.get(node, None)
        if fresh is not None:
            return L.Name(fresh)

        mask = L.mapmask_from_len(len(key_names))
        sa_rel = N.SA_name(map_name, mask)

        fresh = next(self.fresh_names)
        self.repls[node] = fresh

        # Clause binding the fresh variable alongside the key variables.
        member_vars = list(key_names) + [fresh]
        self.new_clauses.append(
            L.SetFromMapMember(member_vars, sa_rel, map_name, mask))
        self.sfm_invs.add(SetFromMapInvariant(sa_rel, map_name, mask))

        return L.Name(fresh)
Example #6
0
    def visit_Unwrap(self, node):
        """Record the invariant needed to maintain an Unwrap node.

        If the operand is an ImgLookup for which imglookup_helper builds an
        AuxmapInvariant, that invariant is added with unwrap_value=True and
        nothing else is recorded. Otherwise, if the operand is a plain Name,
        an unwrapping WrapInvariant over it is added.
        """
        if not isinstance(node.value, L.ImgLookup):
            self.generic_visit(node)
        else:
            # Visit only below the ImgLookup; the lookup itself is absorbed
            # into the auxmap invariant when one can be built.
            lookup = node.value
            self.generic_visit(lookup)
            candidate = self.imglookup_helper(lookup)
            if candidate is not None:
                self.auxmaps.add(candidate._replace(unwrap_value=True))
                return

        # Fallback: treat this as an ordinary unwrap of a named operand.
        if isinstance(node.value, L.Name):
            operand = node.value.id
            self.wraps.add(
                WrapInvariant(N.get_unwrap_name(operand), operand, True))
Example #7
0
def make_demand_query(symtab, query, left_clauses):
    """Define the demand query for *query* and return its symbol.

    The demand query is a comprehension over *left_clauses* producing
    tuples of ``query.demand_params``, with its LHS variables renamed by
    a fresh prefix. It is registered with the same impl as *query*, and
    ``query.demand_query`` is set to its name.
    """
    clausetools = symtab.clausetools
    dq_name = N.get_query_demand_query_name(query.name)

    result_tuple = L.tuplify(query.demand_params)
    dq_type = T.Set(symtab.analyze_expr_type(result_tuple))

    raw_comp = L.Comp(result_tuple, left_clauses)
    fresh_prefix = next(symtab.fresh_names.vars)
    dq_comp = clausetools.comp_rename_lhs_vars(
        raw_comp, lambda var: fresh_prefix + var)

    dq_sym = symtab.define_query(dq_name, type=dq_type, node=dq_comp,
                                 impl=query.impl)
    query.demand_query = dq_name
    return dq_sym
Example #8
0
 def change_rhs(self, cl, query_name):
     """Return a version of clause *cl* whose RHS relation is renamed to
     the result-set relation of the query named *query_name*.
     """
     tools = self.symtab.clausetools
     result_set = N.get_resultset_name(query_name)
     # The renaming callback ignores the old relation name and always
     # yields the result-set name.
     return tools.rename_rhs_rel(cl, lambda _old_rel: result_set)
Example #9
0
 def visit_AttrAssign(self, node):
     """Rewrite an assignment to a tracked attribute.

     When node.attr is in self.objrels.Fs, return two statements: an
     assignment of the tuple (obj, value) to a fresh variable, and a SetAdd
     of that variable to the attribute's F relation. Otherwise return None.
     """
     if node.attr in self.objrels.Fs:
         elem = L.Tuple([node.obj, node.value])
         tmp_var = next(self.fresh_vars)
         assign_stmt = L.Assign(tmp_var, elem)
         add_stmt = L.RelUpdate(N.F(node.attr), L.SetAdd(), tmp_var)
         return (assign_stmt, add_stmt)
     return None
Example #10
0
def make_demand_func(query):
    """Build the demand function definition for *query*.

    The generated function ``_FUNC(_elem)`` (named by
    ``N.get_query_demand_func_name``) adds ``_elem`` to the query's demand
    set (named by ``N.get_query_demand_set_name``) when it is not already a
    member. How the set is bounded depends on ``query.demand_set_maxsize``:

    * ``None`` -- unbounded; the element is simply added.
    * ``1`` -- the set is cleared before the new element is added.
    * an int ``n > 1`` -- while the set holds ``n`` or more elements, an
      element obtained from ``peek()`` is removed; then the new element is
      added.

    Raises L.ProgramError if ``demand_set_maxsize`` is neither None nor a
    positive int.
    """
    func = N.get_query_demand_func_name(query.name)
    uset = N.get_query_demand_set_name(query.name)
    subst = {'_FUNC': func, '_U': uset}

    maxsize = query.demand_set_maxsize

    # Validate before dispatching on the value. Checking `maxsize == 1`
    # first let non-ints that compare equal to 1 (e.g. 1.0) through, while
    # every other non-int value was rejected.
    if maxsize is not None and (not isinstance(maxsize, int)
                                or maxsize <= 0):
        raise L.ProgramError('Invalid value for demand_set_maxsize')

    if maxsize is None:
        code = L.Parser.ps('''
            def _FUNC(_elem):
                if _elem not in _U:
                    _U.reladd(_elem)
            ''', subst=subst)
    elif maxsize == 1:
        code = L.Parser.ps('''
            def _FUNC(_elem):
                if _elem not in _U:
                    _U.relclear()
                    _U.reladd(_elem)
            ''', subst=subst)
    else:
        subst['_MAXSIZE'] = L.Num(maxsize)
        code = L.Parser.ps('''
            def _FUNC(_elem):
                if _elem not in _U:
                    while len(_U) >= _MAXSIZE:
                        _stale = _U.peek()
                        _U.relremove(_stale)
                    _U.reladd(_elem)
            ''', subst=subst)

    return code
Example #11
0
 def visit_AttrDelete(self, node):
     """Rewrite a deletion of a tracked attribute.

     When node.attr is in self.objrels.Fs, return two statements: an
     assignment of the tuple (obj, obj.attr) to a fresh variable, and a
     SetRemove of that variable from the attribute's F relation.
     Otherwise return None.
     """
     if node.attr in self.objrels.Fs:
         current_value = L.Attribute(node.obj, node.attr)
         elem = L.Tuple([node.obj, current_value])
         tmp_var = next(self.fresh_vars)
         return (L.Assign(tmp_var, elem),
                 L.RelUpdate(N.F(node.attr), L.SetRemove(), tmp_var))
     return None
Example #12
0
 def visit_Wrap(self, node):
     """Visit children, then, if the wrapped operand is a plain Name,
     record a (non-unwrapping) WrapInvariant for it."""
     self.generic_visit(node)

     operand = node.value
     if isinstance(operand, L.Name):
         name = operand.id
         self.wraps.add(WrapInvariant(N.get_wrap_name(name), name, False))
Example #13
0
    def visit_SetFromMap(self, node):
        """Visit children, then, if the map operand is a plain Name,
        record a SetFromMapInvariant for it and the node's mask."""
        self.generic_visit(node)

        if isinstance(node.map, L.Name):
            map_name = node.map.id
            sa_rel = N.SA_name(map_name, node.mask)
            self.setfrommaps.add(
                SetFromMapInvariant(sa_rel, map_name, node.mask))
Example #14
0
    def visit_Wrap(self, node):
        """Visit children, then, if the wrapped operand is a plain Name,
        record a (non-unwrapping) WrapInvariant for it."""
        self.generic_visit(node)

        operand = node.value
        if isinstance(operand, L.Name):
            name = operand.id
            self.wraps.add(WrapInvariant(N.get_wrap_name(name), name, False))
Example #15
0
 def visit_RelUpdate(self, node):
     """Attach a maintenance-function call to set updates of watched rels.

     SetAdd/SetRemove updates to a relation in self.rels are combined,
     via L.insert_rel_maint, with a call to the matching maintenance
     function for self.result_var. Any other node is returned as-is.
     """
     is_set_update = isinstance(node.op, (L.SetAdd, L.SetRemove))
     if not (is_set_update and node.rel in self.rels):
         return node

     maint_name = N.get_maint_func_name(self.result_var, node.rel,
                                        L.set_update_name(node.op))
     maint_call = L.Expr(L.Call(maint_name, [L.Name(node.elem)]))
     return L.insert_rel_maint((node,), (maint_call,), node.op)
Example #16
0
    def visit_RelUpdate(self, node):
        """Attach a maintenance-function call to set updates of watched rels.

        SetAdd/SetRemove updates to a relation in self.rels are combined,
        via L.insert_rel_maint, with a call to the matching maintenance
        function for self.result_var. Any other node is returned as-is.
        """
        is_set_update = isinstance(node.op, (L.SetAdd, L.SetRemove))
        if not (is_set_update and node.rel in self.rels):
            return node

        maint_name = N.get_maint_func_name(self.result_var, node.rel,
                                           L.set_update_name(node.op))
        maint_call = L.Expr(L.Call(maint_name, [L.Name(node.elem)]))
        return L.insert_rel_maint((node,), (maint_call,), node.op)
Example #17
0
 def imglookup_helper(self, node):
     """Build (but do not register) an AuxmapInvariant for an ImgLookup.

     Returns None unless node.set is a plain Name. The key is unwrapped
     when there is exactly one bound component; the value is not.
     """
     if isinstance(node.set, L.Name):
         rel = node.set.id
         map_name = N.get_auxmap_name(rel, node.mask)
         single_key = len(node.bounds) == 1
         return AuxmapInvariant(map_name, rel, node.mask, single_key, False)
     return None
Example #18
0
    def imglookup_helper(self, node):
        """Build (but do not register) an AuxmapInvariant for an ImgLookup.

        Returns None unless node.set is a plain Name. The key is unwrapped
        when there is exactly one bound component; the value is not.
        """
        if isinstance(node.set, L.Name):
            rel = node.set.id
            map_name = N.get_auxmap_name(rel, node.mask)
            single_key = len(node.bounds) == 1
            return AuxmapInvariant(map_name, rel, node.mask, single_key, False)
        return None
Example #19
0
def make_setfrommap_maint_func(fresh_vars, setfrommap: SetFromMapInvariant,
                               op: str):
    """Make the maintenance function for a SetFromMap invariant.

    *op* selects which map update is being maintained:

    * ``'assign'`` -- generates ``_FUNC(_key, _val)``. It decomposes
      ``_key`` into its components and adds the tuple combining them with
      ``_val`` (per the invariant's mask) to ``setfrommap.rel``.
    * ``'delete'`` -- generates ``_FUNC(_key)``. It first sets ``_val`` to
      ``setfrommap.map[_key]``, then removes the corresponding tuple from
      ``setfrommap.rel``.

    Raises KeyError for any other *op*.
    """
    mask = setfrommap.mask
    nb = L.break_mapmask(mask)
    # Fresh variables for components of the key and value.
    key_vars = N.get_subnames('_key', nb)

    decomp_code = (L.DecompAssign(key_vars, L.Name('_key')), )

    elem_vars = L.combine_by_mask(mask, key_vars, ['_val'])
    elem = L.tuplify(elem_vars)
    fresh_var_prefix = next(fresh_vars)
    elem_var = fresh_var_prefix + '_elem'

    decomp_code += (L.Assign(elem_var, elem), )

    # This lookup doubles as op validation: anything other than 'assign'
    # or 'delete' raises KeyError here, so the branches below are
    # exhaustive. (The previous `else: assert ()` was unreachable, and as an
    # assert it would vanish under `python -O`.)
    setopcls = {'assign': L.SetAdd, 'delete': L.SetRemove}[op]
    update_code = (L.RelUpdate(setfrommap.rel, setopcls(), elem_var), )

    func_name = setfrommap.get_maint_func_name(op)

    if op == 'assign':
        func = L.Parser.ps('''
            def _FUNC(_key, _val):
                _DECOMP
                _UPDATE
            ''',
                           subst={
                               '_FUNC': func_name,
                               '<c>_DECOMP': decomp_code,
                               '<c>_UPDATE': update_code
                           })
    else:  # op == 'delete', guaranteed by the lookup above
        lookup_expr = L.DictLookup(L.Name(setfrommap.map), L.Name('_key'),
                                   None)
        func = L.Parser.ps('''
            def _FUNC(_key):
                _val = _LOOKUP
                _DECOMP
                _UPDATE
            ''',
                           subst={
                               '_FUNC': func_name,
                               '_LOOKUP': lookup_expr,
                               '<c>_DECOMP': decomp_code,
                               '<c>_UPDATE': update_code
                           })

    return func
Example #20
0
 def visit_IndefImgset(self, cost):
     """Simplify the cost term for an indefinite image-set lookup.

     Returns Unit() when cost.rel is in const_rels, or when it is an F
     relation looked up with mask 'bu'. Otherwise, when cost.rel's symbol
     has a Set-of-Tuple type whose arity equals the mask length, returns
     the normalized cost of the tuple of unbound component types, unless
     that normalizes to Unknown. In every other case *cost* is returned
     unchanged.
     """
     rel = cost.rel
     # Constant-time relations, and field (F) lookups with mask 'bu'.
     if rel in const_rels or (N.is_F(rel) and cost.mask == L.mask('bu')):
         return Unit()

     sym = symtab.get_symbols().get(rel, None)
     rel_type = sym.type if sym is not None else None
     if rel_type is None:
         return cost

     shape_matches = (isinstance(rel_type, T.Set)
                      and isinstance(rel_type.elt, T.Tuple)
                      and len(rel_type.elt.elts) == len(cost.mask.m))
     if not shape_matches:
         return cost

     mask, elts = cost.mask, rel_type.elt.elts
     # Drop the trailing aggregate SetFromMap result component; it is
     # functionally determined by the map keys.
     if N.is_SA(rel) and mask.m[-1] == 'u':
         mask = mask._replace(m=mask.m[:-1])
         elts = elts[:-1]

     _bound_elts, unbound_elts = L.split_by_mask(mask, elts)
     refined = normalize(type_to_cost(T.Tuple(unbound_elts)))
     return cost if isinstance(refined, Unknown) else refined
 def add_demand_function_call(self, query_sym, query_node, ann):
     """Wrap *query_node* so that the query's demand function runs first.

     The node is returned unchanged when the query has no demand set (its
     name is not in self.queries_with_usets) or when *ann* requests
     'nodemand'. Otherwise the result is a FirstThen of a call to the
     demand function on the tuple of demand parameters, followed by the
     query node.
     """
     name = query_sym.name
     skip = (name not in self.queries_with_usets
             or (ann is not None and ann.get('nodemand', False)))
     if skip:
         return query_node

     demand_func = N.get_query_demand_func_name(name)
     demand_arg = L.tuplify(query_sym.demand_params)
     return L.FirstThen(L.Call(demand_func, [demand_arg]), query_node)
Example #22
0
    def add_demand_function_call(self, query_sym, query_node, ann):
        """Wrap *query_node* so that the query's demand function runs first.

        The node is returned unchanged when the query has no demand set
        (its name is not in self.queries_with_usets) or when *ann* requests
        'nodemand'. Otherwise the result is a FirstThen of a call to the
        demand function on the tuple of demand parameters, followed by the
        query node.
        """
        name = query_sym.name
        skip = (name not in self.queries_with_usets
                or (ann is not None and ann.get('nodemand', False)))
        if skip:
            return query_node

        demand_func = N.get_query_demand_func_name(name)
        demand_arg = L.tuplify(query_sym.demand_params)
        return L.FirstThen(L.Call(demand_func, [demand_arg]), query_node)
Example #23
0
 def visit_RelUpdate(self, node):
     """Expand a set update on an M, F, or MAP relation into direct code.

     The element is indexed positionally:
     * M rel   -> SetUpdate of elem[0] with elem[1] (same op).
     * F rel   -> AttrAssign / AttrDelete of elem[0].<attr> (value elem[1]).
     * MAP rel -> DictAssign / DictDelete on elem[0] at key elem[1]
                  (value elem[2]).
     Updates to other relations return the node itself; non-SetAdd /
     SetRemove ops return None.
     """
     op = node.op
     if not isinstance(op, (L.SetAdd, L.SetRemove)):
         return None
     is_add = isinstance(op, L.SetAdd)

     rel = node.rel
     elem = L.Name(node.elem)

     def component(index):
         # Subscript expression selecting one position of the element.
         return L.Subscript(elem, L.Num(index))

     if N.is_M(rel):
         return (L.SetUpdate(component(0), op, component(1)),)

     if N.is_F(rel):
         attr = N.get_F(rel)
         obj, value = component(0), component(1)
         if is_add:
             return (L.AttrAssign(obj, attr, value),)
         return (L.AttrDelete(obj, attr),)

     if N.is_MAP(rel):
         target, key, value = component(0), component(1), component(2)
         if is_add:
             return (L.DictAssign(target, key, value),)
         return (L.DictDelete(target, key),)

     return node
Example #24
0
    def visit_RelUpdate(self, node):
        """Expand a set update on an M, F, or MAP relation into direct code.

        The element is indexed positionally:
        * M rel   -> SetUpdate of elem[0] with elem[1] (same op).
        * F rel   -> AttrAssign / AttrDelete of elem[0].<attr> (value elem[1]).
        * MAP rel -> DictAssign / DictDelete on elem[0] at key elem[1]
                     (value elem[2]).
        Updates to other relations return the node itself; non-SetAdd /
        SetRemove ops return None.
        """
        op = node.op
        if not isinstance(op, (L.SetAdd, L.SetRemove)):
            return None
        is_add = isinstance(op, L.SetAdd)

        rel = node.rel
        elem = L.Name(node.elem)

        def component(index):
            # Subscript expression selecting one position of the element.
            return L.Subscript(elem, L.Num(index))

        if N.is_M(rel):
            return (L.SetUpdate(component(0), op, component(1)),)

        if N.is_F(rel):
            attr = N.get_F(rel)
            obj, value = component(0), component(1)
            if is_add:
                return (L.AttrAssign(obj, attr, value),)
            return (L.AttrDelete(obj, attr),)

        if N.is_MAP(rel):
            target, key, value = component(0), component(1), component(2)
            if is_add:
                return (L.DictAssign(target, key, value),)
            return (L.DictDelete(target, key),)

        return node
Example #25
0
        def visit_IndefImgset(self, cost):
            """Simplify the cost term for an indefinite image-set lookup.

            Returns Unit() when cost.rel is in const_rels, or when it is an
            F relation looked up with mask 'bu'. Otherwise, when cost.rel's
            symbol has a Set-of-Tuple type whose arity equals the mask
            length, returns the normalized cost of the tuple of unbound
            component types, unless that normalizes to Unknown. In every
            other case *cost* is returned unchanged.
            """
            rel = cost.rel
            # Constant-time relations, and field (F) lookups with mask 'bu'.
            if rel in const_rels or (N.is_F(rel)
                                     and cost.mask == L.mask('bu')):
                return Unit()

            sym = symtab.get_symbols().get(rel, None)
            rel_type = sym.type if sym is not None else None
            if rel_type is None:
                return cost

            shape_matches = (isinstance(rel_type, T.Set)
                             and isinstance(rel_type.elt, T.Tuple)
                             and len(rel_type.elt.elts) == len(cost.mask.m))
            if not shape_matches:
                return cost

            mask, elts = cost.mask, rel_type.elt.elts
            # Drop the trailing aggregate SetFromMap result component; it is
            # functionally determined by the map keys.
            if N.is_SA(rel) and mask.m[-1] == 'u':
                mask = mask._replace(m=mask.m[:-1])
                elts = elts[:-1]

            _bound_elts, unbound_elts = L.split_by_mask(mask, elts)
            refined = normalize(type_to_cost(T.Tuple(unbound_elts)))
            return cost if isinstance(refined, Unknown) else refined
Example #26
0
def transform_aux_comp_query(tree, symtab, query):
    """Implement a comprehension query with the auxiliary map strategy.

    Generates a compute function (named by N.get_compute_func_name). It
    runs the comprehension's ordered clauses over ``query.params`` and
    collects distinct result-expression values into a fresh ``Set()``.
    Occurrences of the query in *tree* are rewritten into calls to that
    function, and its definition is prepended to the module body.

    Returns the transformed tree, which is asserted to be an L.Module.
    """
    clausetools = symtab.clausetools
    assert isinstance(query.node, L.Comp)
    tree = preprocess_comp(tree, symtab, query, rewrite_resexp=False)

    clauses = query.node.clauses
    compute_name = N.get_compute_func_name(query.name)

    # Rewrites occurrences of this query into calls to the compute function.
    class Rewriter(S.QueryRewriter):
        expand = True

        def rewrite_comp(self, symbol, name, comp):
            if name != query.name:
                return None
            return L.Call(compute_name, [L.Name(p) for p in query.params])

    tree = Rewriter.run(tree, symtab)

    ordered = order_clauses(clausetools, clauses)
    add_result = L.Parser.pc('''
        if _RESEXP not in _result:
            _result.add(_RESEXP)
        ''', subst={'_RESEXP': query.node.resexp})
    compute_code = clausetools.get_code_for_clauses(ordered, query.params,
                                                    add_result)

    compute_func = L.Parser.ps('''
        def _FUNC(_ARGS):
            _result = Set()
            _COMPUTE
            return _result
        ''', subst={'_FUNC': compute_name, '<c>_COMPUTE': compute_code})
    # Swap the template's placeholder argument list for the real params.
    compute_func = compute_func._replace(args=query.params)

    assert isinstance(tree, L.Module)
    return tree._replace(body=(compute_func,) + tree.body)
Example #27
0
def make_demand_set(symtab, query):
    """Define the demand set relation for *query* and return its symbol.

    The relation holds tuples of ``query.demand_params``. It is marked LRU
    exactly when ``query.demand_set_maxsize`` is set and greater than 1.
    ``query.demand_set`` is updated to the new relation's name.
    """
    set_name = N.get_query_demand_set_name(query.name)
    elem_type = symtab.analyze_expr_type(L.tuplify(query.demand_params))
    set_type = T.Set(elem_type)

    limit = query.demand_set_maxsize
    use_lru = limit is not None and limit > 1

    set_sym = symtab.define_relation(set_name, type=set_type, lru=use_lru)
    query.demand_set = set_name
    return set_sym
def make_demand_set(symtab, query):
    """Define the demand set relation for *query* and return its symbol.

    The relation holds tuples of ``query.demand_params``. It is marked LRU
    exactly when ``query.demand_set_maxsize`` is set and greater than 1.
    ``query.demand_set`` is updated to the new relation's name.
    """
    set_name = N.get_query_demand_set_name(query.name)
    elem_type = symtab.analyze_expr_type(L.tuplify(query.demand_params))
    set_type = T.Set(elem_type)

    limit = query.demand_set_maxsize
    use_lru = limit is not None and limit > 1

    set_sym = symtab.define_relation(set_name, type=set_type, lru=use_lru)
    query.demand_set = set_name
    return set_sym
Example #29
0
def make_setfrommap_maint_func(fresh_vars,
                               setfrommap: SetFromMapInvariant,
                               op: str):
    """Make the maintenance function for a SetFromMap invariant.

    *op* selects which map update is being maintained:

    * ``'assign'`` -- generates ``_FUNC(_key, _val)``. It decomposes
      ``_key`` into its components and adds the tuple combining them with
      ``_val`` (per the invariant's mask) to ``setfrommap.rel``.
    * ``'delete'`` -- generates ``_FUNC(_key)``. It first sets ``_val`` to
      ``setfrommap.map[_key]``, then removes the corresponding tuple from
      ``setfrommap.rel``.

    Raises KeyError for any other *op*.
    """
    mask = setfrommap.mask
    nb = L.break_mapmask(mask)
    # Fresh variables for components of the key and value.
    key_vars = N.get_subnames('_key', nb)

    decomp_code = (L.DecompAssign(key_vars, L.Name('_key')),)

    elem_vars = L.combine_by_mask(mask, key_vars, ['_val'])
    elem = L.tuplify(elem_vars)
    fresh_var_prefix = next(fresh_vars)
    elem_var = fresh_var_prefix + '_elem'

    decomp_code += (L.Assign(elem_var, elem),)

    # This lookup doubles as op validation: anything other than 'assign'
    # or 'delete' raises KeyError here, so the branches below are
    # exhaustive. (The previous `else: assert()` was unreachable, and as an
    # assert it would vanish under `python -O`.)
    setopcls = {'assign': L.SetAdd,
                'delete': L.SetRemove}[op]
    update_code = (L.RelUpdate(setfrommap.rel, setopcls(), elem_var),)

    func_name = setfrommap.get_maint_func_name(op)

    if op == 'assign':
        func = L.Parser.ps('''
            def _FUNC(_key, _val):
                _DECOMP
                _UPDATE
            ''', subst={'_FUNC': func_name,
                        '<c>_DECOMP': decomp_code,
                        '<c>_UPDATE': update_code})
    else:  # op == 'delete', guaranteed by the lookup above
        lookup_expr = L.DictLookup(L.Name(setfrommap.map),
                                   L.Name('_key'), None)
        func = L.Parser.ps('''
            def _FUNC(_key):
                _val = _LOOKUP
                _DECOMP
                _UPDATE
            ''', subst={'_FUNC': func_name,
                        '_LOOKUP': lookup_expr,
                        '<c>_DECOMP': decomp_code,
                        '<c>_UPDATE': update_code})

    return func
Example #30
0
def make_auxmap_maint_func(fresh_vars, auxmap: AuxmapInvariant, op: L.setupop):
    """Make the maintenance function for an auxiliary map invariant.

    The generated ``_FUNC(_elem)`` decomposes ``_elem`` into components,
    splits them into key and value tuples according to ``auxmap.mask``
    (unwrapped per ``unwrap_key`` / ``unwrap_value``), and then runs the
    image-set add or remove code for ``auxmap.map``. The builder is chosen
    by the exact class of *op* (SetAdd or SetRemove); any other class
    raises KeyError.
    """
    # Fresh variables for components of the element.
    elem_parts = N.get_subnames('_elem', len(auxmap.mask.m))
    decomp_code = (L.DecompAssign(elem_parts, L.Name('_elem')), )

    key_parts, value_parts = L.split_by_mask(auxmap.mask, elem_parts)
    key_expr = L.tuplify(key_parts, unwrap=auxmap.unwrap_key)
    value_expr = L.tuplify(value_parts, unwrap=auxmap.unwrap_value)
    prefix = next(fresh_vars)
    key_var = prefix + '_key'
    value_var = prefix + '_value'

    decomp_code += L.Parser.pc('''
        _KEY_VAR = _KEY
        _VALUE_VAR = _VALUE
        ''', subst={'_KEY_VAR': key_var, '_KEY': key_expr,
                    '_VALUE_VAR': value_var, '_VALUE': value_expr})

    img_builders = {L.SetAdd: make_imgadd, L.SetRemove: make_imgremove}
    build_img = img_builders[op.__class__]
    img_code = build_img(fresh_vars, auxmap.map, key_var, value_var)

    maint_name = auxmap.get_maint_func_name(op)

    return L.Parser.ps('''
        def _FUNC(_elem):
            _DECOMP
            _IMGCODE
        ''', subst={'_FUNC': maint_name,
                    '<c>_DECOMP': decomp_code,
                    '<c>_IMGCODE': img_code})
Example #31
0
def transform_query(tree, symtab, query):
    """Apply the implementation strategy chosen for *query* to *tree*.

    Comprehension queries accept Normal, Aux, Inc and Filtered. Aggregate
    queries (Aggr / AggrRestr) accept Normal and Inc. Under Normal the tree
    is returned unmodified.

    Returns ``(tree, success)``, where ``success`` is False for Normal and
    True otherwise.

    Raises AssertionError if ``query.impl`` is Unspecified or not supported
    for the query kind, and L.ProgramError for an unknown query node kind.
    """
    assert query.impl is not S.Unspecified

    if isinstance(query.node, L.Comp):

        if query.impl == S.Normal:
            success = False

        elif query.impl == S.Aux:
            tree = transform_aux_comp_query(tree, symtab, query)
            success = True

        elif query.impl == S.Inc:
            tree = transform_comp_query(tree, symtab, query)
            success = True

        elif query.impl == S.Filtered:
            symtab.print('  Comp: ' + L.Parser.ts(query.node))
            tree = transform_comp_query_with_filtering(tree, symtab, query)
            success = True

        else:
            # Raise explicitly rather than via `assert`: assertions are
            # stripped under `python -O`, which would reach the return below
            # with `success` unbound (UnboundLocalError).
            raise AssertionError(
                'Unsupported impl {} for comprehension query {!r}'.format(
                    query.impl, query.name))

    elif isinstance(query.node, (L.Aggr, L.AggrRestr)):

        if query.impl == S.Normal:
            success = False

        elif query.impl == S.Inc:
            result_var = N.A_name(query.name)
            tree = incrementalize_aggr(tree, symtab, query, result_var)
            # Transform any aggregate map lookups inside comprehensions,
            # including incrementalizing the SetFromMaps used in the
            # additional comp clauses.
            tree = transform_comps_with_maps(tree, symtab)
            success = True

        else:
            # See above: an explicit raise survives `python -O`.
            raise AssertionError(
                'Unsupported impl {} for aggregate query {!r}'.format(
                    query.impl, query.name))

    else:
        raise L.ProgramError('Unknown query kind: {}'.format(
            query.node.__class__.__name__))
    return tree, success
Example #32
0
def transform_aux_comp_query(tree, symtab, query):
    """Implement a comprehension query with the auxiliary map strategy.

    Generates a compute function (named by N.get_compute_func_name). It
    runs the comprehension's ordered clauses over ``query.params`` and
    collects distinct result-expression values into a fresh ``Set()``.
    Occurrences of the query in *tree* are rewritten into calls to that
    function, and its definition is prepended to the module body.

    Returns the transformed tree, which is asserted to be an L.Module.
    """
    clausetools = symtab.clausetools
    assert isinstance(query.node, L.Comp)
    tree = preprocess_comp(tree, symtab, query, rewrite_resexp=False)

    clauses = query.node.clauses
    compute_name = N.get_compute_func_name(query.name)

    # Rewrites occurrences of this query into calls to the compute function.
    class Rewriter(S.QueryRewriter):
        expand = True

        def rewrite_comp(self, symbol, name, comp):
            if name != query.name:
                return None
            return L.Call(compute_name, [L.Name(p) for p in query.params])

    tree = Rewriter.run(tree, symtab)

    ordered = order_clauses(clausetools, clauses)
    add_result = L.Parser.pc('''
        if _RESEXP not in _result:
            _result.add(_RESEXP)
        ''', subst={'_RESEXP': query.node.resexp})
    compute_code = clausetools.get_code_for_clauses(ordered, query.params,
                                                    add_result)

    compute_func = L.Parser.ps('''
        def _FUNC(_ARGS):
            _result = Set()
            _COMPUTE
            return _result
        ''', subst={'_FUNC': compute_name, '<c>_COMPUTE': compute_code})
    # Swap the template's placeholder argument list for the real params.
    compute_func = compute_func._replace(args=query.params)

    assert isinstance(tree, L.Module)
    return tree._replace(body=(compute_func,) + tree.body)
Example #33
0
def make_comp_maint_func(clausetools, fresh_var_prefix, fresh_join_names,
                         comp, result_var, rel, op, *,
                         counted):
    """Build the function that maintains *result_var* for comprehension
    *comp* when *op* (a SetAdd or SetRemove) is applied to relation *rel*.

    The returned definition takes one parameter, ``_elem``. Its body is the
    code produced by ``clausetools.get_maint_code`` for the update
    ``RelUpdate(rel, op, '_elem')``.
    """
    assert isinstance(op, (L.SetAdd, L.SetRemove))

    maint_name = N.get_maint_func_name(result_var, rel,
                                       L.set_update_name(op))
    triggering_update = L.RelUpdate(rel, op, '_elem')
    maint_body = clausetools.get_maint_code(
        fresh_var_prefix, fresh_join_names, comp, result_var,
        triggering_update, counted=counted)

    return L.Parser.ps('''
        def _FUNC(_elem):
            _CODE
        ''', subst={'_FUNC': maint_name, '<c>_CODE': maint_body})
def make_demand_query(symtab, query, left_clauses):
    """Define the demand query for *query* and return its symbol.

    The demand query is a comprehension over *left_clauses* producing
    tuples of ``query.demand_params``, with its LHS variables renamed by
    a fresh prefix. It is registered with the same impl as *query*, and
    ``query.demand_query`` is set to its name.
    """
    clausetools = symtab.clausetools
    dq_name = N.get_query_demand_query_name(query.name)

    result_tuple = L.tuplify(query.demand_params)
    dq_type = T.Set(symtab.analyze_expr_type(result_tuple))

    raw_comp = L.Comp(result_tuple, left_clauses)
    fresh_prefix = next(symtab.fresh_names.vars)
    dq_comp = clausetools.comp_rename_lhs_vars(
        raw_comp, lambda var: fresh_prefix + var)

    dq_sym = symtab.define_query(dq_name, type=dq_type, node=dq_comp,
                                 impl=query.impl)
    query.demand_query = dq_name
    return dq_sym
Example #35
0
def make_auxmap_maint_func(fresh_vars,
                           auxmap: AuxmapInvariant, op: L.setupop):
    """Make the maintenance function for an auxiliary map invariant.

    The generated ``_FUNC(_elem)`` decomposes ``_elem`` into components,
    splits them into key and value tuples according to ``auxmap.mask``
    (unwrapped per ``unwrap_key`` / ``unwrap_value``), and then runs the
    image-set add or remove code for ``auxmap.map``. The builder is chosen
    by the exact class of *op* (SetAdd or SetRemove); any other class
    raises KeyError.
    """
    # Fresh variables for components of the element.
    elem_parts = N.get_subnames('_elem', len(auxmap.mask.m))
    decomp_code = (L.DecompAssign(elem_parts, L.Name('_elem')),)

    key_parts, value_parts = L.split_by_mask(auxmap.mask, elem_parts)
    key_expr = L.tuplify(key_parts, unwrap=auxmap.unwrap_key)
    value_expr = L.tuplify(value_parts, unwrap=auxmap.unwrap_value)
    prefix = next(fresh_vars)
    key_var = prefix + '_key'
    value_var = prefix + '_value'

    decomp_code += L.Parser.pc('''
        _KEY_VAR = _KEY
        _VALUE_VAR = _VALUE
        ''', subst={'_KEY_VAR': key_var, '_KEY': key_expr,
                    '_VALUE_VAR': value_var, '_VALUE': value_expr})

    img_builders = {L.SetAdd: make_imgadd, L.SetRemove: make_imgremove}
    build_img = img_builders[op.__class__]
    img_code = build_img(fresh_vars, auxmap.map, key_var, value_var)

    maint_name = auxmap.get_maint_func_name(op)

    return L.Parser.ps('''
        def _FUNC(_elem):
            _DECOMP
            _IMGCODE
        ''', subst={'_FUNC': maint_name,
                    '<c>_DECOMP': decomp_code,
                    '<c>_IMGCODE': img_code})
Example #36
0
 def rhs_rel(self, cl):
     """Return the relation name N.TUP(k), where k is len(cl.elts)."""
     arity = len(cl.elts)
     return N.TUP(arity)
Example #37
0
 def rhs_rel(self, cl):
     """Return the field relation name N.F(...) for the attribute cl.attr."""
     attr_name = cl.attr
     return N.F(attr_name)
Example #38
0
def make_aggr_restr_maint_func(fresh_vars, aggrinv, op):
    """Build the function that maintains an aggregate map when its
    restriction set is updated.

    The generated function takes ``_key``:

    - On ``L.SetRemove`` it calls ``mapdelete`` on ``aggrinv.map`` for that
      key.
    - On ``L.SetAdd`` it starts from the handler's zero expression and
      runs the handler's update code for every value yielded by an
      ``L.ImgLookup`` over ``aggrinv.rel`` on the unpacked key. It then
      stores the resulting state with ``mapassign``.

    :param fresh_vars: iterator of fresh name prefixes (consumed only for
        ``L.SetAdd``).
    :param aggrinv: aggregate invariant; ``uses_demand`` must be true.
    :param op: an ``L.SetAdd`` or ``L.SetRemove`` node.
    :return: the function definition produced by ``L.Parser.ps``.
    """
    assert isinstance(op, (L.SetAdd, L.SetRemove))
    assert aggrinv.uses_demand

    if not isinstance(op, L.SetAdd):
        # Key left the restriction set: drop its aggregate entry.
        body = L.Parser.pc(
            """
            _MAP.mapdelete(_KEY)
            """,
            subst={"_MAP": aggrinv.map, "_KEY": "_key"},
        )
    else:
        prefix = next(fresh_vars)
        value_name = prefix + "_value"
        state_name = prefix + "_state"
        key_parts = N.get_subnames("_key", len(aggrinv.params))

        unpack_key = (L.DecompAssign(key_parts, L.Name("_key")),)
        lookup = L.ImgLookup(L.Name(aggrinv.rel), aggrinv.mask, key_parts)

        handler = aggrinv.get_handler()
        zero_expr = handler.make_zero_expr()
        update_stmts = handler.make_update_state_code(
            prefix, state_name, op, value_name)

        # With unwrap set, each looked-up element is destructured as a
        # 1-tuple; otherwise it is bound whole.
        if aggrinv.unwrap:
            loop_template = """
                for (_VALUE,) in _RELLOOKUP:
                    _UPDATESTATE
                """
        else:
            loop_template = """
                for _VALUE in _RELLOOKUP:
                    _UPDATESTATE
                """

        fold_loop = L.Parser.pc(loop_template, subst={
            "_VALUE": value_name,
            "_RELLOOKUP": lookup,
            "<c>_UPDATESTATE": update_stmts,
        })

        body = L.Parser.pc(
            """
            _STATE = _ZERO
            _DECOMP_KEY
            _LOOP
            _MAP.mapassign(_KEY, _STATE)
            """,
            subst={
                "_MAP": aggrinv.map,
                "_KEY": "_key",
                "_STATE": state_name,
                "_ZERO": zero_expr,
                "<c>_DECOMP_KEY": unpack_key,
                "<c>_LOOP": fold_loop,
            },
        )

    return L.Parser.ps(
        """
        def _FUNC(_key):
            _MAINT
        """,
        subst={"_FUNC": aggrinv.get_restr_maint_func_name(op), "<c>_MAINT": body},
    )
Example #39
0
 def fresh_join_names():
     """Yield fresh join names derived from ``query.name``, recording each.

     Before a name is yielded, it is appended to ``used_join_names`` and
     mapped in ``join_prefixes`` to the value ``current_prefix`` has at
     that moment. All three are free variables of the enclosing scope.
     """
     name_source = N.fresh_name_generator(query.name + '_J{}')
     for name in name_source:
         used_join_names.append(name)
         join_prefixes[name] = current_prefix
         yield name
Example #40
0
def make_aggr_oper_maint_func(fresh_vars, aggrinv, op):
    """Make the maintenance function for an aggregate invariant and a
    given set update operation (add or remove) to the operand.

    The generated function works as follows:

    - It takes ``_elem`` and unpacks it into one variable per component
      of ``aggrinv.mask``.
    - It binds the key tuple to a fresh key variable. The value bound is
      the lone value component if ``aggrinv.unwrap``, else the value
      tuple.
    - It gets, updates, deletes and re-assigns the state in
      ``aggrinv.map``. When ``aggrinv.uses_demand`` is set, all of this
      is guarded by ``_KEY in _RESTR``.

    :param fresh_vars: iterator of fresh name prefixes; one is consumed.
    :param aggrinv: the aggregate invariant being maintained.
    :param op: an ``L.SetAdd`` or ``L.SetRemove`` node.
    :return: the function definition produced by ``L.Parser.ps``.
    :raises AssertionError: if *op* is neither ``L.SetAdd`` nor
        ``L.SetRemove``.
    """
    # Explicit raise rather than ``assert`` so the check is not stripped
    # under ``python -O``. The branches below rely on op being one of
    # these two classes.
    if not isinstance(op, (L.SetAdd, L.SetRemove)):
        raise AssertionError(
            'expected L.SetAdd or L.SetRemove, got {!r}'.format(op))

    # Decompose the argument tuple into key and value components,
    # just like in auxmap.py.

    elem_vars = N.get_subnames('_elem', len(aggrinv.mask.m))
    kvars, vvars = L.split_by_mask(aggrinv.mask, elem_vars)
    ktuple = L.tuplify(kvars)
    vtuple = L.tuplify(vvars)
    fresh_var_prefix = next(fresh_vars)
    key = fresh_var_prefix + '_key'
    value = fresh_var_prefix + '_value'
    state = fresh_var_prefix + '_state'

    if aggrinv.unwrap:
        # Unwrapping is only defined for a single value component.
        assert len(vvars) == 1
        value_expr = L.Name(vvars[0])
    else:
        value_expr = vtuple

    # Logic specific to aggregate operator.
    handler = aggrinv.get_handler()
    zero = handler.make_zero_expr()
    updatestate_code = handler.make_update_state_code(fresh_var_prefix, state,
                                                      op, value)

    subst = {
        '_KEY': key,
        '_KEY_EXPR': ktuple,
        '_VALUE': value,
        '_VALUE_EXPR': value_expr,
        '_MAP': aggrinv.map,
        '_STATE': state,
        '_ZERO': zero
    }

    if aggrinv.uses_demand:
        subst['_RESTR'] = aggrinv.restr
    else:
        # Empty conditions are only used when we don't have a
        # restriction set.
        subst['_EMPTY'] = handler.make_empty_cond(state)

    decomp_code = (L.DecompAssign(elem_vars, L.Name('_elem')), )
    decomp_code += L.Parser.pc('''
        _KEY = _KEY_EXPR
        _VALUE = _VALUE_EXPR
        ''',
                               subst=subst)

    # Determine what kind of get/set state code to generate.
    if isinstance(op, L.SetAdd):
        # With demand, the body is guarded by ``_KEY in _RESTR``, and the
        # entry is expected to exist already. Without demand it may be
        # absent, so fall back to the zero state.
        definitely_preexists = aggrinv.uses_demand
        setstate_mayremove = False
    else:  # L.SetRemove, guaranteed by the check at the top.
        definitely_preexists = True
        # Without a restriction set, an emptied state is dropped from the map.
        setstate_mayremove = not aggrinv.uses_demand

    if definitely_preexists:
        getstate_template = '_STATE = _MAP[_KEY]'
        delstate_template = '_MAP.mapdelete(_KEY)'
    else:
        getstate_template = '_STATE = _MAP.get(_KEY, _ZERO)'
        delstate_template = '''
            if _KEY in _MAP:
                _MAP.mapdelete(_KEY)
            '''

    if setstate_mayremove:
        setstate_template = '''
            if not _EMPTY:
                _MAP.mapassign(_KEY, _STATE)
            '''
    else:
        setstate_template = '_MAP.mapassign(_KEY, _STATE)'

    getstate_code = L.Parser.pc(getstate_template, subst=subst)
    delstate_code = L.Parser.pc(delstate_template, subst=subst)
    setstate_code = L.Parser.pc(setstate_template, subst=subst)

    maint_code = (getstate_code + updatestate_code + delstate_code +
                  setstate_code)

    # Guard in test if we have a restriction set.
    if aggrinv.uses_demand:
        maint_subst = dict(subst)
        maint_subst['<c>_MAINT'] = maint_code
        maint_code = L.Parser.pc('''
            if _KEY in _RESTR:
                _MAINT
            ''',
                                 subst=maint_subst)

    func_name = aggrinv.get_oper_maint_func_name(op)

    func = L.Parser.ps('''
        def _FUNC(_elem):
            _DECOMP
            _MAINT
        ''',
                       subst={
                           '_FUNC': func_name,
                           '<c>_DECOMP': decomp_code,
                           '<c>_MAINT': maint_code
                       })

    return func
Example #41
0
 def get_restr_maint_func_name(self, op):
     """Return the maintenance-function name for update *op* on the
     restriction set.

     The name is built by ``N.get_maint_func_name`` from ``self.map``,
     ``self.restr`` and ``L.set_update_name(op)``. Asserts that
     ``self.uses_demand`` is true.
     """
     assert self.uses_demand
     update_name = L.set_update_name(op)
     return N.get_maint_func_name(self.map, self.restr, update_name)
Example #42
0
 def get_maint_func_name(self, op):
     """Return the maintenance-function name for set update *op*.

     The name is built by ``N.get_maint_func_name`` from ``self.rel``,
     ``self.oper`` and ``L.set_update_name(op)``.
     """
     update_name = L.set_update_name(op)
     return N.get_maint_func_name(self.rel, self.oper, update_name)
Example #43
0
 def get_maint_func_name(self, op):
     """Return the maintenance-function name for map operation *op*.

     *op* must be the string ``'assign'`` or ``'delete'`` (asserted). The
     name is built by ``N.get_maint_func_name`` from ``self.rel``,
     ``self.map`` and *op*.
     """
     assert op in ('assign', 'delete')
     rel, target_map = self.rel, self.map
     return N.get_maint_func_name(rel, target_map, op)
Example #44
0
def transform_firsthalf(tree, symtab, query):
    """Apply ``preprocess_comp`` and then ``incrementalize_comp`` to *tree*.

    The result-set name passed to ``incrementalize_comp`` comes from
    ``N.get_resultset_name(query.name)``. Returns the transformed tree.
    """
    resultset_name = N.get_resultset_name(query.name)
    preprocessed = preprocess_comp(tree, symtab, query)
    return incrementalize_comp(preprocessed, symtab, query, resultset_name)
Example #45
0
def make_aggr_oper_maint_func(fresh_vars, aggrinv, op):
    """Make the maintenance function for an aggregate invariant and a
    given set update operation (add or remove) to the operand.

    The generated function works as follows:

    - It takes ``_elem`` and unpacks it into one variable per component
      of ``aggrinv.mask``.
    - It binds the key tuple to a fresh key variable. The value bound is
      the lone value component if ``aggrinv.unwrap``, else the value
      tuple.
    - It gets, updates, deletes and re-assigns the state in
      ``aggrinv.map``. When ``aggrinv.uses_demand`` is set, all of this
      is guarded by ``_KEY in _RESTR``.

    :param fresh_vars: iterator of fresh name prefixes; one is consumed.
    :param aggrinv: the aggregate invariant being maintained.
    :param op: an ``L.SetAdd`` or ``L.SetRemove`` node.
    :return: the function definition produced by ``L.Parser.ps``.
    :raises AssertionError: if *op* is neither ``L.SetAdd`` nor
        ``L.SetRemove``.
    """
    # Explicit raise rather than ``assert`` so the check is not stripped
    # under ``python -O``. The branches below rely on op being one of
    # these two classes.
    if not isinstance(op, (L.SetAdd, L.SetRemove)):
        raise AssertionError("expected L.SetAdd or L.SetRemove, got {!r}".format(op))

    # Decompose the argument tuple into key and value components,
    # just like in auxmap.py.

    elem_vars = N.get_subnames("_elem", len(aggrinv.mask.m))
    kvars, vvars = L.split_by_mask(aggrinv.mask, elem_vars)
    ktuple = L.tuplify(kvars)
    vtuple = L.tuplify(vvars)
    fresh_var_prefix = next(fresh_vars)
    key = fresh_var_prefix + "_key"
    value = fresh_var_prefix + "_value"
    state = fresh_var_prefix + "_state"

    if aggrinv.unwrap:
        # Unwrapping is only defined for a single value component.
        assert len(vvars) == 1
        value_expr = L.Name(vvars[0])
    else:
        value_expr = vtuple

    # Logic specific to aggregate operator.
    handler = aggrinv.get_handler()
    zero = handler.make_zero_expr()
    updatestate_code = handler.make_update_state_code(fresh_var_prefix, state, op, value)

    subst = {
        "_KEY": key,
        "_KEY_EXPR": ktuple,
        "_VALUE": value,
        "_VALUE_EXPR": value_expr,
        "_MAP": aggrinv.map,
        "_STATE": state,
        "_ZERO": zero,
    }

    if aggrinv.uses_demand:
        subst["_RESTR"] = aggrinv.restr
    else:
        # Empty conditions are only used when we don't have a
        # restriction set.
        subst["_EMPTY"] = handler.make_empty_cond(state)

    decomp_code = (L.DecompAssign(elem_vars, L.Name("_elem")),)
    decomp_code += L.Parser.pc(
        """
        _KEY = _KEY_EXPR
        _VALUE = _VALUE_EXPR
        """,
        subst=subst,
    )

    # Determine what kind of get/set state code to generate.
    if isinstance(op, L.SetAdd):
        # With demand, the body is guarded by ``_KEY in _RESTR``, and the
        # entry is expected to exist already. Without demand it may be
        # absent, so fall back to the zero state.
        definitely_preexists = aggrinv.uses_demand
        setstate_mayremove = False
    else:  # L.SetRemove, guaranteed by the check at the top.
        definitely_preexists = True
        # Without a restriction set, an emptied state is dropped from the map.
        setstate_mayremove = not aggrinv.uses_demand

    if definitely_preexists:
        getstate_template = "_STATE = _MAP[_KEY]"
        delstate_template = "_MAP.mapdelete(_KEY)"
    else:
        getstate_template = "_STATE = _MAP.get(_KEY, _ZERO)"
        delstate_template = """
            if _KEY in _MAP:
                _MAP.mapdelete(_KEY)
            """

    if setstate_mayremove:
        setstate_template = """
            if not _EMPTY:
                _MAP.mapassign(_KEY, _STATE)
            """
    else:
        setstate_template = "_MAP.mapassign(_KEY, _STATE)"

    getstate_code = L.Parser.pc(getstate_template, subst=subst)
    delstate_code = L.Parser.pc(delstate_template, subst=subst)
    setstate_code = L.Parser.pc(setstate_template, subst=subst)

    maint_code = getstate_code + updatestate_code + delstate_code + setstate_code

    # Guard in test if we have a restriction set.
    if aggrinv.uses_demand:
        maint_subst = dict(subst)
        maint_subst["<c>_MAINT"] = maint_code
        maint_code = L.Parser.pc(
            """
            if _KEY in _RESTR:
                _MAINT
            """,
            subst=maint_subst,
        )

    func_name = aggrinv.get_oper_maint_func_name(op)

    func = L.Parser.ps(
        """
        def _FUNC(_elem):
            _DECOMP
            _MAINT
        """,
        subst={"_FUNC": func_name, "<c>_DECOMP": decomp_code, "<c>_MAINT": maint_code},
    )

    return func
Example #46
0
 def get_tag_name(self, var, n):
     """Return ``N.get_tag_name`` for this query's name, *var* and *n*."""
     query_name = self.query_name
     return N.get_tag_name(query_name, var, n)
Example #47
0
 def get_maint_func_name(self, op):
     """Return the maintenance-function name for set update *op*.

     The name is built by ``N.get_maint_func_name`` from ``self.rel``,
     ``self.oper`` and ``L.set_update_name(op)``.
     """
     update_name = L.set_update_name(op)
     return N.get_maint_func_name(self.rel, self.oper, update_name)
Example #48
0
 def get_tag_name(self, var, n):
     """Return ``N.get_tag_name`` for this query's name, *var* and *n*."""
     query_name = self.query_name
     return N.get_tag_name(query_name, var, n)
Example #49
0
def transform_firsthalf(tree, symtab, query):
    """Apply ``preprocess_comp`` and then ``incrementalize_comp`` to *tree*.

    The result-set name passed to ``incrementalize_comp`` comes from
    ``N.get_resultset_name(query.name)``. Returns the transformed tree.
    """
    resultset_name = N.get_resultset_name(query.name)
    preprocessed = preprocess_comp(tree, symtab, query)
    return incrementalize_comp(preprocessed, symtab, query, resultset_name)
Example #50
0
 def fresh_join_names():
     """Yield fresh join names derived from ``query.name``, recording each.

     Before a name is yielded, it is appended to ``used_join_names`` and
     mapped in ``join_prefixes`` to the value ``current_prefix`` has at
     that moment. All three are free variables of the enclosing scope.
     """
     name_source = N.fresh_name_generator(query.name + '_J{}')
     for name in name_source:
         used_join_names.append(name)
         join_prefixes[name] = current_prefix
         yield name
Example #51
0
 def get_maint_func_name(self, op):
     """Return the maintenance-function name for map operation *op*.

     *op* must be the string ``'assign'`` or ``'delete'`` (asserted). The
     name is built by ``N.get_maint_func_name`` from ``self.rel``,
     ``self.map`` and *op*.
     """
     assert op in ('assign', 'delete')
     rel, target_map = self.rel, self.map
     return N.get_maint_func_name(rel, target_map, op)
Example #52
0
 def get_restr_maint_func_name(self, op):
     """Return the maintenance-function name for update *op* on the
     restriction set.

     The name is built by ``N.get_maint_func_name`` from ``self.map``,
     ``self.restr`` and ``L.set_update_name(op)``. Asserts that
     ``self.uses_demand`` is true.
     """
     assert self.uses_demand
     update_name = L.set_update_name(op)
     return N.get_maint_func_name(self.map, self.restr, update_name)
Example #53
0
 def get_filter_name(self, rel, n):
     """Return ``N.get_filter_name`` for this query's name, *rel* and *n*."""
     query_name = self.query_name
     return N.get_filter_name(query_name, rel, n)
Example #54
0
def make_aggr_restr_maint_func(fresh_vars, aggrinv, op):
    """Build the function that maintains an aggregate map when its
    restriction set is updated.

    The generated function takes ``_key``:

    - On ``L.SetRemove`` it calls ``mapdelete`` on ``aggrinv.map`` for that
      key.
    - On ``L.SetAdd`` it starts from the handler's zero expression and
      runs the handler's update code for every value yielded by an
      ``L.ImgLookup`` over ``aggrinv.rel`` on the unpacked key. It then
      stores the resulting state with ``mapassign``.

    :param fresh_vars: iterator of fresh name prefixes (consumed only for
        ``L.SetAdd``).
    :param aggrinv: aggregate invariant; ``uses_demand`` must be true.
    :param op: an ``L.SetAdd`` or ``L.SetRemove`` node.
    :return: the function definition produced by ``L.Parser.ps``.
    """
    assert isinstance(op, (L.SetAdd, L.SetRemove))
    assert aggrinv.uses_demand

    if not isinstance(op, L.SetAdd):
        # Key left the restriction set: drop its aggregate entry.
        body = L.Parser.pc('''
            _MAP.mapdelete(_KEY)
            ''', subst={'_MAP': aggrinv.map, '_KEY': '_key'})
    else:
        prefix = next(fresh_vars)
        value_name = prefix + '_value'
        state_name = prefix + '_state'
        key_parts = N.get_subnames('_key', len(aggrinv.params))

        unpack_key = (L.DecompAssign(key_parts, L.Name('_key')), )
        lookup = L.ImgLookup(L.Name(aggrinv.rel), aggrinv.mask, key_parts)

        handler = aggrinv.get_handler()
        zero_expr = handler.make_zero_expr()
        update_stmts = handler.make_update_state_code(
            prefix, state_name, op, value_name)

        # With unwrap set, each looked-up element is destructured as a
        # 1-tuple; otherwise it is bound whole.
        if aggrinv.unwrap:
            loop_template = '''
                for (_VALUE,) in _RELLOOKUP:
                    _UPDATESTATE
                '''
        else:
            loop_template = '''
                for _VALUE in _RELLOOKUP:
                    _UPDATESTATE
                '''

        loop_subst = {'_VALUE': value_name,
                      '_RELLOOKUP': lookup,
                      '<c>_UPDATESTATE': update_stmts}
        fold_loop = L.Parser.pc(loop_template, subst=loop_subst)

        body_subst = {'_MAP': aggrinv.map,
                      '_KEY': '_key',
                      '_STATE': state_name,
                      '_ZERO': zero_expr,
                      '<c>_DECOMP_KEY': unpack_key,
                      '<c>_LOOP': fold_loop}
        body = L.Parser.pc('''
            _STATE = _ZERO
            _DECOMP_KEY
            _LOOP
            _MAP.mapassign(_KEY, _STATE)
            ''', subst=body_subst)

    return L.Parser.ps('''
        def _FUNC(_key):
            _MAINT
        ''', subst={'_FUNC': aggrinv.get_restr_maint_func_name(op),
                    '<c>_MAINT': body})