Ejemplo n.º 1
0
class TupleFlattener(L.NodeTransformer):
    """Flatten tuple expressions into clauses over tuple relations.

    Every Tuple node the traversal reaches is replaced by a fresh
    variable v (obtained from ``tupvar_namer``), and a new enumerator
    clause ``(v, e1, ..., en) in trel`` is recorded, where
    ``trel = make_trel(n)`` and e1..en are the tuple's components.
    Enumerator RHS expressions are never visited, and the top-level
    Tuple of an enumerator target is preserved (only its components are
    transformed).

    process(tree) returns (new_tree, new_clauses), where new_clauses
    holds the clauses created during that call. ``trels`` collects the
    tuple relation names used; it is created in __init__ and is not
    reset by process().
    """

    def __init__(self, tupvar_namer):
        super().__init__()
        self.tupvar_namer = tupvar_namer
        self.trels = OrderedSet()

    def process(self, tree):
        self.new_clauses = []
        result = super().process(tree)
        return result, self.new_clauses

    def visit_Enumerator(self, node):
        # The RHS is never visited. A top-level tuple target is kept and
        # only its elements are transformed; any other target is handed
        # to generic_visit.
        target = node.target
        if not isinstance(target, L.Tuple):
            return node._replace(target=self.generic_visit(target))
        flattened = self.visit(target.elts)
        return node._replace(target=target._replace(elts=flattened))

    def visit_Tuple(self, node):
        # Components are deliberately not visited here; deeper nesting is
        # left to whoever invokes this transformer.
        # NOTE(review): uses a .next() method rather than the next()
        # builtin; presumably tupvar_namer is a project name generator.
        fresh = self.tupvar_namer.next()
        relname = make_trel(len(node.elts))
        pattern = L.tuplify((L.sn(fresh),) + node.elts, lval=True)
        self.new_clauses.append(L.Enumerator(pattern, L.ln(relname)))
        self.trels.add(relname)
        return L.sn(fresh)
Ejemplo n.º 2
0
 def rhs_rels_from_comp(self, comp):
     """Return, as a tuple, the non-None results of self.rhs_rel()
     over comp.clauses, collected through an OrderedSet.
     """
     found = OrderedSet()
     for clause in comp.clauses:
         candidate = self.rhs_rel(clause)
         if candidate is None:
             continue
         found.add(candidate)
     return tuple(found)
Ejemplo n.º 3
0
class TupleFlattener(L.NodeTransformer):
    """Flatten tuple expressions into clauses over tuple relations.

    Every Tuple node the traversal reaches is replaced by a fresh
    variable v (obtained from ``tupvar_namer``), and a new enumerator
    clause ``(v, e1, ..., en) in trel`` is recorded, where
    ``trel = make_trel(n)`` and e1..en are the tuple's components.
    Enumerator RHS expressions are never visited, and the top-level
    Tuple of an enumerator target is preserved (only its components are
    transformed).

    process(tree) returns (new_tree, new_clauses), where new_clauses
    holds the clauses created during that call. ``trels`` collects the
    tuple relation names used; it is created in __init__ and is not
    reset by process().
    """

    def __init__(self, tupvar_namer):
        super().__init__()
        self.tupvar_namer = tupvar_namer
        self.trels = OrderedSet()

    def process(self, tree):
        self.new_clauses = []
        result = super().process(tree)
        return result, self.new_clauses

    def visit_Enumerator(self, node):
        # The RHS is never visited. A top-level tuple target is kept and
        # only its elements are transformed; any other target is handed
        # to generic_visit.
        target = node.target
        if not isinstance(target, L.Tuple):
            return node._replace(target=self.generic_visit(target))
        flattened = self.visit(target.elts)
        return node._replace(target=target._replace(elts=flattened))

    def visit_Tuple(self, node):
        # Components are deliberately not visited here; deeper nesting is
        # left to whoever invokes this transformer.
        # NOTE(review): uses a .next() method rather than the next()
        # builtin; presumably tupvar_namer is a project name generator.
        fresh = self.tupvar_namer.next()
        relname = make_trel(len(node.elts))
        pattern = L.tuplify((L.sn(fresh),) + node.elts, lval=True)
        self.new_clauses.append(L.Enumerator(pattern, L.ln(relname)))
        self.trels.add(relname)
        return L.sn(fresh)
Ejemplo n.º 4
0
 def rhs_rels_from_comp(self, comp):
     """Return, as a tuple, the non-None results of self.rhs_rel()
     over comp.clauses, collected through an OrderedSet.
     """
     found = OrderedSet()
     for clause in comp.clauses:
         candidate = self.rhs_rel(clause)
         if candidate is None:
             continue
         found.add(candidate)
     return tuple(found)
Ejemplo n.º 5
0
class BindingFinder(L.NodeVisitor):

    """Collect names of variables that appear in a binding context,
    i.e. names that would carry a Store context in Python's AST.

    Bindings come from Fun parameters, For targets, DecompFor
    variables, Assign targets and DecompAssign variables. Each node's
    own bindings are recorded before its children are visited.
    process() returns the names as an OrderedSet.
    """

    def process(self, tree):
        self.vars = OrderedSet()
        # NOTE(review): write_ctx is not read by any method in this
        # class; presumably a leftover or used by subclasses — confirm.
        self.write_ctx = False
        super().process(tree)
        return self.vars

    def _bind_many(self, node, names):
        # Record several bound names, then descend into the node.
        self.vars.update(names)
        self.generic_visit(node)

    def _bind_one(self, node, name):
        # Record a single bound name, then descend into the node.
        self.vars.add(name)
        self.generic_visit(node)

    def visit_Fun(self, node):
        self._bind_many(node, node.args)

    def visit_For(self, node):
        self._bind_one(node, node.target)

    def visit_DecompFor(self, node):
        self._bind_many(node, node.vars)

    def visit_Assign(self, node):
        self._bind_one(node, node.target)

    def visit_DecompAssign(self, node):
        self._bind_many(node, node.vars)
Ejemplo n.º 6
0
class BindingFinder(L.NodeVisitor):
    """Collect names of variables that appear in a binding context,
    i.e. names that would carry a Store context in Python's AST.

    Bindings come from Fun parameters, For targets, DecompFor
    variables, Assign targets and DecompAssign variables. Each node's
    own bindings are recorded before its children are visited.
    process() returns the names as an OrderedSet.
    """
    def process(self, tree):
        self.vars = OrderedSet()
        # NOTE(review): write_ctx is not read by any method in this
        # class; presumably a leftover or used by subclasses — confirm.
        self.write_ctx = False
        super().process(tree)
        return self.vars

    def _bind_many(self, node, names):
        # Record several bound names, then descend into the node.
        self.vars.update(names)
        self.generic_visit(node)

    def _bind_one(self, node, name):
        # Record a single bound name, then descend into the node.
        self.vars.add(name)
        self.generic_visit(node)

    def visit_Fun(self, node):
        self._bind_many(node, node.args)

    def visit_For(self, node):
        self._bind_one(node, node.target)

    def visit_DecompFor(self, node):
        self._bind_many(node, node.vars)

    def visit_Assign(self, node):
        self._bind_one(node, node.target)

    def visit_DecompAssign(self, node):
        self._bind_many(node, node.vars)
Ejemplo n.º 7
0
def preprocess_var_decls(tree):
    """Eliminate global variable declarations of the form

        S = Set()
        M = Map()

    and return a list of pairs of variable names and type names (i.e.,
    'Set' or 'Map').

    The result is (new_tree, decls). Every top-level statement matching
    `<name> = <name>()` (single target, call with empty argument lists
    in the pattern) is removed from the module body. If the called name
    is neither 'Set' nor 'Map', L.ProgramError is raised. When nothing
    matches, the original tree object is returned unchanged.
    """
    assert isinstance(tree, P.Module)
    decl_pattern = P.Assign(
        [P.Name(P.PatVar('_VAR'), P.Wildcard())],
        P.Call(P.Name(P.PatVar('_KIND'), P.Load()), [], [], None, None))

    found = OrderedSet()
    kept = []
    for stmt in tree.body:
        m = P.match(decl_pattern, stmt)
        if m is None:
            kept.append(stmt)
            continue
        var, kind = m['_VAR'], m['_KIND']
        if kind not in ['Set', 'Map']:
            raise L.ProgramError(
                'Unknown declaration initializer {}'.format(kind))
        found.add((var, kind))

    # Only rebuild the module if at least one declaration was dropped.
    if len(kept) != len(tree.body):
        tree = tree._replace(body=kept)
    return tree, list(found)
Ejemplo n.º 8
0
def preprocess_var_decls(tree):
    """Eliminate global variable declarations of the form

        S = Set()
        M = Map()

    and return a list of pairs of variable names and type names (i.e.,
    'Set' or 'Map').

    The result is (new_tree, decls). Every top-level statement matching
    `<name> = <name>()` (single target, call with empty argument lists
    in the pattern) is removed from the module body. If the called name
    is neither 'Set' nor 'Map', L.ProgramError is raised. When nothing
    matches, the original tree object is returned unchanged.
    """
    assert isinstance(tree, P.Module)
    decl_pattern = P.Assign(
        [P.Name(P.PatVar("_VAR"), P.Wildcard())],
        P.Call(P.Name(P.PatVar("_KIND"), P.Load()), [], [], None, None),
    )

    found = OrderedSet()
    kept = []
    for stmt in tree.body:
        m = P.match(decl_pattern, stmt)
        if m is None:
            kept.append(stmt)
            continue
        var, kind = m["_VAR"], m["_KIND"]
        if kind not in ["Set", "Map"]:
            raise L.ProgramError("Unknown declaration initializer {}".format(kind))
        found.add((var, kind))

    # Only rebuild the module if at least one declaration was dropped.
    if len(kept) != len(tree.body):
        tree = tree._replace(body=kept)
    return tree, list(found)
Ejemplo n.º 9
0
class RetrievalReplacer(L.NodeTransformer):

    """Replace simple field and map retrieval expressions with a
    variable.

    A retrieval is simple when its object (for a field) or map (for a
    subscript) is a bare variable and, for maps, the index key is also a
    bare variable. Any non-simple retrieval raises L.ProgramError.

    Children are rewritten before their parent, so nested expressions
    such as a.b[c.d].e are handled as long as they are built only from
    variables and retrievals.

    Replacement variable names come from field_namer(obj, field) and
    map_namer(map, key). After process(), the OrderedSets field_repls
    and map_repls hold (object/map, field/key, replacement) triples for
    every replacement performed; both are reset at the start of each
    process() call.
    """

    def __init__(self, field_namer, map_namer):
        super().__init__()
        self.field_namer = field_namer
        self.map_namer = map_namer

    def process(self, tree):
        # Start each run with empty replacement logs.
        self.field_repls = OrderedSet()
        self.map_repls = OrderedSet()
        return super().process(tree)

    def visit_Attribute(self, node):
        node = self.generic_visit(node)

        if isinstance(node.value, L.Name):
            obj, field = node.value.id, node.attr
            replacement = self.field_namer(obj, field)
            self.field_repls.add((obj, field, replacement))
            return L.Name(replacement, node.ctx)

        raise L.ProgramError('Non-simple field retrieval', node=node)

    def visit_Subscript(self, node):
        node = self.generic_visit(node)

        is_simple = (isinstance(node.value, L.Name) and
                     isinstance(node.slice, L.Index) and
                     isinstance(node.slice.value, L.Name))
        if not is_simple:
            raise L.ProgramError('Non-simple map retrieval', node=node)

        mapname = node.value.id
        keyname = node.slice.value.id
        replacement = self.map_namer(mapname, keyname)
        self.map_repls.add((mapname, keyname, replacement))
        return L.Name(replacement, node.ctx)
Ejemplo n.º 10
0
 class Finder(L.NodeVisitor):
     """Collect the symbol of every Query node in a tree.

     Symbols are looked up as symtab.get_symbols()[node.name].
     NOTE(review): symtab is a free name here, presumably bound in the
     enclosing function. Nested queries are visited before the
     enclosing query's symbol is recorded.
     """
     def process(self, tree):
         self.syms = OrderedSet()
         super().process(tree)
         return self.syms

     def visit_Query(self, node):
         # Descend first so inner queries are recorded before this one.
         self.generic_visit(node)
         symbol = symtab.get_symbols()[node.name]
         self.syms.add(symbol)
Ejemplo n.º 11
0
    class Finder(L.NodeVisitor):
        """Collect the symbol of every Query node in a tree.

        Symbols are looked up as symtab.get_symbols()[node.name].
        NOTE(review): symtab is a free name here, presumably bound in
        the enclosing function. Nested queries are visited before the
        enclosing query's symbol is recorded.
        """
        def process(self, tree):
            self.syms = OrderedSet()
            super().process(tree)
            return self.syms

        def visit_Query(self, node):
            # Descend first so inner queries are recorded before this one.
            self.generic_visit(node)
            symbol = symtab.get_symbols()[node.name]
            self.syms.add(symbol)
Ejemplo n.º 12
0
class AggrMapReplacer(L.NodeTransformer):

    """Replace all occurrences of map lookups with fresh variables.

    process() returns a triple (tree, new_clauses, []):
      - the transformed AST,
      - a list of new SetFromMapMember clauses binding the fresh
        variables introduced during that call,
      - an always-empty list.

    The SetFromMap invariants that need to be transformed are NOT part
    of the return value; they are accumulated in the ``sfm_invs``
    attribute across all process() calls.

    This is intended to be used on each of a comprehension's clauses and
    its result expression, reusing the same transformer instance. The
    instance remembers which lookups it has already replaced, so a
    lookup seen in an earlier process() call reuses its variable and
    produces no second binding clause.
    """

    def __init__(self, fresh_names):
        super().__init__()
        # Iterator of fresh variable names, consumed with next().
        self.fresh_names = fresh_names
        # Mapping from (already transformed) DictLookup nodes to their
        # replacement variable names; persists across process() calls.
        self.repls = OrderedDict()
        # SetFromMap invariants; persists across process() calls.
        self.sfm_invs = OrderedSet()

    def process(self, tree):
        self.new_clauses = []
        tree = super().process(tree)
        # NOTE(review): the third element is always empty; presumably it
        # fills a slot in a (tree, before, after)-style rewriter
        # interface — confirm with callers.
        return tree, self.new_clauses, []

    def visit_DictLookup(self, node):
        node = self.generic_visit(node)

        # Only simple map lookups are allowed.
        assert isinstance(node.value, L.Name)
        assert L.is_tuple_of_names(node.key)
        assert node.default is None
        map_name = node.value.id
        keyvars = L.detuplify(node.key)

        var = self.repls.get(node, None)
        if var is None:
            mask = L.mapmask_from_len(len(keyvars))
            rel = N.SA_name(map_name, mask)

            # Create a fresh variable.
            self.repls[node] = var = next(self.fresh_names)

            # Construct a clause to bind it: the key variables followed
            # by the fresh value variable.
            clause_vars = list(keyvars) + [var]
            new_clause = L.SetFromMapMember(clause_vars, rel, map_name, mask)
            self.new_clauses.append(new_clause)

            # Construct a corresponding SetFromMap invariant.
            sfm = SetFromMapInvariant(rel, map_name, mask)
            self.sfm_invs.add(sfm)

        return L.Name(var)
Ejemplo n.º 13
0
class AggrMapReplacer(L.NodeTransformer):
    """Replace all occurrences of map lookups with fresh variables.

    process() returns a triple (tree, new_clauses, []):
      - the transformed AST,
      - a list of new SetFromMapMember clauses binding the fresh
        variables introduced during that call,
      - an always-empty list.

    The SetFromMap invariants that need to be transformed are NOT part
    of the return value; they are accumulated in the ``sfm_invs``
    attribute across all process() calls.

    This is intended to be used on each of a comprehension's clauses and
    its result expression, reusing the same transformer instance. The
    instance remembers which lookups it has already replaced, so a
    lookup seen in an earlier process() call reuses its variable and
    produces no second binding clause.
    """
    def __init__(self, fresh_names):
        super().__init__()
        # Iterator of fresh variable names, consumed with next().
        self.fresh_names = fresh_names
        # Mapping from (already transformed) DictLookup nodes to their
        # replacement variable names; persists across process() calls.
        self.repls = OrderedDict()
        # SetFromMap invariants; persists across process() calls.
        self.sfm_invs = OrderedSet()

    def process(self, tree):
        self.new_clauses = []
        tree = super().process(tree)
        # NOTE(review): the third element is always empty; presumably it
        # fills a slot in a (tree, before, after)-style rewriter
        # interface — confirm with callers.
        return tree, self.new_clauses, []

    def visit_DictLookup(self, node):
        node = self.generic_visit(node)

        # Only simple map lookups are allowed.
        assert isinstance(node.value, L.Name)
        assert L.is_tuple_of_names(node.key)
        assert node.default is None
        map_name = node.value.id
        keyvars = L.detuplify(node.key)

        var = self.repls.get(node, None)
        if var is None:
            mask = L.mapmask_from_len(len(keyvars))
            rel = N.SA_name(map_name, mask)

            # Create a fresh variable.
            self.repls[node] = var = next(self.fresh_names)

            # Construct a clause to bind it: the key variables followed
            # by the fresh value variable.
            clause_vars = list(keyvars) + [var]
            new_clause = L.SetFromMapMember(clause_vars, rel, map_name, mask)
            self.new_clauses.append(new_clause)

            # Construct a corresponding SetFromMap invariant.
            sfm = SetFromMapInvariant(rel, map_name, mask)
            self.sfm_invs.add(sfm)

        return L.Name(var)
Ejemplo n.º 14
0
class RelmatchQueryFinder(L.NodeVisitor):

    """Collect the auxmap specs used by some relmatch query
    (SetMatch) or set-map lookup (SMLookup).

    Children are visited before the node itself, so specs of nested
    queries are added to the returned OrderedSet first.
    """

    def process(self, tree):
        self.specs = OrderedSet()
        super().process(tree)
        return self.specs

    def visit_SetMatch(self, node):
        self.generic_visit(node)
        if not is_relmatch(node):
            return
        spec, _ = get_relmatch(node)
        self.specs.add(spec)

    def visit_SMLookup(self, node):
        self.generic_visit(node)
        if not is_relsmlookup(node):
            return
        spec, _ = get_relsmlookup(node)
        self.specs.add(spec)
Ejemplo n.º 15
0
class PlainFunctionFinder(NodeVisitor):
    """Return all names of top-level functions that only use plain
    arguments and calls. Non-top-level function definitions are not
    analyzed. Calls of non-Name nodes are not analyzed. It is an
    error for the functions to have multiple definitions.

    If stmt_only is True, functions that are called in expression
    context are excluded.

    Concretely, process() returns toplevel_funcs - excluded_funcs, where
    a top-level function name is excluded if any of these holds:

      - its definition fails is_plainfuncdef();
      - some call to it (by Name) fails is_plaincall(), whether the
        call is an expression statement or is nested in an expression;
      - stmt_only is True and it is called in expression context, or
        its query function (N.queryfunc(demname)) is referenced by a
        DemQuery.

    Bodies of nested function definitions are skipped entirely, so
    calls appearing only inside them are not considered. A duplicate
    top-level definition trips an assertion.
    """
    def __init__(self, *, stmt_only):
        super().__init__()
        self.stmt_only = stmt_only

    def process(self, tree):
        self.toplevel_funcs = OrderedSet()
        self.excluded_funcs = set()

        self.infunc = False
        super().process(tree)
        assert not self.infunc

        return self.toplevel_funcs - self.excluded_funcs

    def visit_FunctionDef(self, node):
        # Nested definitions are not analyzed (and not descended into).
        if self.infunc:
            return

        name = node.name
        assert name not in self.toplevel_funcs, \
            'Multiple definitions of function ' + name
        self.toplevel_funcs.add(name)

        if not is_plainfuncdef(node):
            self.excluded_funcs.add(name)

        self.infunc = True
        self.generic_visit(node)
        self.infunc = False

    def visit_Expr(self, node):
        call = node.value
        if self.stmt_only and isinstance(call, Call):
            # A call used as a statement is acceptable in stmt_only
            # mode, so bypass visit_Call() (which would exclude it as an
            # expression-context call). It must still be a plain call,
            # though; previously this check was skipped here, so
            # non-plain statement calls went unnoticed.
            if isinstance(call.func, Name) and not is_plaincall(call):
                self.excluded_funcs.add(call.func.id)
            self.generic_visit(call)
        else:
            # Otherwise just recurse as normal.
            self.visit(call)

    def visit_Call(self, node):
        if not isinstance(node.func, Name):
            self.generic_visit(node)
            return
        name = node.func.id

        if self.stmt_only:
            # We only get here if this call occurred in expression
            # context.
            self.excluded_funcs.add(name)
        else:
            if not is_plaincall(node):
                self.excluded_funcs.add(name)

        self.generic_visit(node)

    def visit_DemQuery(self, node):
        # For our purposes here, these are interpreted as calls in
        # expression context.
        name = N.queryfunc(node.demname)
        if self.stmt_only:
            self.excluded_funcs.add(name)

        self.generic_visit(node)
Ejemplo n.º 16
0
class RelationFinder(L.NodeVisitor):
    """Find variables that can statically be inferred to be relations,
    i.e. sets that are unaliased and top-level.

    A name R qualifies when it has a global-scope initialization of one
    of these forms:

        R = Set()
        R = incoq.runtime.Set()
        R = set()

    and every other occurrence of R is one of:

        - the target of a SetUpdate

        - the RHS of a membership clause (including condition clauses)

        - the iterable of a For loop

    Any other Name occurrence disqualifies it. process() returns
    inited - disqual.
    """
    def process(self, tree):
        self.inited = OrderedSet()
        self.disqual = OrderedSet()
        super().process(tree)
        return self.inited - self.disqual

    # The toplevel flag tracks whether traversal is at global scope.
    # NOTE(review): it is only initialized in visit_Module, so process()
    # presumably expects a Module tree.

    def visit_Module(self, node):
        self.toplevel = True
        self.generic_visit(node)

    def nontoplevel_helper(self, node):
        saved = self.toplevel
        self.toplevel = False
        self.generic_visit(node)
        self.toplevel = saved

    visit_FunctionDef = nontoplevel_helper
    visit_ClassDef = nontoplevel_helper

    def visit_Assign(self, node):
        allowed_inits = [
            L.pe(src)
            for src in ('Set()', 'incoq.runtime.Set()', 'set()')
        ]
        # A top-level relation initializer marks the name and stops
        # here, so the assigned Name is not disqualified.
        if self.toplevel and L.is_varassign(node):
            name, value = L.get_varassign(node)
            if value in allowed_inits:
                self.inited.add(name)
                return

        self.generic_visit(node)

    def visit_SetUpdate(self, node):
        if not isinstance(node.target, L.Name):
            self.generic_visit(node)
            return
        # A bare Name target is an allowed use; only visit the element.
        self.visit(node.elem)

    def visit_For(self, node):
        if not isinstance(node.iter, L.Name):
            self.generic_visit(node)
            return
        # A bare Name iterable is an allowed use; visit everything else.
        for part in (node.target, node.body, node.orelse):
            self.visit(part)

    def visit_Comp(self, node):
        # Params and options are never visited. For each clause, skip a
        # bare Name RHS of an enumerator or of a single `in` condition.
        self.visit(node.resexp)
        for cl in node.clauses:
            if isinstance(cl, L.Enumerator) and isinstance(cl.iter, L.Name):
                to_visit = cl.target
            elif (isinstance(cl, L.Compare) and
                  len(cl.ops) == len(cl.comparators) == 1 and
                  isinstance(cl.ops[0], L.In) and
                  isinstance(cl.comparators[0], L.Name)):
                to_visit = cl.left
            else:
                to_visit = cl
            self.visit(to_visit)

    def visit_Name(self, node):
        # Reaching a Name means it was used in a disallowed position.
        self.disqual.add(node.id)
Ejemplo n.º 17
0
Archivo: util.py Proyecto: IncOQ/incoq
class VarsFinder(NodeVisitor):

    """Simple finder of variables (Name nodes).

    process(tree) returns an OrderedSet of the identifiers found.

    Keyword-only flags (all default to False):

        ignore_store:
            Skip Name nodes with Store context, and skip a bare Name
            target of update operations (SetUpdate, RCSetRefUpdate,
            AssignKey, DelKey). As an exception, when an Enumerator's
            target satisfies is_vartuple(), its variables are recorded
            regardless, to stay safe under pattern matching semantics.
        ignore_functions:
            Skip the callee of a Call when it is a bare Name.
        ignore_rels:
            Skip names that appear to be relations: a bare Name RHS of
            an Enumerator, the target of a SetMatch, the RHS of a single
            `in` membership condition in a Comp, and any Aggregate over
            a bare Name (skipped entirely).

    The builtins None, True, and False are always excluded, as they
    are NameConstants, not variables.
    """

    def __init__(self, *,
                 ignore_store=False,
                 ignore_functions=False,
                 ignore_rels=False):
        super().__init__()
        self.ignore_store = ignore_store
        self.ignore_functions = ignore_functions
        self.ignore_rels = ignore_rels

    def process(self, tree):
        self.usedvars = OrderedSet()
        super().process(tree)
        return self.usedvars

    def visit_Name(self, node):
        self.generic_visit(node)
        skip = self.ignore_store and isinstance(node.ctx, Store)
        if not skip:
            self.usedvars.add(node.id)

    def visit_Call(self, node):
        # Placeholder node substituted for the callee so its Name is not
        # visited. NOTE(review): _meta = True presumably marks this as a
        # non-concrete iast node type — confirm.
        class IGNORE(iast.AST):
            _meta = True

        hide_func = isinstance(node.func, Name) and self.ignore_functions
        self.generic_visit(node._replace(func=IGNORE()) if hide_func
                           else node)

    def visit_Enumerator(self, node):
        target = node.target
        if is_vartuple(target):
            # Record LHS variables directly, bypassing ignore_store.
            self.usedvars.update(get_vartuple(target))
        else:
            self.visit(target)

        skip_iter = self.ignore_rels and isinstance(node.iter, Name)
        if not skip_iter:
            self.visit(node.iter)

    def visit_Comp(self, node):
        self.visit(node.resexp)
        for clause in node.clauses:
            # Hack: with ignore_rels, don't grab the relation Name on the
            # RHS of a single `x in R` membership condition.
            rel_membership = (self.ignore_rels and
                              isinstance(clause, Compare) and
                              len(clause.ops) == len(clause.comparators) == 1 and
                              isinstance(clause.ops[0], In) and
                              isinstance(clause.comparators[0], Name))
            self.visit(clause.left if rel_membership else clause)

    def visit_Aggregate(self, node):
        if self.ignore_rels and isinstance(node.value, Name):
            return
        self.generic_visit(node)

    def visit_SetMatch(self, node):
        if not (self.ignore_rels and isinstance(node.target, Name)):
            self.generic_visit(node)
            return
        self.visit(node.key)

    def update_helper(self, node):
        # Sentinel standing in for a bare Name target so it is not visited.
        IGNORE = object()
        if self.ignore_store and isinstance(node.target, Name):
            node = node._replace(target=IGNORE)
        self.generic_visit(node)

    visit_SetUpdate = update_helper
    visit_RCSetRefUpdate = update_helper
    visit_AssignKey = update_helper
    visit_DelKey = update_helper
Ejemplo n.º 18
0
class PlainFunctionFinder(NodeVisitor):

    """Return all names of top-level functions that only use plain
    arguments and calls. Non-top-level function definitions are not
    analyzed. Calls of non-Name nodes are not analyzed. It is an
    error for the functions to have multiple definitions.

    If stmt_only is True, functions that are called in expression
    context are excluded.

    Concretely, process() returns toplevel_funcs - excluded_funcs, where
    a top-level function name is excluded if any of these holds:

      - its definition fails is_plainfuncdef();
      - some call to it (by Name) fails is_plaincall(), whether the
        call is an expression statement or is nested in an expression;
      - stmt_only is True and it is called in expression context, or
        its query function (N.queryfunc(demname)) is referenced by a
        DemQuery.

    Bodies of nested function definitions are skipped entirely, so
    calls appearing only inside them are not considered. A duplicate
    top-level definition trips an assertion.
    """

    def __init__(self, *, stmt_only):
        super().__init__()
        self.stmt_only = stmt_only

    def process(self, tree):
        self.toplevel_funcs = OrderedSet()
        self.excluded_funcs = set()

        self.infunc = False
        super().process(tree)
        assert not self.infunc

        return self.toplevel_funcs - self.excluded_funcs

    def visit_FunctionDef(self, node):
        # Nested definitions are not analyzed (and not descended into).
        if self.infunc:
            return

        name = node.name
        assert name not in self.toplevel_funcs, \
            'Multiple definitions of function ' + name
        self.toplevel_funcs.add(name)

        if not is_plainfuncdef(node):
            self.excluded_funcs.add(name)

        self.infunc = True
        self.generic_visit(node)
        self.infunc = False

    def visit_Expr(self, node):
        call = node.value
        if self.stmt_only and isinstance(call, Call):
            # A call used as a statement is acceptable in stmt_only
            # mode, so bypass visit_Call() (which would exclude it as an
            # expression-context call). It must still be a plain call,
            # though; previously this check was skipped here, so
            # non-plain statement calls went unnoticed.
            if isinstance(call.func, Name) and not is_plaincall(call):
                self.excluded_funcs.add(call.func.id)
            self.generic_visit(call)
        else:
            # Otherwise just recurse as normal.
            self.visit(call)

    def visit_Call(self, node):
        if not isinstance(node.func, Name):
            self.generic_visit(node)
            return
        name = node.func.id

        if self.stmt_only:
            # We only get here if this call occurred in expression
            # context.
            self.excluded_funcs.add(name)
        else:
            if not is_plaincall(node):
                self.excluded_funcs.add(name)

        self.generic_visit(node)

    def visit_DemQuery(self, node):
        # For our purposes here, these are interpreted as calls in
        # expression context.
        name = N.queryfunc(node.demname)
        if self.stmt_only:
            self.excluded_funcs.add(name)

        self.generic_visit(node)
Ejemplo n.º 19
0
class InvariantFinder(L.NodeVisitor):
    """Find all set invariants needed in the program.

    process() returns a triple of OrderedSets (auxmaps, setfrommaps,
    wraps) holding AuxmapInvariant, SetFromMapInvariant and
    WrapInvariant objects respectively. Only operations whose operand
    is a bare Name produce an invariant.
    """
    def process(self, tree):
        self.auxmaps = OrderedSet()
        self.setfrommaps = OrderedSet()
        self.wraps = OrderedSet()
        super().process(tree)
        return self.auxmaps, self.setfrommaps, self.wraps

    def imglookup_helper(self, node):
        """Build (but do not record) the AuxmapInvariant for an
        ImgLookup node, or return None if its set is not a bare Name.
        The key is flagged for unwrapping when there is exactly one
        bound component.
        """
        if not isinstance(node.set, L.Name):
            return None
        rel = node.set.id
        return AuxmapInvariant(N.get_auxmap_name(rel, node.mask), rel,
                               node.mask, len(node.bounds) == 1, False)

    def visit_ImgLookup(self, node):
        self.generic_visit(node)
        inv = self.imglookup_helper(node)
        if inv is not None:
            self.auxmaps.add(inv)

    def visit_SetFromMap(self, node):
        self.generic_visit(node)
        if isinstance(node.map, L.Name):
            mapname = node.map.id
            rel = N.SA_name(mapname, node.mask)
            self.setfrommaps.add(
                SetFromMapInvariant(rel, mapname, node.mask))

    def visit_Unwrap(self, node):
        child = node.value
        if isinstance(child, L.ImgLookup):
            # Handle Unwrap(ImgLookup) as one unit: visit only below the
            # ImgLookup, and record a single auxmap invariant with
            # unwrap_value set. An ImgLookup child is never a Name, so
            # no wrap invariant applies either way.
            self.generic_visit(child)
            inv = self.imglookup_helper(child)
            if inv is not None:
                self.auxmaps.add(inv._replace(unwrap_value=True))
            return

        self.generic_visit(node)

        # Ordinary unwrap of a bare Name.
        if isinstance(child, L.Name):
            oper = child.id
            self.wraps.add(WrapInvariant(N.get_unwrap_name(oper), oper, True))

    def visit_Wrap(self, node):
        self.generic_visit(node)
        if isinstance(node.value, L.Name):
            oper = node.value.id
            self.wraps.add(WrapInvariant(N.get_wrap_name(oper), oper, False))
Ejemplo n.º 20
0
class RelationFinder(L.NodeVisitor):

    """Find variables that can statically be inferred to be relations,
    i.e. sets that are unaliased and top-level.

    A name R qualifies when it has a global-scope initialization of one
    of these forms:

        R = Set()
        R = incoq.runtime.Set()
        R = set()

    and every other occurrence of R is one of:

        - the target of a SetUpdate

        - the RHS of a membership clause (including condition clauses)

        - the iterable of a For loop

    Any other Name occurrence disqualifies it. process() returns
    inited - disqual.
    """

    def process(self, tree):
        self.inited = OrderedSet()
        self.disqual = OrderedSet()
        super().process(tree)
        return self.inited - self.disqual

    # The toplevel flag tracks whether traversal is at global scope.
    # NOTE(review): it is only initialized in visit_Module, so process()
    # presumably expects a Module tree.

    def visit_Module(self, node):
        self.toplevel = True
        self.generic_visit(node)

    def nontoplevel_helper(self, node):
        saved = self.toplevel
        self.toplevel = False
        self.generic_visit(node)
        self.toplevel = saved

    visit_FunctionDef = nontoplevel_helper
    visit_ClassDef = nontoplevel_helper

    def visit_Assign(self, node):
        allowed_inits = [
            L.pe(src)
            for src in ('Set()', 'incoq.runtime.Set()', 'set()')
        ]
        # A top-level relation initializer marks the name and stops
        # here, so the assigned Name is not disqualified.
        if self.toplevel and L.is_varassign(node):
            name, value = L.get_varassign(node)
            if value in allowed_inits:
                self.inited.add(name)
                return

        self.generic_visit(node)

    def visit_SetUpdate(self, node):
        if not isinstance(node.target, L.Name):
            self.generic_visit(node)
            return
        # A bare Name target is an allowed use; only visit the element.
        self.visit(node.elem)

    def visit_For(self, node):
        if not isinstance(node.iter, L.Name):
            self.generic_visit(node)
            return
        # A bare Name iterable is an allowed use; visit everything else.
        for part in (node.target, node.body, node.orelse):
            self.visit(part)

    def visit_Comp(self, node):
        # Params and options are never visited. For each clause, skip a
        # bare Name RHS of an enumerator or of a single `in` condition.
        self.visit(node.resexp)
        for cl in node.clauses:
            if isinstance(cl, L.Enumerator) and isinstance(cl.iter, L.Name):
                to_visit = cl.target
            elif (isinstance(cl, L.Compare) and
                  len(cl.ops) == len(cl.comparators) == 1 and
                  isinstance(cl.ops[0], L.In) and
                  isinstance(cl.comparators[0], L.Name)):
                to_visit = cl.left
            else:
                to_visit = cl
            self.visit(to_visit)

    def visit_Name(self, node):
        # Reaching a Name means it was used in a disallowed position.
        self.disqual.add(node.id)
Ejemplo n.º 21
0
class VarsFinder(NodeVisitor):
    """Simple finder of variables (Name nodes).

    process(tree) returns an OrderedSet of the identifiers found.

    Flags:

        ignore_store:
            Name nodes with Store context are ignored, as are plain-name
            targets of update operations (SetUpdate, RCSetRefUpdate,
            AssignKey, DelKey). As an exception, Name nodes on the LHS
            of Enumerators are not ignored. This is to ensure safety
            under pattern matching semantics.
        ignore_functions:
            The func of a Call is skipped when it is a plain Name.
        ignore_rels:
            Plain names used as relations are skipped: the iter of an
            Enumerator, the RHS of a single ``x in R`` condition in a
            comprehension, the target of a SetMatch, and an Aggregate
            over a plain name (the whole Aggregate is skipped).

    The builtins None, True, and False are always excluded, as they
    are NameConstants, not variables.
    """

    # Placeholder substituted for a Call's func when ignore_functions is
    # set, so generic_visit does not see the Name. Defined once here
    # instead of building a new class on every visit_Call.
    class _IgnoredFunc(iast.AST):
        _meta = True

    # Sentinel substituted for a plain-name update target when
    # ignore_store is set.
    _IGNORED_TARGET = object()

    def __init__(self,
                 *,
                 ignore_store=False,
                 ignore_functions=False,
                 ignore_rels=False):
        super().__init__()
        self.ignore_store = ignore_store
        self.ignore_functions = ignore_functions
        self.ignore_rels = ignore_rels

    def process(self, tree):
        self.usedvars = OrderedSet()
        super().process(tree)
        return self.usedvars

    def visit_Name(self, node):
        self.generic_visit(node)
        if not (self.ignore_store and isinstance(node.ctx, Store)):
            self.usedvars.add(node.id)

    def visit_Call(self, node):
        if self.ignore_functions and isinstance(node.func, Name):
            node = node._replace(func=self._IgnoredFunc())
        self.generic_visit(node)

    def visit_Enumerator(self, node):
        if is_vartuple(node.target):
            # Bypass ignore_store: enumerator LHS variables always count.
            self.usedvars.update(get_vartuple(node.target))
        else:
            self.visit(node.target)

        if not (self.ignore_rels and isinstance(node.iter, Name)):
            self.visit(node.iter)

    def visit_Comp(self, node):
        self.visit(node.resexp)
        # Hack to ensure we don't grab rels on RHS of
        # membership conditions.
        for clause in node.clauses:
            if (self.ignore_rels and isinstance(clause, Compare)
                    and len(clause.ops) == len(clause.comparators) == 1
                    and isinstance(clause.ops[0], In)
                    and isinstance(clause.comparators[0], Name)):
                self.visit(clause.left)
            else:
                self.visit(clause)

    def visit_Aggregate(self, node):
        if self.ignore_rels and isinstance(node.value, Name):
            return
        self.generic_visit(node)

    def visit_SetMatch(self, node):
        if self.ignore_rels and isinstance(node.target, Name):
            self.visit(node.key)
        else:
            self.generic_visit(node)

    def update_helper(self, node):
        if self.ignore_store and isinstance(node.target, Name):
            node = node._replace(target=self._IGNORED_TARGET)
        self.generic_visit(node)

    visit_SetUpdate = update_helper
    visit_RCSetRefUpdate = update_helper
    visit_AssignKey = update_helper
    visit_DelKey = update_helper
Ejemplo n.º 22
0
class DemandTransformer(ContextTracker):

    """Rewrite every query in the tree so that it is restricted by demand.

    A comprehension gains an extra leading clause over its demand set or
    demand query. An aggregate (Aggr) is turned into an AggrRestr.

    A query with no enclosing left clauses (an outer query) is restricted
    by a demand set. It is also wrapped in a call to its demand function,
    unless annotated 'nodemand', and a declaration of that function is
    prepended to the module. Any other (inner) query is restricted by a
    demand query built from the clauses to its left.
    NOTE(review): make_demand_set / make_demand_query presumably record
    the demand_set / demand_query attributes on the symbol; confirm there.

    Only a query's first occurrence is actually rewritten. Later
    occurrences reuse the cached result of that first rewrite.
    """

    # Rewriting proceeds top-down: an outer comprehension receives its
    # demand clause before its inner queries are rewritten, so those
    # inner queries can see it among their left clauses.

    def process(self, tree):
        # Names of outer queries; each gets a demand set and a call to a
        # demand function.
        self.queries_with_usets = OrderedSet()
        # Query name -> rewritten Query node, reused on repeat occurrences.
        self.rewrite_cache = {}
        # Names of the demand queries introduced here; they are returned
        # unchanged when encountered.
        self.demand_queries = set()
        return super().process(tree)

    def add_demand_function_call(self, query_sym, query_node, ann):
        """Wrap query_node as FirstThen(<demand function call>, query_node)
        when the query has a demand set and ann does not request
        'nodemand'. Otherwise return query_node unchanged.
        """
        wants_call = (query_sym.name in self.queries_with_usets and
                      not (ann is not None and ann.get('nodemand', False)))
        if not wants_call:
            return query_node
        func_name = N.get_query_demand_func_name(query_sym.name)
        call_args = [L.tuplify(query_sym.demand_params)]
        return L.FirstThen(L.Call(func_name, call_args), query_node)

    def visit_Module(self, node):
        node = self.generic_visit(node)
        # Prepend one demand-function declaration per outer query.
        decls = tuple(make_demand_func(self.symtab.get_queries()[qname])
                      for qname in self.queries_with_usets)
        return node._replace(body=decls + node.body)

    def rewrite_with_demand(self, query_sym, node):
        """Return node (the query's Comp or Aggr) restricted by demand.
        Subqueries are not transformed here. Queries that do not use
        demand are returned unchanged.
        """
        symtab = self.symtab
        params = query_sym.demand_params
        if not query_sym.uses_demand:
            return node

        left_clauses = self.get_left_clauses()
        if left_clauses is None:
            # Outer query: restrict by a demand set.
            dem_sym = make_demand_set(symtab, query_sym)
            dem_node = L.Name(dem_sym.name)
            dem_clause = L.RelMember(params, dem_sym.name)
            self.queries_with_usets.add(query_sym.name)
        else:
            # Inner query: restrict by a demand query over the left clauses.
            dem_sym = make_demand_query(symtab, query_sym, left_clauses)
            dem_node = dem_sym.make_node()
            dem_clause = L.VarsMember(params, dem_node)
            self.demand_queries.add(dem_sym.name)

        if isinstance(node, L.Comp):
            return node._replace(clauses=(dem_clause,) + node.clauses)
        if isinstance(node, L.Aggr):
            return L.AggrRestr(node.op, node.value, params, dem_node)
        raise AssertionError('No rule for handling demand for {} node'
                             .format(node.__class__.__name__))

    def visit_Query(self, node):
        if node.name in self.demand_queries:
            # A demand query we introduced ourselves; leave it alone.
            return node

        query_sym = self.symtab.get_queries()[node.name]

        if node.name in self.rewrite_cache:
            # Repeat occurrence: reuse the first rewrite, deciding on the
            # demand-function wrapper from this occurrence's annotation.
            return self.add_demand_function_call(
                query_sym, self.rewrite_cache[node.name], node.ann)

        restricted = self.rewrite_with_demand(query_sym, query_sym.node)
        # Recurse afterwards so subqueries see the new demand clause.
        result = super().visit_Query(node._replace(query=restricted))

        query_sym.node = result.query
        self.rewrite_cache[query_sym.name] = result
        return self.add_demand_function_call(query_sym, result, result.ann)
Ejemplo n.º 23
0
class InvariantFinder(L.NodeVisitor):

    """Collect the set invariants the program needs.

    process() returns three OrderedSets, in this order:
    AuxmapInvariants, SetFromMapInvariants, WrapInvariants.
    """

    def process(self, tree):
        self.auxmaps = OrderedSet()
        self.setfrommaps = OrderedSet()
        self.wraps = OrderedSet()
        super().process(tree)
        return self.auxmaps, self.setfrommaps, self.wraps

    def imglookup_helper(self, node):
        """Build, without recording it, the AuxmapInvariant for an
        ImgLookup node. Return None when the looked-up set is not a
        plain name.
        """
        if not isinstance(node.set, L.Name):
            return None
        rel_name = node.set.id
        return AuxmapInvariant(N.get_auxmap_name(rel_name, node.mask),
                               rel_name, node.mask,
                               len(node.bounds) == 1, False)

    def visit_ImgLookup(self, node):
        self.generic_visit(node)
        invariant = self.imglookup_helper(node)
        if invariant is not None:
            self.auxmaps.add(invariant)

    def visit_SetFromMap(self, node):
        self.generic_visit(node)
        if isinstance(node.map, L.Name):
            map_name = node.map.id
            self.setfrommaps.add(SetFromMapInvariant(
                N.SA_name(map_name, node.mask), map_name, node.mask))

    def visit_Unwrap(self, node):
        child = node.value
        if isinstance(child, L.ImgLookup):
            # Visit only below the ImgLookup. If possible, fuse the unwrap
            # into its auxmap invariant (unwrap_value=True).
            self.generic_visit(child)
            invariant = self.imglookup_helper(child)
            if invariant is not None:
                self.auxmaps.add(invariant._replace(unwrap_value=True))
            # An ImgLookup operand is never a plain name, so there is no
            # plain unwrap invariant to fall back on.
            return

        self.generic_visit(node)
        if isinstance(child, L.Name):
            oper = child.id
            self.wraps.add(WrapInvariant(N.get_unwrap_name(oper), oper, True))

    def visit_Wrap(self, node):
        self.generic_visit(node)
        if isinstance(node.value, L.Name):
            oper = node.value.id
            self.wraps.add(WrapInvariant(N.get_wrap_name(oper), oper, False))
Ejemplo n.º 24
0
class DemandTransformer(ContextTracker):
    """Rewrite every query in the tree so that it is restricted by demand.

    A comprehension gains an extra leading clause over its demand set or
    demand query. An aggregate (Aggr) is turned into an AggrRestr.

    A query with no enclosing left clauses (an outer query) is restricted
    by a demand set. It is also wrapped in a call to its demand function,
    unless annotated 'nodemand', and a declaration of that function is
    prepended to the module. Any other (inner) query is restricted by a
    demand query built from the clauses to its left.
    NOTE(review): make_demand_set / make_demand_query presumably record
    the demand_set / demand_query attributes on the symbol; confirm there.

    Only a query's first occurrence is actually rewritten. Later
    occurrences reuse the cached result of that first rewrite.
    """

    # Rewriting proceeds top-down: an outer comprehension receives its
    # demand clause before its inner queries are rewritten, so those
    # inner queries can see it among their left clauses.

    def process(self, tree):
        # Names of outer queries; each gets a demand set and a call to a
        # demand function.
        self.queries_with_usets = OrderedSet()
        # Query name -> rewritten Query node, reused on repeat occurrences.
        self.rewrite_cache = {}
        # Names of the demand queries introduced here; they are returned
        # unchanged when encountered.
        self.demand_queries = set()
        return super().process(tree)

    def add_demand_function_call(self, query_sym, query_node, ann):
        """Wrap query_node as FirstThen(<demand function call>, query_node)
        when the query has a demand set and ann does not request
        'nodemand'. Otherwise return query_node unchanged.
        """
        wants_call = (query_sym.name in self.queries_with_usets and
                      not (ann is not None and ann.get('nodemand', False)))
        if not wants_call:
            return query_node
        func_name = N.get_query_demand_func_name(query_sym.name)
        call_args = [L.tuplify(query_sym.demand_params)]
        return L.FirstThen(L.Call(func_name, call_args), query_node)

    def visit_Module(self, node):
        node = self.generic_visit(node)
        # Prepend one demand-function declaration per outer query.
        decls = tuple(make_demand_func(self.symtab.get_queries()[qname])
                      for qname in self.queries_with_usets)
        return node._replace(body=decls + node.body)

    def rewrite_with_demand(self, query_sym, node):
        """Return node (the query's Comp or Aggr) restricted by demand.
        Subqueries are not transformed here. Queries that do not use
        demand are returned unchanged.
        """
        symtab = self.symtab
        params = query_sym.demand_params
        if not query_sym.uses_demand:
            return node

        left_clauses = self.get_left_clauses()
        if left_clauses is None:
            # Outer query: restrict by a demand set.
            dem_sym = make_demand_set(symtab, query_sym)
            dem_node = L.Name(dem_sym.name)
            dem_clause = L.RelMember(params, dem_sym.name)
            self.queries_with_usets.add(query_sym.name)
        else:
            # Inner query: restrict by a demand query over the left clauses.
            dem_sym = make_demand_query(symtab, query_sym, left_clauses)
            dem_node = dem_sym.make_node()
            dem_clause = L.VarsMember(params, dem_node)
            self.demand_queries.add(dem_sym.name)

        if isinstance(node, L.Comp):
            return node._replace(clauses=(dem_clause, ) + node.clauses)
        if isinstance(node, L.Aggr):
            return L.AggrRestr(node.op, node.value, params, dem_node)
        raise AssertionError(
            'No rule for handling demand for {} node'.format(
                node.__class__.__name__))

    def visit_Query(self, node):
        if node.name in self.demand_queries:
            # A demand query we introduced ourselves; leave it alone.
            return node

        query_sym = self.symtab.get_queries()[node.name]

        if node.name in self.rewrite_cache:
            # Repeat occurrence: reuse the first rewrite, deciding on the
            # demand-function wrapper from this occurrence's annotation.
            return self.add_demand_function_call(
                query_sym, self.rewrite_cache[node.name], node.ann)

        restricted = self.rewrite_with_demand(query_sym, query_sym.node)
        # Recurse afterwards so subqueries see the new demand clause.
        result = super().visit_Query(node._replace(query=restricted))

        query_sym.node = result.query
        self.rewrite_cache[query_sym.name] = result
        return self.add_demand_function_call(query_sym, result, result.ann)
Ejemplo n.º 25
0
class TypeAnalysisStepper(L.AdvNodeVisitor):
    """Run one iteration of transfer functions for all the program's
    nodes.

    process() returns the store (variable -> type mapping). Afterwards,
    the `changed` attribute tells whether that iteration updated the
    store (i.e. whether new information was inferred). Nodes that
    violate a well-typedness constraint accumulate in `illtyped`.
    """
    def __init__(self, store, height_limit=None, unknown=Top, fixed_vars=None):
        super().__init__()
        # Mapping from symbol names to inferred types; each entry may only
        # grow monotonically in the lattice.
        self.store = store
        # Maximum height of type terms in the store; None means no limit.
        self.height_limit = height_limit
        # Nodes at which a well-typedness constraint is violated.
        self.illtyped = OrderedSet()
        # True if the latest process() call changed the store, or if
        # process() has not run yet.
        self.changed = True
        # Type assumed for names missing from the store (Bottom or Top);
        # None makes such lookups re-raise KeyError instead.
        self.unknown = unknown
        # Names whose types inference must not change.
        self.fixed_vars = [] if fixed_vars is None else fixed_vars

    def process(self, tree):
        """Run one pass of the transfer functions over tree and return
        the store. self.changed reports whether this pass modified it.
        """
        self.changed = False
        super().process(tree)
        result = self.store
        return result

    def get_store(self, name):
        """Return the stored type of name.

        A missing name yields self.unknown, or re-raises the KeyError
        when self.unknown is None.
        """
        try:
            found = self.store[name]
        except KeyError:
            if self.unknown is None:
                raise
            found = self.unknown
        return found

    def update_store(self, name, type):
        """Join type into the stored type of name and return the result.

        The joined type is widened to self.height_limit when one is set.
        Fixed variables are never updated; their current type is
        returned. Sets self.changed when the stored type actually changes.
        """
        if name in self.fixed_vars:
            return self.store[name]

        previous = self.get_store(name)
        joined = previous.join(type)
        if self.height_limit is not None:
            joined = joined.widen(self.height_limit)
        if joined != previous:
            self.changed = True
            self.store[name] = joined
        return joined

    def mark_bad(self, node):
        """Record node as violating a well-typedness constraint."""
        bad = self.illtyped
        bad.add(node)

    def readonly(f):
        """Decorator for expression handlers that only make sense in read
        context.

        Being called with a non-None type (write context) marks the node
        as ill-typed; the wrapped handler still runs either way.
        """
        @wraps(f)
        def wrapper(self, node, *, type=None):
            in_write_context = type is not None
            if in_write_context:
                self.mark_bad(node)
            return f(self, node, type=type)
        return wrapper

    @readonly
    def default_expr_handler(self, node, *, type=None):
        """Fallback expression handler: visit all children, yield Top."""
        self.generic_visit(node)
        result = Top
        return result

    # Each visitor handler has a monotonic transfer function and
    # possibly a constraint for well-typedness.
    #
    # The behavior of each handler is described in a comment using
    # the following informal syntax.
    #
    #    X := Y          Assign the join of X and Y to X
    #    Check X <= Y    Well-typedness constraint that X is a
    #                    subtype of Y
    #    Return X        Return X as the type of an expression
    #    Fail            Mark an error at this node and return Top
    #
    # This syntax is augmented by If/Elif/Else and pattern matching,
    # e.g. iter == Set<T> introduces T as the element type of iter.
    # join(T1, T2) is the lattice join of T1 and T2.
    #
    # Expression visitors have a keyword argument 'type', and can be
    # used in read or write context. In read context, type is None.
    # In write context, type is the type passed in from context. In
    # both cases the type of the expression is returned. Handlers
    # that do not tolerate write context are decorated as @readonly;
    # they still run but record a well-typedness error.

    def get_sequence_elt(self, node, t_seq, seq_cls):
        """Return the element type of t_seq viewed through the sequence
        type constructor seq_cls (e.g. Sequence, List, or Set).

        If t_seq cannot safely be viewed that way (e.g. it is not a
        sequence at all), mark node as ill-typed and return Top.
        """
        # Joining with seq_cls(Bottom) moves t_seq into (or above) the
        # seq_cls region of the lattice, so .elt is meaningful if the
        # result is still below seq_cls(Top).
        lifted = t_seq.join(seq_cls(Bottom))
        if lifted.issmaller(seq_cls(Top)):
            return lifted.elt
        self.mark_bad(node)
        return Top

    def get_map_keyval(self, node, t_map):
        """Return the (key, value) types of t_map viewed as a Map.

        If t_map cannot safely be viewed as a Map, mark node as
        ill-typed and return (Top, Top).
        """
        lifted = t_map.join(Map(Bottom, Bottom))
        if lifted.issmaller(Map(Top, Top)):
            return lifted.key, lifted.value
        self.mark_bad(node)
        return Top, Top

    def get_tuple_elts(self, node, t_tup, arity):
        """Return the element types of t_tup viewed as a tuple of the
        given arity.

        If that view is not safe, mark node as ill-typed and return a
        list of arity Tops.
        """
        lifted = t_tup.join(Tuple([Bottom] * arity))
        if lifted.issmaller(Tuple([Top] * arity)):
            return lifted.elts
        self.mark_bad(node)
        return [Top] * arity

    # Use default handler for Return.

    def visit_For(self, node):
        # target := element type of iter viewed as Sequence<T>
        #           (Top if iter is not a sequence)
        # Check iter <= Sequence<Top>
        iter_type = self.visit(node.iter)
        elt_type = self.get_sequence_elt(node, iter_type, Sequence)
        self.update_store(node.target, type=elt_type)
        self.visit(node.body)

    def visit_DecompFor(self, node):
        # If join(iter, Sequence<Tuple<Bottom, ..., Bottom>>) ==
        #    Sequence<Tuple<T1, ..., Tn>>, n == len(vars):
        #   vars_i := T_i for each i
        # Else:
        #   vars_i := Top for each i
        #
        # Check iter <= Sequence<Tuple<Top, ..., Top>>
        elt_type = self.get_sequence_elt(node, self.visit(node.iter), Sequence)
        part_types = self.get_tuple_elts(node, elt_type, len(node.vars))
        for var, var_type in zip(node.vars, part_types):
            self.update_store(var, var_type)
        self.visit(node.body)

    def visit_While(self, node):
        # Check test <= Bool
        test_ok = self.visit(node.test).issmaller(Bool)
        if not test_ok:
            self.mark_bad(node)
        self.visit(node.body)

    def visit_If(self, node):
        # Check test <= Bool
        test_ok = self.visit(node.test).issmaller(Bool)
        if not test_ok:
            self.mark_bad(node)
        for branch in (node.body, node.orelse):
            self.visit(branch)

    # Use default handler for Pass, Break, Continue, and Expr.

    def visit_Assign(self, node):
        # target := value
        value_type = self.visit(node.value)
        self.update_store(node.target, value_type)

    def visit_DecompAssign(self, node):
        # If join(value, Tuple<Bottom, ..., Bottom>) ==
        #    Tuple<T1, ..., Tn>, n == len(vars):
        #   vars_i := T_i for each i
        # Else:
        #   vars_i := Top for each i
        #
        # Check value <= Tuple<Top, ..., Top>
        value_type = self.visit(node.value)
        part_types = self.get_tuple_elts(node, value_type, len(node.vars))
        for var, var_type in zip(node.vars, part_types):
            self.update_store(var, var_type)

    def visit_SetUpdate(self, node):
        # target := Set<value>
        # Check target <= Set<Top>
        elem_type = self.visit(node.value)
        target_type = self.visit(node.target, type=Set(elem_type))
        is_set = target_type.issmaller(Set(Top))
        if not is_set:
            self.mark_bad(node)

    def visit_SetBulkUpdate(self, node):
        # target := Set<T> where T is the element type of value viewed
        #           as a Set (Top if value is not a set)
        # Check value <= Set<Top> and target <= Set<Top>
        value_type = self.visit(node.value)
        elem_type = self.get_sequence_elt(node, value_type, Set)
        target_type = self.visit(node.target, type=Set(elem_type))
        both_sets = (value_type.issmaller(Set(Top)) and
                     target_type.issmaller(Set(Top)))
        if not both_sets:
            self.mark_bad(node)

    def visit_SetClear(self, node):
        # target := Set<Bottom>
        # Check target <= Set<Top>
        target_type = self.visit(node.target, type=Set(Bottom))
        is_set = target_type.issmaller(Set(Top))
        if not is_set:
            self.mark_bad(node)

    def visit_RelUpdate(self, node):
        # rel := Set<elem>
        # Check rel <= Set<Top>
        elem_type = self.get_store(node.elem)
        rel_type = self.update_store(node.rel, Set(elem_type))
        is_set = rel_type.issmaller(Set(Top))
        if not is_set:
            self.mark_bad(node)

    def visit_RelClear(self, node):
        # rel := Set<Bottom>
        # Check rel <= Set<Top>
        rel_type = self.update_store(node.rel, Set(Bottom))
        is_set = rel_type.issmaller(Set(Top))
        if not is_set:
            self.mark_bad(node)

    def visit_DictAssign(self, node):
        # target := Map<key, value>
        # Check target <= Map<Top, Top>
        key_type = self.visit(node.key)
        value_type = self.visit(node.value)
        target_type = self.visit(node.target, type=Map(key_type, value_type))
        is_map = target_type.issmaller(Map(Top, Top))
        if not is_map:
            self.mark_bad(node)

    def visit_DictDelete(self, node):
        # target := Map<key, Bottom>
        # Check target <= Map<Top, Top>
        key_type = self.visit(node.key)
        target_type = self.visit(node.target, type=Map(key_type, Bottom))
        is_map = target_type.issmaller(Map(Top, Top))
        if not is_map:
            self.mark_bad(node)

    def visit_DictClear(self, node):
        # target := Map<Bottom, Bottom>
        # Check target <= Map<Top, Top>
        target_type = self.visit(node.target, type=Map(Bottom, Bottom))
        is_map = target_type.issmaller(Map(Top, Top))
        if not is_map:
            self.mark_bad(node)

    def visit_MapAssign(self, node):
        # map := Map<key, value>
        # Check map <= Map<Top, Top>
        key_type = self.get_store(node.key)
        value_type = self.get_store(node.value)
        map_type = self.update_store(node.map, Map(key_type, value_type))
        is_map = map_type.issmaller(Map(Top, Top))
        if not is_map:
            self.mark_bad(node)

    def visit_MapDelete(self, node):
        # map := Map<key, Bottom>
        # Check map <= Map<Top, Top>
        key_type = self.get_store(node.key)
        map_type = self.update_store(node.map, Map(key_type, Bottom))
        is_map = map_type.issmaller(Map(Top, Top))
        if not is_map:
            self.mark_bad(node)

    def visit_MapClear(self, node):
        # map := Map<Bottom, Bottom>
        # Check map <= Map<Top, Top>
        map_type = self.update_store(node.map, Map(Bottom, Bottom))
        is_map = map_type.issmaller(Map(Top, Top))
        if not is_map:
            self.mark_bad(node)

    # Attribute handlers not implemented:
    # visit_AttrAssign
    # visit_AttrDelete

    @readonly
    def visit_UnaryOp(self, node, *, type=None):
        # If op == Not: Return Bool, Check operand <= Bool
        # Else:         Return Number, Check operand <= Number
        operand_type = self.visit(node.operand)
        result_type = Bool if isinstance(node.op, L.Not) else Number
        if not operand_type.issmaller(result_type):
            self.mark_bad(node)
        return result_type

    @readonly
    def visit_BoolOp(self, node, *, type=None):
        # Return Bool
        # Check v <= Bool for every operand v
        operand_types = [self.visit(operand) for operand in node.values]
        if any(not t.issmaller(Bool) for t in operand_types):
            self.mark_bad(node)
        return Bool

    @readonly
    def visit_BinOp(self, node, *, type=None):
        # Return join(left, right); left is visited before right.
        return self.visit(node.left).join(self.visit(node.right))

    @readonly
    def visit_Compare(self, node, *, type=None):
        # Visit both operands (for their effects); the result is Bool.
        for operand in (node.left, node.right):
            self.visit(operand)
        return Bool

    @readonly
    def visit_IfExp(self, node, *, type=None):
        # Return join(body, orelse)
        # Check test <= Bool
        test_type, body_type, orelse_type = (self.visit(node.test),
                                             self.visit(node.body),
                                             self.visit(node.orelse))
        if not test_type.issmaller(Bool):
            self.mark_bad(node)
        return body_type.join(orelse_type)

    @readonly
    def visit_GeneralCall(self, node, *, type=None):
        # Nothing is known about the callee: visit children, return Top.
        self.generic_visit(node)
        result = Top
        return result

    @readonly
    def visit_Call(self, node, *, type=None):
        # Ask FunctionTypeChecker for the result type given the argument
        # types; if it has none (None), mark the call bad and return Top.
        checker = FunctionTypeChecker()
        arg_types = [self.visit(arg) for arg in node.args]
        result_type = checker.get_call_type(node, arg_types)
        if result_type is not None:
            return result_type
        self.mark_bad(node)
        return Top

    @readonly
    def visit_Num(self, node, *, type=None):
        # Numeric literals are always Number.
        result = Number
        return result

    @readonly
    def visit_Str(self, node, *, type=None):
        # String literals are always String.
        result = String
        return result

    @readonly
    def visit_NameConstant(self, node, *, type=None):
        # For True/False:
        #   Return Bool
        # For None:
        #   Return Top
        if node.value in [True, False]:
            return Bool
        if node.value is None:
            return Top
        # Raise explicitly rather than using `assert`: an assert is
        # stripped under `python -O`, which would make this handler
        # silently return None for an unexpected constant.
        raise AssertionError(
            'Unexpected NameConstant value: {!r}'.format(node.value))

    def visit_Name(self, node, *, type=None):
        # Write context (type given): join type into the store.
        # Read context (type is None): look the name up.
        if type is not None:
            return self.update_store(node.id, type)
        return self.get_store(node.id)

    @readonly
    def visit_List(self, node, *, type=None):
        # Return List<join(T1, ..., Tn)>
        return List(Bottom.join(*(self.visit(elt) for elt in node.elts)))

    @readonly
    def visit_Set(self, node, *, type=None):
        # Return Set<join(T1, ..., Tn)>
        return Set(Bottom.join(*(self.visit(elt) for elt in node.elts)))

    @readonly
    def visit_Tuple(self, node, *, type=None):
        # Return Tuple<T1, ..., Tn>, one component per element.
        return Tuple([self.visit(elt) for elt in node.elts])

    # TODO: More precise behavior requires adding objects to the
    # type algebra.
    # Until then, attribute expressions use default_expr_handler: their
    # children are visited and the result is Top (write context is
    # flagged as ill-typed).
    visit_Attribute = default_expr_handler

    @readonly
    def visit_Subscript(self, node, *, type=None):
        # If value == Bottom:
        #   Return Bottom
        # Elif value == Tuple<T0, ..., Tn>:
        #   If index == Num(k) node, 0 <= k <= n:
        #     Return Tk
        #   Else:
        #     Return join(T0, ..., Tn)
        # Elif join(value, List<Bottom>) == List<T>:
        #   Return T
        # Else:
        #   Return Top
        #
        # Check value <= List<Top> or value is a Tuple
        # Check index <= Number
        value_type = self.visit(node.value)
        index_type = self.visit(node.index)
        if not index_type.issmaller(Number):
            self.mark_bad(node)

        if not isinstance(value_type, Tuple):
            # Not a tuple type: treat it as a list or list subtype.
            return self.get_sequence_elt(node, value_type, List)

        # Tuple types are detected with isinstance(), since no single
        # lattice element covers tuples of every arity.
        index = node.index
        constant_in_range = (isinstance(index, L.Num) and
                             0 <= index.n < len(value_type.elts))
        if constant_in_range:
            return value_type.elts[index.n]
        return Bottom.join(*value_type.elts)

    def visit_DictLookup(self, node, *, type=None):
        # If type != None:
        #   value := Map<Bottom, type>
        #
        # If join(value, Map<Bottom, Bottom>) == Map<K, V>:
        #   R = V
        # Else:
        #   R = Top
        # Return join(R, default) if a default is given, else R
        #
        # Check value <= Map<Top, Top>
        ctx_type = Map(Bottom, type) if type is not None else None
        map_type = self.visit(node.value, type=ctx_type)
        default_type = (self.visit(node.default)
                        if node.default is not None else None)
        # get_map_keyval marks the node bad and yields Top when map_type
        # is not a map; joining Top with anything stays Top.
        _, result_type = self.get_map_keyval(node, map_type)
        # The default is optional: only join it in when present, instead
        # of handing None to the lattice join.
        if default_type is not None:
            result_type = result_type.join(default_type)
        return result_type

    visit_FirstThen = default_expr_handler

    @readonly
    def visit_ImgLookup(self, node, *, type=None):
        # An image lookup is typed coarsely as Set<Top>; the only
        # constraint is that the relation being queried is a set.
        #   Check set <= Set<Top>
        #   Return Set<Top>
        rel_type = self.visit(node.set)
        if rel_type.issmaller(Set(Top)):
            return Set(Top)
        self.mark_bad(node)
        return Top

    @readonly
    def visit_SetFromMap(self, node, *, type=None):
        # The set built from a map is typed coarsely as Set<Top>.
        #   Check map <= Map<Top, Top>
        #   Return Set<Top>
        map_type = self.visit(node.map)
        if map_type.issmaller(Map(Top, Top)):
            return Set(Top)
        self.mark_bad(node)
        return Top

    # The following node kinds all use the generic default handler.
    visit_Unwrap = default_expr_handler
    visit_Wrap = default_expr_handler

    visit_IsSet = default_expr_handler
    visit_HasField = default_expr_handler
    visit_IsMap = default_expr_handler
    visit_HasArity = default_expr_handler

    @readonly
    def visit_Query(self, node, *, type=None):
        # A query node has exactly the type of the expression it wraps.
        wrapped = node.query
        return self.visit(wrapped)

    @readonly
    def visit_Comp(self, node, *, type=None):
        # Visit the clauses first (their handlers record variable types
        # in the store), then type the result expression.
        #   Return Set<resexp>
        for clause in node.clauses:
            self.visit(clause)
        return Set(self.visit(node.resexp))

    @readonly
    def visit_Member(self, node, *, type=None):
        # Membership clause over node.iter:
        #   target := T    where join(iter, Set<Bottom>) == Set<T>,
        #                  or Top if iter is not a set
        #   Check iter <= Set<Top>
        elt_type = self.get_sequence_elt(node, self.visit(node.iter), Set)
        # Visiting the target in write context pushes elt_type into it.
        self.visit(node.target, type=elt_type)

    @readonly
    def visit_RelMember(self, node, *, type=None):
        # Bind each of node.vars to the matching component type of the
        # tuples in relation node.rel (whose type is read from the store):
        #   vars_i := T_i  if join(rel, Set<Tuple<Bottom, ..., Bottom>>)
        #                  == Set<Tuple<T1, ..., Tn>>, n == len(vars)
        #   vars_i := Top  otherwise
        #   Check rel <= Set<Tuple<Top, ..., Top>>
        arity = len(node.vars)
        rel_type = self.get_store(node.rel)
        tuple_type = self.get_sequence_elt(node, rel_type, Set)
        component_types = self.get_tuple_elts(node, tuple_type, arity)
        for var, var_type in zip(node.vars, component_types):
            self.update_store(var, var_type)

    @readonly
    def visit_SingMember(self, node, *, type=None):
        # Destructure the type of node.value as a tuple over node.vars:
        #   vars_i := T_i  if join(value, Tuple<Bottom, ..., Bottom>)
        #                  == Tuple<T1, ..., Tn>, n == len(vars)
        #   vars_i := Top  otherwise
        #   Check value <= Tuple<Top, ..., Top>
        value_type = self.visit(node.value)
        component_types = self.get_tuple_elts(node, value_type,
                                              len(node.vars))
        for var, var_type in zip(node.vars, component_types):
            self.update_store(var, var_type)

    @readonly
    def visit_WithoutMember(self, node, *, type=None):
        # Conservative handling: just visit the children. Ideally type
        # information would flow from value into the nested clause's
        # target, but there is no channel for passing an element type
        # into a clause. One option would be to reuse the 'type'
        # parameter, with the convention that for membership clauses it
        # carries the element type.
        self.generic_visit(node)

    @readonly
    def visit_VarsMember(self, node, *, type=None):
        # Iterate over a set of tuples, binding node.vars componentwise:
        #   vars_i := T_i  if join(iter, Set<Tuple<Bottom, ..., Bottom>>)
        #                  == Set<Tuple<T1, ..., Tn>>, n == len(vars)
        #   vars_i := Top  otherwise
        #   Check iter <= Set<Tuple<Top, ..., Top>>
        iter_type = self.visit(node.iter)
        tuple_type = self.get_sequence_elt(node, iter_type, Set)
        component_types = self.get_tuple_elts(node, tuple_type,
                                              len(node.vars))
        for var, var_type in zip(node.vars, component_types):
            self.update_store(var, var_type)

    @readonly
    def visit_SetFromMapMember(self, node, *, type=None):
        # If join(rel, Set<Tuple<Bottom, ..., Bottom>>) ==
        #      Set<Tuple<T1, ..., Tn>> and
        #    join(map, Map<Tuple<Bottom, ..., Bottom>, Bottom>) ==
        #      Map<Tuple<U1, ..., Un-1>, Un>, n == len(vars):
        #   vars_i := join(T_i, U_i) for each i
        # Else:
        #   vars_i := Top for each i
        #
        # Check rel <= Set<Tuple<Top, ..., Top>>         (arity n)
        # Check map <= Map<Tuple<Top, ..., Top>, Top>    (key arity n - 1)
        #
        # Both rel and map are names whose types are read from the store.
        # The last var corresponds to the map's value type; the others
        # to the components of its key tuple.
        n = len(node.vars)
        t_rel = self.get_store(node.rel)
        t_relelt = self.get_sequence_elt(node, t_rel, Set)
        t_relvars = self.get_tuple_elts(node, t_relelt, n)
        t_map = self.get_store(node.map)
        t_key, t_value = self.get_map_keyval(node, t_map)
        t_keyvars = self.get_tuple_elts(node, t_key, n - 1)
        t_mapvars = [*t_keyvars, t_value]
        for v, t, u in zip(node.vars, t_relvars, t_mapvars):
            self.update_store(v, t.join(u))

    # Object domain clauses not implemented:
    # MMember
    # FMember
    # MAPMember
    # TUPMember

    @readonly
    def visit_Cond(self, node, *, type=None):
        # A condition clause binds no variables; just analyze its
        # expression.
        cond_expr = node.cond
        self.visit(cond_expr)

    def aggrop_helper(self, node, op, t_elt):
        """Return the result type of aggregate operator op applied to a
        set whose element type is t_elt.

        count yields Number for any element type. sum requires
        t_elt <= Number and yields Number; otherwise node is marked
        ill-typed and Top is returned. min and max yield t_elt.
        Raises AssertionError for any other operator.
        """
        # If op is count:
        #   Return Number
        # Elif op is sum:
        #   Check t_elt <= Number
        #   Return Number
        # Elif op is min or op is max:
        #   Return t_elt
        if isinstance(op, L.Count):
            # A cardinality is a number regardless of what is counted.
            return Number
        elif isinstance(op, L.Sum):
            if not t_elt.issmaller(Number):
                self.mark_bad(node)
                return Top
            return Number
        elif isinstance(op, (L.Min, L.Max)):
            return t_elt
        # Explicit raise instead of `assert ()`: assertions are stripped
        # under -O, which would make this silently return None.
        raise AssertionError('Unknown aggregate operator: {!r}'.format(op))

    @readonly
    def visit_Aggr(self, node, *, type=None):
        # Aggregate over a set:
        #   Return aggrop_helper(op, T)   if join(value, Set<Bottom>) == Set<T>
        #   Return Top                    otherwise
        #   Check value <= Set<Top>
        elt_type = self.get_sequence_elt(node, self.visit(node.value), Set)
        return self.aggrop_helper(node, node.op, elt_type)

    @readonly
    def visit_AggrRestr(self, node, *, type=None):
        # Typed exactly like Aggr, plus one extra constraint:
        #   Check restr <= Set<Tuple<Top, ..., Top>> of arity len(params)
        value_type = self.visit(node.value)
        elt_type = self.get_sequence_elt(node, value_type, Set)
        result = self.aggrop_helper(node, node.op, elt_type)

        restr_type = self.visit(node.restr)
        if not restr_type.issmaller(Set(Tuple([Top] * len(node.params)))):
            self.mark_bad(node)

        return result
Ejemplo n.º 26
0
class TypeAnalysisStepper(L.AdvNodeVisitor):

    """Run one iteration of transfer functions over all of the
    program's nodes.

    process() returns the store (variable -> type mapping), which is
    updated in place. After each call the ``changed`` attribute tells
    whether the store was modified, i.e. whether new information was
    inferred. Nodes that violate a well-typedness constraint accumulate
    in ``illtyped``.
    """

    def __init__(self, store, height_limit=None, unknown=Top,
                 fixed_vars=None):
        super().__init__()
        self.store = store
        """Mapping from symbol names to inferred types.
        Each type may only change in a monotonically increasing way.
        """
        self.height_limit = height_limit
        """Maximum height of type terms in the store. None for no limit."""
        self.illtyped = OrderedSet()
        """Nodes where the well-typedness constraints are violated."""
        self.changed = True
        """True if the last call to process() updated the store
        (or if there was no call so far).
        """
        self.unknown = unknown
        """Type to use for variables not in the store. Should be Bottom,
        Top, or None. None indicates that an error should be raised for
        unknown variables.
        """
        if fixed_vars is None:
            fixed_vars = []
        self.fixed_vars = fixed_vars
        """Names of variables whose types cannot be changed by inference."""

    def process(self, tree):
        """Run every transfer function over tree once and return the
        store (updated in place). Inspect ``changed`` afterwards to see
        whether anything new was inferred.
        """
        self.changed = False
        super().process(tree)
        return self.store

    def get_store(self, name):
        """Return the type of name, or ``unknown`` if name has no store
        entry. Re-raises the KeyError when unknown is None.
        """
        try:
            return self.store[name]
        except KeyError:
            if self.unknown is None:
                raise
            else:
                return self.unknown

    def update_store(self, name, type):
        """Join type into the store entry for name and return the
        resulting type.

        Variables in fixed_vars are never modified; their current type
        is returned unchanged. If height_limit is set, the joined type
        is widened to that height. Sets ``changed`` when the entry
        actually changes.
        """
        if name in self.fixed_vars:
            # Go through get_store() so a fixed variable without a store
            # entry falls back to `unknown` (or raises if unknown is
            # None), like any other read, instead of an unconditional
            # KeyError.
            return self.get_store(name)

        old_type = self.get_store(name)
        new_type = old_type.join(type)
        if self.height_limit is not None:
            # Keep type terms within the configured maximum height.
            new_type = new_type.widen(self.height_limit)
        if new_type != old_type:
            self.changed = True
            self.store[name] = new_type
        return new_type

    def mark_bad(self, node):
        """Record node as violating a well-typedness constraint."""
        self.illtyped.add(node)

    def readonly(f):
        """Decorator for handlers for expression nodes that only
        make sense in read context.

        When invoked in write context (type is not None) the node is
        marked ill-typed, but the wrapped handler still runs.
        """
        @wraps(f)
        def wrapper(self, node, *, type=None):
            if type is not None:
                self.mark_bad(node)
            return f(self, node, type=type)
        return wrapper

    @readonly
    def default_expr_handler(self, node, *, type=None):
        """Expression handler that just recurses and returns Top."""
        self.generic_visit(node)
        return Top

    # Each visitor handler has a monotonic transfer function and
    # possibly a constraint for well-typedness.
    #
    # The behavior of each handler is described in a comment using
    # the following informal syntax.
    #
    #    X := Y          Assign the join of X and Y to X
    #    Check X <= Y    Well-typedness constraint that X is a
    #                    subtype of Y
    #    Return X        Return X as the type of an expression
    #    Fail            Mark an error at this node and return Top
    #
    # This syntax is augmented by If/Elif/Else and pattern matching,
    # e.g. iter == Set<T> introduces T as the element type of iter.
    # join(T1, T2) is the lattice join of T1 and T2.
    #
    # Expression visitors have a keyword argument 'type', and can be
    # used in read or write context. In read context, type is None.
    # In write context, type is the type passed in from context. In
    # both cases the type of the expression is returned. Handlers
    # that do not tolerate write context are decorated as @readonly;
    # they still run but record a well-typedness error.

    def get_sequence_elt(self, node, t_seq, seq_cls):
        """Given a sequence type, and a sequence type constructor
        (e.g. Sequence, List, or Set), return the element type of
        the sequence type. If the sequence type cannot safely be
        converted to the given type constructor's form (for instance
        if it is not actually a sequence), return Top and mark node
        as ill-typed.
        """
        # Join to ensure that we're looking at a type object that
        # is an instance of seq_cls, as opposed to some other type
        # object in the lattice.
        t_seq = t_seq.join(seq_cls(Bottom))
        if not t_seq.issmaller(seq_cls(Top)):
            self.mark_bad(node)
            return Top
        return t_seq.elt

    def get_map_keyval(self, node, t_map):
        """Given a map type, return its key and value type. If the
        given type cannot safely be converted to a map type, return
        a pair of Tops instead and mark the node as ill-typed.
        """
        t_map = t_map.join(Map(Bottom, Bottom))
        if not t_map.issmaller(Map(Top, Top)):
            self.mark_bad(node)
            return Top, Top
        return t_map.key, t_map.value

    def get_tuple_elts(self, node, t_tup, arity):
        """Given a tuple type, return the element types. If the given
        type cannot safely be converted to a tuple of the given arity,
        return arity many Tops, and mark node as ill-typed.
        """
        t_tup = t_tup.join(Tuple([Bottom] * arity))
        if not t_tup.issmaller(Tuple([Top] * arity)):
            self.mark_bad(node)
            return [Top] * arity
        return t_tup.elts

    # Use default handler for Return.

    def visit_For(self, node):
        # If join(iter, Sequence<Bottom>) == Sequence<T>:
        #   target := T
        # Else:
        #   target := Top
        #
        # Check iter <= Sequence<Top>
        t_iter = self.visit(node.iter)
        t_target = self.get_sequence_elt(node, t_iter, Sequence)
        self.update_store(node.target, type=t_target)
        self.visit(node.body)

    def visit_DecompFor(self, node):
        # If join(iter, Sequence<Tuple<Bottom, ..., Bottom>>) ==
        #    Sequence<Tuple<T1, ..., Tn>>, n == len(vars):
        #   vars_i := T_i for each i
        # Else:
        #   vars_i := Top for each i
        #
        # Check iter <= Sequence<Tuple<Top, ..., Top>>
        n = len(node.vars)
        t_iter = self.visit(node.iter)
        t_target = self.get_sequence_elt(node, t_iter, Sequence)
        t_vars = self.get_tuple_elts(node, t_target, n)
        for v, t in zip(node.vars, t_vars):
            self.update_store(v, t)
        self.visit(node.body)

    def visit_While(self, node):
        # Check test <= Bool
        t_test = self.visit(node.test)
        if not t_test.issmaller(Bool):
            self.mark_bad(node)
        self.visit(node.body)

    def visit_If(self, node):
        # Check test <= Bool
        t_test = self.visit(node.test)
        if not t_test.issmaller(Bool):
            self.mark_bad(node)
        self.visit(node.body)
        self.visit(node.orelse)

    # Use default handler for Pass, Break, Continue, and Expr.

    def visit_Assign(self, node):
        # target := value
        t_value = self.visit(node.value)
        self.update_store(node.target, t_value)

    def visit_DecompAssign(self, node):
        # If join(value, Tuple<Bottom, ..., Bottom>) ==
        #    Tuple<T1, ..., Tn>, n == len(vars):
        #   vars_i := T_i for each i
        # Else:
        #   vars_i := Top for each i
        #
        # Check value <= Tuple<Top, ..., Top>
        n = len(node.vars)
        t_value = self.visit(node.value)
        t_vars = self.get_tuple_elts(node, t_value, n)
        for v, t in zip(node.vars, t_vars):
            self.update_store(v, t)

    def visit_SetUpdate(self, node):
        # target := Set<value>
        # Check target <= Set<Top>
        t_value = self.visit(node.value)
        t_target = self.visit(node.target, type=Set(t_value))
        if not t_target.issmaller(Set(Top)):
            self.mark_bad(node)

    def visit_SetBulkUpdate(self, node):
        # If join(value, Set<Bottom>) == Set<T>:
        #   target := Set<T>
        # Else:
        #   target := Set<Top>
        #
        # Check value <= Set<Top> and target <= Set<Top>
        t_value = self.visit(node.value)
        t_elt = self.get_sequence_elt(node, t_value, Set)
        t_target = self.visit(node.target, type=Set(t_elt))
        if not (t_value.issmaller(Set(Top)) and
                t_target.issmaller(Set(Top))):
            self.mark_bad(node)

    def visit_SetClear(self, node):
        # target := Set<Bottom>
        # Check target <= Set<Top>
        t_target = self.visit(node.target, type=Set(Bottom))
        if not t_target.issmaller(Set(Top)):
            self.mark_bad(node)

    def visit_RelUpdate(self, node):
        # rel := Set<elem>
        # Check rel <= Set<Top>
        t_value = self.get_store(node.elem)
        t_rel = self.update_store(node.rel, Set(t_value))
        if not t_rel.issmaller(Set(Top)):
            self.mark_bad(node)

    def visit_RelClear(self, node):
        # rel := Set<Bottom>
        # Check rel <= Set<Top>
        t_rel = self.update_store(node.rel, Set(Bottom))
        if not t_rel.issmaller(Set(Top)):
            self.mark_bad(node)

    def visit_DictAssign(self, node):
        # target := Map<key, value>
        # Check target <= Map<Top, Top>
        t_key = self.visit(node.key)
        t_value = self.visit(node.value)
        t_target = self.visit(node.target, type=Map(t_key, t_value))
        if not t_target.issmaller(Map(Top, Top)):
            self.mark_bad(node)

    def visit_DictDelete(self, node):
        # target := Map<key, Bottom>
        # Check target <= Map<Top, Top>
        t_key = self.visit(node.key)
        t_target = self.visit(node.target, type=Map(t_key, Bottom))
        if not t_target.issmaller(Map(Top, Top)):
            self.mark_bad(node)

    def visit_DictClear(self, node):
        # target := Map<Bottom, Bottom>
        # Check target <= Map<Top, Top>
        t_target = self.visit(node.target, type=Map(Bottom, Bottom))
        if not t_target.issmaller(Map(Top, Top)):
            self.mark_bad(node)

    def visit_MapAssign(self, node):
        # map := Map<key, value>
        # Check map <= Map<Top, Top>
        t_key = self.get_store(node.key)
        t_value = self.get_store(node.value)
        t_map = self.update_store(node.map, Map(t_key, t_value))
        if not t_map.issmaller(Map(Top, Top)):
            self.mark_bad(node)

    def visit_MapDelete(self, node):
        # map := Map<key, Bottom>
        # Check map <= Map<Top, Top>
        t_key = self.get_store(node.key)
        t_map = self.update_store(node.map, Map(t_key, Bottom))
        if not t_map.issmaller(Map(Top, Top)):
            self.mark_bad(node)

    def visit_MapClear(self, node):
        # map := Map<Bottom, Bottom>
        # Check map <= Map<Top, Top>
        t_map = self.update_store(node.map, Map(Bottom, Bottom))
        if not t_map.issmaller(Map(Top, Top)):
            self.mark_bad(node)

    # Attribute handlers not implemented:
    # visit_AttrAssign
    # visit_AttrDelete

    @readonly
    def visit_UnaryOp(self, node, *, type=None):
        # If op == Not:
        #   Return Bool
        #   Check operand <= Bool
        # Else:
        #   Return Number
        #   Check operand <= Number
        t_operand = self.visit(node.operand)
        if isinstance(node.op, L.Not):
            t = Bool
        else:
            t = Number
        if not t_operand.issmaller(t):
            self.mark_bad(node)
        return t

    @readonly
    def visit_BoolOp(self, node, *, type=None):
        # Return Bool
        # Check v <= Bool for v in values
        t_values = [self.visit(v) for v in node.values]
        if not all(t.issmaller(Bool) for t in t_values):
            self.mark_bad(node)
        return Bool

    @readonly
    def visit_BinOp(self, node, *, type=None):
        # Return join(left, right)
        t_left = self.visit(node.left)
        t_right = self.visit(node.right)
        return t_left.join(t_right)

    @readonly
    def visit_Compare(self, node, *, type=None):
        # Return Bool.
        self.visit(node.left)
        self.visit(node.right)
        return Bool

    @readonly
    def visit_IfExp(self, node, *, type=None):
        # Return join(body, orelse)
        # Check test <= Bool
        t_test = self.visit(node.test)
        t_body = self.visit(node.body)
        t_orelse = self.visit(node.orelse)
        if not t_test.issmaller(Bool):
            self.mark_bad(node)
        return t_body.join(t_orelse)

    @readonly
    def visit_GeneralCall(self, node, *, type=None):
        # Return Top
        self.generic_visit(node)
        return Top

    @readonly
    def visit_Call(self, node, *, type=None):
        # Return the result type computed by FunctionTypeChecker from the
        # argument types; Fail if it cannot determine one (None).
        checker = FunctionTypeChecker()
        t_args = [self.visit(a) for a in node.args]
        t_result = checker.get_call_type(node, t_args)
        if t_result is None:
            self.mark_bad(node)
            return Top
        return t_result

    @readonly
    def visit_Num(self, node, *, type=None):
        # Return Number
        return Number

    @readonly
    def visit_Str(self, node, *, type=None):
        # Return String
        return String

    @readonly
    def visit_NameConstant(self, node, *, type=None):
        # For True/False:
        #   Return Bool
        # For None:
        #   Return Top
        if node.value in [True, False]:
            return Bool
        elif node.value is None:
            return Top
        # Explicit raise rather than `assert()`, which is stripped
        # under -O and would make this silently return None.
        raise AssertionError(
            'Unexpected NameConstant value: {!r}'.format(node.value))

    def visit_Name(self, node, *, type=None):
        # Read or update the type in the store, depending on
        # whether we're in read or write context.
        name = node.id
        if type is None:
            return self.get_store(name)
        else:
            return self.update_store(name, type)

    @readonly
    def visit_List(self, node, *, type=None):
        # Return List<join(T1, ..., Tn)>
        t_elts = [self.visit(e) for e in node.elts]
        t_elt = Bottom.join(*t_elts)
        return List(t_elt)

    @readonly
    def visit_Set(self, node, *, type=None):
        # Return Set<join(T1, ..., Tn)>
        t_elts = [self.visit(e) for e in node.elts]
        t_elt = Bottom.join(*t_elts)
        return Set(t_elt)

    @readonly
    def visit_Tuple(self, node, *, type=None):
        # Return Tuple<elts>
        t_elts = [self.visit(e) for e in node.elts]
        return Tuple(t_elts)

    # TODO: More precise behavior requires adding objects to the
    # type algebra.

    visit_Attribute = default_expr_handler

    @readonly
    def visit_Subscript(self, node, *, type=None):
        # If value == Bottom:
        #   Return Bottom
        # Elif value == Tuple<T0, ..., Tn>:
        #   If index == Num(k) node, 0 <= k <= n:
        #     Return Tk
        #   Else:
        #     Return join(T0, ..., Tn)
        # Elif join(value, List<Bottom>) == List<T>:
        #   Return T
        # Else:
        #   Return Top
        #
        # Check value <= List<Top> or value is a Tuple
        # Check index <= Number
        t_value = self.visit(node.value)
        t_index = self.visit(node.index)
        if not t_index.issmaller(Number):
            self.mark_bad(node)

        # Try Tuple case first. Since we don't have a type for tuples
        # of arbitrary arity, we'll use an isinstance() check. This
        # may have to change if we add new subtypes of Tuple to the
        # lattice.
        if isinstance(t_value, Tuple):
            if (isinstance(node.index, L.Num) and
                    0 <= node.index.n < len(t_value.elts)):
                return t_value.elts[node.index.n]
            else:
                return Bottom.join(*t_value.elts)

        # Otherwise, treat it as a list or list subtype.
        return self.get_sequence_elt(node, t_value, List)

    def visit_DictLookup(self, node, *, type=None):
        # If type != None:
        #   value := Map<Bottom, type>
        #
        # If join(value, Map<Bottom, Bottom>) == Map<K, V>:
        #   R = V
        # Else:
        #   R = Top
        # Return join(R, default)      (just R when there is no default)
        #
        # Check value <= Map<Top, Top>
        #
        # Not @readonly: in write context the incoming type is pushed
        # into the value type of the map being looked up.
        # NOTE(review): node.key is never visited, so sub-expressions of
        # the key are not analyzed -- confirm this is intended.
        t_value = Map(Bottom, type) if type is not None else None
        t_value = self.visit(node.value, type=t_value)
        t_default = (self.visit(node.default)
                     if node.default is not None else None)
        t_value = t_value.join(Map(Bottom, Bottom))
        if not t_value.issmaller(Map(Top, Top)):
            self.mark_bad(node)
            return Top
        # With no default there is nothing to join in; never hand None
        # to the lattice join.
        if t_default is None:
            return t_value.value
        return t_value.value.join(t_default)

    visit_FirstThen = default_expr_handler

    @readonly
    def visit_ImgLookup(self, node, *, type=None):
        # Check rel <= Set<Top>
        # Return Set<Top>
        t_rel = self.visit(node.set)
        if not t_rel.issmaller(Set(Top)):
            self.mark_bad(node)
            return Top
        return Set(Top)

    @readonly
    def visit_SetFromMap(self, node, *, type=None):
        # Check map <= Map<Top, Top>
        # Return Set<Top>
        t_map = self.visit(node.map)
        if not t_map.issmaller(Map(Top, Top)):
            self.mark_bad(node)
            return Top
        return Set(Top)

    visit_Unwrap = default_expr_handler
    visit_Wrap = default_expr_handler

    visit_IsSet = default_expr_handler
    visit_HasField = default_expr_handler
    visit_IsMap = default_expr_handler
    visit_HasArity = default_expr_handler

    @readonly
    def visit_Query(self, node, *, type=None):
        # Return query
        return self.visit(node.query)

    @readonly
    def visit_Comp(self, node, *, type=None):
        # Return Set<resexp>
        for cl in node.clauses:
            self.visit(cl)
        t_resexp = self.visit(node.resexp)
        return Set(t_resexp)

    @readonly
    def visit_Member(self, node, *, type=None):
        # If join(iter, Set<Bottom>) == Set<T>:
        #   target := T
        # Else:
        #   target := Top
        #
        # Check iter <= Set<Top>
        t_iter = self.visit(node.iter)
        t_target = self.get_sequence_elt(node, t_iter, Set)
        self.visit(node.target, type=t_target)

    @readonly
    def visit_RelMember(self, node, *, type=None):
        # If join(rel, Set<Tuple<Bottom, ..., Bottom>>) ==
        #    Set<Tuple<T1, ..., Tn>>, n == len(vars):
        #   vars_i := T_i for each i
        # Else:
        #   vars_i := Top for each i
        #
        # Check rel <= Set<Tuple<Top, ..., Top>>
        n = len(node.vars)
        t_rel = self.get_store(node.rel)
        t_target = self.get_sequence_elt(node, t_rel, Set)
        t_vars = self.get_tuple_elts(node, t_target, n)
        for v, t in zip(node.vars, t_vars):
            self.update_store(v, t)

    @readonly
    def visit_SingMember(self, node, *, type=None):
        # If join(value, Tuple<Bottom, ..., Bottom>) ==
        #    Tuple<T1, ..., Tn>, n == len(vars):
        #   vars_i := T_i for each i
        # Else:
        #   vars_i := Top for each i
        #
        # Check value <= Tuple<Top, ..., Top>
        n = len(node.vars)
        t_value = self.visit(node.value)
        t_vars = self.get_tuple_elts(node, t_value, n)
        for v, t in zip(node.vars, t_vars):
            self.update_store(v, t)

    @readonly
    def visit_WithoutMember(self, node, *, type=None):
        # We don't have an easy way to propagate information into
        # the nested clause, or else we'd flow type information from
        # value to cl.target. Could fix by using the type parameter,
        # with the convention that for membership clauses, type is the
        # type of the element.
        self.generic_visit(node)

    @readonly
    def visit_VarsMember(self, node, *, type=None):
        # If join(iter, Set<Tuple<Bottom, ..., Bottom>>) ==
        #    Set<Tuple<T1, ..., Tn>>, n == len(vars):
        #   vars_i := T_i for each i
        # Else:
        #   vars_i := Top for each i
        #
        # Check iter <= Set<Tuple<Top, ..., Top>>
        n = len(node.vars)
        t_iter = self.visit(node.iter)
        t_target = self.get_sequence_elt(node, t_iter, Set)
        t_vars = self.get_tuple_elts(node, t_target, n)
        for v, t in zip(node.vars, t_vars):
            self.update_store(v, t)

    @readonly
    def visit_SetFromMapMember(self, node, *, type=None):
        # If join(rel, Set<Tuple<Bottom, ..., Bottom>>) ==
        #      Set<Tuple<T1, ..., Tn>> and
        #    join(map, Map<Tuple<Bottom, ..., Bottom>, Bottom>) ==
        #      Map<Tuple<U1, ..., Un-1>, Un>, n == len(vars):
        #   vars_i := join(T_i, U_i) for each i
        # Else:
        #   vars_i := Top for each i
        #
        # Check rel <= Set<Tuple<Top, ..., Top>>         (arity n)
        # Check map <= Map<Tuple<Top, ..., Top>, Top>    (key arity n - 1)
        #
        # The last var corresponds to the map's value type; the others
        # to the components of its key tuple.
        n = len(node.vars)
        t_rel = self.get_store(node.rel)
        t_relelt = self.get_sequence_elt(node, t_rel, Set)
        t_relvars = self.get_tuple_elts(node, t_relelt, n)
        t_map = self.get_store(node.map)
        t_key, t_value = self.get_map_keyval(node, t_map)
        t_keyvars = self.get_tuple_elts(node, t_key, n - 1)
        t_mapvars = [*t_keyvars, t_value]
        for v, t, u in zip(node.vars, t_relvars, t_mapvars):
            self.update_store(v, t.join(u))

    # Object domain clauses not implemented:
    # MMember
    # FMember
    # MAPMember
    # TUPMember

    @readonly
    def visit_Cond(self, node, *, type=None):
        self.visit(node.cond)

    def aggrop_helper(self, node, op, t_elt):
        """Return the result type of aggregate operator op applied to a
        set whose element type is t_elt.

        count yields Number for any element type. sum requires
        t_elt <= Number and yields Number; otherwise node is marked
        ill-typed and Top is returned. min and max yield t_elt.
        Raises AssertionError for any other operator.
        """
        # If op is count:
        #   Return Number
        # Elif op is sum:
        #   Check t_elt <= Number
        #   Return Number
        # Elif op is min or op is max:
        #   Return t_elt
        if isinstance(op, L.Count):
            # A cardinality is a number regardless of what is counted.
            return Number
        elif isinstance(op, L.Sum):
            if not t_elt.issmaller(Number):
                self.mark_bad(node)
                return Top
            return Number
        elif isinstance(op, (L.Min, L.Max)):
            return t_elt
        # Explicit raise instead of `assert()`: assertions are stripped
        # under -O, which would make this silently return None.
        raise AssertionError('Unknown aggregate operator: {!r}'.format(op))

    @readonly
    def visit_Aggr(self, node, *, type=None):
        # If join(value, Set<Bottom>) == Set<T>:
        #   Return aggrop_helper(op, T)
        # Else:
        #   Return Top
        # Check value <= Set<Top>
        t_value = self.visit(node.value)
        t_elt = self.get_sequence_elt(node, t_value, Set)
        return self.aggrop_helper(node, node.op, t_elt)

    @readonly
    def visit_AggrRestr(self, node, *, type=None):
        # As for Aggr, except we have the additional condition:
        #
        # Check restr <= Set<Tuple<Top, ..., Top>> of arity |params|
        t_value = self.visit(node.value)
        t_elt = self.get_sequence_elt(node, t_value, Set)
        t = self.aggrop_helper(node, node.op, t_elt)

        n = len(node.params)
        t_restr = self.visit(node.restr)
        if not t_restr.issmaller(Set(Tuple([Top] * n))):
            self.mark_bad(node)

        return t