def visit_Break(self, node):
    """Rewrite a ``break`` statement into data-flow the static graph can run.

    The break is replaced by a boolean flag variable V: V is set where the
    break occurred, statements that would have been skipped are guarded by
    ``if not V``, and V is folded into the loop condition so the loop stops.

    Args:
        node: the gast.Break node being visited.
    """
    loop_node_index = self._find_ancestor_loop_index(node)
    # A 'break' with no enclosing loop is a syntax error in the source.
    assert loop_node_index != -1, "SyntaxError: 'break' outside loop"
    loop_node = self.ancestor_nodes[loop_node_index]

    # 1. Map the 'break/continue' stmt with an unique boolean variable V.
    variable_name = unique_name.generate(BREAK_NAME_PREFIX)

    # 2. Find the first ancestor block containing this 'break/continue', a
    # block can be a node containing stmt list. We should remove all stmts
    # after the 'break/continue' and set the V to True here.
    first_block_index = self._remove_stmts_after_break_continue(
        node, variable_name, loop_node_index)

    # 3. Add 'if V' for stmts in ancestor blocks between the first one
    # (exclusive) and the ancestor loop (inclusive)
    self._replace_if_stmt(loop_node_index, first_block_index, variable_name)

    # 4. For 'break' add break into condition of the loop.
    # V starts False each time execution reaches the loop.
    assign_false_node = create_fill_constant_node(variable_name, False)
    self._add_stmt_before_cur_node(loop_node_index, assign_false_node)

    # The extra loop condition is 'not V'.
    cond_var_node = gast.UnaryOp(
        op=gast.Not(),
        operand=gast.Name(
            id=variable_name,
            ctx=gast.Load(),
            annotation=None,
            type_comment=None))
    if isinstance(loop_node, gast.While):
        # while test  ->  while test and (not V)
        loop_node.test = gast.BoolOp(
            op=gast.And(), values=[loop_node.test, cond_var_node])
    elif isinstance(loop_node, gast.For):
        # A for-loop has no condition to extend, so it is first converted
        # into an equivalent while-loop carrying 'not V'.
        parent_node = self.ancestor_nodes[loop_node_index - 1]
        for_to_while = ForToWhileTransformer(parent_node, loop_node,
                                             cond_var_node)
        for_to_while.transform()
def visit_Call(self, node):
    """Translate an ``ast.Call`` into a ``gast.Call``.

    Before Python 3.5 the parser stored ``f(*a, **k)`` in the dedicated
    ``starargs`` / ``kwargs`` fields of ``Call``; gast normalizes them into
    the ``args`` list (as a ``Starred`` node) and the ``keywords`` list (as
    a ``keyword`` with ``arg=None``), matching the modern AST layout.
    """
    # Compare the full (major, minor) tuple: testing ``minor`` alone would
    # misclassify any interpreter whose major version is not 3.
    if sys.version_info[:2] < (3, 5):
        if node.starargs:
            star = gast.Starred(self._visit(node.starargs), gast.Load())
            gast.copy_location(star, node)
            starred = [star]
        else:
            starred = []
        if node.kwargs:
            kwargs = [gast.keyword(None, self._visit(node.kwargs))]
        else:
            kwargs = []
    else:
        # Modern ASTs already carry Starred/keyword(None) inline.
        starred = kwargs = []
    new_node = gast.Call(
        self._visit(node.func),
        self._visit(node.args) + starred,
        self._visit(node.keywords) + kwargs,
    )
    gast.copy_location(new_node, node)
    return new_node
def test_unparse(self):
    """A hand-built gast.If should unparse to the expected source text."""
    assign_from_name = gast.Assign(
        targets=[
            gast.Name(
                'a', ctx=gast.Store(), annotation=None, type_comment=None)
        ],
        value=gast.Name(
            'b', ctx=gast.Load(), annotation=None, type_comment=None))
    assign_from_const = gast.Assign(
        targets=[
            gast.Name(
                'a', ctx=gast.Store(), annotation=None, type_comment=None)
        ],
        value=gast.Constant('c', kind=None))
    node = gast.If(
        test=gast.Constant(1, kind=None),
        body=[assign_from_name],
        orelse=[assign_from_const])

    expected = textwrap.dedent("""
        # coding=utf-8
        if 1:
            a = b
        else:
            a = 'c'
    """).strip()

    source = parser.unparse(node, indentation='  ')
    self.assertEqual(expected, source.strip())
def interprocedural_type_translator(s, n):
    """Deferred combiner applied at a later call site of the current function.

    ``s`` is the type-inference visitor state and ``n`` the gast.Call node.
    Free variables captured from the enclosing scope: ``parametric_type``,
    ``args`` (the formal parameter list), ``node_id`` (the formal argument
    being tracked), ``op`` and ``unary_op``.
    """
    # Synthetic name node used only as a key carrying the instantiated type.
    translated_othernode = ast.Name(
        '__fake__', ast.Load(), None, None)
    s.result[translated_othernode] = (
        parametric_type.instanciate(
            s.current, [s.result[arg] for arg in n.args]))
    # look for modified argument: find the effective argument passed in the
    # position of the tracked formal parameter.
    for p, effective_arg in enumerate(n.args):
        formal_arg = args[p]
        if formal_arg.id == node_id:
            translated_node = effective_arg
            break
    try:
        s.combine(translated_node, translated_othernode, op, unary_op,
                  register=True, aliasing_type=True)
    except NotImplementedError:
        # this may fail when the effective
        # parameter is an expression
        pass
    except UnboundLocalError:
        # this may fail when translated_node was never bound above, e.g.
        # the tracked formal only receives its default value at this site.
        pass
def tokenize(self, s):
    '''A simple contextual "parser" for an OpenMP string.

    Keywords and everything inside a reserved-context clause are copied
    verbatim; any other identifier is replaced by a ``{}`` placeholder and
    recorded as a dependency in ``self.deps``.

    Note: the signature previously lacked ``self`` even though the body
    uses ``self.deps``, which made every non-keyword word raise NameError;
    the parameter is restored here.
    '''
    # not completely satisfying if there are strings in if expressions
    out = ''
    par_count = 0
    curr_index = 0
    in_reserved_context = False
    while curr_index < len(s):
        m = re.match(r'^([a-zA-Z_]\w*)', s[curr_index:])
        if m:
            word = m.group(0)
            curr_index += len(word)
            # Keywords (outside parentheses) and reserved-context words are
            # kept literally; other identifiers become placeholders whose
            # values are resolved later from self.deps.
            if (in_reserved_context or
                    (par_count == 0 and word in keywords)):
                out += word
                in_reserved_context = word in reserved_contex
            else:
                v = '{}'
                self.deps.append(ast.Name(word, ast.Load(), None))
                out += v
        elif s[curr_index] == '(':
            par_count += 1
            curr_index += 1
            out += '('
        elif s[curr_index] == ')':
            par_count -= 1
            curr_index += 1
            out += ')'
            # Closing the outermost parenthesis ends the reserved context.
            if par_count == 0:
                in_reserved_context = False
        else:
            # A comma separates clauses and also ends a reserved context.
            if s[curr_index] == ',':
                in_reserved_context = False
            out += s[curr_index]
            curr_index += 1
    return out
def size_container_folding(value):
    """
    Convert value to ast expression if size is not too big.

    Converter for sized container.
    """
    def flat_size(container):
        # numpy arrays report their flattened element count; plain
        # containers fall back to their own len().
        return len(getattr(container, 'flatten', lambda: container)())

    # Refuse to inline containers that would blow up the generated code.
    if flat_size(value) >= MAX_LEN:
        raise ToNotEval()

    if isinstance(value, list):
        return ast.List([to_ast(item) for item in value], ast.Load())
    if isinstance(value, tuple):
        return ast.Tuple([to_ast(item) for item in value], ast.Load())
    if isinstance(value, set):
        if not value:
            # The empty set has no literal form: emit ``builtins.set()``.
            return ast.Call(
                func=ast.Attribute(
                    ast.Name(mangle('builtins'), ast.Load(), None, None),
                    'set',
                    ast.Load()),
                args=[],
                keywords=[])
        return ast.Set([to_ast(item) for item in value])
    if isinstance(value, dict):
        folded_keys = [to_ast(k) for k in value.keys()]
        folded_values = [to_ast(v) for v in value.values()]
        return ast.Dict(folded_keys, folded_values)
    if isinstance(value, np.ndarray):
        # Rebuild the array through ``numpy.array(<nested tuple>, dtype)``.
        return ast.Call(
            func=ast.Attribute(
                ast.Name(mangle('numpy'), ast.Load(), None, None),
                'array',
                ast.Load()),
            args=[to_ast(totuple(value.tolist())),
                  dtype_to_ast(value.dtype.name)],
            keywords=[])
    raise ConversionError()
def make_dispatcher(static_expr, func_true, func_false, imported_ids):
    """Build ``builtins.pythran.static_if(expr, t, f)(imported_ids...)``.

    The outer call selects one of the two specialized functions at compile
    time; the inner call forwards the captured identifiers to it.
    """
    true_ref = ast.Name(func_true.name, ast.Load(), None, None)
    false_ref = ast.Name(func_false.name, ast.Load(), None, None)

    static_if_attr = ast.Attribute(
        ast.Attribute(
            ast.Name("builtins", ast.Load(), None, None),
            "pythran",
            ast.Load()),
        "static_if",
        ast.Load())

    dispatcher = ast.Call(
        static_if_attr, [static_expr, true_ref, false_ref], [])

    forwarded_ids = [ast.Name(ident, ast.Load(), None, None)
                     for ident in imported_ids]
    return ast.Call(dispatcher, forwarded_ids, [])
def visit_ExtSlice(self, node):
    """Normalize a legacy ExtSlice node into a gast.Tuple of dimensions."""
    converted = gast.Tuple(self._visit(node.dims), gast.Load())
    gast.copy_location(converted, node)
    # ExtSlice carries no end position of its own.
    converted.end_lineno = converted.end_col_offset = None
    return converted
class PyASTGraphsTest(parameterized.TestCase):
    """Tests for the Python-AST-to-graph conversion and its schema."""

    def test_schema(self):
        """Check that the generated schema is reasonable."""
        # Building should succeed
        schema = py_ast_graphs.SCHEMA

        # Should have one graph node for each AST node, plus sorted helpers.
        seq_helpers = [
            "BoolOp_values-seq-helper",
            "Call_args-seq-helper",
            "For_body-seq-helper",
            "FunctionDef_body-seq-helper",
            "If_body-seq-helper",
            "If_orelse-seq-helper",
            "Module_body-seq-helper",
            "While_body-seq-helper",
            "arguments_args-seq-helper",
        ]
        expected_keys = [
            *py_ast_graphs.PY_AST_SPECS.keys(),
            *seq_helpers,
        ]
        self.assertEqual(list(schema.keys()), expected_keys)

        # Check a few elements
        expected_fundef_in_edges = {
            "parent_in",
            "returns_in",
            "returns_missing",
            "body_in",
            "args_in",
        }
        expected_fundef_out_edges = {
            "parent_out",
            "returns_out",
            "body_out_all",
            "body_out_first",
            "body_out_last",
            "args_out",
        }
        self.assertEqual(
            set(schema["FunctionDef"].in_edges), expected_fundef_in_edges)
        self.assertEqual(
            set(schema["FunctionDef"].out_edges), expected_fundef_out_edges)

        expected_expr_in_edges = {"parent_in", "value_in"}
        expected_expr_out_edges = {"parent_out", "value_out"}
        self.assertEqual(set(schema["Expr"].in_edges), expected_expr_in_edges)
        self.assertEqual(set(schema["Expr"].out_edges), expected_expr_out_edges)

        expected_seq_helper_in_edges = {
            "parent_in",
            "item_in",
            "next_in",
            "next_missing",
            "prev_in",
            "prev_missing",
        }
        expected_seq_helper_out_edges = {
            "parent_out",
            "item_out",
            "next_out",
            "prev_out",
        }
        for seq_helper in seq_helpers:
            self.assertEqual(
                set(schema[seq_helper].in_edges), expected_seq_helper_in_edges)
            self.assertEqual(
                set(schema[seq_helper].out_edges),
                expected_seq_helper_out_edges)

    def test_ast_graph_conforms_to_schema(self):
        # Some example code using a few different syntactic constructs, to cover
        # a large set of nodes in the schema
        root = gast.parse(
            textwrap.dedent("""\
                def foo(n):
                  if n <= 1:
                    return 1
                  else:
                    return foo(n-1) + foo(n-2)

                def bar(m, n) -> int:
                  x = n
                  for i in range(m):
                    if False:
                      continue
                    x = x + i
                  while True:
                    break
                  return x

                x0 = 1 + 2 - 3 * 4 / 5
                x1 = (1 == 2) and (3 < 4) and (5 > 6)
                x2 = (7 <= 8) and (9 >= 10) or (11 != 12)
                x2 = bar(13, 14 + 15)
                """))

        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # Graph should match the schema
        schema_util.assert_conforms_to_schema(graph, py_ast_graphs.SCHEMA)

    def test_ast_graph_nodes(self):
        """Check node IDs, node types, and forward mapping."""
        root = gast.parse(
            textwrap.dedent("""\
                pass

                def foo(n):
                  if n <= 1:
                    return 1
                """))

        graph, forward_map = py_ast_graphs.py_ast_to_graph(root)

        # pytype: disable=attribute-error
        self.assertIn("root__Module", graph)
        self.assertEqual(graph["root__Module"].node_type, "Module")
        self.assertEqual(forward_map[id(root)], "root__Module")

        self.assertIn("root_body_1__Module_body-seq-helper", graph)
        self.assertEqual(
            graph["root_body_1__Module_body-seq-helper"].node_type,
            "Module_body-seq-helper")

        self.assertIn("root_body_1_item_body_0_item__If", graph)
        self.assertEqual(graph["root_body_1_item_body_0_item__If"].node_type,
                         "If")
        self.assertEqual(forward_map[id(root.body[1].body[0])],
                         "root_body_1_item_body_0_item__If")

        self.assertIn("root_body_1_item_body_0_item_test_left__Name", graph)
        self.assertEqual(
            graph["root_body_1_item_body_0_item_test_left__Name"].node_type,
            "Name")
        self.assertEqual(forward_map[id(root.body[1].body[0].test.left)],
                         "root_body_1_item_body_0_item_test_left__Name")
        # pytype: enable=attribute-error

    def test_ast_graph_unique_field_edges(self):
        """Test that edges for unique fields are correct."""
        root = gast.parse("print(1)")
        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        self.assertEqual(
            graph["root_body_0_item__Expr"].out_edges["value_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_0_item_value__Call"),
                    in_edge=graph_types.InEdgeType("parent_in"))
            ])
        self.assertEqual(
            graph["root_body_0_item_value__Call"].out_edges["parent_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_0_item__Expr"),
                    in_edge=graph_types.InEdgeType("value_in"))
            ])

    def test_ast_graph_optional_field_edges(self):
        """Test that edges for optional fields are correct."""
        root = gast.parse("return 1\nreturn")
        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # Present optional value: normal parent/child edge pair.
        self.assertEqual(
            graph["root_body_0_item__Return"].out_edges["value_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId(
                        "root_body_0_item_value__Constant"),
                    in_edge=graph_types.InEdgeType("parent_in"))
            ])
        self.assertEqual(
            graph["root_body_0_item_value__Constant"].out_edges["parent_out"],
            [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_0_item__Return"),
                    in_edge=graph_types.InEdgeType("value_in"))
            ])
        # Missing optional value: the edge loops back with a 'missing' tag.
        self.assertEqual(
            graph["root_body_1_item__Return"].out_edges["value_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_1_item__Return"),
                    in_edge=graph_types.InEdgeType("value_missing"))
            ])

    def test_ast_graph_sequence_field_edges(self):
        """Test that edges for sequence fields are correct.

        Note that sequence fields produce connections between three nodes:
        the parent, the helper node, and the child.
        """
        root = gast.parse(
            textwrap.dedent("""\
                print(1)
                print(2)
                print(3)
                print(4)
                print(5)
                print(6)
                """))

        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # Child edges from the parent node
        node = graph["root__Module"]
        self.assertLen(node.out_edges["body_out_all"], 6)
        self.assertEqual(node.out_edges["body_out_first"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])
        self.assertEqual(node.out_edges["body_out_last"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_5__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])

        # Edges from the sequence helper
        node = graph["root_body_0__Module_body-seq-helper"]
        self.assertEqual(node.out_edges["parent_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId("root__Module"),
                in_edge=graph_types.InEdgeType("body_in"))
        ])
        self.assertEqual(node.out_edges["item_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId("root_body_0_item__Expr"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])
        # First helper has no predecessor: prev_out loops back as missing.
        self.assertEqual(node.out_edges["prev_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("prev_missing"))
        ])
        self.assertEqual(node.out_edges["next_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_1__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("prev_in"))
        ])

        # Parent edge of the item
        node = graph["root_body_0_item__Expr"]
        self.assertEqual(node.out_edges["parent_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("item_in"))
        ])

    @parameterized.named_parameters(
        {
            "testcase_name": "unexpected_type",
            "ast": gast.Subscript(value=None, slice=None, ctx=None),
            "expected_error": "Unknown AST node type 'Subscript'",
        },
        {
            "testcase_name": "too_many_unique",
            "ast": gast.Assign(
                targets=[
                    gast.Name("foo", gast.Store(), None, None),
                    gast.Name("bar", gast.Store(), None, None)
                ],
                value=gast.Constant(True, None)),
            "expected_error":
                "Expected 1 child for field 'targets' of node .*; got 2",
        },
        {
            "testcase_name": "missing_unique",
            "ast": gast.Assign(targets=[], value=gast.Constant(True, None)),
            "expected_error":
                "Expected 1 child for field 'targets' of node .*; got 0",
        },
        {
            "testcase_name": "too_many_optional",
            "ast": gast.Return(value=[
                gast.Name("foo", gast.Load(), None, None),
                gast.Name("bar", gast.Load(), None, None)
            ]),
            "expected_error":
                "Expected at most 1 child for field 'value' of node .*; got 2",
        })
    def test_invalid_graphs(self, ast, expected_error):
        """Malformed ASTs should be rejected with a descriptive ValueError."""
        with self.assertRaisesRegex(ValueError, expected_error):
            py_ast_graphs.py_ast_to_graph(ast)
def visit_Call(self, node):
    """
    Transform call site to have normal function call.

    Examples
    --------
    For methods:

    >> a = [1, 2, 3]
    >> a.append(1)

    Becomes

    >> __list__.append(a, 1)

    For functions:

    >> __builtin__.dict.fromkeys([1, 2, 3])

    Becomes

    >> __builtin__.__dict__.fromkeys([1, 2, 3])
    """
    node = self.generic_visit(node)
    # Only attributes function can be Pythonic and should be normalized
    if isinstance(node.func, ast.Attribute):
        if node.func.attr in methods:
            # Get object targeted by methods
            obj = lhs = node.func.value
            # Get the most left identifier to check if it is not an
            # imported module
            while isinstance(obj, ast.Attribute):
                obj = obj.value
            is_not_module = (
                not isinstance(obj, ast.Name) or
                obj.id not in self.imports)
            if is_not_module:
                self.update = True
                # As it was a methods call, push targeted object as first
                # arguments and add correct module prefix
                node.args.insert(0, lhs)
                mod = methods[node.func.attr][0]
                # Submodules import full module
                self.to_import.add(mangle(mod[0]))
                # Rebuild the function as module.submodule...attr, folding
                # the module path left-to-right into Attribute nodes.
                node.func = reduce(
                    lambda v, o: ast.Attribute(v, o, ast.Load()),
                    mod[1:] + (node.func.attr, ),
                    ast.Name(mangle(mod[0]), ast.Load(), None))
            # else methods have been called using function syntax
        if node.func.attr in methods or node.func.attr in functions:
            # Now, methods and function have both function syntax
            def rec(path, cur_module):
                """
                Recursively rename path content looking in matching module.

                Prefers __module__ to module if it exists.
                This recursion is done as modules are visited top->bottom
                while attributes have to be visited bottom->top.
                """
                err = "Function path is chained attributes and name"
                assert isinstance(path, (ast.Name, ast.Attribute)), err
                if isinstance(path, ast.Attribute):
                    new_node, cur_module = rec(path.value, cur_module)
                    new_id, mname = self.renamer(path.attr, cur_module)
                    return (ast.Attribute(new_node, new_id, ast.Load()),
                            cur_module[mname])
                else:
                    new_id, mname = self.renamer(path.id, cur_module)
                    if mname not in cur_module:
                        raise PythranSyntaxError(
                            "Unbound identifier '{}'".format(mname), node)
                    return (ast.Name(new_id, ast.Load(), None),
                            cur_module[mname])

            # Rename module path to avoid naming issue.
            node.func.value, _ = rec(node.func.value, MODULES)
            self.update = True

    return node
def add_stararg(self, a):
    """Append a ``*args``-style argument as a ``tuple(a)`` call node.

    Positional arguments collected so far are consumed first so the
    star-argument keeps its position in the reconstructed call.
    """
    self._consume_args()
    # ``keywords`` must be a Python list: AST sequence fields reject a
    # tuple when the tree is validated/compiled, so the previous ``()``
    # produced a malformed node.
    self._argspec.append(
        gast.Call(gast.Name('tuple', gast.Load(), None), [a], []))
class Square(Transformation):

    """
    Replaces **2 by a call to numpy.square.

    >>> import gast as ast
    >>> from pythran import passmanager, backend
    >>> node = ast.parse('a**2')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(Square, node)
    >>> print(pm.dump(backend.Python, node))
    import numpy
    numpy.square(a)
    >>> node = ast.parse('numpy.power(a,2)')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(Square, node)
    >>> print(pm.dump(backend.Python, node))
    import numpy
    numpy.square(a)
    """

    # NOTE: doctests above previously used the Python-2 ``print x``
    # statement form, which is a SyntaxError under Python 3 doctest.

    # Matches ``<expr> ** 2``.
    POW_PATTERN = ast.BinOp(AST_any(), ast.Pow(), ast.Num(2))
    # Matches ``numpy.power(<expr>, 2)``.
    POWER_PATTERN = ast.Call(
        ast.Attribute(ast.Name('numpy', ast.Load(), None),
                      'power',
                      ast.Load()),
        [AST_any(), ast.Num(2)],
        [])

    def __init__(self):
        Transformation.__init__(self)

    def replace(self, value):
        """Return ``numpy.square(value)`` and flag the numpy import."""
        self.update = self.need_import = True
        return ast.Call(
            ast.Attribute(ast.Name('numpy', ast.Load(), None),
                          'square',
                          ast.Load()),
            [value],
            [])

    def visit_Module(self, node):
        """Visit the module and prepend ``import numpy`` if a rewrite occurred."""
        self.need_import = False
        self.generic_visit(node)
        if self.need_import:
            importIt = ast.Import(names=[ast.alias(name='numpy', asname=None)])
            node.body.insert(0, importIt)
        return node

    def expand_pow(self, node, n):
        """Expand ``node ** n`` (integer n >= 0) by repeated squaring."""
        if n == 0:
            return ast.Num(1)
        elif n == 1:
            return node
        else:
            node_square = self.replace(node)
            node_pow = self.expand_pow(node_square, n >> 1)
            if n & 1:
                # Odd exponent: multiply by one extra copy of the base.
                return ast.BinOp(node_pow, ast.Mult(), copy.deepcopy(node))
            else:
                return node_pow

    def visit_BinOp(self, node):
        self.generic_visit(node)
        if ASTMatcher(Square.POW_PATTERN).search(node):
            return self.replace(node.left)
        elif isinstance(node.op, ast.Pow) and isinstance(node.right, ast.Num):
            n = node.right.n
            # Only positive integral exponents can be expanded safely.
            if int(n) == n and n > 0:
                return self.expand_pow(node.left, n)
            else:
                return node
        else:
            return node

    def visit_Call(self, node):
        self.generic_visit(node)
        if ASTMatcher(Square.POWER_PATTERN).search(node):
            return self.replace(node.args[0])
        else:
            return node
def replace(self, value):
    """Rewrite ``value ** 2`` as ``numpy.square(value)`` and flag the import."""
    self.need_import = True
    self.update = True
    square_func = ast.Attribute(
        ast.Name('numpy', ast.Load(), None), 'square', ast.Load())
    return ast.Call(square_func, [value], [])
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes."""
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        node, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            owner_type=c)
        # NOTE(review): class_namespace is initialized to {} above, so this
        # None check can never be true — the update branch always runs.
        if class_namespace is None:
            class_namespace = namespace
        else:
            class_namespace.update(namespace)
        converted_members[m] = node[0]
    namer = program_ctx.new_namer(class_namespace)
    class_name = namer.compiled_class_name(c.__name__, c)

    # TODO(mdan): This needs to be explained more thoroughly.
    # Process any base classes: if the superclass if of a whitelisted type, an
    # absolute import line is generated. Otherwise, it is marked for conversion
    # (as a side effect of the call to namer.compiled_class_name() followed by
    # program_ctx.update_name_map(namer)).
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # NOTE(review): this is presumably meant to detect ``object`` as a
        # base; ``isinstance(object, base)`` also matches metaclass bases
        # such as ``type`` — ``base is object`` would be stricter. Confirm.
        if isinstance(object, base):
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            # This will trigger a conversion into a class with this name.
            alias = namer.compiled_class_name(base.__name__, base)
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
    program_ctx.update_name_map(namer)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(
        class_name,
        bases=bases,
        keywords=[],
        body=list(converted_members.values()),
        decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)
    return output_nodes, class_name, class_namespace
def combine_(self, node, othernode, op, unary_op, register):
    """Combine the type of ``othernode`` into the type of ``node``.

    ``op`` merges two types, ``unary_op`` transforms the incoming type.
    When ``register`` is set the combination comes from an assignment, so
    aliasing and inter-procedural propagation must be handled too.
    Unresolvable left-hand sides raise UnboundableRValue, which is ignored.
    """
    try:
        if register:
            # this comes from an assignment,
            # so we must check where the value is assigned
            node_id, depth = self.node_to_id(node)
            if depth > 0:
                # Assignment through subscripting: retarget to the base name
                # and wrap/flatten the incoming type accordingly.
                node = ast.Name(node_id, ast.Load(), None)
                former_unary_op = unary_op

                # update the type to reflect container nesting
                def unary_op(x):
                    return reduce(lambda t, n: ContainerType(t),
                                  range(depth), former_unary_op(x))

                # patch the op, as we no longer apply op, but infer content
                def op(*types):
                    if len(types) == 1:
                        return types[0]
                    else:
                        return CombinedTypes(*types)

            self.name_to_nodes.setdefault(node_id, set()).add(node)

        # only perform inter procedural combination upon stage 0
        if register and self.isargument(node) and self.stage == 0:
            node_id, _ = self.node_to_id(node)
            if node not in self.result:
                self.result[node] = unary_op(self.result[othernode])
            assert self.result[node], "found an alias with a type"

            parametric_type = PType(self.current, self.result[othernode])
            # register() returns whether this parametric type is new; only
            # new types need a deferred combiner at future call sites.
            if self.register(parametric_type):
                current_function = self.combiners[self.current]

                def translator_generator(args, op, unary_op):
                    ''' capture args for translator generation'''

                    def interprocedural_type_translator(s, n):
                        translated_othernode = ast.Name(
                            '__fake__', ast.Load(), None)
                        s.result[translated_othernode] = (
                            parametric_type.instanciate(
                                s.current,
                                [s.result[arg] for arg in n.args]))
                        # look for modified argument
                        for p, effective_arg in enumerate(n.args):
                            formal_arg = args[p]
                            if formal_arg.id == node_id:
                                translated_node = effective_arg
                                break
                        try:
                            s.combine(translated_node,
                                      translated_othernode,
                                      op, unary_op, register=True,
                                      aliasing_type=True)
                        except NotImplementedError:
                            # this may fail when the effective
                            # parameter is an expression
                            pass
                        except UnboundLocalError:
                            # this may fail when translated_node
                            # is a default parameter
                            pass

                    return interprocedural_type_translator

                translator = translator_generator(
                    self.current.args.args, op, unary_op)
                # deferred combination
                current_function.add_combiner(translator)
        else:
            # Plain intra-procedural combination: merge the transformed
            # incoming type with whatever is already known for node.
            new_type = unary_op(self.result[othernode])
            if node not in self.result or self.result[node] is UnknownType:
                self.result[node] = new_type
            else:
                self.result[node] = op(self.result[node], new_type)
    except UnboundableRValue:
        pass
def class_to_graph(c, program_ctx): """Specialization of `entity_to_graph` for classes.""" # TODO(mdan): Revisit this altogether. Not sure we still need it. converted_members = {} method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m ) members = tf_inspect.getmembers(c, predicate=method_filter) if not members: raise ValueError('Cannot convert %s: it has no member methods.' % c) class_namespace = {} for _, m in members: # Only convert the members that are directly defined by the class. if inspect_utils.getdefiningclass(m, c) is not c: continue nodes, _, namespace = function_to_graph( m, program_ctx=program_ctx, arg_values={}, arg_types={'self': (c.__name__, c)}, do_rename=False) if class_namespace is None: class_namespace = namespace else: class_namespace.update(namespace) converted_members[m] = nodes[0] namer = naming.Namer(class_namespace) class_name = namer.class_name(c.__name__) # Process any base classes: if the superclass if of a whitelisted type, an # absolute import line is generated. output_nodes = [] renames = {} base_names = [] for base in c.__bases__: if isinstance(object, base): base_names.append('object') continue if is_whitelisted_for_graph(base): alias = namer.new_symbol(base.__name__, ()) output_nodes.append( gast.ImportFrom( module=base.__module__, names=[gast.alias(name=base.__name__, asname=alias)], level=0)) else: raise NotImplementedError( 'Conversion of classes that do not directly extend classes from' ' whitelisted modules is temporarily suspended. If this breaks' ' existing code please notify the AutoGraph team immediately.') base_names.append(alias) renames[qual_names.QN(base.__name__)] = qual_names.QN(alias) # Generate the definition of the converted class. bases = [gast.Name(n, gast.Load(), None) for n in base_names] class_def = gast.ClassDef(class_name, bases=bases, keywords=[], body=list(converted_members.values()), decorator_list=[]) # Make a final pass to replace references to the class or its base classes. 
# Most commonly, this occurs when making super().__init__() calls. # TODO(mdan): Making direct references to superclass' superclass will fail. class_def = qual_names.resolve(class_def) renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name) class_def = ast_util.rename_symbols(class_def, renames) output_nodes.append(class_def) return output_nodes, class_name, class_namespace
def visit_Call(self, node):
    """Generate a forward-mode tangent statement for a function call.

    Either fills in a known tangent template for ``func``, or — when no
    template exists but source is available — rewrites the call to a
    generated tangent function ``dfunc`` and records it for later
    differentiation.
    """
    if not self.target:
        return node
    func = anno.getanno(node, 'func')

    if func in tangents.UNIMPLEMENTED_TANGENTS:
        raise errors.ForwardNotImplementedError(func)

    if func == tracing.Traceable:
        raise NotImplementedError('Tracing of %s is not enabled in forward mode' %
                                  quoting.unquote(node))

    if func not in tangents.tangents:
        # No template: we must be able to differentiate from source.
        # NOTE(review): bare ``except:`` also swallows KeyboardInterrupt /
        # SystemExit — a narrower exception type would be safer. Confirm.
        try:
            quoting.parse_function(func)
        except:
            raise ValueError('No tangent found for %s, and could not get source.' %
                             func.__name__)

        # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
        active_args = tuple(i for i, arg in enumerate(node.args)
                            if isinstance(arg, gast.Name))
        # TODO: Stack arguments are currently not considered
        # active, but for forward-mode applied to call trees,
        # they have to be. When we figure out how to update activity
        # analysis to do the right thing, we'll want to add the extra check:
        # `and arg.id in self.active_variables`

        # TODO: Duplicate of code in reverse_ad.
        already_counted = False
        for f, a in self.required:
            if f.__name__ == func.__name__ and set(a) == set(active_args):
                already_counted = True
                break
        if not already_counted:
            self.required.append((func, active_args))

        fn_name = naming.tangent_name(func, active_args)
        orig_args = quoting.parse_function(func).body[0].args
        tangent_keywords = []
        for i in active_args:
            # Pass the tangent of each active argument as a keyword, named
            # after the tangent of the corresponding formal parameter.
            grad_node = create.create_grad(node.args[i], self.namer,
                                           tangent=True)
            arg_grad_node = create.create_grad(
                orig_args.args[i], self.namer, tangent=True)
            grad_node.ctx = gast.Load()
            tangent_keywords.append(
                gast.keyword(arg=arg_grad_node.id, value=grad_node))
        # Update the original call
        rhs = gast.Call(
            func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
            args=node.args,
            keywords=tangent_keywords + node.keywords)
        # Set self.value to False to trigger whole primal replacement
        self.value = False
        return [rhs]

    template_ = tangents.tangents[func]

    # Match the function call to the template
    sig = funcsigs.signature(template_)
    # Drop the first template parameter (the output slot).
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)
    bound_args.apply_defaults()

    # If any keyword arguments weren't passed, we fill them using the
    # defaults of the original function
    if grads.DEFAULT in bound_args.arguments.values():
        # Build a mapping from names to defaults
        args = quoting.parse_function(func).body[0].args
        defaults = {}
        # Positional defaults align from the right; reverse both lists.
        for arg, default in zip(*map(reversed, [args.args, args.defaults])):
            defaults[arg.id] = default
        for arg, default in zip(args.kwonlyargs, args.kw_defaults):
            if default is not None:
                defaults[arg.id] = default
        for name, value in bound_args.arguments.items():
            if value is grads.DEFAULT:
                bound_args.arguments[name] = defaults[name]

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: self.tmp_node}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
        to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
        vararg_name = six.get_function_code(template_).co_varnames[-1]
        target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
        value = gast.Tuple(elts=to_pack, ctx=gast.Load())

        # And we fill in the packed tuple into the template
        arg_replacements[six.get_function_code(template_).co_varnames[
            -1]] = target
    tangent_node = template.replace(
        template_,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        **arg_replacements)

    # If the template uses the answer in the RHS of the tangent,
    # we need to make sure that the regular answer is replaced
    # with self.tmp_node, but that the gradient is not. We have
    # to be extra careful for statements like a = exp(a), because
    # both the target and RHS variables have the same name.
    tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
    tmp_grad_name = tmp_grad_node.id
    ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
    for _node in tangent_node:
        for succ in gast.walk(_node):
            if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
                succ.id = ans_grad_node.id

    if flags & inspect.CO_VARARGS:
        # If the template packs arguments, then we have to unpack the
        # derivatives afterwards
        # We also have to update the replacements tuple then
        dto_pack = [
            create.create_temp_grad(arg, self.namer, True) for arg in to_pack
        ]
        value = create.create_grad(target, self.namer, tangent=True)
        target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

    # Stack pops have to be special-cased, we have
    # to set the 'push' attribute, so we know that if we
    # remove this pop, we have to remove the equivalent push.
    # NOTE: this only works if we're doing forward-over-reverse,
    # where reverse is applied in joint mode, with no call tree.
    # Otherwise, the pushes and pops won't be matched within a single
    # function call.
    if func == tangent.pop:
        if len(self.metastack):
            anno.setanno(tangent_node[0], 'push', self.metastack.pop())
        else:
            anno.setanno(tangent_node[0], 'push', None)

    return tangent_node
def get_while_stmt_nodes(self, node):
    """Rewrite a tensor-dependent ``while`` loop into static-graph form.

    Returns a list of statements that replaces the original ``gast.While``
    node: static-variable creation for names first bound inside the loop,
    to-static conversion of loop variables, a generated condition function,
    a generated body function, and a ``while_loop`` call node wiring the two
    functions together.  If the loop does not depend on tensors, the node is
    returned unchanged (wrapped in a list).
    """
    # TODO: consider while - else in python
    if not self.name_visitor.is_control_flow_loop(node):
        # Condition is plain Python; no graph rewrite needed.
        return [node]

    loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
        node)
    new_stmts = []

    # Python can create a variable inside the loop and use it after the
    # loop, e.g.
    #
    # while x < 10:
    #     x += 1
    #     y = x
    # z = y
    #
    # We need to create static variables for those names up front.
    # Names containing "." are attribute accesses, not creatable variables.
    for name in create_var_names:
        if "." not in name:
            new_stmts.append(create_static_variable_gast_node(name))

    # "while x < 10" in dygraph must be converted so the comparison runs on
    # static tensors.
    for name in loop_var_names:
        new_stmts.append(to_static_variable_gast_node(name))

    # Rewrite and/or/not in the test into tensor-compatible logical ops.
    logical_op_transformer = LogicalOpTransformer(node.test)
    cond_value_node = logical_op_transformer.transform()

    # Condition function: takes all loop variables, returns the (rewritten)
    # loop test.
    condition_func_node = gast.FunctionDef(
        name=unique_name.generate(WHILE_CONDITION_PREFIX),
        args=gast.arguments(
            args=[
                gast.Name(id=name,
                          ctx=gast.Param(),
                          annotation=None,
                          type_comment=None) for name in loop_var_names
            ],
            posonlyargs=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=None,
            kwarg=None,
            defaults=[]),
        body=[gast.Return(value=cond_value_node)],
        decorator_list=[],
        returns=None,
        type_comment=None)
    # Dotted names (attribute accesses) cannot be parameters; rename them to
    # fresh generated identifiers inside the condition function.
    for name in loop_var_names:
        if "." in name:
            rename_transformer = RenameTransformer(condition_func_node)
            rename_transformer.rename(
                name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
    new_stmts.append(condition_func_node)

    # Body function: the original loop body plus a trailing return of all
    # loop variables.  NOTE(review): node.body is mutated in place here.
    new_body = node.body
    new_body.append(
        gast.Return(
            value=generate_name_node(loop_var_names, ctx=gast.Load())))
    body_func_node = gast.FunctionDef(
        name=unique_name.generate(WHILE_BODY_PREFIX),
        args=gast.arguments(
            args=[
                gast.Name(id=name,
                          ctx=gast.Param(),
                          annotation=None,
                          type_comment=None) for name in loop_var_names
            ],
            posonlyargs=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=None,
            kwarg=None,
            defaults=[]),
        body=new_body,
        decorator_list=[],
        returns=None,
        type_comment=None)
    # Same dotted-name renaming for the body function.
    for name in loop_var_names:
        if "." in name:
            rename_transformer = RenameTransformer(body_func_node)
            rename_transformer.rename(
                name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
    new_stmts.append(body_func_node)

    # Finally, the while_loop call tying condition and body together.
    while_loop_node = create_while_node(condition_func_node.name,
                                        body_func_node.name, loop_var_names)
    new_stmts.append(while_loop_node)
    return new_stmts
def get_for_stmt_nodes(self, node):
    """Rewrite a tensor-dependent ``for ... in range(...)`` loop into
    static-graph form.

    The ``for`` loop is lowered to init/condition/increment statements plus a
    generated condition function, a generated body function, and a
    ``while_loop`` call node.  Loops that are not control-flow-dependent, not
    over ``range``, or whose target is not a plain name are returned
    unchanged (wrapped in a list).
    """
    # TODO: consider for - else in python
    if not self.name_visitor.is_control_flow_loop(node):
        return [node]

    # TODO: support non-range case
    range_call_node = self.get_for_range_node(node)
    if range_call_node is None:
        return [node]

    if not isinstance(node.target, gast.Name):
        # Tuple targets etc. are not handled; leave the loop as-is.
        return [node]
    iter_var_name = node.target.id
    # init_stmt seeds the iteration variable, cond_stmt is the loop test,
    # change_stmt is the per-iteration increment.
    init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts(
        iter_var_name, range_call_node.args)

    loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
        node)
    new_stmts = []

    # Python can create a variable inside the loop and use it after the
    # loop, e.g.
    #
    # for x in range(10):
    #     y += x
    # print(x)
    # x = 10
    #
    # We need to create static variables for those names up front.
    for name in create_var_names:
        if "." not in name:
            new_stmts.append(create_static_variable_gast_node(name))

    new_stmts.append(init_stmt)

    # "for x in range(10)" in dygraph must be converted so the loop test and
    # increment run on static tensors.
    for name in loop_var_names:
        new_stmts.append(to_static_variable_gast_node(name))

    # Condition function: takes all loop variables, returns the loop test.
    condition_func_node = gast.FunctionDef(
        name=unique_name.generate(FOR_CONDITION_PREFIX),
        args=gast.arguments(
            args=[
                gast.Name(id=name,
                          ctx=gast.Param(),
                          annotation=None,
                          type_comment=None) for name in loop_var_names
            ],
            posonlyargs=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=None,
            kwarg=None,
            defaults=[]),
        body=[gast.Return(value=cond_stmt)],
        decorator_list=[],
        returns=None,
        type_comment=None)
    # Dotted names (attribute accesses) cannot be parameters; rename them to
    # fresh generated identifiers inside the condition function.
    for name in loop_var_names:
        if "." in name:
            rename_transformer = RenameTransformer(condition_func_node)
            rename_transformer.rename(
                name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
    new_stmts.append(condition_func_node)

    # Body function: original body + increment + return of loop variables.
    # NOTE(review): node.body is mutated in place here.
    new_body = node.body
    new_body.append(change_stmt)
    new_body.append(
        gast.Return(
            value=generate_name_node(loop_var_names, ctx=gast.Load())))
    body_func_node = gast.FunctionDef(
        name=unique_name.generate(FOR_BODY_PREFIX),
        args=gast.arguments(
            args=[
                gast.Name(id=name,
                          ctx=gast.Param(),
                          annotation=None,
                          type_comment=None) for name in loop_var_names
            ],
            posonlyargs=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=None,
            kwarg=None,
            defaults=[]),
        body=new_body,
        decorator_list=[],
        returns=None,
        type_comment=None)
    # Same dotted-name renaming for the body function.
    for name in loop_var_names:
        if "." in name:
            rename_transformer = RenameTransformer(body_func_node)
            rename_transformer.rename(
                name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
    new_stmts.append(body_func_node)

    # Finally, the while_loop call tying condition and body together.
    while_loop_node = create_while_node(condition_func_node.name,
                                        body_func_node.name, loop_var_names)
    new_stmts.append(while_loop_node)
    return new_stmts
def makeattr():
    """Build the AST for a ``__builtin__.map`` attribute lookup."""
    builtin_ref = ast.Name(id='__builtin__', ctx=ast.Load(), annotation=None)
    return ast.Attribute(value=builtin_ref, attr='map', ctx=ast.Load())
""" Optimization for Python costly pattern. """ from pythran.conversion import mangle from pythran.analyses import Check, Placeholder from pythran.passmanager import Transformation import gast as ast # Tuple of : (pattern, replacement) # replacement have to be a lambda function to have a new ast to replace when # replacement is inserted in main ast know_pattern = [ # __builtin__.len(__builtin__.set(X)) => __builtin__.pythran.len_set(X) (ast.Call(func=ast.Attribute(value=ast.Name(id='__builtin__', ctx=ast.Load(), annotation=None), attr="len", ctx=ast.Load()), args=[ ast.Call(func=ast.Attribute(value=ast.Name(id='__builtin__', ctx=ast.Load(), annotation=None), attr="set", ctx=ast.Load()), args=[Placeholder(0)], keywords=[]) ], keywords=[]), lambda: ast.Call(func=ast.Attribute(value=ast.Attribute(value=ast.Name( id='__builtin__', ctx=ast.Load(), annotation=None), attr="pythran",
def _consume_args(self):
    """Flush accumulated argument nodes into the argspec as one tuple node.

    No-op when nothing has been accumulated.
    """
    pending = self._arg_accumulator
    if not pending:
        return
    self._argspec.append(gast.Tuple(elts=pending, ctx=gast.Load()))
    self._arg_accumulator = []
def combine_(self, node, othernode, op, unary_op, register):
    """Combine the type of ``othernode`` into the type of ``node``.

    ``op`` merges two existing types; ``unary_op`` adapts ``othernode``'s
    type before merging.  When ``register`` is true and ``node`` is a
    function argument (stage 0), an interprocedural combiner is recorded on
    the current function instead of combining immediately.  All
    ``UnboundableRValue`` failures are deliberately swallowed: such nodes
    simply do not contribute type information.
    """
    try:
        # This comes from an assignment, so we must check where the value
        # is assigned.
        if register:
            try:
                node_id, depth = self.node_to_id(node)
                if depth:
                    # The assignment targets a subscripted element; retarget
                    # the combination at the whole container variable.
                    node = ast.Name(node_id, ast.Load(), None, None)
                    former_unary_op = unary_op

                    # Update the type to reflect container nesting.
                    def merge_container_type(ty, index):
                        # An integral index makes it possible to correctly
                        # update a tuple type (element types may differ).
                        if isinstance(index, int):
                            kty = self.builder.NamedType(
                                'std::integral_constant<long,{}>'.format(
                                    index))
                            return self.builder.IndexableContainerType(kty,
                                                                       ty)
                        else:
                            return self.builder.ContainerType(ty)

                    # Wrap the original unary_op so the combined type is
                    # re-nested once per level of subscripting.
                    def unary_op(x):
                        return reduce(merge_container_type, depth,
                                      former_unary_op(x))

                    # Patch the op: we no longer apply op, but infer content.
                    op = self.combined
                self.name_to_nodes[node_id].append(node)
            except UnboundableRValue:
                pass

        # Only perform interprocedural combination upon stage 0.
        if register and self.isargument(node) and self.stage == 0:
            node_id, _ = self.node_to_id(node)
            if node not in self.result:
                self.result[node] = unary_op(self.result[othernode])
            assert self.result[node], "found an alias with a type"

            parametric_type = self.builder.PType(self.current,
                                                 self.result[othernode])
            if self.register(parametric_type):
                current_function = self.combiners[self.current]

                def translator_generator(args, op, unary_op):
                    ''' capture args for translator generation'''
                    def interprocedural_type_translator(s, n):
                        # Fake node standing for othernode, typed with the
                        # parametric type instantiated at the call site.
                        translated_othernode = ast.Name(
                            '__fake__', ast.Load(), None, None)
                        s.result[translated_othernode] = (
                            parametric_type.instanciate(
                                s.current,
                                [s.result[arg] for arg in n.args]))
                        # Look for the modified argument at the call site.
                        for p, effective_arg in enumerate(n.args):
                            formal_arg = args[p]
                            if formal_arg.id == node_id:
                                translated_node = effective_arg
                                break
                        try:
                            s.combine(translated_node,
                                      translated_othernode,
                                      op, unary_op, register=True,
                                      aliasing_type=True)
                        except NotImplementedError:
                            pass
                            # This may fail when the effective
                            # parameter is an expression.
                        except UnboundLocalError:
                            pass
                            # This may fail when translated_node
                            # is a default parameter.
                    return interprocedural_type_translator

                translator = translator_generator(
                    self.current.args.args, op, unary_op)
                # Deferred combination: replayed at each call site.
                current_function.add_combiner(translator)
        else:
            new_type = unary_op(self.result[othernode])
            UnknownType = self.builder.UnknownType
            if node not in self.result or self.result[node] is UnknownType:
                self.result[node] = new_type
            else:
                if isinstance(self.result[node], tuple):
                    # Tuples have per-element types and cannot be merged
                    # as a whole.
                    raise UnboundableRValue
                self.result[node] = op(self.result[node], new_type)
    except UnboundableRValue:
        pass
def sub():
    """Pattern fragment: a tuple node whose elements are placeholder 0."""
    elements = Placeholder(0)
    return ast.Tuple(elements, ast.Load())
def convert_class_to_ast(c, program_ctx):
    """Specialization of `convert_entity_to_ast` for classes.

    Converts every method directly defined by class `c`, generates imports
    for whitelisted base classes, and emits a fresh ClassDef with references
    to the class and its bases renamed to the converted symbols.

    Args:
        c: the class to convert.
        program_ctx: conversion context passed through to the per-method
            conversion.

    Returns:
        Tuple of (output AST nodes, new class name, EntityInfo).

    Raises:
        ValueError: if the class has no member methods, or its methods were
            built with mismatched __future__ features.
        NotImplementedError: if a base class is not whitelisted.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(
        m)
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('cannot convert %s: no member methods' % c)

    # TODO(mdan): Don't clobber namespaces for each method in one class
    # namespace.
    # The assumption that one namespace suffices for all methods only holds
    # if all methods were defined in the same module.
    # If, instead, functions are imported from multiple modules and then
    # spliced into the class, then each function has its own globals and
    # __future__ imports that need to stay separate.
    # For example, C's methods could both have `global x` statements
    # referring to mod1.x and mod2.x, but using one namespace for C would
    # cause a conflict.
    # from mod1 import f1
    # from mod2 import f2
    # class C(object):
    #   method1 = f1
    #   method2 = f2
    class_namespace = {}
    future_features = None
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        (node,), _, entity_info = convert_func_to_ast(
            m, program_ctx=program_ctx, do_rename=False)
        class_namespace.update(entity_info.namespace)
        converted_members[m] = node

        # TODO(mdan): Similarly check the globals.
        if future_features is None:
            future_features = entity_info.future_features
        elif frozenset(future_features) ^ frozenset(
                entity_info.future_features):
            # Note: we can support this case if ever needed.
            # NOTE(review): message typo "if has" — left as-is since it is a
            # runtime string.
            raise ValueError(
                'cannot convert {}: if has methods built with mismatched '
                'future features: {} and {}'.format(
                    c, future_features, entity_info.future_features))
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type,
    # an absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # NOTE(review): `isinstance(object, base)` checks whether the
        # `object` class itself is an instance of `base` — true for `object`
        # and `type`; presumably intended as a base-is-object test. Confirm.
        if isinstance(object, base):
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base
    # classes.  Most commonly, this occurs when making super().__init__()
    # calls.
    # TODO(mdan): Making direct references to superclass' superclass will
    # fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    # TODO(mdan): Find a way better than forging this object.
    entity_info = transformer.EntityInfo(source_code=None,
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=class_namespace)
    return output_nodes, class_name, entity_info
def visit_If(self, node):
    """Outline a statically-evaluable ``if`` into two functions plus a
    dispatcher call.

    Both branches are extracted as new functions; the original ``if`` is
    replaced by a call whose result unpacks the variables the branches
    assign, plus — when the branches may return/break/continue — status
    variables that reproduce the original control flow at the call site.
    Non-static ``if`` nodes are just visited recursively.
    """
    if node.test not in self.static_expressions:
        return self.generic_visit(node)

    imported_ids = self.gather(ImportedIds, node)

    # Variables assigned in each branch that escape the if.
    assigned_ids_left = self.escaping_ids(node, node.body)
    assigned_ids_right = self.escaping_ids(node, node.orelse)
    assigned_ids_both = assigned_ids_left.union(assigned_ids_right)

    # A name assigned in only one branch must be passed in, so the other
    # branch can return its (unchanged) incoming value.
    imported_ids.update(i for i in assigned_ids_left
                        if i not in assigned_ids_right)
    imported_ids.update(i for i in assigned_ids_right
                        if i not in assigned_ids_left)

    # Sorted for deterministic parameter/return ordering.
    imported_ids = sorted(imported_ids)
    assigned_ids = sorted(assigned_ids_both)

    fbody = self.make_fake(node.body)
    true_has_return = self.gather(HasReturn, fbody)
    true_has_break = self.gather(HasBreak, fbody)
    true_has_cont = self.gather(HasContinue, fbody)

    felse = self.make_fake(node.orelse)
    false_has_return = self.gather(HasReturn, felse)
    false_has_break = self.gather(HasBreak, felse)
    false_has_cont = self.gather(HasContinue, felse)

    has_return = true_has_return or false_has_return
    has_break = true_has_break or false_has_break
    has_cont = true_has_cont or false_has_cont

    self.generic_visit(node)

    func_true = outline(self.true_name(), imported_ids, assigned_ids,
                        node.body, has_return, has_break, has_cont)
    func_false = outline(self.false_name(), imported_ids, assigned_ids,
                         node.orelse, has_return, has_break, has_cont)
    self.new_functions.extend((func_true, func_false))

    actual_call = self.make_dispatcher(node.test, func_true, func_false,
                                       imported_ids)

    # Variables modified within the static_if.
    expected_return = [ast.Name(ii, ast.Store(), None, None)
                       for ii in assigned_ids]

    self.update = True

    # Names for the various variables resulting from the static_if; the
    # counter keeps them unique per outlined pair.
    n = len(self.new_functions)
    status_n = "$status{}".format(n)
    return_n = "$return{}".format(n)
    cont_n = "$cont{}".format(n)

    if has_return:
        cfg = self.cfgs[-1]
        # If every CFG successor of the if is a return/yield, both branches
        # returning means the call result can be returned unconditionally.
        always_return = all(isinstance(x, (ast.Return, ast.Yield))
                            for x in cfg[node])
        always_return &= true_has_return and false_has_return

        fast_return = [ast.Name(status_n, ast.Store(), None, None),
                       ast.Name(return_n, ast.Store(), None, None),
                       ast.Name(cont_n, ast.Store(), None, None)]

        if always_return:
            return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                               actual_call, None),
                    ast.Return(ast.Name(return_n, ast.Load(), None, None))]
        else:
            # Conditional return: test the status flag, otherwise fall
            # through to the control-flow handlers.
            cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                       expected_return,
                                                       has_cont, has_break)
            cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                               [ast.Eq()], [ast.Constant(EARLY_RET, None)])
            return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                               actual_call, None),
                    ast.If(cmpr,
                           [ast.Return(ast.Name(return_n, ast.Load(),
                                                None, None))],
                           cont_ass)]
    elif has_break or has_cont:
        cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                   expected_return,
                                                   has_cont, has_break)
        fast_return = [ast.Name(status_n, ast.Store(), None, None),
                       ast.Name(cont_n, ast.Store(), None, None)]
        return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                           actual_call, None)] + cont_ass
    elif expected_return:
        # No control flow escapes: just unpack the assigned variables.
        return ast.Assign([ast.Tuple(expected_return, ast.Store())],
                          actual_call, None)
    else:
        # Nothing assigned, nothing returned: keep the call for its effects.
        return ast.Expr(actual_call)
def _to_reference_list(self, names):
    """Return a gast.List of reference nodes, one per name, in Load context."""
    references = list(map(self._to_reference, names))
    return gast.List(references, ctx=gast.Load())
def _process_body_item(self, node):
    """Wrap assignments reading 'y' inside an ``if x:`` guard.

    Returns (replacement_node, nested_body_or_None).
    """
    wraps = isinstance(node, gast.Assign) and node.value.id == 'y'
    if not wraps:
        return node, None
    guard = gast.If(gast.Name('x', gast.Load(), None), [node], [])
    return guard, guard.body
def analyse(node, env, non_generic=None):
    """Computes the type of the expression given by node.

    The type of the node is computed in the context of the supplied type
    environment env. Data types can be introduced into the language simply
    by having a predefined set of identifiers in the initial environment;
    this way there is no need to change the syntax or, more importantly,
    the type-checking program when extending the language.

    Args:
        node: The root of the abstract syntax tree.
        env: The type environment is a mapping of expression identifier
            names to type assignments.
        non_generic: A set of non-generic variables, or None

    Returns:
        The computed type of the expression (for expression nodes) or the
        updated environment (for statement nodes).

    Raises:
        InferenceError: The type of the expression could not be inferred,
        PythranTypeError: InferenceError with user friendly message +
            location
    """
    if non_generic is None:
        non_generic = set()

    # expr
    if isinstance(node, gast.Name):
        if isinstance(node.ctx, (gast.Store)):
            # A store introduces a fresh, non-generic type variable.
            new_type = TypeVariable()
            non_generic.add(new_type)
            env[node.id] = new_type
        return get_type(node.id, env, non_generic)
    elif isinstance(node, gast.Num):
        if isinstance(node.n, (int, long)):
            return Integer()
        elif isinstance(node.n, float):
            return Float()
        elif isinstance(node.n, complex):
            return Complex()
        else:
            raise NotImplementedError
    elif isinstance(node, gast.Str):
        return Str()
    elif isinstance(node, gast.Compare):
        # Chained comparisons: each op is a binary function from the
        # previous operand and the next comparator to the result.
        left_type = analyse(node.left, env, non_generic)
        comparators_type = [analyse(comparator, env, non_generic)
                            for comparator in node.comparators]
        ops_type = [analyse(op, env, non_generic) for op in node.ops]
        prev_type = left_type
        result_type = TypeVariable()
        for op_type, comparator_type in zip(ops_type, comparators_type):
            try:
                unify(Function([prev_type, comparator_type], result_type),
                      op_type)
                prev_type = comparator_type
            except InferenceError:
                raise PythranTypeError(
                    "Invalid comparison, between `{}` and `{}`".format(
                        prev_type, comparator_type), node)
        return result_type
    elif isinstance(node, gast.Call):
        if is_getattr(node):
            # getattr(self, "attr") calls: type the attribute as a function
            # of self.
            self_type = analyse(node.args[0], env, non_generic)
            attr_name = node.args[1].s
            _, attr_signature = attributes[attr_name]
            attr_type = tr(attr_signature)
            result_type = TypeVariable()
            try:
                unify(Function([self_type], result_type), attr_type)
            except InferenceError:
                if isinstance(prune(attr_type), MultiType):
                    msg = 'no attribute found, tried:\n{}'.format(attr_type)
                else:
                    msg = 'tried {}'.format(attr_type)
                raise PythranTypeError(
                    "Invalid attribute for getattr call with self"
                    "of type `{}`, {}".format(self_type, msg), node)
        else:
            fun_type = analyse(node.func, env, non_generic)
            arg_types = [analyse(arg, env, non_generic)
                         for arg in node.args]
            result_type = TypeVariable()
            try:
                unify(Function(arg_types, result_type), fun_type)
            except InferenceError:
                # recover original type (unification may have mutated it)
                fun_type = analyse(node.func, env, non_generic)
                if isinstance(prune(fun_type), MultiType):
                    msg = 'no overload found, tried:\n{}'.format(fun_type)
                else:
                    msg = 'tried {}'.format(fun_type)
                raise PythranTypeError(
                    "Invalid argument type for function call to "
                    "`Callable[[{}], ...]`, {}"
                    .format(', '.join('{}'.format(at) for at in arg_types),
                            msg),
                    node)
        return result_type
    elif isinstance(node, gast.IfExp):
        test_type = analyse(node.test, env, non_generic)
        unify(Function([test_type], Bool()),
              tr(MODULES['__builtin__']['bool_']))
        # `x if x is None else y` narrows x to NoneType in the body and to
        # the non-None alternative in orelse.
        if is_test_is_none(node.test):
            none_id = node.test.left.id
            body_env = env.copy()
            body_env[none_id] = NoneType
        else:
            none_id = None
            body_env = env
        body_type = analyse(node.body, body_env, non_generic)
        if none_id:
            orelse_env = env.copy()
            if is_option_type(env[none_id]):
                orelse_env[none_id] = prune(env[none_id]).types[0]
            else:
                orelse_env[none_id] = TypeVariable()
        else:
            orelse_env = env
        orelse_type = analyse(node.orelse, orelse_env, non_generic)
        try:
            return merge_unify(body_type, orelse_type)
        except InferenceError:
            raise PythranTypeError(
                "Incompatible types from different branches:"
                "`{}` and `{}`".format(body_type, orelse_type), node)
    elif isinstance(node, gast.UnaryOp):
        operand_type = analyse(node.operand, env, non_generic)
        op_type = analyse(node.op, env, non_generic)
        result_type = TypeVariable()
        try:
            unify(Function([operand_type], result_type), op_type)
            return result_type
        except InferenceError:
            raise PythranTypeError(
                "Invalid operand for `{}`: `{}`".format(
                    symbol_of[type(node.op)], operand_type), node)
    elif isinstance(node, gast.BinOp):
        left_type = analyse(node.left, env, non_generic)
        op_type = analyse(node.op, env, non_generic)
        right_type = analyse(node.right, env, non_generic)
        result_type = TypeVariable()
        try:
            unify(Function([left_type, right_type], result_type), op_type)
        except InferenceError:
            raise PythranTypeError(
                "Invalid operand for `{}`: `{}` and `{}`".format(
                    symbol_of[type(node.op)], left_type, right_type), node)
        return result_type
    # Operator nodes: each returns the signature of its stand-in function.
    elif isinstance(node, gast.Pow):
        return tr(MODULES['numpy']['power'])
    elif isinstance(node, gast.Sub):
        return tr(MODULES['operator_']['sub'])
    elif isinstance(node, (gast.USub, gast.UAdd)):
        return tr(MODULES['operator_']['pos'])
    elif isinstance(node, (gast.Eq, gast.NotEq, gast.Lt, gast.LtE, gast.Gt,
                           gast.GtE, gast.Is, gast.IsNot)):
        return tr(MODULES['operator_']['eq'])
    elif isinstance(node, (gast.In, gast.NotIn)):
        # `x in y` maps to contains(y, x): swap the argument order.
        contains_sig = tr(MODULES['operator_']['contains'])
        contains_sig.types[:-1] = reversed(contains_sig.types[:-1])
        return contains_sig
    elif isinstance(node, gast.Add):
        return tr(MODULES['operator_']['add'])
    elif isinstance(node, gast.Mult):
        return tr(MODULES['operator_']['mul'])
    elif isinstance(node, (gast.Div, gast.FloorDiv)):
        return tr(MODULES['operator_']['floordiv'])
    elif isinstance(node, gast.Mod):
        return tr(MODULES['operator_']['mod'])
    elif isinstance(node, (gast.LShift, gast.RShift)):
        return tr(MODULES['operator_']['lshift'])
    elif isinstance(node, (gast.BitXor, gast.BitAnd, gast.BitOr)):
        # NOTE(review): bitwise ops reuse the 'lshift' signature — same
        # (int, int) -> int shape; confirm this is intentional.
        return tr(MODULES['operator_']['lshift'])
    elif isinstance(node, gast.List):
        # All elements must unify to a single element type.
        new_type = TypeVariable()
        for elt in node.elts:
            elt_type = analyse(elt, env, non_generic)
            try:
                unify(new_type, elt_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible list element type `{}` and `{}`".format(
                        new_type, elt_type), node)
        return List(new_type)
    elif isinstance(node, gast.Set):
        new_type = TypeVariable()
        for elt in node.elts:
            elt_type = analyse(elt, env, non_generic)
            try:
                unify(new_type, elt_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible set element type `{}` and `{}`".format(
                        new_type, elt_type), node)
        return Set(new_type)
    elif isinstance(node, gast.Dict):
        new_key_type = TypeVariable()
        for key in node.keys:
            key_type = analyse(key, env, non_generic)
            try:
                unify(new_key_type, key_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible dict key type `{}` and `{}`".format(
                        new_key_type, key_type), node)
        new_value_type = TypeVariable()
        for value in node.values:
            value_type = analyse(value, env, non_generic)
            try:
                unify(new_value_type, value_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible dict value type `{}` and `{}`".format(
                        new_value_type, value_type), node)
        return Dict(new_key_type, new_value_type)
    elif isinstance(node, gast.Tuple):
        return Tuple([analyse(elt, env, non_generic) for elt in node.elts])
    elif isinstance(node, gast.Index):
        return analyse(node.value, env, non_generic)
    elif isinstance(node, gast.Slice):
        def unify_int_or_none(t, name):
            # Slice components must be int or None.
            try:
                unify(t, Integer())
            except InferenceError:
                try:
                    unify(t, NoneType)
                except InferenceError:
                    raise PythranTypeError(
                        "Invalid slice {} type `{}`, expecting int or None"
                        .format(name, t))
        if node.lower:
            lower_type = analyse(node.lower, env, non_generic)
            unify_int_or_none(lower_type, 'lower bound')
        else:
            lower_type = Integer()
        if node.upper:
            upper_type = analyse(node.upper, env, non_generic)
            unify_int_or_none(upper_type, 'upper bound')
        else:
            upper_type = Integer()
        if node.step:
            step_type = analyse(node.step, env, non_generic)
            unify_int_or_none(step_type, 'step')
        else:
            step_type = Integer()
        return Slice
    elif isinstance(node, gast.ExtSlice):
        return [analyse(dim, env, non_generic) for dim in node.dims]
    elif isinstance(node, gast.NameConstant):
        # NOTE(review): only None is handled here; True/False fall through
        # and return Python None implicitly — confirm that is intended.
        if node.value is None:
            return env['None']
    elif isinstance(node, gast.Subscript):
        new_type = TypeVariable()
        value_type = prune(analyse(node.value, env, non_generic))
        try:
            slice_type = prune(analyse(node.slice, env, non_generic))
        except PythranTypeError as e:
            # Re-raise with this node's location.
            raise PythranTypeError(e.msg, node)
        if isinstance(node.slice, gast.ExtSlice):
            nbslice = len(node.slice.dims)
            dtype = TypeVariable()
            try:
                unify(Array(dtype, nbslice), clone(value_type))
            except InferenceError:
                raise PythranTypeError(
                    "Dimension mismatch when slicing `{}`".format(
                        value_type),
                    node)
            return TypeVariable()  # FIXME
        elif isinstance(node.slice, gast.Index):
            # handle tuples in a special way
            isnum = isinstance(node.slice.value, gast.Num)
            if isnum and is_tuple_type(value_type):
                try:
                    unify(prune(prune(value_type.types[0]).types[0])
                          .types[node.slice.value.n],
                          new_type)
                    return new_type
                except IndexError:
                    raise PythranTypeError(
                        "Invalid tuple indexing, "
                        "out-of-bound index `{}` for type `{}`".format(
                            node.slice.value.n, value_type), node)
        try:
            unify(tr(MODULES['operator_']['getitem']),
                  Function([value_type, slice_type], new_type))
        except InferenceError:
            raise PythranTypeError(
                "Invalid subscripting of `{}` by `{}`".format(
                    value_type, slice_type), node)
        return new_type
        # NOTE(review): duplicate return preserved from the original;
        # unreachable dead code.
        return new_type
    elif isinstance(node, gast.Attribute):
        from pythran.utils import attr_to_path
        obj, path = attr_to_path(node)
        if obj.signature is typing.Any:
            return TypeVariable()
        else:
            return tr(obj)
    # stmt
    elif isinstance(node, gast.Import):
        for alias in node.names:
            if alias.name not in MODULES:
                raise NotImplementedError(
                    "unknown module: %s " % alias.name)
            if alias.asname is None:
                target = alias.name
            else:
                target = alias.asname
            env[target] = tr(MODULES[alias.name])
        return env
    elif isinstance(node, gast.ImportFrom):
        if node.module not in MODULES:
            raise NotImplementedError("unknown module: %s" % node.module)
        for alias in node.names:
            if alias.name not in MODULES[node.module]:
                raise NotImplementedError(
                    "unknown function: %s in %s" % (alias.name,
                                                    node.module))
            if alias.asname is None:
                target = alias.name
            else:
                target = alias.asname
            env[target] = tr(MODULES[node.module][alias.name])
        return env
    elif isinstance(node, gast.FunctionDef):
        # One signature per possible arity (defaults may be omitted).
        ftypes = []
        for i in range(1 + len(node.args.defaults)):
            # NOTE(review): old_type appears unused in this scope.
            old_type = env[node.name]
            new_env = env.copy()
            new_non_generic = non_generic.copy()

            # reset return special variables
            new_env.pop('@ret', None)
            new_env.pop('@gen', None)

            hy = HasYield()
            for stmt in node.body:
                hy.visit(stmt)
            new_env['@gen'] = hy.has_yield

            arg_types = []
            istop = len(node.args.args) - i
            for arg in node.args.args[:istop]:
                arg_type = TypeVariable()
                new_env[arg.id] = arg_type
                new_non_generic.add(arg_type)
                arg_types.append(arg_type)
            # Defaulted arguments take the type of their default value.
            for arg, expr in zip(node.args.args[istop:],
                                 node.args.defaults[-i:]):
                arg_type = analyse(expr, new_env, new_non_generic)
                new_env[arg.id] = arg_type

            analyse_body(node.body, new_env, new_non_generic)

            result_type = new_env.get('@ret', NoneType)

            if new_env['@gen']:
                result_type = Generator(result_type)

            ftype = Function(arg_types, result_type)
            ftypes.append(ftype)
        if len(ftypes) == 1:
            ftype = ftypes[0]
            env[node.name] = ftype
        else:
            env[node.name] = MultiType(ftypes)
        return env
    elif isinstance(node, gast.Module):
        analyse_body(node.body, env, non_generic)
        return env
    elif isinstance(node, (gast.Pass, gast.Break, gast.Continue)):
        return env
    elif isinstance(node, gast.Expr):
        analyse(node.value, env, non_generic)
        return env
    elif isinstance(node, gast.Delete):
        for target in node.targets:
            if isinstance(target, gast.Name):
                if target.id in env:
                    del env[target.id]
                else:
                    raise PythranTypeError(
                        "Invalid del: unbound identifier `{}`".format(
                            target.id), node)
            else:
                analyse(target, env, non_generic)
        return env
    elif isinstance(node, gast.Print):
        if node.dest is not None:
            analyse(node.dest, env, non_generic)
        for value in node.values:
            analyse(value, env, non_generic)
        return env
    elif isinstance(node, gast.Assign):
        defn_type = analyse(node.value, env, non_generic)
        for target in node.targets:
            target_type = analyse(target, env, non_generic)
            try:
                unify(target_type, defn_type)
            except InferenceError:
                raise PythranTypeError(
                    "Invalid assignment from type `{}` to type `{}`".format(
                        target_type, defn_type), node)
        return env
    elif isinstance(node, gast.AugAssign):
        # FIXME: not optimal: evaluates type of node.value twice
        fake_target = deepcopy(node.target)
        fake_target.ctx = gast.Load()
        fake_op = gast.BinOp(fake_target, node.op, node.value)
        gast.copy_location(fake_op, node)
        analyse(fake_op, env, non_generic)

        value_type = analyse(node.value, env, non_generic)
        target_type = analyse(node.target, env, non_generic)
        try:
            unify(target_type, value_type)
        except InferenceError:
            raise PythranTypeError(
                "Invalid update operand for `{}`: `{}` and `{}`".format(
                    symbol_of[type(node.op)], value_type, target_type),
                node)
        return env
    elif isinstance(node, gast.Raise):
        return env  # TODO
    elif isinstance(node, gast.Return):
        if env['@gen']:
            # A return inside a generator does not constrain '@ret'.
            return env
        if node.value is None:
            ret_type = NoneType
        else:
            ret_type = analyse(node.value, env, non_generic)
        if '@ret' in env:
            try:
                ret_type = merge_unify(env['@ret'], ret_type)
            except InferenceError:
                raise PythranTypeError(
                    "function may returns with incompatible types "
                    "`{}` and `{}`".format(env['@ret'], ret_type), node)
        env['@ret'] = ret_type
        return env
    elif isinstance(node, gast.Yield):
        assert env['@gen']
        assert node.value is not None
        # NOTE(review): the None branch below is dead given the assert
        # above; preserved from the original.
        if node.value is None:
            ret_type = NoneType
        else:
            ret_type = analyse(node.value, env, non_generic)
        if '@ret' in env:
            try:
                ret_type = merge_unify(env['@ret'], ret_type)
            except InferenceError:
                raise PythranTypeError(
                    "function may yields incompatible types "
                    "`{}` and `{}`".format(env['@ret'], ret_type), node)
        env['@ret'] = ret_type
        return env
    elif isinstance(node, gast.For):
        iter_type = analyse(node.iter, env, non_generic)
        target_type = analyse(node.target, env, non_generic)
        # The iterable must be a collection whose element type matches the
        # loop target.
        unify(Collection(TypeVariable(), TypeVariable(), TypeVariable(),
                         target_type),
              iter_type)
        analyse_body(node.body, env, non_generic)
        analyse_body(node.orelse, env, non_generic)
        return env
    elif isinstance(node, gast.If):
        test_type = analyse(node.test, env, non_generic)
        unify(Function([test_type], Bool()),
              tr(MODULES['__builtin__']['bool_']))

        body_env = env.copy()
        body_non_generic = non_generic.copy()

        # `if x is None:` narrows x to NoneType in the body.
        if is_test_is_none(node.test):
            none_id = node.test.left.id
            body_env[none_id] = NoneType
        else:
            none_id = None

        analyse_body(node.body, body_env, body_non_generic)

        orelse_env = env.copy()
        orelse_non_generic = non_generic.copy()

        if none_id:
            if is_option_type(env[none_id]):
                orelse_env[none_id] = prune(env[none_id]).types[0]
            else:
                orelse_env[none_id] = TypeVariable()

        analyse_body(node.orelse, orelse_env, orelse_non_generic)

        # Merge names newly bound by either branch back into env.
        for var in body_env:
            if var not in env:
                if var in orelse_env:
                    try:
                        new_type = merge_unify(body_env[var],
                                               orelse_env[var])
                    except InferenceError:
                        raise PythranTypeError(
                            "Incompatible types from different branches "
                            "for `{}`: `{}` and `{}`".format(
                                var, body_env[var], orelse_env[var]),
                            node)
                else:
                    new_type = body_env[var]
                env[var] = new_type

        for var in orelse_env:
            if var not in env:
                # may not be unified by the prev loop if a del occured
                if var in body_env:
                    new_type = merge_unify(orelse_env[var], body_env[var])
                else:
                    new_type = orelse_env[var]
                env[var] = new_type

        if none_id:
            try:
                new_type = merge_unify(body_env[none_id],
                                       orelse_env[none_id])
            except InferenceError:
                msg = ("Inconsistent types while merging values of `{}` "
                       "from conditional branches: `{}` and `{}`")
                err = msg.format(none_id, body_env[none_id],
                                 orelse_env[none_id])
                raise PythranTypeError(err, node)
            env[none_id] = new_type
        return env
    elif isinstance(node, gast.While):
        test_type = analyse(node.test, env, non_generic)
        unify(Function([test_type], Bool()),
              tr(MODULES['__builtin__']['bool_']))
        analyse_body(node.body, env, non_generic)
        analyse_body(node.orelse, env, non_generic)
        return env
    elif isinstance(node, gast.Try):
        analyse_body(node.body, env, non_generic)
        for handler in node.handlers:
            analyse(handler, env, non_generic)
        analyse_body(node.orelse, env, non_generic)
        analyse_body(node.finalbody, env, non_generic)
        return env
    elif isinstance(node, gast.ExceptHandler):
        if(node.name):
            new_type = ExceptionType
            non_generic.add(new_type)
            if node.name.id in env:
                unify(env[node.name.id], new_type)
            else:
                env[node.name.id] = new_type
        analyse_body(node.body, env, non_generic)
        return env
    elif isinstance(node, gast.Assert):
        if node.msg:
            analyse(node.msg, env, non_generic)
        analyse(node.test, env, non_generic)
        return env
    elif isinstance(node, gast.UnaryOp):
        # NOTE(review): unreachable — shadowed by the earlier gast.UnaryOp
        # branch above; preserved from the original.
        operand_type = analyse(node.operand, env, non_generic)
        return_type = TypeVariable()
        op_type = analyse(node.op, env, non_generic)
        unify(Function([operand_type], return_type), op_type)
        return return_type
    elif isinstance(node, gast.Invert):
        return MultiType([Function([Bool()], Integer()),
                          Function([Integer()], Integer())])
    elif isinstance(node, gast.Not):
        return tr(MODULES['__builtin__']['bool_'])
    elif isinstance(node, gast.BoolOp):
        op_type = analyse(node.op, env, non_generic)
        value_types = [analyse(value, env, non_generic)
                       for value in node.values]
        # Every operand must be usable as a boolean.
        for value_type in value_types:
            unify(Function([value_type], Bool()),
                  tr(MODULES['__builtin__']['bool_']))
        return_type = TypeVariable()
        prev_type = value_types[0]
        for value_type in value_types[1:]:
            unify(Function([prev_type, value_type], return_type), op_type)
            prev_type = value_type
        return return_type
    elif isinstance(node, (gast.And, gast.Or)):
        # Either both operands share a type (and so does the result), or
        # the result is unconstrained.
        x_type = TypeVariable()
        return MultiType([
            Function([x_type, x_type], x_type),
            Function([TypeVariable(), TypeVariable()], TypeVariable()),
        ])

    raise RuntimeError("Unhandled syntax node {0}".format(type(node)))
def tokenize(s): '''A simple contextual "parser" for an OpenMP string''' # not completely satisfying if there are strings in if expressions out = '' par_count = 0 curr_index = 0 in_reserved_context = False in_declare = False in_shared = in_private = False while curr_index < len(s): bounds = [] if in_declare and is_declare_typename(s, curr_index, bounds): start, stop = bounds pytypes = parse_pytypes(s[start:stop]) out += ', '.join(map(pytype_to_ctype, pytypes)) curr_index = stop continue m = re.match(r'^([a-zA-Z_]\w*)', s[curr_index:]) if m: word = m.group(0) curr_index += len(word) if (in_reserved_context or (in_declare and word in declare_keywords) or (par_count == 0 and word in keywords)): out += word in_reserved_context = word in reserved_contex in_declare |= word == 'declare' in_private |= word == 'private' in_shared |= word == 'shared' else: out += '{}' self.deps.append(ast.Name(word, ast.Load(), None, None)) isattr = re.match(r'^\s*(\.\s*[a-zA-Z_]\w*)', s[curr_index:]) if isattr: attr = isattr.group(0) curr_index += len(attr) self.deps[-1] = ast.Attribute( self.deps[-1], attr[1:], ast.Load()) if in_private: self.private_deps.append(self.deps[-1]) if in_shared: self.shared_deps.append(self.deps[-1]) elif s[curr_index] == '(': par_count += 1 curr_index += 1 out += '(' elif s[curr_index] == ')': par_count -= 1 curr_index += 1 out += ')' if par_count == 0: in_reserved_context = False in_shared = in_private = False else: if s[curr_index] in ',:': in_reserved_context = False out += s[curr_index] curr_index += 1 return out