def __init__(self, name, bases, body, env, ast_def_class):
    """
    @types: str, list[ast.AST], list[ast.AST], LinkedList, ast.AST

    Build a class type named `name` defined in `env`.

    name          -- class name (must be a str)
    bases         -- AST nodes of the base-class expressions
    body          -- statements of the class body
    env           -- environment the class is defined in; bases are looked up here
    ast_def_class -- the defining AST node (not necessarily a ClassDef, e.g.
                     an Import node is passed when a module is modelled)

    Attributes of every base that resolves to exactly one ClassType are copied
    into self.attrs. Then `body` is handed to self.__saveClassAttrs (not
    visible here), presumably to add the class's own attributes on top.
    """
    assert IS(name, str)
    self.name = name
    self.env = env
    self.attrs = {}
    self.ast = ast_def_class
    for base in bases:
        # Only plain names can be looked up in env, and `object` contributes
        # nothing. Dotted bases (Attribute) and computed bases such as
        # `namedtuple(...)` (Call) are skipped. The latter have no `.id`, so
        # reading it used to raise AttributeError.
        if not IS(base, Name) or base.id == "object":
            continue
        baseClasses = lookup(base.id, env)
        if baseClasses and len(baseClasses) == 1 and IS(baseClasses[0], ClassType):
            # limit to one possible type
            # NOTE(review): attributes of later bases overwrite those of
            # earlier ones, the reverse of Python's MRO for `class C(A, B)`.
            # Confirm whether that is intended.
            self.attrs.update(baseClasses[0].attrs)
        else:
            logger.error("Can't infer base of %s: %s %s", name, baseClasses, base.id)
    self.__saveClassAttrs(body)
def _inferBranches(t1, env1, t2, env2, following, env, stk):
    """Join two alternative branches and continue with the following statements.

    Shared by the If, While, For, TryExcept and TryFinally cases of inferSeq.
    (t1, env1) and (t2, env2) are the (types, env) results of the two
    branches. `following` holds the statements after the compound statement,
    and `env` is the environment in effect before it.

    - Both branches terminate: every statement in `following` is flagged as
      unreachable and not analysed.
    - Exactly one terminates: `following` is analysed in the other branch's env.
    - Neither terminates: `following` is analysed in mergeEnv(env1, env2).

    Non-terminating branch types go through finalize() after `following` has
    been analysed.
    """
    term1 = isTerminating(t1)
    term2 = isTerminating(t2)
    if term1 and term2:
        for stmt in following:
            putInfo(stmt, TypeError("unreachable code"))
        return (union([t1, t2]), env)
    if term1:
        (t3, env3) = inferSeq(following, env2, stk)
        return (union([t1, finalize(t2), t3]), env3)
    if term2:
        (t3, env3) = inferSeq(following, env1, stk)
        return (union([finalize(t1), t2, t3]), env3)
    (t3, env3) = inferSeq(following, mergeEnv(env1, env2), stk)
    return (union([finalize(t1), finalize(t2), t3]), env3)


def inferSeq(exp, env, stk):
    """
    @types: list[ast.stmt], LinkedList, LinkedList -> (list[Type], LinkedList)

    Infer the types of a sequence of statements in `env`.

    The first statement is analysed and the rest of the list is handled
    recursively. Returns a pair (types, env_after). An empty sequence yields
    [contType], meaning control fell off the end without a return. `stk` is
    threaded through to infer()/invokeClosure.

    Raises TypeError for a statement node that has no case below.
    """
    debug("Infering sequence", exp)
    if exp == []:
        # reached end without return
        return ([contType], env)

    e = exp[0]
    following = exp[1:]

    if IS(e, If):
        _ = infer(e.test, env, stk)
        (t1, env1) = inferSeq(e.body, close(e.body, env), stk)
        (t2, env2) = inferSeq(e.orelse, close(e.orelse, env), stk)
        return _inferBranches(t1, env1, t2, env2, following, env, stk)

    elif IS(e, While):
        # todo evaluate test
        (t1, env1) = inferSeq(e.body, close(e.body, env), stk)
        (t2, env2) = inferSeq(e.orelse, close(e.orelse, env), stk)
        return _inferBranches(t1, env1, t2, env2, following, env, stk)

    elif IS(e, For):
        values = infer(e.iter, env, stk)
        value = flatten(values)
        env = bind(e.target, value, env)
        (t1, env1) = inferSeq(e.body, close(e.body, env), stk)
        (t2, env2) = inferSeq(e.orelse, close(e.orelse, env), stk)
        return _inferBranches(t1, env1, t2, env2, following, env, stk)

    elif IS(e, Assign):
        t = infer(e.value, env, stk)
        for target in e.targets:
            env = bind(target, t, env)
        return inferSeq(following, env, stk)

    elif IS(e, AugAssign):
        t = infer(e.value, env, stk)
        env = bind(e.target, t, env)
        return inferSeq(following, env, stk)

    elif IS(e, FunctionDef):
        cs = lookup(e.name, env)
        if not cs:
            debug("Function %s not found in scope %s" % (e.name, env))
        # lookup() returns None for an unbound name; iterating None would raise
        for c in cs or []:
            c.env = env  # create circular env to support recursion
            for d in e.args.defaults:
                # infer types for default arguments
                c.defaults.append(infer(d, env, stk))
        return inferSeq(following, env, stk)

    elif IS(e, Return):
        if e.value is None:
            t1 = [PrimType(None)]
        else:
            t1 = infer(e.value, env, stk)
        # The statements after a return are still analysed (result discarded)
        # and then flagged as unreachable.
        inferSeq(following, env, stk)
        for e2 in following:
            putInfo(e2, TypeError("unreachable code"))
        return (t1, env)

    elif IS(e, Expr):
        infer(e.value, env, stk)
        return inferSeq(following, env, stk)

    elif IS(e, ImportFrom):
        _, module_symbols = get_module_symbols(e.module)
        for alias in e.names:
            name_to_import = alias.name
            name_import_as = alias.asname or name_to_import
            module_symbol = lookup(name_to_import, module_symbols)
            env = bind(getName(name_import_as, e.lineno), module_symbol, env)
        return inferSeq(following, env, stk)

    elif IS(e, Import):
        for alias in e.names:
            name_to_import = alias.name
            module, module_env = get_module_symbols(name_to_import)
            name_import_as = alias.asname or alias.name
            # a module is modelled as an instance of a synthetic "module" class
            module_class = ClassType("module", [], module.body, module_env, e)
            module_obj = [ObjType(module_class, [], nil, e)]
            env = bind(getName(name_import_as, e.lineno), module_obj, env)
        return inferSeq(following, env, stk)

    elif IS(e, ClassDef):
        cs = lookup(e.name, env)
        if not cs:
            debug("Class def %s not found in scope %s" % (e.name, env))
        # lookup() returns None for an unbound name; iterating None would raise
        for c in cs or []:
            c.env = env
        return inferSeq(following, env, stk)

    elif IS(e, TryExcept):
        (t1, env1) = inferSeq(e.body, close(e.body, env), stk)
        (t2, env2) = inferSeq(e.orelse, close(e.orelse, env), stk)
        # handlers are analysed, but their result does not feed the join
        inferSeq(e.handlers, close(e.handlers, env), stk)
        return _inferBranches(t1, env1, t2, env2, following, env, stk)

    elif IS(e, TryFinally):
        (t1, env1) = inferSeq(e.body, close(e.body, env), stk)
        (t2, env2) = inferSeq(e.finalbody, close(e.finalbody, env), stk)
        return _inferBranches(t1, env1, t2, env2, following, env, stk)

    elif IS(e, ExceptHandler):
        (t1, env1) = inferSeq(e.body, close(e.body, env), stk)
        (t3, env3) = inferSeq(following, env1, stk)
        return (union([t1, t3]), env3)

    elif IS(e, With):
        # TODO infer e.context_expr,
        #   call __enter__ from e.context_expr
        #   bind e.optional_vars to the result of __enter__
        #   call __exit__
        (t1, env1) = inferSeq(e.body, close(e.body, env), stk)
        (t2, env2) = inferSeq(following, env1, stk)
        return (union([t1, t2]), env2)

    elif IS(e, (Break, Continue, Raise, Pass, Print, Assert, Global,
                ast.Delete, ast.Subscript, ast.Exec)):
        # Statements with no effect on the inferred types.
        # Global: TODO this should affect bind behaviour when assigning.
        #   We don't have a way to change env for now, we can only append
        #   (see tests/assign.py).
        return inferSeq(following, env, stk)

    else:
        raise TypeError("unrecognized node in effect context", e)
def invokeClosure(call, actualParams, clo, env, stk):
    """
    @types: ast.Call, list[ast.AST], Closure, LinkedList, LinkedList -> list[Type]

    Analyse one call of closure `clo` at call site `call`.

    Actual arguments (inferred in the caller's `env`) are bound to the formal
    parameters in this order: positionals, extra positionals into *vararg,
    keywords (unknown ones collected into **kwarg), then defaults for formals
    that are still unbound. The function body is then inferred in the
    closure's env extended with those bindings.

    Argument mismatches are reported with putInfo(call, TypeError(...)).
    Excess positional arguments without a *vararg abort the call and return
    [err]. If the same call site is already on `stk` with the same input
    types, [bottomType] is returned to cut off the recursion.
    """
    debug("invoking closure", clo.func, "with args", actualParams)
    debug(clo.func.body)
    func = clo.func
    fenv = clo.env
    formals = func.args.args
    nformals = len(formals)
    pos = nil
    kwarg = nil

    # bind positionals first (zip stops at the shorter of formals/actuals)
    for formal, actual in zip(formals, actualParams):
        pos = bind(formal, infer(actual, env, stk), pos)

    # put extra positionals into vararg if provided
    # report error and go on otherwise
    if len(actualParams) > nformals:
        if func.args.vararg is None:
            err = TypeError("excess arguments to function")
            putInfo(call, err)
            return [err]
        ts = []
        for extra in actualParams[nformals:]:
            ts.extend(infer(extra, env, stk))
        pos = bind(func.args.vararg, ts, pos)

    # bind keywords, collect kwarg
    ids = [getId(formal) for formal in formals]
    for k in call.keywords:
        ts = infer(k.value, env, stk)
        tloc1 = lookup(k.arg, pos)
        if tloc1 is not None:
            putInfo(call, TypeError("multiple values for keyword argument", k.arg, tloc1))
        elif k.arg not in ids:
            kwarg = bind(k.arg, ts, kwarg)
        else:
            pos = bind(k.arg, ts, pos)

    # put extras in kwarg or report them
    # bind call.keywords to func.args.kwarg
    if kwarg != nil:
        if func.args.kwarg is not None:
            pos = bind(func.args.kwarg, [DictType(reverse(kwarg))], pos)
        else:
            putInfo(call, TypeError("unexpected keyword arguments", kwarg))
    elif func.args.kwarg is not None:
        pos = bind(func.args.kwarg, [DictType(nil)], pos)

    # bind defaults, avoid overwriting bound vars.
    # Types for defaults were already inferred when the function was defined.
    # inferSeq appends to clo.defaults each time it visits the FunctionDef, so
    # the list can be longer than func.args.defaults. zip() stops at the last
    # formal instead of indexing past it (previously an IndexError).
    firstDefault = nformals - len(func.args.defaults)
    for formal, dt in zip(formals[firstDefault:], clo.defaults):
        if lookup(getId(formal), pos) is None:
            pos = bind(formal, dt, pos)

    # finish building the input type
    fromtype = maplist(lambda p: SimplePair(first(p), typeOnly(rest(p))), pos)

    # check whether the same call site is on stack with same input types
    # if so, we are back to a loop, terminate
    if onStack(call, fromtype, stk):
        return [bottomType]

    # push the call site onto the stack and analyze the function body
    stk = ext(call, fromtype, stk)
    fenv = append(pos, fenv)
    to = infer(func.body, fenv, stk)

    # record the function type
    putInfo(func, FuncType(reverse(fromtype), to))
    return to
def infer(exp, env, stk):
    """
    @types: ast.AST|object, LinkedList, LinkedList -> list[Type]

    Infer what expression (or statement list / module) `exp` can evaluate to
    in `env`.

    Literals (Num, Str) are returned as their own AST objects rather than as
    primitive types. Nodes with no case below yield [UnknownType(exp)].
    """
    debug("infering", exp, exp.__class__)
    assert exp is not None

    if IS(exp, Module):
        return infer(exp.body, env, stk)

    elif IS(exp, list):
        env = close(exp, env)
        (t, _) = inferSeq(exp, env, stk)  # env ignored (out of scope)
        return t

    elif IS(exp, Num):
        # we need objects, not types
        return [exp]

    elif IS(exp, Str):
        # we need objects, not types
        return [exp]

    elif IS(exp, Name):
        b = lookup(exp.id, env)
        debug("infering name:", b, env)
        if b is not None:
            putInfo(exp, b)
            return b
        try:
            # Fall back to the Python interpreter's builtins. The explicit
            # empty globals matter: a bare eval(exp.id) also sees this
            # function's locals (exp, env, stk, b) and this module's globals,
            # so an unbound name in the analysed code could resolve to the
            # analyser's own objects. exp.id is a parsed identifier, so this
            # is a plain name lookup.
            t = eval(exp.id, {})
            return [PrimType(t)]
        except NameError:
            putInfo(exp, UnknownType(exp))
            return [UnknownType(exp)]

    elif IS(exp, Lambda):
        c = Closure(exp, env)
        for d in exp.args.defaults:
            c.defaults.append(infer(d, env, stk))
        return [c]

    elif IS(exp, Call):
        return invoke(exp, env, stk)

    elif IS(exp, Attribute):
        t = infer(exp.value, env, stk)
        if not t:
            return [UnknownType(exp)]
        # find attr name in each possible object and return it
        attribs = []
        for o in t:
            if not IS(o, (ObjType, ClassType, DictType)):
                attribs.append(TypeError("unknown object", o))
            elif exp.attr in o.attrs:
                attribs.append(AttrType(o.attrs[exp.attr], o, exp.value))
            else:
                attribs.append(TypeError("no such attribute", exp.attr))
        return attribs

    elif IS(exp, ObjType):
        # NOTE(review): returns the object itself, not a one-element list like
        # the other branches. Confirm callers expect this.
        return exp

    elif IS(exp, ast.List):
        infered_elts = flatten([infer(el, env, stk) for el in exp.elts])
        return [ListType(tuple(infered_elts))]

    elif IS(exp, ast.Dict):
        # Catch the interpreter's TypeError by its builtin name. Elsewhere in
        # this module TypeError(...) values are built as analysis results,
        # which suggests the bare name may be rebound to a project class
        # (NOTE(review): confirm). Naming the builtin is correct either way.
        try:
            import __builtin__ as _builtins  # Python 2
        except ImportError:
            import builtins as _builtins  # Python 3
        infered_keys = [infer(key, env, stk) for key in exp.keys]
        infered_values = [infer(value, env, stk) for value in exp.values]
        hashable = {}
        dic = nil
        for keys, value in zip(infered_keys, infered_values):
            for key in keys:
                try:
                    # only the last value with the same key is stored
                    hashable[key] = value
                except _builtins.TypeError:
                    # unhashable instance: keep every entry
                    dic = ext(key, value, dic)
        for key, value in hashable.items():
            dic = ext(key, value, dic)
        return [DictType(dic)]

    else:
        # Not modelled (e.g. BinOp, Compare, Tuple): report unknown instead of
        # falling off the end and returning None, which callers such as
        # invokeClosure (`ts.extend(t)` / `ts + t`) cannot handle.
        return [UnknownType(exp)]