示例#1
0
    def visit_Member(self, node):
        """Rewrite a membership over a tracked relation into a RelMember.

        Children are transformed first. If the iterated expression is a
        bare name listed in ``self.rels`` and the target is a tuple of
        names, the clause becomes ``L.RelMember`` over the detuplified
        target variables; otherwise the visited node is returned as-is.
        """
        visited = self.generic_visit(node)
        source = visited.iter
        is_rel_name = (isinstance(source, L.Name)
                       and source.id in self.rels)
        if not (is_rel_name and L.is_tuple_of_names(visited.target)):
            return visited
        return L.RelMember(L.detuplify(visited.target), source.id)
示例#2
0
    def functionally_determines(self, cl, bindenv):
        """Report whether the bound LHS variables of ``cl`` determine the rest.

        Under the binding mask ``'bu'`` this is always the case; every
        other binding pattern is decided by the superclass.
        """
        bound_mask = L.mask_from_bounds(self.lhs_vars(cl), bindenv)
        if bound_mask == L.mask('bu'):
            return True
        return super().functionally_determines(cl, bindenv)
示例#3
0
    def get_priority(self, cl, bindenv):
        """Return the scheduling priority of ``cl`` under ``bindenv``.

        A ``'bu'`` binding of the clause's LHS variables yields
        ``Priority.Constant``; any other pattern defers to the superclass.
        """
        lhs_mask = L.mask_from_bounds(self.lhs_vars(cl), bindenv)
        return (Priority.Constant if lhs_mask == L.mask('bu')
                else super().get_priority(cl, bindenv))
def py_preprocess(tree):
    """Partially preprocess a Python AST and convert it to IncAST.

    Returns a triple ``(tree, decls, info)``: the IncAST tree, the
    relation declarations found in the program, and the parsed
    directive/symbol info (whose ``query_info`` queries are converted to
    IncAST as well).
    """
    # Tree-to-tree passes, applied in this exact order:
    #  1. Rewrite QUERY directives so their source strings become the
    #     corresponding parsed Python ASTs. Provided the later passes are
    #     functional (they treat every occurrence of the same AST alike),
    #     any later rewrite of a query also rewrites its occurrence inside
    #     the QUERY directive.
    #  2. Admit constructs that are syntactic sugar for IncAST.
    #  3. Drop the runtime library import and its qualifiers.
    #  4. Drop the main-call boilerplate.
    for tree_pass in (preprocess_query_directives,
                      preprocess_constructs,
                      preprocess_runtime_import,
                      preprocess_main_call):
        tree = tree_pass(tree)

    # Passes that also extract information: relation declarations, then
    # symbol info from directives.
    tree, decls = preprocess_var_decls(tree)
    tree, info = preprocess_directives(tree)

    # Convert the tree, then each parsed query, to IncAST.
    tree = L.import_incast(tree)
    converted = []
    for query, value in info.query_info:
        converted.append((L.import_incast(query), value))
    info.query_info = converted

    return tree, decls, info
示例#5
0
 def get_priority(self, cl, bindenv):
     """Return ``Priority.Constant`` for a ``'bu'`` binding of ``cl``.

     Any other binding pattern of the clause's LHS variables is ranked
     by the superclass implementation.
     """
     bindings = L.mask_from_bounds(self.lhs_vars(cl), bindenv)
     if not bindings == L.mask('bu'):
         return super().get_priority(cl, bindenv)
     return Priority.Constant
示例#6
0
 def functionally_determines(self, cl, bindenv):
     """Return True for a ``'bu'`` binding of ``cl``'s LHS variables.

     Other binding patterns are answered by the superclass.
     """
     lhs_mask = L.mask_from_bounds(self.lhs_vars(cl), bindenv)
     is_bu = lhs_mask == L.mask('bu')
     return True if is_bu else super().functionally_determines(cl, bindenv)
示例#7
0
 def get_code(self, cl, bindenv, body):
     """Emit code that matches clause ``cl`` and then runs ``body``.

     ``cl.vars`` is split by ``cl.mask`` into key variables and a single
     value variable for the map ``cl.map``.

     - All variables bound: test ``key in map and map[key] == value``.
     - Binding pattern equal to ``cl.mask``: if the key is present, bind
       the value from the map and run the body.
     - Any other pattern: defer to the superclass.
     """
     assert_unique(cl.vars)
     mask = L.mask_from_bounds(cl.vars, bindenv)
     keyvars, valvars = L.split_by_mask(cl.mask, cl.vars)
     valvar = valvars[0]

     # The all-unbound case could also be handled by iterating over
     # dict.items(), but that would need a fresh variable for
     # decomposing the key tuple.

     if L.mask_is_allbound(mask):
         test = L.Parser.pe('_KEY in _MAP and _MAP[_KEY] == _VALUE',
                            subst={'_MAP': cl.map,
                                   '_KEY': L.tuplify(keyvars),
                                   '_VALUE': valvar})
         return (L.If(test, body, ()),)

     if mask == cl.mask:
         return L.Parser.pc('''
             if _KEY in _MAP:
                 _VALUE = _MAP[_KEY]
                 _BODY
             ''', subst={'_MAP': cl.map,
                         '_KEY': L.tuplify(keyvars),
                         '_VALUE': valvar,
                         '<c>_BODY': body})

     return super().get_code(cl, bindenv, body)
示例#8
0
def make_demand_query(symtab, query, left_clauses):
    """Define the demand query for ``query`` and return its symbol.

    The demand query is a comprehension over ``left_clauses`` whose
    result expression is the tuple of ``query.demand_params``. Its LHS
    variables are renamed by prepending a fresh prefix drawn from
    ``symtab.fresh_names.vars``. Its type is a set of the analyzed
    demand-tuple type, and it reuses ``query.impl``.

    Side effect: ``query.demand_query`` is set to the new query's name.
    """
    clausetools = symtab.clausetools
    name = N.get_query_demand_query_name(query.name)

    result_tuple = L.tuplify(query.demand_params)
    set_type = T.Set(symtab.analyze_expr_type(result_tuple))

    raw_comp = L.Comp(result_tuple, left_clauses)
    prefix = next(symtab.fresh_names.vars)

    def add_prefix(var):
        return prefix + var

    comp = clausetools.comp_rename_lhs_vars(raw_comp, add_prefix)

    demand_sym = symtab.define_query(name, type=set_type, node=comp,
                                     impl=query.impl)
    query.demand_query = name
    return demand_sym
示例#9
0
 def get_code(self, cl, bindenv, body):
     """Emit code that matches a map-membership clause, then runs ``body``.

     The handled binding masks suggest the clause's LHS variables are
     ordered ``(cl.map, cl.key, cl.value)`` — TODO confirm against
     ``lhs_vars``.

     - All bound: run ``body`` only if the key is in the map and maps to
       the value.
     - ``'bbu'``: if the key is in the map, bind the value and run
       ``body``.
     - ``'buu'``: loop over the map's items, binding key and value.
     - Anything else: defer to the superclass.

     When ``self.use_typecheck`` is set, the first three cases are also
     guarded by an ``IsMap`` check on ``cl.map``.
     """
     vars = self.lhs_vars(cl)
     assert_unique(vars)
     mask = L.mask_from_bounds(vars, bindenv)

     if L.mask_is_allbound(mask):
         # Test membership before indexing so that a key absent from the
         # map yields no match instead of a KeyError in generated code
         # (same guard as the keyed map clause handler).
         comparison = L.Parser.pe('_KEY in _MAP and _MAP[_KEY] == _VALUE',
                                  subst={'_MAP': cl.map,
                                         '_KEY': cl.key,
                                         '_VALUE': cl.value})
         code = (L.If(comparison, body, ()),)
         needs_typecheck = True

     elif mask == L.mask('bbu'):
         # Only bind the value when the key is actually present.
         code = L.Parser.pc('''
             if _KEY in _MAP:
                 _VALUE = _MAP[_KEY]
                 _BODY
             ''', subst={'_MAP': cl.map,
                         '_KEY': cl.key,
                         '_VALUE': cl.value,
                         '<c>_BODY': body})
         needs_typecheck = True

     elif mask == L.mask('buu'):
         items_expr = L.Parser.pe('_MAP.items()', subst={'_MAP': cl.map})
         code = (L.DecompFor([cl.key, cl.value], items_expr, body),)
         needs_typecheck = True

     else:
         code = super().get_code(cl, bindenv, body)
         needs_typecheck = False

     if needs_typecheck and self.use_typecheck:
         code = (L.If(L.IsMap(L.Name(cl.map)), code, ()),)

     return code
示例#10
0
 def visit_DictLookup(self, node):
     """Replace a simple map lookup ``m[k1, ..., kn]`` with a variable.

     Only lookups of a Name map with a tuple-of-names key and no default
     are accepted (checked by assertions). For a lookup not yet present
     in ``self.repls``:

     - a fresh variable is drawn from ``self.fresh_names`` and cached;
     - a SetFromMapMember clause binding it is appended to
       ``self.new_clauses``;
     - the matching SetFromMapInvariant is added to ``self.sfm_invs``.

     A cached variable is reused for later occurrences of the same lookup.
     """
     node = self.generic_visit(node)

     # Only simple map lookups are allowed.
     assert isinstance(node.value, L.Name)
     assert L.is_tuple_of_names(node.key)
     assert node.default is None
     map_name = node.value.id
     key_vars = L.detuplify(node.key)

     cached = self.repls.get(node)
     if cached is not None:
         return L.Name(cached)

     setmask = L.mapmask_from_len(len(key_vars))
     rel_name = N.SA_name(map_name, setmask)

     # Allocate and remember the replacement variable.
     fresh = next(self.fresh_names)
     self.repls[node] = fresh

     # Clause that binds the fresh variable from the map's set view.
     clause_vars = list(key_vars) + [fresh]
     clause = L.SetFromMapMember(clause_vars, rel_name, map_name, setmask)
     self.new_clauses.append(clause)

     # Invariant maintaining that set view.
     invariant = SetFromMapInvariant(rel_name, map_name, setmask)
     self.sfm_invs.add(invariant)

     return L.Name(fresh)
示例#11
0
 def rewrite_comp(self, symbol, name, comp):
     """Replace the comprehension of ``query`` with a reference to its result.

     ``query``, ``result_var`` and ``orig_arity`` are free variables from
     an enclosing scope not shown here. Without parameters the
     replacement is ``result_var`` itself; otherwise it is an image
     lookup on ``result_var`` keyed by the query parameters. For other
     query names None is returned (presumably "no change" for the
     caller — confirm).
     """
     if name != query.name:
         return None
     if query.params == ():
         return L.Name(result_var)
     mask = L.keymask_from_len(len(query.params), orig_arity)
     return L.ImgLookup(L.Name(result_var), mask, query.params)
示例#12
0
def convert_subquery_clause(clause):
    """Turn a VarsMember clause over an incrementalized subquery into RelMember.

    Two right-hand-side shapes are recognized:

        - a Name node, giving ``RelMember(vars, name)``;

        - an ImgLookup node on a Name with a keymask, giving a RelMember
          whose variables are the lookup's bound parameters followed by
          the clause's own variables.

    Any other clause, including a VarsMember of another shape, is
    returned unchanged.
    """
    if not isinstance(clause, L.VarsMember):
        return clause

    rhs = clause.iter
    if isinstance(rhs, L.Name):
        return L.RelMember(clause.vars, rhs.id)

    is_keyed_img = (isinstance(rhs, L.ImgLookup)
                    and isinstance(rhs.set, L.Name)
                    and L.is_keymask(rhs.mask))
    if not is_keyed_img:
        return clause

    num_bound, num_unbound = L.break_keymask(rhs.mask)
    # The keymask must agree with the lookup parameters and with the
    # variables bound by the clause.
    assert num_bound == len(rhs.bounds)
    assert num_unbound == len(clause.vars)
    return L.RelMember(rhs.bounds + clause.vars, rhs.set.id)
 def visit_Member(self, node):
     """Recognize special membership forms.

     - ``<target> in <expr> - {<elem>}`` is reoriented into a
       WithoutMember wrapping ``<target> in <expr>`` *before* children
       are visited, and the visited result is returned.
     - Otherwise, after visiting children, ``<vars> in {<elem>}`` with a
       tuple-of-names target becomes a SingMember.
     """
     source = node.iter

     # <target> in <expr> - {<elem>}
     is_minus_singleton = (isinstance(source, L.BinOp) and
                           isinstance(source.op, L.Sub) and
                           isinstance(source.right, L.Set) and
                           len(source.right.elts) == 1)
     if is_minus_singleton:
         inner = L.Member(node.target, source.left)
         wrapped = L.WithoutMember(inner, source.right.elts[0])
         return self.generic_visit(wrapped)

     node = self.generic_visit(node)

     # <vars> in {<elem>}
     is_singleton = (L.is_tuple_of_names(node.target) and
                     isinstance(node.iter, L.Set) and
                     len(node.iter.elts) == 1)
     if not is_singleton:
         return node
     return L.SingMember(L.detuplify(node.target), node.iter.elts[0])
示例#14
0
def py_preprocess(tree):
    """Partially preprocess a Python AST and return it as IncAST.

    Returns ``(tree, decls, info)``: the converted tree, the relation
    declarations, and the parsed directive info whose ``query_info``
    entries have their queries converted to IncAST too.
    """
    # QUERY directives first: replace their source strings with parsed
    # Python ASTs. As long as the later steps are functional (applying
    # equally to every occurrence of the same AST), rewriting a query
    # elsewhere also rewrites its occurrence inside the QUERY directive.
    with_queries = preprocess_query_directives(tree)

    # Syntactic sugar that IncAST would otherwise exclude.
    desugared = preprocess_constructs(with_queries)

    # Runtime library import and qualifiers.
    no_runtime = preprocess_runtime_import(desugared)

    # Main-call boilerplate.
    no_main = preprocess_main_call(no_runtime)

    # Relation declarations, then symbol info.
    with_decls, decls = preprocess_var_decls(no_main)
    final_py, info = preprocess_directives(with_decls)

    # Conversion to IncAST, tree first, then the parsed queries.
    incast_tree = L.import_incast(final_py)
    info.query_info = [(L.import_incast(q), v) for q, v in info.query_info]

    return incast_tree, decls, info
示例#15
0
def make_comp_maint_func(clausetools, fresh_var_prefix, fresh_join_names, comp,
                         result_var, rel, op, *, counted):
    """Build the maintenance function for ``comp`` w.r.t. updates to ``rel``.

    ``op`` must be an ``L.SetAdd`` or ``L.SetRemove`` instance. The
    generated function is named by ``N.get_maint_func_name`` from the
    result variable, the relation and the update kind. It takes one
    argument ``_elem``, and its body is the code returned by
    ``clausetools.get_maint_code`` for the update of ``rel`` by ``op``
    with ``_elem``.
    """
    assert isinstance(op, (L.SetAdd, L.SetRemove))

    func_name = N.get_maint_func_name(result_var, rel, L.set_update_name(op))

    update = L.RelUpdate(rel, op, '_elem')
    maint_code = clausetools.get_maint_code(
        fresh_var_prefix, fresh_join_names, comp, result_var, update,
        counted=counted)

    template = '''
        def _FUNC(_elem):
            _CODE
        '''
    return L.Parser.ps(template, subst={'_FUNC': func_name,
                                        '<c>_CODE': maint_code})
示例#16
0
    def visit_DictLookup(self, node):
        """Replace a simple map lookup with a (cached) fresh variable.

        The lookup must be on a Name map, with a tuple-of-names key and no
        default (asserted). The first time a lookup is seen, a fresh
        variable is cached in ``self.repls``, a SetFromMapMember clause
        binding it is appended to ``self.new_clauses``, and its
        SetFromMapInvariant is added to ``self.sfm_invs``.
        """
        node = self.generic_visit(node)

        # Only simple map lookups are allowed.
        assert isinstance(node.value, L.Name)
        assert L.is_tuple_of_names(node.key)
        assert node.default is None
        mapname, keys = node.value.id, L.detuplify(node.key)

        replacement = self.repls.get(node, None)
        if replacement is None:
            m = L.mapmask_from_len(len(keys))
            relname = N.SA_name(mapname, m)

            # Allocate and cache a fresh variable for this lookup.
            replacement = next(self.fresh_names)
            self.repls[node] = replacement

            # Bind it with a clause over the map's set view ...
            self.new_clauses.append(L.SetFromMapMember(
                list(keys) + [replacement], relname, mapname, m))

            # ... and record the invariant that maintains that view.
            self.sfm_invs.add(SetFromMapInvariant(relname, mapname, m))

        return L.Name(replacement)
示例#17
0
 def get_loop_for_join(self, comp, body, query_name):
     """Return a one-statement tuple that iterates over join ``comp``.

     The join is wrapped in a Query node named ``query_name``. Each
     element is decomposed into the join's LHS variables, and ``body``
     runs once per element.
     """
     assert self.is_join(comp)
     loop_vars = self.lhs_vars_from_comp(comp)
     query_node = L.Query(query_name, comp, None)
     loop = L.DecompFor(loop_vars, query_node, body)
     return (loop,)
示例#18
0
def convert_subquery_clause(clause):
    """Return a RelMember equivalent of a subquery VarsMember clause.

    Recognized VarsMember right-hand sides:

        - a Name node (the subquery's result relation);

        - an ImgLookup on a Name with a keymask; the lookup's bound
          parameters are prepended to the clause variables.

    Every other clause is returned unchanged.
    """
    if isinstance(clause, L.VarsMember):
        source = clause.iter
        if isinstance(source, L.Name):
            return L.RelMember(clause.vars, source.id)
        if (isinstance(source, L.ImgLookup) and
                isinstance(source.set, L.Name) and
                L.is_keymask(source.mask)):
            n_bound, n_unbound = L.break_keymask(source.mask)
            assert n_bound == len(source.bounds)
            assert n_unbound == len(clause.vars)
            full_vars = source.bounds + clause.vars
            return L.RelMember(full_vars, source.set.id)
    return clause
 def visit_Call(self, node):
     """Rewrite ``len(x)`` into the count aggregate ``Aggr(Count(), x)``.

     Raises L.ProgramError if ``len`` is called with anything other than
     exactly one argument. Other calls are returned after visiting
     their children.
     """
     node = self.generic_visit(node)
     if node.func != 'len':
         return node
     if len(node.args) != 1:
         raise L.ProgramError('Expected one argument for len()')
     return L.Aggr(L.Count(), node.args[0])
示例#20
0
    def rewrite_with_demand(self, query_sym, node):
        """Return the demand-transformed version of a query's node.

        ``node`` is the query's Comp or Aggr node; its subqueries are not
        transformed. If the query does not use demand, ``node`` is
        returned unchanged. Otherwise:

        - with no left clauses, a demand set is created and the query
          name is added to ``self.queries_with_usets``;
        - with left clauses, a demand query is built from them and its
          name is added to ``self.demand_queries``.

        A Comp gets the demand clause prepended to its clauses; an Aggr
        becomes an AggrRestr restricted by the demand node.

        Raises AssertionError for any other node type.
        """
        symtab = self.symtab
        demand_params = query_sym.demand_params

        if not query_sym.uses_demand:
            return node

        left_clauses = self.get_left_clauses()
        if left_clauses is not None:
            # Demand comes from a query over the left clauses.
            dem_sym = make_demand_query(symtab, query_sym, left_clauses)
            dem_node = dem_sym.make_node()
            dem_clause = L.VarsMember(demand_params, dem_node)
            self.demand_queries.add(dem_sym.name)
        else:
            # No left clauses: use a plain demand set.
            dem_sym = make_demand_set(symtab, query_sym)
            dem_node = L.Name(dem_sym.name)
            dem_clause = L.RelMember(demand_params, dem_sym.name)
            self.queries_with_usets.add(query_sym.name)

        if isinstance(node, L.Comp):
            return node._replace(clauses=(dem_clause,) + node.clauses)
        if isinstance(node, L.Aggr):
            return L.AggrRestr(node.op, node.value, demand_params, dem_node)
        raise AssertionError(
            'No rule for handling demand for {} node'.format(
                node.__class__.__name__))
 def visit_Member(self, node):
     """Turn ``<names> in <Query>`` into a VarsMember clause.

     Children are visited first; other Member nodes are returned as-is.
     """
     node = self.generic_visit(node)
     target, source = node.target, node.iter
     if L.is_tuple_of_names(target) and isinstance(source, L.Query):
         return L.VarsMember(L.detuplify(target), source)
     return node
示例#22
0
    def visit_RelClear(self, node):
        """Pair a clear of a tracked relation with a clear of the result.

        Clears of relations not in ``self.rels`` are returned unchanged.
        Otherwise a RelClear of ``self.result_var`` is inserted as
        maintenance code via ``L.insert_rel_maint`` with the SetRemove
        update kind.
        """
        if node.rel in self.rels:
            maint = (L.RelClear(self.result_var),)
            return L.insert_rel_maint((node,), maint, L.SetRemove())
        return node
示例#23
0
 def visit_DecompFor(self, node):
     """Unwrap a DecompFor over one of ``self.rels`` into a plain For.

     Children are visited first, so loops nested in the body are also
     unwrapped and checked. Without this, only the outermost loop would
     be handled, despite the requirement below applying to all loops.

     Raises L.TransformationError if a loop over one of ``self.rels``
     does not have exactly one target variable.
     """
     node = self.generic_visit(node)
     if isinstance(node.iter, L.Name) and node.iter.id in self.rels:
         if len(node.vars) != 1:
             raise L.TransformationError(
                 'Singleton unwrapping requires all DecompFor loops '
                 'over relation to have exactly one target variable')
         return L.For(node.vars[0], node.iter, node.body)
     return node
示例#24
0
 def visit_AttrAssign(self, node):
     """Record an attribute assignment as an add to its F relation.

     Only attributes listed in ``self.objrels.Fs`` are handled; for
     others None is returned. The ``(obj, value)`` pair is bound to a
     fresh variable, which is then added to ``N.F(node.attr)``.
     """
     if node.attr in self.objrels.Fs:
         pair = L.Tuple([node.obj, node.value])
         tmp = next(self.fresh_vars)
         assign = L.Assign(tmp, pair)
         update = L.RelUpdate(N.F(node.attr), L.SetAdd(), tmp)
         return (assign, update)
     return None
示例#25
0
 def visit_DictAssign(self, node):
     """Record a DictAssign as an add of ``(target, key, value)`` to MAP.

     Handled only when ``self.objrels.MAP`` is true; otherwise None is
     returned. The triple is bound to a fresh variable, which is then
     added to ``N.MAP``.
     """
     if not self.objrels.MAP:
         return None
     triple = L.Tuple([node.target, node.key, node.value])
     tmp = next(self.fresh_vars)
     return (L.Assign(tmp, triple),
             L.RelUpdate(N.MAP, L.SetAdd(), tmp))
示例#26
0
 def visit_DictDelete(self, node):
     """Record a DictDelete as a removal from the MAP relation.

     Handled only when ``self.objrels.MAP`` is true; otherwise None is
     returned. The removed triple is ``(target, key, <current value>)``,
     where the value is read with a DictLookup. It is bound to a fresh
     variable, which is then removed from ``N.MAP``.
     """
     if self.objrels.MAP:
         current = L.DictLookup(node.target, node.key, None)
         triple = L.Tuple([node.target, node.key, current])
         tmp = next(self.fresh_vars)
         return (L.Assign(tmp, triple),
                 L.RelUpdate(N.MAP, L.SetRemove(), tmp))
     return None
示例#27
0
 def visit_AttrDelete(self, node):
     """Record an attribute deletion as a removal from its F relation.

     Only attributes in ``self.objrels.Fs`` are handled; for others None
     is returned. The removed pair is ``(obj, obj.attr)``, with the
     current value read via an Attribute node. It is bound to a fresh
     variable, which is then removed from ``N.F(node.attr)``.
     """
     if node.attr not in self.objrels.Fs:
         return None
     current = L.Attribute(node.obj, node.attr)
     pair = L.Tuple([node.obj, current])
     tmp = next(self.fresh_vars)
     return (L.Assign(tmp, pair),
             L.RelUpdate(N.F(node.attr), L.SetRemove(), tmp))
示例#28
0
 def rewrite_aggr(self, symbol, name, aggr):
     """Wrap a non-tuple relation operand of ``aggr`` in ``Unwrap(Wrap(...))``.

     Only Name operands are considered. Their relation type is looked up
     through ``symtab``, a free variable from an enclosing scope not
     shown here. If that type is not a set of tuples, the operand is
     replaced by ``L.Unwrap(L.Wrap(operand))``. Otherwise ``aggr`` is
     returned unchanged.
     """
     operand = aggr.value
     if not isinstance(operand, L.Name):
         return aggr
     rel_type = symtab.get_relations()[operand.id].type
     is_tuple_set = (isinstance(rel_type, T.Set)
                     and isinstance(rel_type.elt, T.Tuple))
     if is_tuple_set:
         return aggr
     return aggr._replace(value=L.Unwrap(L.Wrap(operand)))
示例#29
0
    def get_priority(self, cl, bindenv):
        """Rank ``cl`` by the binding pattern of its LHS variables.

        All bound gives ``Priority.Constant``; none bound gives
        ``Priority.Unpreferred``; any mix gives ``Priority.Normal``.
        """
        mask = L.mask_from_bounds(self.lhs_vars(cl), bindenv)
        if L.mask_is_allbound(mask):
            priority = Priority.Constant
        elif L.mask_is_allunbound(mask):
            priority = Priority.Unpreferred
        else:
            priority = Priority.Normal
        return priority
示例#30
0
    def visit_MapClear(self, node):
        """Also clear the SetFromMap relation derived from a cleared map.

        If ``self.setfrommaps_by_map`` has no invariant for ``node.map``,
        the node is returned unchanged. Otherwise a RelClear of the
        invariant's relation is inserted via ``L.insert_rel_maint`` with
        the SetRemove update kind.
        """
        sfm = self.setfrommaps_by_map.get(node.map)
        if sfm is not None:
            return L.insert_rel_maint((node,), (L.RelClear(sfm.rel),),
                                      L.SetRemove())
        return node
示例#31
0
 def get_priority(self, cl, bindenv):
     """Priority of ``cl``: Constant if all LHS variables are bound,
     Unpreferred if none are, Normal otherwise.
     """
     binding = L.mask_from_bounds(self.lhs_vars(cl), bindenv)
     if L.mask_is_allbound(binding):
         return Priority.Constant
     return (Priority.Unpreferred if L.mask_is_allunbound(binding)
             else Priority.Normal)
示例#32
0
 def visit_SetUpdate(self, node):
     """Mirror a set add/remove as the same update on the M relation.

     Only SetAdd/SetRemove updates are handled, and only when
     ``self.objrels.M`` is true; otherwise None is returned. The
     ``(target, value)`` pair is bound to a fresh variable, and that
     variable is added to or removed from ``N.M`` using the original
     operation.
     """
     handled = (isinstance(node.op, (L.SetAdd, L.SetRemove))
                and self.objrels.M)
     if not handled:
         return None
     pair = L.Tuple([node.target, node.value])
     tmp = next(self.fresh_vars)
     return (L.Assign(tmp, pair),
             L.RelUpdate(N.M, node.op, tmp))
示例#33
0
 def visit_Query(self, node):
     """Validate a Query node's annotation after visiting its children.

     A present annotation must be a dict or frozendict, and every key
     must appear in ``S.annotations``. Otherwise L.ProgramError is
     raised, naming the first unknown key found.
     """
     self.generic_visit(node)
     annotation = node.ann
     if annotation is None:
         return
     if not isinstance(annotation, (dict, frozendict)):
         raise L.ProgramError('Query annotation must be a '
                              'dictionary')
     unknown = [key for key in annotation.keys()
                if key not in S.annotations]
     if unknown:
         raise L.ProgramError(
             'Unknown annotation key "{}"'.format(unknown[0]))
示例#34
0
    def visit_MapDelete(self, node):
        """Insert a call to the SetFromMap ``'delete'`` maintenance function.

        Deletions from maps with no invariant in
        ``self.setfrommaps_by_map`` are returned unchanged. Otherwise a
        call to the invariant's delete maintenance function, passing the
        deleted key, is inserted via ``L.insert_rel_maint`` with the
        SetRemove update kind.
        """
        sfm = self.setfrommaps_by_map.get(node.map)
        if sfm is None:
            return node
        maint_func = sfm.get_maint_func_name('delete')
        call = L.Expr(L.Call(maint_func, [L.Name(node.key)]))
        return L.insert_rel_maint((node,), (call,), L.SetRemove())
示例#35
0
 def Tuple_helper(self, node):
     """Flatten a tuple-of-names expression into a TUPMember clause.

     Returns the name chosen by ``self.tuple_namer`` for the tuple.
     Side effects: the tuple's arity is appended to
     ``self.objrels.TUPs``, and the clause is inserted at the front of
     ``self.after_clauses``. Raises L.ProgramError for a tuple that is
     not made only of names.
     """
     if not L.is_tuple_of_names(node):
         raise L.ProgramError('Non-simple tuple expression: {}'
                              .format(node))
     components = L.detuplify(node)
     tup_name = self.tuple_namer(components)
     self.objrels.TUPs.append(len(components))
     self.after_clauses.insert(0, L.TUPMember(tup_name, components))
     return tup_name
示例#36
0
    def process(expr):
        """Rewrite ``<name> in <expr>`` into a VarsMember clause.

        Nodes that are not Members, or whose target is not a Name, are
        returned unchanged. An ``Unwrap(e)`` right-hand side becomes
        ``e``; any other right-hand side ``e`` becomes ``Wrap(e)``.
        Always returns a triple ``(clause, [], [])``.
        """
        is_name_member = (isinstance(expr, L.Member)
                          and isinstance(expr.target, L.Name))
        if not is_name_member:
            return expr, [], []
        source = expr.iter
        rhs = source.value if isinstance(source, L.Unwrap) else L.Wrap(source)
        return L.VarsMember([expr.target.id], rhs), [], []
示例#37
0
def aggrinv_from_query(symtab, query, result_var):
    """Build the AggrInvariant describing ``query``'s aggregate.

    ``query.node`` must be an Aggr or AggrRestr. Its operand, after
    stripping an optional Unwrap, must be a relation Name or an
    ImgLookup over a relation Name. The relation's arity comes from the
    type returned by ``get_rel_type``, which must be a set of tuples.

    Raises:
        L.ProgramError: unsupported operand form, operand type that is
            not a set of tuples, or AggrRestr restriction that is not a
            Name.
        L.TransformationError: AggrRestr parameters that differ from the
            ImgLookup parameters.
    """
    node = query.node
    assert isinstance(node, (L.Aggr, L.AggrRestr))
    operand = node.value
    op = node.op

    unwrap = isinstance(operand, L.Unwrap)
    if unwrap:
        operand = operand.value

    # Relation, mask and parameters. A mask of None means all-unbound
    # and is filled in once the arity is known.
    if isinstance(operand, L.Name):
        rel, mask, params = operand.id, None, ()
    elif isinstance(operand, L.ImgLookup) and isinstance(operand.set, L.Name):
        rel, mask, params = operand.set.id, operand.mask, operand.bounds
    else:
        raise L.ProgramError('Unknown aggregate form: {}'.format(node))

    # Use the relation's type to determine its arity.
    t_rel = get_rel_type(symtab, rel)
    if not (isinstance(t_rel, T.Set) and isinstance(t_rel.elt, T.Tuple)):
        raise L.ProgramError(
            'Invalid type for aggregate operand: {}'.format(t_rel))
    arity = len(t_rel.elt.elts)

    if mask is None:
        mask = L.mask('u' * arity)
    else:
        # An explicit mask must agree with the relation's arity.
        assert len(mask.m) == arity

    restr = None
    if isinstance(node, L.AggrRestr):
        # Restriction parameters must match the ImgLookup parameters.
        if node.params != params:
            raise L.TransformationError('AggrRestr params do not match '
                                        'ImgLookup params')
        if not isinstance(node.restr, L.Name):
            raise L.ProgramError('Bad AggrRestr restriction expr')
        restr = node.restr.id

    return AggrInvariant(result_var, op, rel, mask, unwrap, params, restr)
示例#38
0
        def visit_Fun(self, node):
            """Prepend cost comments to functions with a known cost.

            After visiting children, if ``node.name`` is a key of the
            enclosing ``func_costs`` mapping, two Comment nodes are
            prepended to the body. The first shows the raw cost; the
            second shows the cost simplified by
            ``rewrite_cost_using_types`` using ``symtab``.
            """
            node = self.generic_visit(node)
            if node.name not in func_costs:
                return node
            cost = func_costs[node.name]
            simplified = rewrite_cost_using_types(cost, symtab)
            raw_text = PrettyPrinter.run(cost)
            simple_text = PrettyPrinter.run(simplified)
            header = (L.Comment('Cost: O({})'.format(raw_text)),
                      L.Comment('      O({})'.format(simple_text)))
            return node._replace(body=header + node.body)
示例#39
0
 def visit_RelUpdate(self, node):
     """Attach a call to this query's maintenance function to an update.

     Only SetAdd/SetRemove updates of relations in ``self.rels`` are
     affected. The maintenance function is named from
     ``self.result_var``, the relation and the update kind. It is called
     with the updated element and inserted via ``L.insert_rel_maint``.
     """
     is_set_op = isinstance(node.op, (L.SetAdd, L.SetRemove))
     if not is_set_op or node.rel not in self.rels:
         return node
     maint_name = N.get_maint_func_name(self.result_var, node.rel,
                                        L.set_update_name(node.op))
     call = L.Expr(L.Call(maint_name, [L.Name(node.elem)]))
     return L.insert_rel_maint((node,), (call,), node.op)
示例#40
0
    def visit_RelUpdate(self, node):
        """Insert a maintenance call after/around a tracked set update.

        Non-set updates and updates to relations outside ``self.rels``
        are returned unchanged. Otherwise the maintenance function for
        ``(self.result_var, node.rel, <update kind>)`` is called with the
        element, via ``L.insert_rel_maint``.
        """
        if isinstance(node.op, (L.SetAdd, L.SetRemove)):
            if node.rel in self.rels:
                kind = L.set_update_name(node.op)
                func = N.get_maint_func_name(self.result_var, node.rel, kind)
                maint = (L.Expr(L.Call(func, [L.Name(node.elem)])),)
                return L.insert_rel_maint((node,), maint, node.op)
        return node
示例#41
0
    def rewrite_resexp_with_params(self, comp, params):
        """Prepend ``params`` as Name components to ``comp``'s tuple result.

        Every parameter must be one of the comprehension's LHS variables,
        and the result expression must already be an ``L.Tuple``. Both
        conditions are checked with assertions.
        """
        lhs_vars = self.lhs_vars_from_comp(comp)
        assert set(params).issubset(set(lhs_vars)), \
            'params: {}, lhs_vars: {}'.format(params, lhs_vars)
        assert isinstance(comp.resexp, L.Tuple)

        param_elts = tuple(map(L.Name, params))
        return comp._replace(resexp=L.Tuple(param_elts + comp.resexp.elts))
示例#42
0
 def get_code(self, cl, bindenv, body):
     """Emit code matching a TUP clause, then run ``body``.

     The clause relates ``cl.tup`` to the tuple of ``cl.elts``.

     - All bound: test the equality.
     - Tuple bound (mask starts with 'b'): bind components from the tuple
       with ``L.bind_by_mask``. The equality is also tested unless no
       component was bound beforehand.
     - Only the tuple unbound: assign it from its components.

     Any other pattern raises L.TransformationError. When
     ``self.use_typecheck`` is set, the first two cases are guarded by an
     arity check on ``cl.tup``.
     """
     lhs = self.lhs_vars(cl)
     assert_unique(lhs)
     mask = L.mask_from_bounds(lhs, bindenv)

     comparison = L.Compare(L.Name(cl.tup), L.Eq(), L.tuplify(cl.elts))

     if L.mask_is_allbound(mask):
         code = (L.If(comparison, body, ()),)
         needs_typecheck = True
     elif mask.m.startswith('b'):
         elts_mask = L.mask_from_bounds(cl.elts, bindenv)
         code = L.bind_by_mask(elts_mask, cl.elts, L.Name(cl.tup))
         tail = (body if L.mask_is_allunbound(elts_mask)
                 else (L.If(comparison, body, ()),))
         code += tail
         needs_typecheck = True
     elif mask == L.mask('u' + 'b' * len(cl.elts)):
         code = (L.Assign(cl.tup, L.tuplify(cl.elts)),)
         code += body
         needs_typecheck = False
     else:
         raise L.TransformationError('Cannot emit code for TUP clause '
                                     'that would require an auxiliary '
                                     'map; use demand filtering')

     if not (needs_typecheck and self.use_typecheck):
         return code
     arity_check = L.HasArity(L.Name(cl.tup), len(cl.elts))
     return (L.If(arity_check, code, ()),)
示例#43
0
    def visit_RelUpdate(self, node):
        """Attach aggregate maintenance calls to relevant set updates.

        For a SetAdd/SetRemove of the aggregate's operand relation, the
        operand maintenance function is called. If the invariant uses
        demand, an update of its restriction relation calls the
        restriction maintenance function instead. All other updates are
        returned unchanged.
        """
        if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
            return node

        inv = self.aggrinv
        if node.rel == inv.rel:
            func = inv.get_oper_maint_func_name(node.op)
        elif inv.uses_demand and node.rel == inv.restr:
            func = inv.get_restr_maint_func_name(node.op)
        else:
            return node
        return L.insert_rel_maint_call(node, func)
示例#44
0
 def visit_RelClear(self, node):
     """Extend a RelClear with clears of structures derived from the relation.

     Every aux map registered for ``node.rel`` in ``self.auxmaps_by_rel``
     gets a MapClear. Every wrap invariant in ``self.wraps_by_rel`` gets
     a RelClear. Aux maps are handled first. Each clear is inserted via
     ``L.insert_rel_maint`` with the SetRemove update kind.
     """
     code = (node,)
     for auxmap in self.auxmaps_by_rel.get(node.rel, set()):
         code = L.insert_rel_maint(code, (L.MapClear(auxmap.map),),
                                   L.SetRemove())
     for wrap in self.wraps_by_rel.get(node.rel, set()):
         code = L.insert_rel_maint(code, (L.RelClear(wrap.rel),),
                                   L.SetRemove())
     return code
示例#45
0
def transform_source(input_source, *, options=None, query_options=None):
    """Transform a module's Python source; return ``(source, symtab)``.

    The source is parsed, run through ``transform_ast`` and converted
    back to source, with a trailing newline appended. These statistics
    are then stored in ``symtab.stats``:

    - ``'lines'``: from ``get_loc_source`` on the output;
    - ``'ast_nodes'``: ``L.tree_size`` of the re-parsed output;
    - ``'time'``: process time spent in ``transform_ast`` only.
    """
    tree = P.Parser.p(input_source)

    start = process_time()
    tree, symtab = transform_ast(tree, options=options,
                                 query_options=query_options)
    elapsed = process_time() - start

    # All good human beings have trailing newlines in their text files.
    source = P.Parser.ts(tree) + '\n'

    symtab.stats['lines'] = get_loc_source(source)
    # L.tree_size() is meant for IncASTs but also works on Python ASTs.
    # Re-parsing the emitted source drops our Comment pseudo-nodes.
    symtab.stats['ast_nodes'] = L.tree_size(P.Parser.p(source))
    symtab.stats['time'] = elapsed

    return source, symtab
示例#46
0
def flatten_memberships(comp):
    """Rewrite ``<name> in <name>`` Member clauses of ``comp`` as MMember.

    Member clauses with a Name target and an Unwrap iterable (subquery
    clauses) are left as Members for now. Any other Member clause raises
    L.ProgramError. Returns the rewritten comprehension and an
    ObjRelations whose M flag records whether any MMember was produced.
    """
    uses_m = False

    def process(clause):
        nonlocal uses_m
        if not isinstance(clause, L.Member):
            return clause, [], []
        target_is_name = isinstance(clause.target, L.Name)
        if target_is_name and isinstance(clause.iter, L.Name):
            uses_m = True
            return L.MMember(clause.iter.id, clause.target.id), [], []
        if target_is_name and isinstance(clause.iter, L.Unwrap):
            # Subquery clause, leave as Member for now.
            return clause, [], []
        raise L.ProgramError('Cannot flatten Member clause: {}'
                             .format(clause))

    tree = L.rewrite_comp(comp, process)
    return tree, ObjRelations(uses_m, [], False, [])
示例#47
0
 def rewrite_comp(self, symbol, name, comp):
     """Substitute the result variable for ``query``'s comprehension.

     Relies on the free variables ``query``, ``result_var`` and
     ``orig_arity`` from an enclosing scope not shown here. With
     parameters, the replacement is an ImgLookup on the result variable
     keyed by ``query.params``. Other query names yield None.
     """
     if name == query.name:
         target = L.Name(result_var)
         if query.params == ():
             return target
         keymask = L.keymask_from_len(len(query.params), orig_arity)
         return L.ImgLookup(target, keymask, query.params)
     return None
示例#48
0
def make_auxmap_type(auxmapinv, reltype):
    """Given an auxiliary map invariant and a relation type, determine
    the corresponding auxiliary map type.

    The invariant's mask (``auxmapinv.mask``) fixes the relation arity
    and which tuple components form the key (bound) and the value
    (unbound).

    We obtain by lattice join the smallest relation type that is at
    least as big as the given relation type and that has the correct
    arity. This should have the form {(T1, ..., Tn)}. The map type is
    then from a tuple of the bound Ts to a set of tuples of the unbound
    Ts. If ``auxmapinv.unwrap_key`` (resp. ``unwrap_value``) is set, the
    first bound (resp. unbound) component type is used directly instead
    of a tuple.

    If no such type exists, e.g. if the given relation type is {Top}
    or a set of tuples of incorrect arity, we instead give the map type
    {Top: Top}.
    """
    mask = auxmapinv.mask
    arity = len(mask.m)
    bottom_reltype = T.Set(T.Tuple([T.Bottom] * arity))
    top_reltype = T.Set(T.Tuple([T.Top] * arity))

    # Normalize to a set-of-tuples shape of the right arity, if possible.
    norm_type = reltype.join(bottom_reltype)
    well_typed = norm_type.issmaller(top_reltype)

    if well_typed:
        assert (isinstance(norm_type, T.Set) and
                isinstance(norm_type.elt, T.Tuple) and
                len(norm_type.elt.elts) == arity)
        t_bs, t_us = L.split_by_mask(mask, norm_type.elt.elts)
        t_key = t_bs[0] if auxmapinv.unwrap_key else T.Tuple(t_bs)
        t_value = t_us[0] if auxmapinv.unwrap_value else T.Tuple(t_us)
        map_type = T.Map(t_key, T.Set(t_value))
    else:
        map_type = T.Map(T.Top, T.Top)

    return map_type
示例#49
0
def is_duplicate_safe(clausetools, comp):
    """Return whether duplicates in ``comp``'s result can be ruled out.

    This is the case when the result expression is injective and every
    variable it mentions is determined by the comprehension's clauses.
    """
    if L.is_injective(comp.resexp):
        resvars = L.IdentFinder.find_vars(comp.resexp)
        return clausetools.all_vars_determined(comp.clauses, resvars)
    return False
示例#50
0
def make_setfrommap_type(mask, maptype):
    """Given a mask and a map type, determine the corresponding relation
    type.

    The mask must contain exactly one 'u' (the value position); its 'b'
    positions correspond to the key tuple's components.

    We obtain by lattice join the smallest map type that is at least as
    big as the given map type and that has the correct key tuple arity.
    This should have the form {(K1, ..., Kn): V}. The relation type is
    then a set of tuples of these types interleaved according to the
    mask.

    If no such type exists, e.g. if the given map type is {Top: Top}
    or the key is not a tuple of correct arity, we instead give the
    relation type {Top}.
    """
    nb = mask.m.count('b')
    assert mask.m.count('u') == 1
    bottom_maptype = T.Map(T.Tuple([T.Bottom] * nb), T.Bottom)
    top_maptype = T.Map(T.Tuple([T.Top] * nb), T.Top)

    # Normalize to a tuple-keyed map shape of the right arity, if possible.
    norm_type = maptype.join(bottom_maptype)
    well_typed = norm_type.issmaller(top_maptype)

    if well_typed:
        assert (isinstance(norm_type, T.Map) and
                isinstance(norm_type.key, T.Tuple) and
                len(norm_type.key.elts) == nb)
        t_elts = L.combine_by_mask(mask, norm_type.key.elts,
                                   [norm_type.value])
        rel_type = T.Set(T.Tuple(t_elts))
    else:
        rel_type = T.Set(T.Top)

    return rel_type
示例#51
0
 def is_join(self, comp):
     """Return whether ``comp`` is a join.

     A join's result expression decomposes (via ``L.detuplify``) into
     exactly the LHS variables of its clauses. They are compared as
     sorted sequences, so multiplicities must match too. If
     ``L.detuplify`` raises ValueError, ``comp`` is not a join.
     """
     lhs_vars = self.lhs_vars_from_clauses(comp.clauses)
     try:
         res_vars = L.detuplify(comp.resexp)
     except ValueError:
         return False
     else:
         return sorted(lhs_vars) == sorted(res_vars)
示例#52
0
def make_setfrommap_maint_func(fresh_vars,
                               setfrommap: SetFromMapInvariant,
                               op: str):
    """Build the maintenance function for a SetFromMap invariant.

    fresh_vars: iterator of fresh name prefixes. One prefix is consumed to
      name the temporary that holds the relation element.
    setfrommap: the invariant. Its mask, map, rel and
      get_maint_func_name(op) are used.
    op: 'assign' or 'delete'.
      - 'assign' builds a function (_key, _val) that adds the element
        built from the key components and _val to setfrommap.rel.
      - 'delete' builds a function (_key) that first looks up _val in
        setfrommap.map, then removes the element from setfrommap.rel.

    Returns the parsed function definition.

    Raises ValueError if op is not 'assign' or 'delete'.
    """
    # Validate explicitly. Previously an unknown op surfaced as a bare
    # KeyError from the dispatch dict below, and the trailing `assert()`
    # branch was unreachable (and would be stripped under -O anyway).
    if op not in ('assign', 'delete'):
        raise ValueError('Unknown SetFromMap maintenance op: {!r}'
                         .format(op))

    mask = setfrommap.mask
    nb = L.break_mapmask(mask)
    # Fresh variables for the components of the key tuple.
    key_vars = N.get_subnames('_key', nb)

    decomp_code = (L.DecompAssign(key_vars, L.Name('_key')),)

    # Named elem_vars rather than `vars` to avoid shadowing the builtin.
    elem_vars = L.combine_by_mask(mask, key_vars, ['_val'])
    elem = L.tuplify(elem_vars)
    fresh_var_prefix = next(fresh_vars)
    elem_var = fresh_var_prefix + '_elem'

    decomp_code += (L.Assign(elem_var, elem),)

    setopcls = {'assign': L.SetAdd,
                'delete': L.SetRemove}[op]
    update_code = (L.RelUpdate(setfrommap.rel, setopcls(), elem_var),)

    func_name = setfrommap.get_maint_func_name(op)

    if op == 'assign':
        func = L.Parser.ps('''
            def _FUNC(_key, _val):
                _DECOMP
                _UPDATE
            ''', subst={'_FUNC': func_name,
                        '<c>_DECOMP': decomp_code,
                        '<c>_UPDATE': update_code})
    else:  # op == 'delete', guaranteed by the check above.
        # Deleting a key: recover its current value before removal.
        lookup_expr = L.DictLookup(L.Name(setfrommap.map),
                                   L.Name('_key'), None)
        func = L.Parser.ps('''
            def _FUNC(_key):
                _val = _LOOKUP
                _DECOMP
                _UPDATE
            ''', subst={'_FUNC': func_name,
                        '_LOOKUP': lookup_expr,
                        '<c>_DECOMP': decomp_code,
                        '<c>_UPDATE': update_code})

    return func
示例#53
0
 def visit_RelClear(self, node):
     """Attach result-relation maintenance to a RelClear of a tracked rel.

     If node.rel is in self.rels, return the result of
     L.insert_rel_maint combining the original clear with a clear of
     self.result_var, using a SetRemove op. Otherwise return node
     unchanged.
     """
     if node.rel in self.rels:
         original = (node,)
         maint = (L.RelClear(self.result_var),)
         return L.insert_rel_maint(original, maint, L.SetRemove())
     return node
示例#54
0
 def get_code(self, cl, bindenv, body):
     """Generate code for a clause equating the tuple of cl.vars with
     cl.value.

     The shape of the code depends on which of cl.vars are bound in
     bindenv:
       - all bound: guard body with an equality test;
       - all unbound: destructure cl.value into cl.vars, then run body;
       - mixed: bind the unbound ones by mask, then guard body with the
         equality test.
     """
     assert_unique(cl.vars)
     mask = L.mask_from_bounds(cl.vars, bindenv)
     equal_test = L.Compare(L.tuplify(cl.vars), L.Eq(), cl.value)

     # Checked before all-unbound: an empty mask satisfies both
     # predicates, and the original gives all-bound priority.
     if L.mask_is_allbound(mask):
         return (L.If(equal_test, body, ()),)

     if L.mask_is_allunbound(mask):
         return (L.DecompAssign(cl.vars, cl.value),) + body

     # Keep `+=`: bind_by_mask's result type is not visible here.
     result = L.bind_by_mask(mask, cl.vars, cl.value)
     result += (L.If(equal_test, body, ()),)
     return result
示例#55
0
def match_eq_cond(tree):
    """Match a condition clause of the form <var> == <var>.

    Returns the (left, right) pair bound by eq_cond_pattern if tree
    matches, otherwise None.
    """
    bindings = L.match(eq_cond_pattern, tree)
    if bindings is not None:
        return bindings['LEFT'], bindings['RIGHT']
    return None
示例#56
0
 def get_code(self, cl, bindenv, body):
     """Guard body with a disequality test, then delegate to the inner
     clause.

     The guard runs body only when the tuple of the clause's LHS
     variables (from self.visitor.lhs_vars) differs from cl.value. Code
     for the wrapped clause cl.cl is then generated by
     self.visitor.get_code around that guarded body.
     """
     subst = {'_VARS': L.tuplify(self.visitor.lhs_vars(cl)),
              '_VALUE': cl.value,
              '<c>_BODY': body}
     guarded_body = L.Parser.pc('''
         if _VARS != _VALUE:
             _BODY
         ''', subst=subst)
     return self.visitor.get_code(cl.cl, bindenv, guarded_body)
示例#57
0
 def visit_MapClear(self, node):
     """Attach set-from-map maintenance to a MapClear.

     If a SetFromMap invariant is registered for node.map, return the
     result of L.insert_rel_maint combining the clear with a clear of
     that invariant's relation, using a SetRemove op. Otherwise return
     node unchanged.
     """
     setfrommap = self.setfrommaps_by_map.get(node.map, None)
     if setfrommap is not None:
         original = (node,)
         maint = (L.RelClear(setfrommap.rel),)
         return L.insert_rel_maint(original, maint, L.SetRemove())
     return node
示例#58
0
 def visit_RelUpdate(self, node):
     """Insert maintenance calls for a set add/remove on a relation.

     Non-SetAdd/SetRemove updates are returned unchanged. Otherwise, for
     each invariant registered for node.rel, first in auxmaps_by_rel and
     then in wraps_by_rel, a call to that invariant's maintenance
     function is added. Each call is passed the updated element and is
     inserted with L.insert_rel_maint.
     """
     if not isinstance(node.op, (L.SetAdd, L.SetRemove)):
         return node

     result = (node,)
     # Same order as handling aux maps fully, then wraps; .get is still
     # evaluated lazily per registry.
     for registry in (self.auxmaps_by_rel, self.wraps_by_rel):
         for invariant in registry.get(node.rel, set()):
             maint_name = invariant.get_maint_func_name(node.op)
             maint_call = L.Expr(L.Call(maint_name, [L.Name(node.elem)]))
             result = L.insert_rel_maint(result, (maint_call,), node.op)
     return result
示例#59
0
 def visit_MapDelete(self, node):
     """Attach set-from-map maintenance to a MapDelete.

     If a SetFromMap invariant is registered for node.map, insert a call
     to its 'delete' maintenance function, passing the deleted key. The
     call is combined with the original statement via L.insert_rel_maint
     using a SetRemove op. Otherwise return node unchanged.
     """
     setfrommap = self.setfrommaps_by_map.get(node.map, None)
     if setfrommap is None:
         return node
     original = (node,)
     delete_func = setfrommap.get_maint_func_name('delete')
     maint = (L.Expr(L.Call(delete_func, [L.Name(node.key)])),)
     return L.insert_rel_maint(original, maint, L.SetRemove())