Beispiel #1
0
    def from_node(cls, node):
        """Construct from an L.Aggregate node.

        The aggregate's operand (node.value) may be wrapped in an
        L.DemQuery, in which case its demand name and the ids of its
        argument names are recorded. The unwrapped operand must be one of:

            - an L.Name, giving rel = its id, mask Mask.U, params ();
            - an L.SetMatch whose target is an L.Name and whose key is a
              vartuple, giving rel = the target id, Mask(oper.mask), and
              the key variables as params.

        Returns cls(node.op, rel, relmask, params, demname, demparams),
        where demname and demparams are None when there is no DemQuery.

        Raises L.ProgramError if a DemQuery argument is not an L.Name or
        the operand has any other form. The node itself must be an
        L.Aggregate (enforced by checktype).
        """
        checktype(node, L.Aggregate)

        if isinstance(node.value, L.DemQuery):
            demquery = node.value
            # Validate explicitly instead of with `assert`: asserts are
            # stripped under `python -O`, after which a non-Name argument
            # would only surface as an AttributeError on `a.id` below.
            if not all(isinstance(a, L.Name) for a in demquery.args):
                raise L.ProgramError('Bad aggregate demand parameters',
                                     node=node)
            oper_demparams = tuple(a.id for a in demquery.args)
            oper_demname = demquery.demname
            oper = demquery.value
        else:
            oper_demparams = None
            oper_demname = None
            oper = node.value

        if isinstance(oper, L.Name):
            rel = oper.id
            relmask = Mask.U
            params = ()

        elif (isinstance(oper, L.SetMatch) and isinstance(oper.target, L.Name)
              and L.is_vartuple(oper.key)):
            rel = oper.target.id
            relmask = Mask(oper.mask)
            params = L.get_vartuple(oper.key)

        else:
            raise L.ProgramError('Bad aggregate operand', node=node)

        return cls(node.op, rel, relmask, params, oper_demname, oper_demparams)
Beispiel #2
0
 def from_node(cls, node):
     """Construct from an L.Aggregate node.
     
     The aggregate's operand (node.value) may be wrapped in an
     L.DemQuery, in which case its demand name and the ids of its
     argument names are recorded. The unwrapped operand must be one of:
     
         - an L.Name, giving rel = its id, mask Mask.U, params ();
         - an L.SetMatch whose target is an L.Name and whose key is a
           vartuple, giving rel = the target id, Mask(oper.mask), and
           the key variables as params.
     
     Returns cls(node.op, rel, relmask, params, demname, demparams),
     where demname and demparams are None when there is no DemQuery.
     
     Raises L.ProgramError if a DemQuery argument is not an L.Name or
     the operand has any other form. The node itself must be an
     L.Aggregate (enforced by checktype).
     """
     checktype(node, L.Aggregate)
     
     if isinstance(node.value, L.DemQuery):
         demquery = node.value
         # Validate explicitly instead of with `assert`: asserts are
         # stripped under `python -O`, after which a non-Name argument
         # would only surface as an AttributeError on `a.id` below.
         if not all(isinstance(a, L.Name) for a in demquery.args):
             raise L.ProgramError('Bad aggregate demand parameters',
                                  node=node)
         oper_demparams = tuple(a.id for a in demquery.args)
         oper_demname = demquery.demname
         oper = demquery.value
     else:
         oper_demparams = None
         oper_demname = None
         oper = node.value
     
     if isinstance(oper, L.Name):
         rel = oper.id
         relmask = Mask.U
         params = ()
     
     elif (isinstance(oper, L.SetMatch) and
           isinstance(oper.target, L.Name) and
           L.is_vartuple(oper.key)):
         rel = oper.target.id
         relmask = Mask(oper.mask)
         params = L.get_vartuple(oper.key)
     
     else:
         raise L.ProgramError('Bad aggregate operand', node=node)
     
     return cls(node.op, rel, relmask, params,
                oper_demname, oper_demparams)
Beispiel #3
0
    def from_AST(cls, node, factory):
        """Build a clause from an L.Enumerator node shaped like

            <vars> in <rel>

        The right-hand side may instead be a SetMatch over a relation
        name whose key is a vartuple and whose mask is a lookup mask; the
        key variables are then prepended to the enumerator's variables.

        Raises TypeError when the right-hand side has neither form.
        `factory` is not used by this method (presumably it is kept for a
        shared from_AST signature).
        """
        checktype(node, L.Enumerator)

        enum_vars = L.get_vartuple(node.target)
        source = node.iter

        if L.is_name(source):
            return cls(enum_vars, L.get_name(source))

        if not (isinstance(source, L.SetMatch) and L.is_vartuple(source.key)):
            raise TypeError

        key_vars = L.get_vartuple(source.key)
        # The mask must be a lookup mask, and its number of bound
        # positions must agree with the number of key variables.
        mask = Mask(source.mask)
        assert mask.is_lookupmask
        assert mask.lookup_arity == len(key_vars)

        return cls(key_vars + enum_vars, L.get_name(source.target))
Beispiel #4
0
 def from_AST(cls, node, factory):
     """Build a clause from an L.Enumerator node shaped like
     
         <vars> in <rel>
     
     The right-hand side may instead be a SetMatch over a relation
     name whose key is a vartuple and whose mask is a lookup mask; the
     key variables are then prepended to the enumerator's variables.
     
     Raises TypeError when the right-hand side has neither form.
     `factory` is not used by this method (presumably it is kept for a
     shared from_AST signature).
     """
     checktype(node, L.Enumerator)
     
     enum_vars = L.get_vartuple(node.target)
     source = node.iter
     
     if L.is_name(source):
         return cls(enum_vars, L.get_name(source))
     
     if not (isinstance(source, L.SetMatch) and L.is_vartuple(source.key)):
         raise TypeError
     
     key_vars = L.get_vartuple(source.key)
     # The mask must be a lookup mask, and its number of bound
     # positions must agree with the number of key variables.
     mask = Mask(source.mask)
     assert mask.is_lookupmask
     assert mask.lookup_arity == len(key_vars)
     
     return cls(key_vars + enum_vars, L.get_name(source.target))
Beispiel #5
0
 def visit_Enumerator(self, node):
     """Validate an enumerator that iterates over self.comp.
     
     When the enumerator's iterable equals self.comp, its target must be
     a vartuple of length self.arity, otherwise self.Failure is raised;
     such an enumerator is not recursed into. Any other enumerator is
     traversed with generic_visit.
     """
     iterates_comp = node.iter == self.comp
     if not iterates_comp:
         self.generic_visit(node)
         return
     
     # Enumerator over the comprehension: its target shape must fit.
     if not L.is_vartuple(node.target):
         raise self.Failure
     if self.arity != len(L.get_vartuple(node.target)):
         raise self.Failure
Beispiel #6
0
    def visit_Enumerator(self, node):
        """Validate an enumerator that iterates over self.comp.

        When the enumerator's iterable equals self.comp, its target must be
        a vartuple of length self.arity, otherwise self.Failure is raised;
        such an enumerator is not recursed into. Any other enumerator is
        traversed with generic_visit.
        """
        iterates_comp = node.iter == self.comp
        if not iterates_comp:
            self.generic_visit(node)
            return

        # Enumerator over the comprehension: its target shape must fit.
        if not L.is_vartuple(node.target):
            raise self.Failure
        if self.arity != len(L.get_vartuple(node.target)):
            raise self.Failure
Beispiel #7
0
    def visit_For(self, node):
        """Replace a for-loop over self.comp with join code when possible.

        The loop is rewritten when it iterates over self.comp, has no else
        clause, self.spec is duplicate-safe, both the loop target and
        spec.resexp are vartuples, and either they bind the same
        variables or the target is the single name '_'. The replacement
        is a comment 'Iterate <spec>' followed by
        spec.join.get_code(spec.params, node.body, augmented=...), and it
        is itself visited. Any other loop is traversed with generic_visit.
        """
        # Only a loop over our comprehension is a rewrite candidate; the
        # special case is decided before any recursion happens.
        if node.iter != self.comp:
            return self.generic_visit(node)

        spec = self.spec
        target = node.target
        rewritable = (
            node.orelse == ()
            and spec.is_duplicate_safe
            and L.is_vartuple(target)
            and L.is_vartuple(spec.resexp)
            and (L.get_vartuple(target) == L.get_vartuple(spec.resexp)
                 or (L.is_name(target) and L.get_name(target) == '_'))
        )
        if not rewritable:
            return self.generic_visit(node)

        header = L.Comment('Iterate ' + str(spec))
        join_code = spec.join.get_code(spec.params, node.body,
                                       augmented=self.augmented)
        return self.visit((header,) + join_code)
Beispiel #8
0
 def visit_For(self, node):
     """Replace a for-loop over self.comp with join code when possible.
     
     The loop is rewritten when it iterates over self.comp, has no else
     clause, self.spec is duplicate-safe, both the loop target and
     spec.resexp are vartuples, and either they bind the same
     variables or the target is the single name '_'. The replacement
     is a comment 'Iterate <spec>' followed by
     spec.join.get_code(spec.params, node.body, augmented=...), and it
     is itself visited. Any other loop is traversed with generic_visit.
     """
     # Only a loop over our comprehension is a rewrite candidate; the
     # special case is decided before any recursion happens.
     if node.iter != self.comp:
         return self.generic_visit(node)
     
     spec = self.spec
     target = node.target
     rewritable = (
         node.orelse == ()
         and spec.is_duplicate_safe
         and L.is_vartuple(target)
         and L.is_vartuple(spec.resexp)
         and (L.get_vartuple(target) == L.get_vartuple(spec.resexp)
              or (L.is_name(target) and L.get_name(target) == '_'))
     )
     if not rewritable:
         return self.generic_visit(node)
     
     header = L.Comment('Iterate ' + str(spec))
     join_code = spec.join.get_code(spec.params, node.body,
                                    augmented=self.augmented)
     return self.visit((header,) + join_code)
Beispiel #9
0
    def membercond_to_enum(cls, cl):
        """Turn a membership condition clause into an enumerator clause.

        A condition clause whose AST is a comparison `<vartuple> in <expr>`
        is converted by building the equivalent L.Enumerator and passing
        it to cls.from_AST. Any other condition clause is returned
        unchanged. A clause that is not a condition raises TypeError.
        """
        if cl.kind is not Clause.KIND_COND:
            raise TypeError

        clast = cl.to_AST()
        if not L.is_cmp(clast):
            return cl

        lhs, op, rhs = L.get_cmp(clast)
        if not (L.is_vartuple(lhs) and isinstance(op, L.In)):
            return cl

        # The left-hand variables become the (assignable) enumerator target.
        enum_target = L.tuplify(L.get_vartuple(lhs), lval=True)
        return cls.from_AST(L.Enumerator(enum_target, rhs))
Beispiel #10
0
 def membercond_to_enum(cls, cl):
     """Turn a membership condition clause into an enumerator clause.
     
     A condition clause whose AST is a comparison `<vartuple> in <expr>`
     is converted by building the equivalent L.Enumerator and passing
     it to cls.from_AST. Any other condition clause is returned
     unchanged. A clause that is not a condition raises TypeError.
     """
     if cl.kind is not Clause.KIND_COND:
         raise TypeError
     
     clast = cl.to_AST()
     if not L.is_cmp(clast):
         return cl
     
     lhs, op, rhs = L.get_cmp(clast)
     if not (L.is_vartuple(lhs) and isinstance(op, L.In)):
         return cl
     
     # The left-hand variables become the (assignable) enumerator target.
     enum_target = L.tuplify(L.get_vartuple(lhs), lval=True)
     return cls.from_AST(L.Enumerator(enum_target, rhs))
Beispiel #11
0
    def expr_tosizecost(self, expr):
        """Return a cost bound on the cardinality of an iterated expression.

        Cases, checked in order:
          - a plain name: NameCost of that name;
          - `<name>.elements()` (iteration over a delta set): NameCost of
            the receiver, ignoring that it may contain duplicates;
          - a SetMatch: UnitCost over a Set/DeltaMatch target;
            DefImgsetCost when the target is a name, the key is a vartuple,
            and every key variable is in self.args; IndefImgsetCost when
            only some are; otherwise WarnUnknownCost;
          - a DeltaMatch or a Set/List/Tuple/Dict literal: UnitCost;
          - anything else: WarnUnknownCost.
        """
        if isinstance(expr, L.Name):
            return NameCost(expr.id)

        # Iteration over a delta set is written as `<name>.elements()`.
        # It is charged as O(delta set) even though it may have duplicates.
        if (isinstance(expr, L.Call)
                and isinstance(expr.func, L.Attribute)
                and isinstance(expr.func.value, L.Name)
                and expr.func.attr == 'elements'):
            return NameCost(expr.func.value.id)

        if isinstance(expr, L.SetMatch):
            target = expr.target
            if isinstance(target, (L.Set, L.DeltaMatch)):
                return UnitCost()
            if not (isinstance(target, L.Name) and L.is_vartuple(expr.key)):
                return self.WarnUnknownCost(expr)
            keys = L.get_vartuple(expr.key)
            # The image set is fully determined only when every key
            # variable is one of self.args.
            if all(k in self.args for k in keys):
                return DefImgsetCost(target.id, Mask(expr.mask), keys)
            return IndefImgsetCost(target.id, Mask(expr.mask))

        if isinstance(expr, (L.DeltaMatch, L.Set, L.List, L.Tuple, L.Dict)):
            return UnitCost()

        return self.WarnUnknownCost(expr)