Example #1
0
 def from_AST(cls, node, factory):
     """Build an instance from an Enumerator node.

     Two shapes are accepted. The plain form is

         <vars> in <rel>

     Alternatively the iterable may be a SetMatch over a relation name.
     Its mask must be a lookup mask and its key a vartuple. In that case
     the key vars are prepended to the enumerator's own vars.

     Raises TypeError for any other iterable. Mask mismatches trip
     assertions.
     """
     checktype(node, L.Enumerator)

     target_vars = L.get_vartuple(node.target)
     it = node.iter

     if L.is_name(it):
         return cls(target_vars, L.get_name(it))

     if isinstance(it, L.SetMatch) and L.is_vartuple(it.key):
         key_vars = L.get_vartuple(it.key)
         # The setmatch has to use a lookup mask whose number of
         # bound positions equals the number of key vars.
         mask = Mask(it.mask)
         assert mask.is_lookupmask
         assert mask.lookup_arity == len(key_vars)
         return cls(key_vars + target_vars, L.get_name(it.target))

     raise TypeError
Example #2
0
    def from_AST(cls, node, factory):
        """Construct an instance from an Enumerator clause.

        The usual clause looks like

            <vars> in <rel>

        The right-hand side may instead be a SetMatch on a relation
        name, provided its key is a vartuple and its mask is a lookup
        mask. The key vars then come first in the resulting var tuple.

        Any other right-hand side raises TypeError.
        """
        checktype(node, L.Enumerator)

        own_vars = L.get_vartuple(node.target)
        source = node.iter

        if L.is_name(source):
            relname = L.get_name(source)
            all_vars = own_vars
        elif isinstance(source, L.SetMatch) and L.is_vartuple(source.key):
            bound = L.get_vartuple(source.key)
            # Only lookup masks are allowed, and their arity must agree
            # with the number of key vars supplied.
            lookup_mask = Mask(source.mask)
            assert lookup_mask.is_lookupmask
            assert lookup_mask.lookup_arity == len(bound)
            all_vars = bound + own_vars
            relname = L.get_name(source.target)
        else:
            raise TypeError

        return cls(all_vars, relname)
Example #3
0
def unflatten_set_clause(cl):
    """Undo M-set flattening for a single clause.

    Two clause shapes are rewritten:

      - an Enumerator for which get_menum() returns a (cont, item) pair
        becomes an enumerator of `item` over `cont`. `cont` is switched
        to Load context (presumably because it came from a target
        position).
      - a membership condition `(cont, item) in M`, where M is a name
        accepted by is_mrel(), becomes `item in cont`.

    Anything else is returned unchanged. Works for both enumerators and
    conditions.
    """
    if isinstance(cl, L.Enumerator):
        found = get_menum(cl)
        if found is None:
            return cl
        cont, item = found
        return L.Enumerator(item, L.ContextSetter.run(cont, L.Load))

    if not (isinstance(cl, L.expr) and L.is_cmp(cl)):
        return cl

    lhs, op, rhs = L.get_cmp(cl)
    is_mset_membership = (isinstance(op, L.In) and
                          isinstance(lhs, L.Tuple) and
                          len(lhs.elts) == 2 and
                          L.is_name(rhs) and
                          is_mrel(L.get_name(rhs)))
    if not is_mset_membership:
        return cl

    cont, item = lhs.elts
    return L.cmp(item, L.In(), cont)
Example #4
0
 def visit_SetUpdate(self, node):
     """Attach demand-maintenance code to updates of tracked relations.

     Updates whose target relation is not in self.at_rels are left
     alone (the method returns None). For tracked relations, postcode
     is built that loops over the elements of the delta set
     (self.delta_name). For each element it calls the demand function
     (op == 'add') or the undemand function (any other op), and it then
     clears the delta set. The update and postcode are combined via
     self.with_outer_maint.
     """
     rel = L.get_name(node.target)
     if rel not in self.at_rels:
         return

     # New code gets inserted after the update.
     # This is true even if the update was a removal.
     # It shouldn't matter where we do the U-set update,
     # so long as the invariants are properly maintained
     # at the time.

     prefix = self.manager.namegen.next_prefix()
     # Prefixed copies of the projection vars serve as loop variables.
     # (Named to avoid shadowing the builtin `vars`.)
     loop_vars = [prefix + v for v in self.projvars]

     if node.op == 'add':
         funcname = L.N.demfunc(self.demname)
     else:
         funcname = L.N.undemfunc(self.demname)

     call_func = L.Call(L.ln(funcname),
                        tuple(L.ln(v) for v in loop_vars),
                        (), None, None)
     postcode = L.pc('''
         for S_PROJVARS in DELTA.elements():
             CALL_FUNC
         DELTA.clear()
         ''', subst={'S_PROJVARS': L.tuplify(loop_vars, lval=True),
                     'DELTA': self.delta_name,
                     'CALL_FUNC': call_func})

     # The empty tuple is presumably the (absent) precode -- confirm
     # against with_outer_maint's signature.
     return self.with_outer_maint(node, funcname, L.ts(node),
                                  (), postcode)
Example #5
0
 def from_AST(cls, node, factory):
     """Build an instance from an Enumerator node of the form

         var in {<rel>.smlookup(<mask>, <key vars>)}

     The instance's vars are the key vars followed by `var`. Raises
     TypeError if the mask differs from Mask.from_keylen() applied to
     the number of key vars.
     """
     checktype(node, L.Enumerator)

     elem_var = L.get_name(node.target)
     lookup = L.get_singletonset(node.iter)
     checktype(lookup, L.SMLookup)
     relname = L.get_name(lookup.target)
     mask = Mask(lookup.mask)
     key_vars = L.get_vartuple(lookup.key)

     # The mask has to agree with the arity of the lookup key.
     expected_mask = Mask.from_keylen(len(key_vars))
     if mask != expected_mask:
         raise TypeError

     return cls(key_vars + (elem_var,), relname)
Example #6
0
    def from_AST(cls, node, factory):
        """Construct from an Enumerator clause shaped like

            var in {<rel>.smlookup(<mask>, <key vars>)}

        Produces an instance over (key vars..., var) and the relation
        <rel>. A mask that is not the one Mask.from_keylen() gives for
        this key length raises TypeError.
        """
        checktype(node, L.Enumerator)

        result_var = L.get_name(node.target)
        smlookup = L.get_singletonset(node.iter)
        checktype(smlookup, L.SMLookup)
        relname = L.get_name(smlookup.target)
        given_mask = Mask(smlookup.mask)
        bound_vars = L.get_vartuple(smlookup.key)

        # Reject masks that don't match how many key vars were given.
        if given_mask != Mask.from_keylen(len(bound_vars)):
            raise TypeError

        all_vars = bound_vars + (result_var,)
        return cls(all_vars, relname)
Example #7
0
 def from_AST(cls, node, factory):
     """Build an instance from an Enumerator node of the form

         (<var>, <var>) in _M

     The two vars are taken as (cont, item). Raises TypeError unless
     the target has exactly two vars and the relation name satisfies
     is_mrel().
     """
     checktype(node, L.Enumerator)

     pair = L.get_vartuple(node.target)
     relname = L.get_name(node.iter)
     # is_mrel() is only consulted once the arity check has passed.
     if len(pair) != 2 or not is_mrel(relname):
         raise TypeError

     cont, item = pair
     return cls(cont, item)
Example #8
0
 def from_AST(cls, node, factory):
     """Construct from an Enumerator node of the form

         (<var>, <var>, <var>) in _MAP

     The three vars are taken as (map, key, value). Raises TypeError if
     the target is not exactly three vars or the relation name does not
     satisfy is_maprel().
     """
     checktype(node, L.Enumerator)

     lhs = L.get_vartuple(node.target)
     rel = L.get_name(node.iter)
     if len(lhs) != 3:
         raise TypeError
     # `map_var` rather than `map` so the builtin isn't shadowed.
     map_var, key, value = lhs
     if not is_maprel(rel):
         raise TypeError
     return cls(map_var, key, value)
Example #9
0
 def from_expr(cls, node):
     """Build an instance from a membership expression of the form

         (<var>, <var>) in _M

     The left side must be a (cont, item) pair of vars and the right
     side a name accepted by is_mrel(). Violations trip assertions.
     """
     checktype(node, L.AST)

     left, op, right = L.get_cmp(node)
     checktype(op, L.In)

     pair = L.get_vartuple(left)
     assert len(pair) == 2
     cont, item = pair

     relname = L.get_name(right)
     assert is_mrel(relname)

     return cls(cont, item)
Example #10
0
    def from_expr(cls, node):
        """Construct from a membership comparison shaped like

            (<var>, <var>) in _M

        i.e. an `in` comparison whose left operand is a pair of vars
        (cont, item) and whose right operand names an M-set relation.
        Malformed input trips assertions.
        """
        checktype(node, L.AST)

        lhs_expr, cmp_op, rhs_expr = L.get_cmp(node)
        checktype(cmp_op, L.In)
        members = L.get_vartuple(lhs_expr)
        assert len(members) == 2
        container, element = members
        mrel = L.get_name(rhs_expr)
        assert is_mrel(mrel)
        return cls(container, element)
Example #11
0
    def from_AST(cls, node, factory):
        """Construct from an Enumerator clause shaped like

            (<var>, <var>) in _M

        The two target vars become (cont, item). TypeError is raised if
        there are not exactly two of them, or if the iterated name is
        not an M-set relation according to is_mrel().
        """
        checktype(node, L.Enumerator)

        members = L.get_vartuple(node.target)
        mrel = L.get_name(node.iter)
        wrong_arity = len(members) != 2
        # Short-circuit keeps is_mrel() from running on bad arity.
        if wrong_arity or not is_mrel(mrel):
            raise TypeError
        container, element = members
        return cls(container, element)
Example #12
0
 def visit_Maintenance(self, node):
     """Count Maintenance nodes for tracked invariants on tracked sets.

     A node is counted (self.count += 1) when its name is in self.invs
     and its description string parses to an update whose target is a
     plain Name listed in self.rels. Counted nodes are not recursed
     into. All other nodes are visited with generic_visit.
     """
     is_relevant = False
     # Only parse the description when the invariant name matches;
     # otherwise the parse result would never be used.
     if node.name in self.invs:
         # Figure out whether this is an update to one of the
         # given sets by scanning the description string. Hackish.
         update_node = L.ps(node.desc)
         # A parsed description without a `target` attribute cannot be
         # an update to one of our sets; treat it as irrelevant rather
         # than failing with AttributeError.
         target = getattr(update_node, 'target', None)
         is_relevant = (isinstance(target, L.Name) and
                        L.get_name(target) in self.rels)

     if is_relevant:
         self.count += 1
         # Don't recurse. We don't want to double-count the update,
         # and there shouldn't be any original updates in the
         # inserted precode/postcode.
     else:
         self.generic_visit(node)
Example #13
0
    def from_AST(cls, node, factory):
        """Construct from an Enumerator node of the form

            (<var>, <var>, <var>) in _MAP

        The three vars are taken as (map, key, value). Raises TypeError
        if the target is not exactly three vars or the relation name does
        not satisfy is_maprel().
        """
        checktype(node, L.Enumerator)

        lhs = L.get_vartuple(node.target)
        rel = L.get_name(node.iter)
        if len(lhs) != 3:
            raise TypeError
        # `map_var` rather than `map` so the builtin isn't shadowed.
        map_var, key, value = lhs
        if not is_maprel(rel):
            raise TypeError
        return cls(map_var, key, value)
Example #14
0
 def from_expr(cls, node):
     """Build an instance from a membership condition expression

         <vars> in <rel>

     The expression is decomposed with L.get_cmp(). This is
     syntactically different from the Enumerator form used in
     comprehensions, even though both read the same in source code.
     """
     checktype(node, L.AST)

     left, op, right = L.get_cmp(node)
     checktype(op, L.In)
     return cls(L.get_vartuple(left), L.get_name(right))
Example #15
0
    def from_expr(cls, node):
        """Construct from an `in` comparison expression

            <vars> in <rel>

        Note: a membership condition is a comparison node, which is a
        different syntactic form from the Enumerator used inside
        comprehensions, although both print identically.
        """
        checktype(node, L.AST)

        operand, comparator, relexpr = L.get_cmp(node)
        checktype(comparator, L.In)
        member_vars = L.get_vartuple(operand)
        relname = L.get_name(relexpr)
        return cls(member_vars, relname)
Example #16
0
    def from_AST(cls, node, factory):
        """Construct from an enumerator of the form

            (tupvar, elt1, ..., eltn) in _TUPN

        The first target var names the tuple and the rest its
        components. Raises TypeError if the relation is not recognized
        by is_trel(). Asserts that the arity reported by get_trel()
        equals the number of component vars.
        """
        checktype(node, L.Enumerator)

        target_vars = L.get_vartuple(node.target)
        relname = L.get_name(node.iter)
        if not is_trel(relname):
            raise TypeError

        tupvar, *components = target_vars
        assert get_trel(relname) == len(components)

        return cls(tupvar, tuple(components))
Example #17
0
    def from_AST(cls, node, factory):
        """Construct from an Enumerator clause shaped like

            <vars> in deltamatch(<rel>, <mask>, <val>, <limit>)

        Raises TypeError when the limit is anything other than 0 or 1.
        Asserts that the given mask equals Mask.from_vars(vars, vars)
        for the enumerated vars.
        """
        checktype(node, L.Enumerator)

        target_vars = L.get_vartuple(node.target)
        dm = node.iter
        checktype(dm, L.DeltaMatch)
        relname = L.get_name(dm.target)
        given_mask = Mask(dm.mask)
        elem = dm.elem
        limit = dm.limit
        # Only limits 0 and 1 are supported.
        if limit not in (0, 1):
            raise TypeError

        # The mask must be the one derived from the vars themselves
        # (presumably the all-bound mask -- confirm with Mask.from_vars).
        assert given_mask == Mask.from_vars(target_vars, target_vars)

        return cls(target_vars, relname, elem, limit)
Example #18
0
 def from_AST(cls, node, factory):
     """Build an instance from an Enumerator node of the form

         <vars> in deltamatch(<rel>, <mask>, <val>, <limit>)

     A limit outside {0, 1} raises TypeError. The mask is asserted to
     equal Mask.from_vars() applied to the target vars (as both
     arguments).
     """
     checktype(node, L.Enumerator)

     lhs_vars = L.get_vartuple(node.target)
     delta = node.iter
     checktype(delta, L.DeltaMatch)
     rel = L.get_name(delta.target)
     mask = Mask(delta.mask)
     value = delta.elem
     lim = delta.limit
     if lim not in (0, 1):
         raise TypeError

     expected = Mask.from_vars(lhs_vars, lhs_vars)
     assert mask == expected

     return cls(lhs_vars, rel, value, lim)
Example #19
0
    def visit_For(self, node):
        """Replace a loop over self.comp with the join's code when safe.

        The loop is rewritten only when all of the following hold:
          - it has no else-clause;
          - the spec is duplicate-safe;
          - both the loop target and spec.resexp are vartuples;
          - the target's vars equal the result expression's vars, or
            the target is the bare name `_`.
        The rewrite is a comment followed by the code from
        spec.join.get_code(), which is then visited. In every other
        case the node is handled by generic_visit.
        """
        # Recurse only after we've handled the potential special case.
        if node.iter != self.comp:
            return self.generic_visit(node)

        spec = self.spec
        target = node.target

        if not (node.orelse == () and spec.is_duplicate_safe and
                L.is_vartuple(target) and L.is_vartuple(spec.resexp)):
            return self.generic_visit(node)

        # Loop vars must match the result vars, unless the loop
        # variable is the throwaway name `_`.
        vars_match = L.get_vartuple(target) == L.get_vartuple(spec.resexp)
        if not vars_match and not (L.is_name(target) and
                                   L.get_name(target) == '_'):
            return self.generic_visit(node)

        header = (L.Comment('Iterate ' + str(spec)),)
        join_code = spec.join.get_code(spec.params, node.body,
                                       augmented=self.augmented)
        return self.visit(header + join_code)
Example #20
0
 def visit_For(self, node):
     """Inline the join code for a loop that iterates over self.comp.

     The special case requires all of the following:
       - no else-clause;
       - a duplicate-safe spec;
       - vartuple loop target and result expression;
       - the two var tuples are equal, or the target is `_`.
     Otherwise the loop is left to generic_visit.
     """
     # Recurse only after we've handled the potential special case.
     if node.iter != self.comp:
         return self.generic_visit(node)

     spec = self.spec
     loop_target = node.target

     eligible = (node.orelse == () and spec.is_duplicate_safe and
                 L.is_vartuple(loop_target) and L.is_vartuple(spec.resexp))
     if not eligible:
         return self.generic_visit(node)

     same_vars = (L.get_vartuple(loop_target) ==
                  L.get_vartuple(spec.resexp))
     ignores_target = (not same_vars and
                       not (L.is_name(loop_target) and
                            L.get_name(loop_target) == '_'))
     if ignores_target:
         return self.generic_visit(node)

     comment = L.Comment('Iterate ' + str(spec))
     join_code = spec.join.get_code(spec.params, node.body,
                                    augmented=self.augmented)
     return self.visit((comment,) + join_code)
Example #21
0
 def visit_SetUpdate(self, node):
     """Record the relation targeted by a set update.

     Only updates whose target is a plain Name are recorded; the name
     is added to self.rels.
     """
     if isinstance(node.target, L.Name):
         self.rels.add(L.get_name(node.target))