def visit_Subscript(self, node):
    """Type a subscript expression from the type of its subscripted value.

    A one-argument combiner is built according to the shape of the slice and
    handed to ``self.combine`` together with ``node.value``:

    - extended slice ``a[1:2, 3, 4:1]``: the type of
      ``declval(a)(slice, long, slice)``;
    - non-negative numeric constant ``a[2]``: the element type of ``a``;
    - anything else ``a[i]``: the type of ``a[i]`` built from the slice type.
    """
    self.visit(node.value)

    if isextslice(node.slice):
        self.visit(node.slice)

        def combiner(container_type):
            # the formatter is created per call, exactly as before
            def call_expr(callee, *dims):
                return "{0}({1})".format(callee, ", ".join(dims))
            dim_types = tuple(self.result[dim] for dim in node.slice.elts)
            return self.builder.ExpressionType(call_expr,
                                               (container_type, ) + dim_types)
    elif isnum(node.slice) and node.slice.value >= 0:
        # Special-casing constant non-negative indices makes type inference
        # easier for the back end compiler.
        def combiner(container_type):
            return self.builder.ElementType(node.slice.value, container_type)
    else:
        # the type is the return type of the matching subscript function
        self.visit(node.slice)

        def combiner(container_type):
            return self.builder.ExpressionType(
                "{0}[{1}]".format,
                (container_type, self.result[node.slice]))

    self.combine(node, node.value, unary_op=combiner)
def visit_Subscript(self, node):
    ''' Resulting node alias stores the subscript relationship if we don't know
        anything about the subscripted node.

        >>> from pythran import passmanager
        >>> pm = passmanager.PassManager('demo')
        >>> module = ast.parse('def foo(a): return a[0]')
        >>> result = pm.gather(Aliases, module)
        >>> Aliases.dump(result, filter=ast.Subscript)
        a[0] => ['a[0]']

        If we know something about the container, e.g. in case of a list, we
        can use this information to get more accurate information:

        >>> module = ast.parse('def foo(a, b, c): return [a, b][c]')
        >>> result = pm.gather(Aliases, module)
        >>> Aliases.dump(result, filter=ast.Subscript)
        [a, b][c] => ['a', 'b']

        Moreover, in case of a tuple indexed by a constant value, we can
        further refine the aliasing information:

        >>> fun = """
        ... def f(a, b): return a, b
        ... def foo(a, b): return f(a, b)[0]"""
        >>> module = ast.parse(fun)
        >>> result = pm.gather(Aliases, module)
        >>> Aliases.dump(result, filter=ast.Subscript)
        f(a, b)[0] => ['a']

        Nothing is done for slices, even if the indices are known :-/

        >>> module = ast.parse('def foo(a, b, c): return [a, b, c][1:]')
        >>> result = pm.gather(Aliases, module)
        >>> Aliases.dump(result, filter=ast.Subscript)
        [a, b, c][1:] => ['<unbound-value>']
    '''
    if not isinstance(node.slice, ast.Index):
        # could be enhanced through better handling of containers
        self.generic_visit(node)
        return self.add(node, None)

    self.visit(node.slice)
    container_aliases = self.visit(node.value)
    found = set()
    for candidate in container_aliases:
        if isinstance(candidate, ContainerOf):
            subscript = node.slice.value
            # slicing a container does not yield one of its elements
            if isinstance(subscript, ast.Slice):
                continue
            # a constant index only selects the containee stored at that index
            if isnum(subscript) and subscript.value != candidate.index:
                continue
            # FIXME: what if the index is a slice variable...
            found.add(candidate.containee)
        elif isinstance(getattr(candidate, 'ctx', None), ast.Param):
            found.add(ast.Subscript(candidate, node.slice, node.ctx))
    # an empty result is recorded as "unknown"
    return self.add(node, found or None)
def is_true_predicate(node):
    """Return True if ``node`` is an expression statically known to be truthy.

    Recognized patterns:
    - a constant accepted by ``isnum`` whose value is truthy;
    - an attribute named ``True`` (e.g. ``builtins.True``);
    - a list, tuple or set display with at least one non-starred element;
    - a dict display with at least one ``key: value`` item.

    Starred elements (``[*x]``) and ``**`` unpacking (``{**d}``) may expand to
    nothing, so on their own they do not prove the container is non-empty.
    Returns False whenever truthiness cannot be decided statically.
    """
    # FIXME: there may be more patterns here
    if isinstance(node, (ast.List, ast.Tuple, ast.Set)):
        return any(not isinstance(elt, ast.Starred) for elt in node.elts)
    if isinstance(node, ast.Dict):
        # ``{**d}`` is represented by a None key, which may contribute nothing
        return any(key is not None for key in node.keys)
    if isinstance(node, ast.Attribute):
        return node.attr == 'True'
    return bool(isnum(node) and node.value)
def can_use_c_for(self, node):
    """Tell whether the ``for`` loop ``node`` can be emitted as a C-style loop.

    All of the following must hold:
    - the loop target (a plain name) is never assigned in the loop body;
    - the iterator is a call to ``builtins.range``;
    - the iteration direction is known at compile time, i.e. the step
      argument is omitted or is a numeric constant.
    """
    assert isinstance(node.target, ast.Name)
    range_call = ast.Call(
        func=ast.Attribute(
            value=ast.Name('builtins', ast.Load(), None, None),
            attr='range',
            ctx=ast.Load()),
        args=AST_any(),
        keywords=[])

    assigned_names = set()
    for stmt in node.body:
        assigned_names |= {name.id for name in self.gather(IsAssigned, stmt)}

    matches = ASTMatcher(range_call).search(node.iter)
    if node.iter not in matches:
        return False
    if node.target.id in assigned_names:
        return False

    range_args = node.iter.args
    if len(range_args) >= 3 and not isnum(range_args[2]):
        # a non-constant step leaves the loop direction unknown
        return False
    return True
def fold_mult_right(self, node):
    """Return True when ``node`` has the shape ``<number> * <list/tuple>``.

    The operator is checked first, so that other binary operators never
    trigger the float diagnostic below.

    Raises:
        PythranSyntaxError: for ``<non-integral number> * <list/tuple>``.
    """
    if not isinstance(node.op, ast.Mult):
        return False
    if not isinstance(node.right, (ast.List, ast.Tuple)):
        return False
    if not isnum(node.left):
        return False
    # FIXME: remove that check once we have a proper type inference engine
    if not isintegral(node.left):
        raise PythranSyntaxError("Multiplying a sequence by a float", node)
    return True
def _Attribute(self, t):
    """Unparse an attribute access ``value.attr``.

    ``3.__abs__()`` is a syntax error (``3.`` is read as a float literal), so
    an integer-literal receiver is followed by a space: ``3 .__abs__()``.
    """
    receiver = t.value
    self.dispatch(receiver)
    needs_space = isnum(receiver) and isinstance(receiver.value, int)
    if needs_space:
        self.write(" ")
    self.write(".")
    self.write(t.attr)
def visit_Call(self, node):
    """Collect constant shape arguments of ``_make_shape`` calls.

    When ``node.func`` aliases exactly ``_make_shape``, every argument that
    is a non-negative ``int`` constant is added to ``self.result`` and the
    call is not traversed further. Other calls are visited generically.
    """
    candidates = self.aliases[node.func]
    is_make_shape = (len(candidates) == 1
                     and next(iter(candidates)) is _make_shape)
    if not is_make_shape:
        return self.generic_visit(node)
    self.result.update(
        arg for arg in node.args
        if isnum(arg) and isinstance(arg.value, int) and arg.value >= 0)
    return None
def handle_real_loop_comparison(self, args, target, upper_bound):
    """Build the C loop condition comparing ``target`` to ``upper_bound``.

    ``args`` are the ``range`` call arguments, used to deduce the loop
    direction:
    - no explicit step: increasing;
    - constant step: increasing iff it is positive;
    - constant start and stop: increasing iff stop > start;
    - otherwise: unknown.
    An increasing loop compares with ``<``; decreasing and unknown loops
    compare with ``>``.
    """
    if len(args) <= 2:
        increasing = True
    elif isnum(args[2]):
        increasing = int(args[2].value) > 0
    elif isnum(args[1]) and isnum(args[0]):
        # NOTE(review): this orders start/stop while ignoring the sign of a
        # non-constant step; it looks reachable only when callers accept
        # non-constant steps -- confirm before relying on it.
        increasing = int(args[1].value) > int(args[0].value)
    else:
        # direction not known at compile time
        increasing = False
    operator = "<" if increasing else ">"
    return "{} {} {}".format(target, operator, upper_bound)
def visit_If(self, node):
    """Translate an ``if`` statement.

    Test, body and else branch are all visited first. When the test is a
    numeric constant equal to 1, only the body is kept, wrapped in a
    ``Block``: some OpenMP directives require a compound statement.
    """
    test = self.visit(node.test)
    body = [self.visit(stmt) for stmt in node.body]
    orelse = [self.visit(stmt) for stmt in node.orelse]

    always_true = isnum(node.test) and node.test.value == 1
    if always_true:
        stmt = Block(body)
    else:
        then_block = Block(body)
        else_block = Block(orelse) if orelse else None
        stmt = If(test, then_block, else_block)
    with_omp = self.process_omp_attachements(node, stmt)
    return self.process_locals(node, with_omp)
def is_true_predicate(node):
    """Return True if ``node`` is an expression statically known to be truthy.

    Recognized patterns:
    - a constant accepted by ``isnum`` whose value is truthy;
    - an attribute named ``True`` (e.g. ``builtins.True``);
    - a list, tuple or set display with at least one non-starred element;
    - a dict display with at least one ``key: value`` item.

    Starred elements (``[*x]``) and ``**`` unpacking (``{**d}``) may expand to
    nothing, so on their own they do not prove the container is non-empty.
    Returns False whenever truthiness cannot be decided statically.
    """
    # FIXME: there may be more patterns here
    if isinstance(node, (ast.List, ast.Tuple, ast.Set)):
        return any(not isinstance(elt, ast.Starred) for elt in node.elts)
    if isinstance(node, ast.Dict):
        # ``{**d}`` is represented by a None key, which may contribute nothing
        return any(key is not None for key in node.keys)
    if isinstance(node, ast.Attribute):
        return node.attr == 'True'
    return bool(isnum(node) and node.value)
def visit_BinOp(self, node):
    """Simplify power operations after visiting both operands.

    - When ``Square.POW_PATTERN`` matches, the node is replaced by
      ``self.replace(node.left)`` (presumably the ``x ** 2`` case -- confirm).
    - A power whose exponent is a positive integral numeric constant is
      expanded through ``self.expand_pow``.
    - Everything else is returned unchanged.
    """
    self.generic_visit(node)
    if ASTMatcher(Square.POW_PATTERN).match(node):
        return self.replace(node.left)
    if not (isinstance(node.op, ast.Pow) and isnum(node.right)):
        return node

    exponent = node.right.value
    # Only ints and integer-valued floats qualify. The previous
    # ``int(n) == n`` test raised TypeError for complex exponents and
    # OverflowError for infinite ones (``x ** 1e999``);
    # ``float.is_integer`` is False for inf and nan.
    if isinstance(exponent, float):
        integral = exponent.is_integer()
    else:
        integral = isinstance(exponent, int)
    if integral and exponent > 0:
        return self.expand_pow(node.left, exponent)
    return node
def visit_Subscript(self, node):
    """
    Fold ``a[lower:upper:step][i]`` into ``a[lower + i * step]``.

    Only a non-negative integer constant index ``i`` is folded: a negative
    index counts from the end of the *sliced* sequence. When ``lower`` is
    omitted, it is taken as 0, which is only valid for an omitted or positive
    constant step; other steps are left untouched. ``upper`` is ignored, so
    an out-of-range access is not preserved. For instance, ``a[-4:]`` has at
    most four elements in the last example below.

    >>> import gast as ast
    >>> from pythran import passmanager, backend
    >>> pm = passmanager.PassManager("test")

    >>> node = ast.parse("def foo(a): a[1:][3]")
    >>> _, node = pm.apply(PartialConstantFolding, node)
    >>> _, node = pm.apply(ConstantFolding, node)
    >>> print(pm.dump(backend.Python, node))
    def foo(a):
        a[4]

    >>> node = ast.parse("def foo(a): a[::2][3]")
    >>> _, node = pm.apply(PartialConstantFolding, node)
    >>> _, node = pm.apply(ConstantFolding, node)
    >>> print(pm.dump(backend.Python, node))
    def foo(a):
        a[6]

    >>> node = ast.parse("def foo(a): a[-4:][5]")
    >>> _, node = pm.apply(PartialConstantFolding, node)
    >>> _, node = pm.apply(ConstantFolding, node)
    >>> print(pm.dump(backend.Python, node))
    def foo(a):
        a[1]
    """
    self.generic_visit(node)
    if not isinstance(node.value, ast.Subscript):
        return node
    slice_ = node.value.slice
    if not isinstance(slice_, ast.Slice):
        return node
    index = node.slice
    if not isinstance(index, ast.Index):
        return node
    if not isnum(index.value):
        return node

    offset = index.value.value
    # ``a[1:][-1]`` is the last element of ``a``, not ``a[0]``.
    if not isinstance(offset, int) or offset < 0:
        return node
    if slice_.lower is None and slice_.step is not None:
        # an omitted lower bound means 0 only when the step is positive;
        # ``a[::-1][0]`` is the last element of ``a``
        step_node = slice_.step
        positive_step = (isnum(step_node)
                         and isinstance(step_node.value, int)
                         and step_node.value > 0)
        if not positive_step:
            return node

    # reuse the inner subscript node, swapping its slice for the index
    node = node.value
    node.slice = index
    if slice_.lower is not None:
        lower = slice_.lower
    else:
        lower = ast.Constant(0, None)
    if slice_.step is not None:
        step = slice_.step
    else:
        step = ast.Constant(1, None)
    node.slice.value = ast.BinOp(lower, ast.Add(),
                                 ast.BinOp(index.value, ast.Mult(), step))
    self.update = True
    return node
def visit_Slice(self, node):
    """Record the slice type of ``node`` in ``self.result``.

    Subnodes are visited first since they may carry typing information.
    An omitted step, or a numeric constant step equal to 1, gives
    ``pythonic::types::contiguous_slice``; any other step gives
    ``pythonic::types::slice``.
    """
    self.generic_visit(node)
    step = node.step
    unit_step = step is None or (isnum(step) and step.value == 1)
    if unit_step:
        type_name = 'pythonic::types::contiguous_slice'
    else:
        type_name = 'pythonic::types::slice'
    self.result[node] = self.builder.NamedType(type_name)
def visit_Slice(self, node):
    """Return the C++ expression that builds the slice ``node``.

    Missing bounds are rendered as ``pythonic::builtins::None``. With an
    omitted step, or a numeric constant step equal to 1, a two-argument
    ``contiguous_slice`` is produced; otherwise a three-argument ``slice``.
    """
    lower, upper, step = (
        self.visit(part) if part else 'pythonic::builtins::None'
        for part in (node.lower, node.upper, node.step))
    if node.step is None or (isnum(node.step) and node.step.value == 1):
        return "pythonic::types::contiguous_slice({},{})".format(lower, upper)
    return "pythonic::types::slice({},{},{})".format(lower, upper, step)
def visit_Call(self, node):
    """Collect call arguments that must be compile-time immediates.

    For every alias of ``node.func`` exposing a non-empty
    ``immediate_arguments`` collection, the arguments at those positions are
    added to ``self.result``. When ``node.func`` aliases exactly
    ``_make_shape``, its non-negative ``int`` constant arguments are added
    too and traversal stops there. Otherwise the call is visited generically.
    """
    func_aliases = self.aliases[node.func]
    for target in func_aliases:
        positions = getattr(target, "immediate_arguments", [])
        if not positions:
            continue
        for pos, arg in enumerate(node.args):
            if pos in positions:
                self.result.add(arg)

    if len(func_aliases) != 1 or next(iter(func_aliases)) is not _make_shape:
        return self.generic_visit(node)
    self.result.update(
        arg for arg in node.args
        if isnum(arg) and isinstance(arg.value, int) and arg.value >= 0)
    return None
def isrange(self, elts):
    """Return ``(start, stop, step)`` when ``elts`` spell out a ``range``.

    ``elts`` is a sequence of AST nodes. Every element must be an integer
    constant, and consecutive values must differ by the same non-zero step.
    Booleans are rejected because a ``range`` yields ints, not bools.
    A single element ``k`` is reported as ``(k, k + 1, 1)``. Returns None
    otherwise, including for an empty sequence and repeated values (a zero
    step, which ``range`` itself rejects with ValueError).
    """
    if not elts:
        return None
    if not all(isnum(x) and isinstance(x.value, int)
               and not isinstance(x.value, bool) for x in elts):
        return None
    values = [x.value for x in elts]
    start = values[0]
    if len(values) == 1:
        return start, start + 1, 1
    step = values[1] - start
    if step == 0:
        return None
    stop = values[-1] + step
    if values == list(range(start, stop, step)):
        return start, stop, step
    return None
def _UnaryOp(self, t): self.write("(") self.write(self.unop[t.op.__class__.__name__]) self.write(" ") # If we're applying unary minus to a number, parenthesize the number. # This is necessary: -2147483648 is different from -(2147483648) on # a 32-bit machine (the first is an int, the second a long), and # -7j is different from -(7j). (The first has real part 0.0, the # second has real part -0.0.) if isinstance(t.op, ast.USub) and isnum(t.operand): self.write("(") self.dispatch(t.operand) self.write(")") else: self.dispatch(t.operand) self.write(")")
def visit_Slice(self, node):
    """Record the slice type of ``node``, using range information if any.

    Subnodes are visited first since they may carry typing information.
    A slice whose step is omitted or is a numeric constant equal to 1 is
    contiguous. It becomes ``fast_contiguous_slice`` when ``self.range_values``
    proves both bounds non-negative, and ``contiguous_slice`` otherwise.
    Any other step gives ``pythonic::types::slice``.

    NOTE(review): a missing bound (None) is looked up in ``range_values`` as
    is; presumably that mapping yields an unknown range for it -- confirm.
    """
    self.generic_visit(node)
    step = node.step
    if step is not None and not (isnum(step) and step.value == 1):
        self.result[node] = self.builder.NamedType('pythonic::types::slice')
        return
    bounds_non_negative = all(self.range_values[bound].low >= 0
                              for bound in (node.lower, node.upper))
    if bounds_non_negative:
        ntype = "pythonic::types::fast_contiguous_slice"
    else:
        ntype = "pythonic::types::contiguous_slice"
    self.result[node] = self.builder.NamedType(ntype)
def node_to_id(self, n, depth=()):
    """Map an expression to ``(name, depth)``.

    ``depth`` collects one entry per non-slice subscript, from the outermost
    to the innermost: the constant index when ``isnum`` accepts it, None
    otherwise. For example, ``a[1][2]`` gives ``('a', (2, 1))``. Slice
    subscripts add nothing. A call is resolved through
    ``self.strict_aliases``: the first alias other than the call itself that
    maps to a name wins.

    Raises:
        UnboundableRValue: when no name can be associated with ``n``.
    """
    if isinstance(n, ast.Name):
        return n.id, depth
    if isinstance(n, ast.Subscript):
        if isinstance(n.slice, ast.Slice):
            return self.node_to_id(n.value, depth)
        index = n.slice.value if isnum(n.slice) else None
        return self.node_to_id(n.value, depth + (index, ))
    if isinstance(n, ast.Call):
        # use alias information if any
        for alias in self.strict_aliases[n]:
            if alias is n:
                # no specific alias info
                continue
            try:
                return self.node_to_id(alias, depth)
            except UnboundableRValue:
                pass
    raise UnboundableRValue()
def visit_Subscript(self, node):
    """Return the C++ expression for a subscript.

    Cases, tried in order:
    - a non-negative ``int`` constant index becomes ``std::get<N>(value)``;
    - an index accepted by ``self.all_positive`` uses ``value.fast(index)``;
    - an extended slice (``a[i, j:k]``) becomes ``value(i,j:k...)``;
    - anything else becomes ``value[index]``.
    When ``isstr`` accepts the subscripted value, it is first wrapped in
    ``pythonic::types::str`` since ``[]`` cannot be overloaded otherwise.
    """
    value = self.visit(node.value)
    # we cannot overload the [] operator in that case
    if isstr(node.value):
        value = 'pythonic::types::str({})'.format(value)

    # positive static index case. The type is checked before comparing,
    # because a complex constant does not support ``>=``. A bool is an int
    # in Python but must be printed as 0/1 in the C++ template argument.
    if (isnum(node.slice) and isinstance(node.slice.value, int) and
            node.slice.value >= 0):
        return "std::get<{0}>({1})".format(int(node.slice.value), value)
    # positive indexing case
    elif self.all_positive(node.slice):
        slice_ = self.visit(node.slice)
        return "{1}.fast({0})".format(slice_, value)
    # extended slice case
    elif isextslice(node.slice):
        slices = [self.visit(elt) for elt in node.slice.elts]
        return "{1}({0})".format(','.join(slices), value)
    # standard case
    else:
        slice_ = self.visit(node.slice)
        return "{1}[{0}]".format(slice_, value)
def fold_mult_right(self, node):
    """Return True when ``node`` is ``<number> * <list or tuple literal>``.

    The right operand must be a list or tuple display, the left operand a
    constant accepted by ``isnum``, and the operator a multiplication.
    """
    return (isinstance(node.right, (ast.List, ast.Tuple))
            and bool(isnum(node.left))
            and isinstance(node.op, ast.Mult))
def analyse(node, env, non_generic=None):
    """Computes the type of the expression given by node.

    The type of the node is computed in the context of the supplied type
    environment env. Data types can be introduced into the language simply
    by having a predefined set of identifiers in the initial environment;
    this way there is no need to change the syntax or, more importantly,
    the type-checking program when extending the language.

    Statements are analysed for their effect on env, which is updated in
    place and returned. Two reserved keys describe the enclosing function:
    ``'@gen'`` (whether its body yields) and ``'@ret'`` (the merged type of
    the values it returns or yields).

    Args:
        node: The root of the abstract syntax tree.
        env: The type environment is a mapping of expression identifier names
            to type assignments.
        non_generic: A set of non-generic variables, or None

    Returns:
        The computed type of the expression (env for most statements).

    Raises:
        InferenceError: The type of the expression could not be inferred,
        PythranTypeError: InferenceError with user friendly message + location
    """

    if non_generic is None:
        non_generic = set()

    # expr
    if isinstance(node, gast.Name):
        if isinstance(node.ctx, (gast.Store)):
            new_type = TypeVariable()
            non_generic.add(new_type)
            env[node.id] = new_type
        return get_type(node.id, env, non_generic)
    elif isinstance(node, gast.Constant):
        if isinstance(node.value, str):
            return Str()
        elif isinstance(node.value, int):
            # bool is a subclass of int, so True/False are typed as Integer
            return Integer()
        elif isinstance(node.value, float):
            return Float()
        elif isinstance(node.value, complex):
            return Complex()
        elif node.value is None:
            return NoneType
        else:
            raise NotImplementedError
    elif isinstance(node, gast.Compare):
        left_type = analyse(node.left, env, non_generic)
        comparators_type = [analyse(comparator, env, non_generic)
                            for comparator in node.comparators]
        ops_type = [analyse(op, env, non_generic)
                    for op in node.ops]
        prev_type = left_type
        result_type = TypeVariable()
        for op_type, comparator_type in zip(ops_type, comparators_type):
            try:
                unify(Function([prev_type, comparator_type], result_type),
                      op_type)
                prev_type = comparator_type
            except InferenceError:
                raise PythranTypeError(
                    "Invalid comparison, between `{}` and `{}`".format(
                        prev_type,
                        comparator_type
                    ),
                    node)
        return result_type
    elif isinstance(node, gast.Call):
        if is_getattr(node):
            self_type = analyse(node.args[0], env, non_generic)
            attr_name = node.args[1].value
            _, attr_signature = attributes[attr_name]
            attr_type = tr(attr_signature)
            result_type = TypeVariable()
            try:
                unify(Function([self_type], result_type), attr_type)
            except InferenceError:
                if isinstance(prune(attr_type), MultiType):
                    msg = 'no attribute found, tried:\n{}'.format(attr_type)
                else:
                    msg = 'tried {}'.format(attr_type)
                raise PythranTypeError(
                    "Invalid attribute for getattr call with self "
                    "of type `{}`, {}".format(self_type, msg), node)
        else:
            fun_type = analyse(node.func, env, non_generic)
            arg_types = [analyse(arg, env, non_generic) for arg in node.args]
            result_type = TypeVariable()
            try:
                unify(Function(arg_types, result_type), fun_type)
            except InferenceError:
                # recover original type
                fun_type = analyse(node.func, env, non_generic)
                if isinstance(prune(fun_type), MultiType):
                    msg = 'no overload found, tried:\n{}'.format(fun_type)
                else:
                    msg = 'tried {}'.format(fun_type)
                raise PythranTypeError(
                    "Invalid argument type for function call to "
                    "`Callable[[{}], ...]`, {}"
                    .format(', '.join('{}'.format(at) for at in arg_types),
                            msg),
                    node)
        return result_type

    elif isinstance(node, gast.IfExp):
        test_type = analyse(node.test, env, non_generic)
        unify(Function([test_type], Bool()),
              tr(MODULES['builtins']['bool']))

        # ``x if v is None else y``: v is None in the body, not in the orelse
        if is_test_is_none(node.test):
            none_id = node.test.left.id
            body_env = env.copy()
            body_env[none_id] = NoneType
        else:
            none_id = None
            body_env = env

        body_type = analyse(node.body, body_env, non_generic)

        if none_id:
            orelse_env = env.copy()
            if is_option_type(env[none_id]):
                orelse_env[none_id] = prune(env[none_id]).types[0]
            else:
                orelse_env[none_id] = TypeVariable()
        else:
            orelse_env = env

        orelse_type = analyse(node.orelse, orelse_env, non_generic)

        try:
            return merge_unify(body_type, orelse_type)
        except InferenceError:
            raise PythranTypeError(
                "Incompatible types from different branches: "
                "`{}` and `{}`".format(
                    body_type,
                    orelse_type
                ),
                node
            )
    elif isinstance(node, gast.UnaryOp):
        operand_type = analyse(node.operand, env, non_generic)
        op_type = analyse(node.op, env, non_generic)
        result_type = TypeVariable()
        try:
            unify(Function([operand_type], result_type), op_type)
            return result_type
        except InferenceError:
            raise PythranTypeError(
                "Invalid operand for `{}`: `{}`".format(
                    symbol_of[type(node.op)],
                    operand_type
                ),
                node
            )
    elif isinstance(node, gast.BinOp):
        left_type = analyse(node.left, env, non_generic)
        op_type = analyse(node.op, env, non_generic)
        right_type = analyse(node.right, env, non_generic)
        result_type = TypeVariable()
        try:
            unify(Function([left_type, right_type], result_type), op_type)
        except InferenceError:
            raise PythranTypeError(
                "Invalid operand for `{}`: `{}` and `{}`".format(
                    symbol_of[type(node.op)],
                    left_type,
                    right_type),
                node
            )
        return result_type
    elif isinstance(node, gast.Pow):
        return tr(MODULES['numpy']['power'])
    elif isinstance(node, gast.Sub):
        return tr(MODULES['operator']['sub'])
    elif isinstance(node, (gast.USub, gast.UAdd)):
        return tr(MODULES['operator']['pos'])
    elif isinstance(node, (gast.Eq, gast.NotEq, gast.Lt, gast.LtE, gast.Gt,
                           gast.GtE, gast.Is, gast.IsNot)):
        return tr(MODULES['operator']['eq'])
    elif isinstance(node, (gast.In, gast.NotIn)):
        # ``a in b`` calls contains(b, a): swap the two argument types
        contains_sig = tr(MODULES['operator']['contains'])
        contains_sig.types[:-1] = reversed(contains_sig.types[:-1])
        return contains_sig
    elif isinstance(node, gast.Add):
        return tr(MODULES['operator']['add'])
    elif isinstance(node, gast.Mult):
        return tr(MODULES['operator']['mul'])
    elif isinstance(node, gast.MatMult):
        return tr(MODULES['operator']['matmul'])
    elif isinstance(node, (gast.Div, gast.FloorDiv)):
        return tr(MODULES['operator']['floordiv'])
    elif isinstance(node, gast.Mod):
        return tr(MODULES['operator']['mod'])
    elif isinstance(node, (gast.LShift, gast.RShift)):
        return tr(MODULES['operator']['lshift'])
    elif isinstance(node, (gast.BitXor, gast.BitAnd, gast.BitOr)):
        return tr(MODULES['operator']['lshift'])
    elif isinstance(node, gast.List):
        new_type = TypeVariable()
        for elt in node.elts:
            elt_type = analyse(elt, env, non_generic)
            try:
                unify(new_type, elt_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible list element type `{}` and `{}`".format(
                        new_type, elt_type),
                    node
                )
        return List(new_type)
    elif isinstance(node, gast.Set):
        new_type = TypeVariable()
        for elt in node.elts:
            elt_type = analyse(elt, env, non_generic)
            try:
                unify(new_type, elt_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible set element type `{}` and `{}`".format(
                        new_type, elt_type),
                    node
                )
        return Set(new_type)
    elif isinstance(node, gast.Dict):
        new_key_type = TypeVariable()
        for key in node.keys:
            key_type = analyse(key, env, non_generic)
            try:
                unify(new_key_type, key_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible dict key type `{}` and `{}`".format(
                        new_key_type, key_type),
                    node
                )
        new_value_type = TypeVariable()
        for value in node.values:
            value_type = analyse(value, env, non_generic)
            try:
                unify(new_value_type, value_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible dict value type `{}` and `{}`".format(
                        new_value_type, value_type),
                    node
                )
        return Dict(new_key_type, new_value_type)
    elif isinstance(node, gast.Tuple):
        return Tuple([analyse(elt, env, non_generic) for elt in node.elts])
    elif isinstance(node, gast.Index):
        return analyse(node.value, env, non_generic)
    elif isinstance(node, gast.Slice):
        def unify_int_or_none(t, name):
            try:
                unify(t, Integer())
            except InferenceError:
                try:
                    unify(t, NoneType)
                except InferenceError:
                    raise PythranTypeError(
                        "Invalid slice {} type `{}`, expecting int or None"
                        .format(name, t)
                    )
        # each present bound must be an int or None (checked in order)
        for bound, name in ((node.lower, 'lower bound'),
                            (node.upper, 'upper bound'),
                            (node.step, 'step')):
            if bound:
                unify_int_or_none(analyse(bound, env, non_generic), name)
        return Slice
    elif isinstance(node, gast.ExtSlice):
        return [analyse(dim, env, non_generic) for dim in node.dims]
    elif isinstance(node, gast.Subscript):
        new_type = TypeVariable()
        value_type = prune(analyse(node.value, env, non_generic))
        try:
            slice_type = prune(analyse(node.slice, env, non_generic))
        except PythranTypeError as e:
            raise PythranTypeError(e.msg, node)

        if isinstance(node.slice, gast.ExtSlice):
            nbslice = len(node.slice.dims)
            dtype = TypeVariable()
            try:
                unify(Array(dtype, nbslice), clone(value_type))
            except InferenceError:
                raise PythranTypeError(
                    "Dimension mismatch when slicing `{}`".format(value_type),
                    node)
            return TypeVariable()  # FIXME
        elif isinstance(node.slice, gast.Index):
            # handle tuples in a special way: an integer constant index
            # selects the matching element type. gast.Constant has no
            # deprecated ``.n`` alias, so the value is read from ``.value``.
            index = node.slice.value
            num = isnum(index) and isinstance(index.value, int)
            if num and is_tuple_type(value_type):
                try:
                    unify(prune(prune(value_type.types[0]).types[0])
                          .types[index.value],
                          new_type)
                    return new_type
                except IndexError:
                    raise PythranTypeError(
                        "Invalid tuple indexing, "
                        "out-of-bound index `{}` for type `{}`".format(
                            index.value,
                            value_type),
                        node)
        try:
            unify(tr(MODULES['operator']['getitem']),
                  Function([value_type, slice_type], new_type))
        except InferenceError:
            raise PythranTypeError(
                "Invalid subscripting of `{}` by `{}`".format(
                    value_type,
                    slice_type),
                node)
        return new_type
    elif isinstance(node, gast.Attribute):
        from pythran.utils import attr_to_path
        obj, path = attr_to_path(node)
        if obj.signature is typing.Any:
            return TypeVariable()
        else:
            return tr(obj)

    # stmt
    elif isinstance(node, gast.Import):
        for alias in node.names:
            if alias.name not in MODULES:
                raise NotImplementedError("unknown module: %s " % alias.name)
            if alias.asname is None:
                target = alias.name
            else:
                target = alias.asname
            env[target] = tr(MODULES[alias.name])
        return env
    elif isinstance(node, gast.ImportFrom):
        if node.module not in MODULES:
            raise NotImplementedError("unknown module: %s" % node.module)
        for alias in node.names:
            if alias.name not in MODULES[node.module]:
                raise NotImplementedError(
                    "unknown function: %s in %s" % (alias.name, node.module))
            if alias.asname is None:
                target = alias.name
            else:
                target = alias.asname
            env[target] = tr(MODULES[node.module][alias.name])
        return env
    elif isinstance(node, gast.FunctionDef):
        # one signature per number of trailing arguments left to defaults
        ftypes = []
        for i in range(1 + len(node.args.defaults)):
            new_env = env.copy()
            new_non_generic = non_generic.copy()

            # reset return special variables
            new_env.pop('@ret', None)
            new_env.pop('@gen', None)

            hy = HasYield()
            for stmt in node.body:
                hy.visit(stmt)
            new_env['@gen'] = hy.has_yield

            arg_types = []
            istop = len(node.args.args) - i
            for arg in node.args.args[:istop]:
                arg_type = TypeVariable()
                new_env[arg.id] = arg_type
                new_non_generic.add(arg_type)
                arg_types.append(arg_type)
            # the last i arguments take the types of their default values
            for arg, expr in zip(node.args.args[istop:],
                                 node.args.defaults[-i:]):
                arg_type = analyse(expr, new_env, new_non_generic)
                new_env[arg.id] = arg_type

            analyse_body(node.body, new_env, new_non_generic)

            result_type = new_env.get('@ret', NoneType)

            if new_env['@gen']:
                result_type = Generator(result_type)

            ftype = Function(arg_types, result_type)
            ftypes.append(ftype)
        if len(ftypes) == 1:
            ftype = ftypes[0]
            env[node.name] = ftype
        else:
            env[node.name] = MultiType(ftypes)
        return env
    elif isinstance(node, gast.Module):
        analyse_body(node.body, env, non_generic)
        return env
    elif isinstance(node, (gast.Pass, gast.Break, gast.Continue)):
        return env
    elif isinstance(node, gast.Expr):
        analyse(node.value, env, non_generic)
        return env
    elif isinstance(node, gast.Delete):
        for target in node.targets:
            if isinstance(target, gast.Name):
                if target.id in env:
                    del env[target.id]
                else:
                    raise PythranTypeError(
                        "Invalid del: unbound identifier `{}`".format(
                            target.id),
                        node)
            else:
                analyse(target, env, non_generic)
        return env
    elif isinstance(node, gast.Print):
        if node.dest is not None:
            analyse(node.dest, env, non_generic)
        for value in node.values:
            analyse(value, env, non_generic)
        return env
    elif isinstance(node, gast.Assign):
        defn_type = analyse(node.value, env, non_generic)
        for target in node.targets:
            target_type = analyse(target, env, non_generic)
            try:
                unify(target_type, defn_type)
            except InferenceError:
                raise PythranTypeError(
                    "Invalid assignment from type `{}` to type `{}`".format(
                        target_type,
                        defn_type),
                    node)
        return env
    elif isinstance(node, gast.AugAssign):
        # FIXME: not optimal: evaluates type of node.value twice
        fake_target = deepcopy(node.target)
        fake_target.ctx = gast.Load()
        fake_op = gast.BinOp(fake_target, node.op, node.value)
        gast.copy_location(fake_op, node)
        res_type = analyse(fake_op, env, non_generic)
        target_type = analyse(node.target, env, non_generic)

        try:
            unify(target_type, res_type)
        except InferenceError:
            raise PythranTypeError(
                "Invalid update operand for `{}`: `{}` and `{}`".format(
                    symbol_of[type(node.op)],
                    res_type,
                    target_type
                ),
                node
            )
        return env
    elif isinstance(node, gast.Raise):
        return env  # TODO
    elif isinstance(node, gast.Return):
        # in a generator, the yielded values define '@ret'
        if env['@gen']:
            return env

        if node.value is None:
            ret_type = NoneType
        else:
            ret_type = analyse(node.value, env, non_generic)
        if '@ret' in env:
            try:
                ret_type = merge_unify(env['@ret'], ret_type)
            except InferenceError:
                raise PythranTypeError(
                    "function may returns with incompatible types "
                    "`{}` and `{}`".format(env['@ret'], ret_type),
                    node
                )

        env['@ret'] = ret_type
        return env
    elif isinstance(node, gast.Yield):
        assert env['@gen']
        # a bare ``yield`` produces None
        if node.value is None:
            ret_type = NoneType
        else:
            ret_type = analyse(node.value, env, non_generic)

        if '@ret' in env:
            try:
                ret_type = merge_unify(env['@ret'], ret_type)
            except InferenceError:
                raise PythranTypeError(
                    "function may yields incompatible types "
                    "`{}` and `{}`".format(env['@ret'], ret_type),
                    node
                )

        env['@ret'] = ret_type
        return env
    elif isinstance(node, gast.For):
        iter_type = analyse(node.iter, env, non_generic)
        target_type = analyse(node.target, env, non_generic)
        unify(Collection(TypeVariable(), TypeVariable(), TypeVariable(),
                         target_type),
              iter_type)
        analyse_body(node.body, env, non_generic)
        analyse_body(node.orelse, env, non_generic)
        return env
    elif isinstance(node, gast.If):
        test_type = analyse(node.test, env, non_generic)
        unify(Function([test_type], Bool()),
              tr(MODULES['builtins']['bool']))

        body_env = env.copy()
        body_non_generic = non_generic.copy()

        # ``if v is None``: v is None in the body, not in the orelse
        if is_test_is_none(node.test):
            none_id = node.test.left.id
            body_env[none_id] = NoneType
        else:
            none_id = None

        analyse_body(node.body, body_env, body_non_generic)

        orelse_env = env.copy()
        orelse_non_generic = non_generic.copy()

        if none_id:
            if is_option_type(env[none_id]):
                orelse_env[none_id] = prune(env[none_id]).types[0]
            else:
                orelse_env[none_id] = TypeVariable()
        analyse_body(node.orelse, orelse_env, orelse_non_generic)

        # variables introduced by either branch are merged back into env
        for var in body_env:
            if var not in env:
                if var in orelse_env:
                    try:
                        new_type = merge_unify(body_env[var], orelse_env[var])
                    except InferenceError:
                        raise PythranTypeError(
                            "Incompatible types from different branches for "
                            "`{}`: `{}` and `{}`".format(
                                var,
                                body_env[var],
                                orelse_env[var]
                            ),
                            node
                        )
                else:
                    new_type = body_env[var]
                env[var] = new_type

        for var in orelse_env:
            if var not in env:
                # may not be unified by the prev loop if a del occured
                if var in body_env:
                    new_type = merge_unify(orelse_env[var], body_env[var])
                else:
                    new_type = orelse_env[var]
                env[var] = new_type

        if none_id:
            try:
                new_type = merge_unify(body_env[none_id], orelse_env[none_id])
            except InferenceError:
                msg = ("Inconsistent types while merging values of `{}` from "
                       "conditional branches: `{}` and `{}`")
                err = msg.format(none_id,
                                 body_env[none_id],
                                 orelse_env[none_id])
                raise PythranTypeError(err, node)
            env[none_id] = new_type

        return env
    elif isinstance(node, gast.While):
        test_type = analyse(node.test, env, non_generic)
        unify(Function([test_type], Bool()),
              tr(MODULES['builtins']['bool']))

        analyse_body(node.body, env, non_generic)
        analyse_body(node.orelse, env, non_generic)
        return env
    elif isinstance(node, gast.Try):
        analyse_body(node.body, env, non_generic)
        for handler in node.handlers:
            analyse(handler, env, non_generic)
        analyse_body(node.orelse, env, non_generic)
        analyse_body(node.finalbody, env, non_generic)
        return env
    elif isinstance(node, gast.ExceptHandler):
        if node.name:
            new_type = ExceptionType
            non_generic.add(new_type)
            if node.name.id in env:
                unify(env[node.name.id], new_type)
            else:
                env[node.name.id] = new_type
        analyse_body(node.body, env, non_generic)
        return env
    elif isinstance(node, gast.Assert):
        if node.msg:
            analyse(node.msg, env, non_generic)
        analyse(node.test, env, non_generic)
        return env
    elif isinstance(node, gast.Invert):
        return MultiType([Function([Bool()], Integer()),
                          Function([Integer()], Integer())])
    elif isinstance(node, gast.Not):
        return tr(MODULES['builtins']['bool'])
    elif isinstance(node, gast.BoolOp):
        op_type = analyse(node.op, env, non_generic)
        value_types = [analyse(value, env, non_generic)
                       for value in node.values]

        for value_type in value_types:
            unify(Function([value_type], Bool()),
                  tr(MODULES['builtins']['bool']))

        return_type = TypeVariable()
        prev_type = value_types[0]
        for value_type in value_types[1:]:
            unify(Function([prev_type, value_type], return_type), op_type)
            prev_type = value_type
        return return_type
    elif isinstance(node, (gast.And, gast.Or)):
        x_type = TypeVariable()
        return MultiType([
            Function([x_type, x_type], x_type),
            Function([TypeVariable(), TypeVariable()], TypeVariable()),
        ])

    raise RuntimeError("Unhandled syntax node {0}".format(type(node)))