def create_fill_constant_node(name, value):
    """Build a gast statement that creates a one-element constant variable.

    The generated statement is::

        <name> = paddle.fluid.layers.fill_constant(
            shape=[1], dtype=<dtype>, value=<value>, name='<name>')

    where ``dtype`` is chosen from the Python type of ``value``:
    ``bool`` -> 'bool', ``float`` -> 'float64', ``int`` -> 'int64'.

    Args:
        name (str): Variable name; used both as the assignment target and as
            the ``name`` argument of ``fill_constant``.
        value (bool|float|int): The constant to store.

    Returns:
        The parsed assignment statement node, or None when ``value`` is not a
        bool, float or int.
    """
    # Order matters: bool is a subclass of int, so it must be matched first.
    dtype_table = (
        (bool, 'bool'),
        (float, 'float64'),
        (int, 'int64'),
    )
    for py_type, dtype in dtype_table:
        if not isinstance(value, py_type):
            continue
        statement = ("{} = paddle.fluid.layers.fill_constant(shape=[1], "
                     "dtype='{}', value={}, name='{}')").format(
                         name, dtype, value, name)
        return gast.parse(statement).body[0]
    return None
def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node, return_name,
                                           parent_node_of_return):
    """Guard every statement after ``node`` with ``if not <return_name>:``.

    ``stmt_list`` is modified in place: all statements following ``node`` are
    moved into the body of a new ``gast.If`` whose test is
    ``not <return_name>``. When ``parent_node_of_return`` is a ``gast.If``, an
    assignment ``<return_name> = _jst.create_bool_as_type(<parent test>, False)``
    is additionally inserted immediately before ``node``.

    Args:
        stmt_list (list): Statement list that contains ``node``.
        node: The statement after which the remaining statements are guarded.
        return_name (str): Name of the boolean return flag variable.
        parent_node_of_return: Parent node of the return statement; only a
            ``gast.If`` triggers insertion of the flag initialisation.

    Returns:
        bool: False when ``node`` is not located in ``stmt_list`` (the
        ``index_in_list`` result is out of range — presumably -1 means "not
        found"), otherwise True.
    """
    i = index_in_list(stmt_list, node)
    if i < 0 or i >= len(stmt_list):
        return False
    if i == len(stmt_list) - 1:
        # No need to add, we consider this as added successfully
        return True

    # The flag is *read* in `not <flag>`, so the Name node needs a Load
    # context; a Store context in an expression position is an invalid AST.
    flag_name = gast.Name(id=return_name,
                          ctx=gast.Load(),
                          annotation=None,
                          type_comment=None)
    if_stmt = gast.If(test=gast.UnaryOp(op=gast.Not(), operand=flag_name),
                      body=stmt_list[i + 1:],
                      orelse=[])
    stmt_list[i + 1:] = [if_stmt]

    # Here assume that the parent node of return is gast.If
    if isinstance(parent_node_of_return, gast.If):
        # Prepend control flow boolean nodes such as '__return@1 = False'
        node_str = "{} = _jst.create_bool_as_type({}, False)".format(
            return_name,
            ast_to_source_code(parent_node_of_return.test).strip())
        assign_false_node = gast.parse(node_str).body[0]
        # Insert before `node` (which currently sits at index i).
        stmt_list[i:i] = [assign_false_node]
    return True
def test_nested_loop_vars(self):
    """Check loop / newly-created variable names for each loop of
    ``nested_for_loop_func``, in the order ``gast.walk`` yields them."""
    source = inspect.getsource(self.nested_for_loop_func)
    root = gast.parse(source)
    name_visitor = NameVisitor(root)

    # Expected values, indexed by the position of the loop in the walk order.
    self.loop_var_names = [
        {"j", "two"},
        {"i", "three", "b"},
        {"i", "j"},
    ]
    self.create_var_names = [set(), {"b"}, set()]

    loop_nodes = (candidate for candidate in gast.walk(root)
                  if isinstance(candidate, (gast.While, gast.For)))
    for idx, loop_node in enumerate(loop_nodes):
        got_loop, got_created = name_visitor.get_loop_var_names(loop_node)
        want_loop = self.loop_var_names[idx]
        want_created = self.create_var_names[idx]
        self.assertEqual(
            got_loop,
            want_loop,
            msg="loop_var_names : {}, \nexpected loop_var_names : {}".format(
                got_loop, want_loop))
        self.assertEqual(
            got_created,
            want_created,
            msg="i = {}\ncreate_var_names : {}, \nexpected create_var_names : {}"
            .format(idx, got_created, want_created))
def _replace_pop(self, node):
    """
    Replace a pop statement for a list or dict.

    For example:

        list_a = [0,1,2,3,4]
        x = list_a.pop()  # --> convert_pop(list_a)
        y = list_a.pop(1) # --> convert_pop(list_a, 1)

        dict_a = {"red":0, "blue":1, "yellow":2}
        m = dict_a.pop("red")           # --> convert_pop(dict_a, "red")
        n = dict_a.pop("black", 3)      # --> convert_pop(dict_a, "black", 3)

    Calls with more than two positional arguments are returned unchanged.
    """
    assert isinstance(node, gast.Call)
    assert isinstance(node.func, gast.Attribute)

    receiver_src = ast_to_source_code(node.func.value).strip()
    arg_srcs = [ast_to_source_code(a).strip() for a in node.args]

    # NOTE(liym27):
    # 0 args -> list.pop(); 1 arg -> list.pop(i) or dict.pop(k);
    # 2 args -> dict.pop(k, default). Anything else is left as-is.
    if len(arg_srcs) > 2:
        return node

    call_src = "_jst.convert_pop({}, {})".format(receiver_src,
                                                 ",".join(arg_srcs))
    # gast.parse returns Module(body=[Expr(value=Call)]); unwrap to the Call.
    return gast.parse(call_src).body[0].value
def _create_tensor_array(self, value_node):
    """Return the expression node
    ``paddle.tensor.create_array('float32', <value_node>)``.

    Args:
        value_node: AST node whose source becomes the initial value argument.
    """
    # Although `dtype='float32'`, other types such as `int32` can also be supported
    initial_src = ast_to_source_code(value_node).strip()
    module = gast.parse(
        "paddle.tensor.create_array('float32', {})".format(initial_src))
    return module.body[0].value
def create_convert_shape_node(var_shape_node,
                              slice_node=None,
                              in_control_flow=False):
    """Replace ``x.shape`` / ``x.shape[...]`` with a ``convert_var_shape`` call.

    Args:
        var_shape_node (gast.Attribute|gast.Subscript): The ``x.shape``
            attribute, or a subscript of it.
        slice_node (gast.Subscript, optional): The subscript wrapping the
            attribute, when there is one. Used only in the Attribute case.
        in_control_flow (bool): Forwarded to ``convert_var_shape``.

    Returns:
        ``paddle.jit.dy2static.convert_var_shape(x[, idx], in_control_flow=...)``,
        possibly wrapped in a ``gast.Subscript`` when the slice is not a plain
        number.
    """
    assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript))

    if isinstance(var_shape_node, gast.Subscript):
        # Work on a copy so the caller's tree is not shared; recurse on the
        # inner value and hand over the whole subscript as slice information.
        subscript_copy = copy.deepcopy(var_shape_node)
        return create_convert_shape_node(subscript_copy.value, subscript_copy,
                                         in_control_flow)

    # (1) A slice can be a simple number such as 1, -2, i.e. gast.Index or gast.Constant
    # (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index or gast.Constant
    # In (1) case, we pass the number as 'idx' argument in convert_var_shape
    # In (2) case, we have to make it like `convert_var_shape(x)[slice]`
    has_slice = slice_node is not None
    numeric_slice = has_slice and slice_is_num(slice_node)

    call_args = [ast_to_source_code(var_shape_node.value).strip()]
    if numeric_slice:
        call_args.append(ast_to_source_code(slice_node.slice).strip())

    call_src = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
        ",".join(call_args), in_control_flow)
    shape_call = gast.parse(call_src).body[0].value

    if has_slice and not numeric_slice:
        return gast.Subscript(value=shape_call,
                              slice=slice_node.slice,
                              ctx=gast.Load())
    return shape_call
def _create_bool_op_node(self, nodes, api_type):
    '''
    Build ``paddle.jit.dy2static.convert_logical_<api_type>(lambda:a, lambda:b)``.

    NOTE(liym27):
    The arguments of function convert_logical_XX should be callable so that they can be run
    according to the actual order. In `convert_logical_and(lambda:x>1, lambda:y<1)`, `lambda:y<1`
    must be run after `lambda:x>1`, If `x>1` is False, `y<1` should NOT be run.

    More than two operands are folded into nested calls: the first two are
    combined, the remainder is combined recursively, and the two results are
    combined at the top.
    '''
    assert len(
        nodes
    ) > 1, "The length of BoolOp should be at least 2, but received {}.".format(
        len(nodes))

    if len(nodes) > 2:
        head = self._create_bool_op_node(nodes[:2], api_type)
        tail_nodes = nodes[2:]
        if len(tail_nodes) == 1:
            tail = tail_nodes[0]
        else:
            tail = self._create_bool_op_node(tail_nodes, api_type)
        nodes = [head, tail]

    lhs_src, rhs_src = [ast_to_source_code(operand) for operand in nodes]
    call_src = "paddle.jit.dy2static.convert_logical_{}(lambda:{}, lambda:{})".format(
        api_type, lhs_src, rhs_src)
    # NOTE: gast.parse return Module(body=[expr(...)])
    return gast.parse(call_src).body[0].value
def test_construct_node_wrapper(self):
    """Validate the node-wrapper tree built by StaticAnalysisVisitor for
    every function in ``test_funcs``."""
    for test_func in test_funcs:
        root = gast.parse(inspect.getsource(test_func))
        analysis = StaticAnalysisVisitor(root)
        self._check_wrapper(analysis.get_node_wrapper_root(),
                            analysis.get_node_to_wrapper_map())
def _to_array_write_node(self, node):
    """Turn a method call ``<array>.<method>(x, ...)`` into
    ``paddle.tensor.array_write(x=x, i=paddle.tensor.array_length(<array>), array=<array>)``.

    Only ``node.func.value`` (the receiver) and ``node.args[0]`` are used;
    presumably this handles ``list.append(x)`` — TODO confirm with callers.
    """
    assert isinstance(node, gast.Call)

    def _source_of(ast_node):
        return astor.to_source(gast.gast_to_ast(ast_node))

    array_src = _source_of(node.func.value)
    value_src = _source_of(node.args[0])
    # Appending means writing at index == current length.
    index_src = "paddle.tensor.array_length({})".format(array_src)
    write_src = "paddle.tensor.array_write(x={}, i={}, array={})".format(
        value_src, index_src, array_src)
    return gast.parse(write_src).body[0].value
def visit_Assert(self, node):
    """Replace ``assert test[, msg]`` with the expression statement
    ``paddle.jit.dy2static.convert_assert(test, msg)``.

    When the assert has no message, the second argument is left empty.
    """
    test_src = ast_to_source_code(node.test)
    msg_src = "" if not node.msg else ast_to_source_code(node.msg)
    call_src = 'paddle.jit.dy2static.convert_assert({test}, {msg})'.format(
        test=test_src, msg=msg_src)
    call_node = gast.parse(call_src).body[0].value
    return gast.Expr(value=call_node)
def visit_Attribute(self, node):
    """Replace ``<expr>.shape`` with
    ``paddle.jit.dy2static.convert_var_shape_simple(<expr>)``.

    Child nodes are visited first. Otherwise any ``.shape`` nested under a
    different attribute is never reached. Example: in
    ``paddle.reshape(x, x.shape).sum()`` the outer ``.sum`` attribute used to
    be returned untouched, so the inner ``x.shape`` was not converted.

    Returns:
        The replacement call node for ``.shape`` attributes, otherwise the
        (child-transformed) original node.
    """
    self.generic_visit(node)
    if node.attr == 'shape':
        args = ast_to_source_code(node.value).strip()
        convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape_simple({})".format(
            args)
        api_shape_node = gast.parse(convert_var_shape_func).body[0].value
        return api_shape_node
    return node
def visit_UnaryOp(self, node):
    """Rewrite ``not <operand>`` as ``_jst.convert_logical_not(<operand>)``.

    Children are transformed first; unary operators other than ``not`` are
    returned unchanged.
    """
    self.generic_visit(node)
    if not isinstance(node.op, gast.Not):
        return node
    operand_src = ast_to_source_code(node.operand)
    module = gast.parse("_jst.convert_logical_not({})".format(operand_src))
    # NOTE: gast.parse returns Module(body=[expr(value=...)])
    return module.body[0].value
def code_gast_ast(source):
    """
    Transform source_code into gast.Node and modify it, then back to ast.Node.

    Returns:
        str: ``ast.dump`` of the transformed tree.
    """
    gast_root = gast.parse(textwrap.dedent(source))
    transformed = GastNodeTransformer(gast_root).apply()
    return ast.dump(gast.gast_to_ast(transformed))
def create_convert_ifelse_node(return_name_ids,
                               pred,
                               true_func,
                               false_func,
                               is_if_expr=False):
    """
    Create `paddle.jit.dy2static.convert_ifelse(
            pred, true_fn, false_fn, true_args, false_args, return_vars)`
    to replace original `python if/else` statement.

    Args:
        return_name_ids (list[str]): Names produced by the branches. When
            empty, the call is wrapped in an expression statement instead of
            an assignment.
        pred: AST node of the condition.
        true_func, false_func: For ``is_if_expr`` these are branch expressions
            (wrapped in ``lambda :``); otherwise they are function definition
            nodes whose ``name`` and ``args.args`` are used.
        is_if_expr (bool): True when converting a conditional expression.
    """

    def _load_tuple(elements):
        return gast.Tuple(elts=elements, ctx=gast.Load())

    if is_if_expr:
        true_args = _load_tuple([])
        false_args = _load_tuple([])
        true_func_source = "lambda : {}".format(ast_to_source_code(true_func))
        false_func_source = "lambda : {}".format(
            ast_to_source_code(false_func))
    else:
        true_args = _load_tuple(true_func.args.args)
        false_args = _load_tuple(false_func.args.args)
        true_func_source = true_func.name
        false_func_source = false_func.name

    if return_name_ids:
        return_vars = _load_tuple([
            gast.Name(id=name_id,
                      ctx=gast.Load(),
                      annotation=None,
                      type_comment=None) for name_id in return_name_ids
        ])
    else:
        return_vars = _load_tuple([])

    call_src = (
        '_jst.convert_ifelse('
        '{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})'
        .format(pred=ast_to_source_code(pred),
                true_fn=true_func_source,
                false_fn=false_func_source,
                true_args=ast_to_source_code(true_args),
                false_args=ast_to_source_code(false_args),
                return_vars=ast_to_source_code(return_vars)))
    ifelse_call = gast.parse(call_src).body[0].value

    if not return_name_ids:
        # No variables can be returned if no assign statement in if.body.
        return gast.Expr(value=ifelse_call)
    _, assign_node = create_assign_node(return_name_ids, ifelse_call)
    return assign_node
def visit_Call(self, node):
    """Rewrite ``T(x, ...)``, where ``T`` is listed in ``self._castable_type``,
    as ``paddle.jit.dy2static.convert_var_dtype(x, 'T')``.

    Only the first positional argument is kept. Calls with no positional
    arguments, or to other callees, are returned unchanged after their
    children are visited.
    """
    self.generic_visit(node)
    callee_src = ast_to_source_code(node.func).strip()
    if callee_src not in self._castable_type or not node.args:
        return node
    first_arg_src = ast_to_source_code(node.args[0]).strip()
    replacement_src = "paddle.jit.dy2static.convert_var_dtype({}, '{}')".format(
        first_arg_src, callee_src)
    return gast.parse(replacement_src).body[0].value
def test_loop_vars(self):
    """For each function in ``self.loop_funcs``, every loop's variable sets
    must equal the expectation stored at the same index."""
    for idx, loop_func in enumerate(self.loop_funcs):
        root = gast.parse(inspect.getsource(loop_func))
        name_visitor = NameVisitor(root)
        for candidate in gast.walk(root):
            if not isinstance(candidate, (gast.While, gast.For)):
                continue
            got_loop, got_created = name_visitor.get_loop_var_names(candidate)
            self.assertEqual(got_loop, self.loop_var_names[idx])
            self.assertEqual(got_created, self.create_var_names[idx])
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
    """Build a ``paddle.jit.dy2static.choose_shape_attr_or_api(...)`` call.

    The arguments are ``attr_shape_name`` and an
    ``eval_if_exist_else_none('<api_shape_name>', globals())`` expression.
    When ``slice_node`` has a numeric slice, its slice source is appended as
    a third argument. When the slice is non-numeric, the call is wrapped in a
    ``gast.Subscript`` using that slice.

    Args:
        attr_shape_name (str): Source text of the shape-attribute variable.
        api_shape_name (str): Name of the shape-API variable; looked up
            through ``globals()`` at runtime.
        slice_node (gast.Subscript, optional): Subscript applied to the shape.
    """
    has_slice = slice_node is not None
    numeric_slice = has_slice and slice_is_num(slice_node)

    call_args = [
        attr_shape_name,
        "paddle.jit.dy2static.eval_if_exist_else_none('{}', globals())".format(
            api_shape_name),
    ]
    if numeric_slice:
        call_args.append(ast_to_source_code(slice_node.slice).strip())

    call_src = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format(
        ",".join(call_args))
    choose_call = gast.parse(call_src).body[0].value

    if has_slice and not numeric_slice:
        return gast.Subscript(value=choose_call,
                              slice=slice_node.slice,
                              ctx=gast.Load())
    return choose_call
def test_paddle_api(self):
    """A condition based on ``fluid.layers.shape`` must be transformed."""
    source = textwrap.dedent("""
        def foo(x):
            if fluid.layers.shape(x)[0] > 16:
                x = x + 1
            return x
    """)
    module = gast.parse(source)
    analysis = StaticAnalysisVisitor(module)
    if_test = module.body[0].body[0].test
    self.assertTrue(is_control_flow_to_transform(if_test, analysis))
def test_shape_with_andOr(self):
    """A mixed and/or condition containing a shape comparison must be
    transformed."""
    source = textwrap.dedent("""
        def foo(x):
            batch_size = fluid.layers.shape(x)
            if x is not None and batch_size[0] > 16 or 2 > 1:
                x = x + 1
            return x
    """)
    module = gast.parse(source)
    analysis = StaticAnalysisVisitor(module)
    if_test = module.body[0].body[1].test
    self.assertTrue(is_control_flow_to_transform(if_test, analysis))
def test_with_node_var_type_map(self):
    """``x > 1`` needs transformation only when ``x`` is typed as a Tensor."""
    compare_node = gast.parse("x > 1").body[0].value
    cases = (
        # if x is a Tensor
        (NodeVarType.TENSOR, self.assertTrue),
        # if x is not a Tensor
        (NodeVarType.NUMPY_NDARRAY, self.assertFalse),
    )
    for var_type, check in cases:
        type_map = {"x": {var_type}}
        check(is_control_flow_to_transform(compare_node,
                                           var_name_to_type=type_map))
def get_code(self, dygraph_func):
    """
    Returns the translated static function string code from dygraph function.

    The function is unwrapped, its source is dedented and parsed, the
    resulting AST is transformed by ``DygraphToStaticAst``, and the
    transformed tree is turned back into source text.

    Args:
        dygraph_func (callable): the dygraph function.

    Returns:
        str: the string code of translated static function.

    Examples:
        .. code-block:: python

            import paddle

            def func(x):
                if paddle.mean(x) > 0:
                    x_v = x - 1
                else:
                    x_v = x + 1
                return x_v

            prog_trans = paddle.jit.ProgramTranslator()

            code = prog_trans.get_code(func)
            print(type(code)) # <class 'str'>

    """
    assert callable(
        dygraph_func
    ), "Input dygraph_func is not a callable in ProgramTranslator.get_code"

    dedented_source = textwrap.dedent(inspect.getsource(unwrap(dygraph_func)))
    dygraph_root = gast.parse(dedented_source)

    static_root_wrapper = DygraphToStaticAst().get_static_ast(dygraph_root)
    return ast_to_source_code(static_root_wrapper.node)
def test_log_api(self):
    """Exercise the logging_utils API so CI coverage includes it."""
    logging_utils.set_verbosity(1, True)

    logging_utils.warn("warn")
    logging_utils.error("error")
    for level in (1, 2):
        logging_utils.log(level, "log level {}".format(level))

    ast_code = gast.parse("x = 3")
    for code_level in (1, logging_utils.LOG_AllTransformer):
        logging_utils.set_code_level(code_level, True)
        logging_utils.log_transformed_code(code_level, ast_code,
                                           "TestTransformer")
def visit_Call(self, node):
    """Wrap the callee of a call as ``_jst.convert_call(<callee>)``.

    Children are visited first. Calls that ``self._no_need_convert_call``
    rejects, and callees whose source contains ``PDB_SET``, are returned
    unchanged. Otherwise ``node.func`` is replaced in place and the same node
    is returned.
    """
    self.generic_visit(node)
    if self._no_need_convert_call(node):
        return node

    callee_src = ast_to_source_code(node.func).strip()
    # NOTE(liym27): keep `pdb.set_trace` unconverted even if the conversion
    # would not take effect anyway; leaving it as-is makes it clearer where
    # the breakpoint is called from.
    if PDB_SET not in callee_src:
        wrapped_src = "_jst.convert_call({})".format(callee_src)
        node.func = gast.parse(wrapped_src).body[0].value
    return node
def _transform_slice_to_tensor_write(self, node):
    """Rewrite ``name[idx] = value`` as a ``paddle.tensor.array_write`` call.

    A numeric index (per ``slice_is_num``) becomes::

        name = paddle.tensor.array_write(
            x=value,
            i=paddle.cast(x=_jst.to_static_variable(idx), dtype='int64'),
            array=name)

    Slice-range targets (``name[a:b] = ...``) and non-numeric indices are not
    converted, and the original statement is returned. Previously the function
    fell through and returned None, which a NodeTransformer-style caller would
    treat as "delete this statement" — TODO confirm against callers.

    Args:
        node (gast.Assign): Assignment whose first target is a Subscript of a
            plain name.

    Returns:
        The new assignment node, or ``node`` unchanged.
    """
    assert isinstance(node, gast.Assign)
    target_node = node.targets[0]
    target_name = target_node.value.id
    slice_node = target_node.slice

    if not isinstance(slice_node, gast.Slice) and slice_is_num(target_node):
        value_code = ast_to_source_code(node.value)
        i = "paddle.cast(" \
            "x=_jst.to_static_variable({})," \
            "dtype='int64')".format(ast_to_source_code(slice_node))
        assign_code = "{} = paddle.tensor.array_write(x={}, i={}, array={})" \
            .format(target_name, value_code, i, target_name)
        assign_node = gast.parse(assign_code).body[0]
        return assign_node

    # Unsupported form: keep the original assignment instead of dropping it.
    return node
def create_and_update_origin_info_map(transformed_node,
                                      static_func,
                                      is_global=True):
    """
    Creates a original information map between transformed static function and original dygraph function.

    The source of ``static_func`` is re-parsed and annotated with its own
    origin info. It is then walked together with ``transformed_node``. For
    every pair of nodes that both carry ``ORIGI_INFO``, the static node's line
    location is mapped to the dygraph node's origin info. The new entries are
    always merged into the module-level ``global_origin_info_map``.

    Args:
        transformed_node(gast.AST): The AST node of transformed dygraph function with attached source information of original dygraph function.
        static_func(Callable): The static function transformed by dygraph function corresponding to transformed_node.
        is_global(bool, optional): If True (default), return the updated
            ``global_origin_info_map``. Otherwise return only the entries
            built by this call. The global map is updated in both cases.

    Returns:
        The original information map.

    Raises:
        AssertionError: If the two walks yield nodes of different types at
            the same position.
    """

    origin_info_map = {}
    static_source = inspect.getsource(static_func)
    static_node = gast.parse(static_source)
    static_node = attach_origin_info(static_node, static_func)

    # NOTE(review): ast_walk presumably yields node pairs from both trees in
    # lockstep; the type-equality assert below relies on the trees having the
    # same shape.
    for t_node, s_node in ast_walk(transformed_node, static_node):
        assert type(t_node) == type(s_node), \
            "The node types should be the same, but received type(t_node) is {}, and type(s_node) is {}." \
                .format(type(t_node), type(s_node))
        dygraph_info = getattr(t_node, ORIGI_INFO, None)
        static_info = getattr(s_node, ORIGI_INFO, None)

        if dygraph_info is None or static_info is None:
            continue
        static_loc = static_info.location.line_location
        exist_origin_info = origin_info_map.get(static_loc)

        # Several dygraph nodes can map to the same static location. An
        # existing entry is replaced only when the new dygraph node lies on a
        # strictly later line AND at a strictly smaller column offset;
        # otherwise the first recorded entry is kept.
        if exist_origin_info is not None:
            if exist_origin_info.location.lineno >= dygraph_info.location.lineno:
                continue
            if exist_origin_info.location.col_offset <= dygraph_info.location.col_offset:
                continue

        origin_info_map[static_loc] = dygraph_info

    # Side effect: the module-level map accumulates entries across calls.
    global_origin_info_map.update(origin_info_map)
    if is_global:
        return global_origin_info_map

    return origin_info_map
def test_log_transformed_code(self):
    """log_transformed_code must emit the transformed source text.

    A temporary StreamHandler is attached to the translator logger to capture
    output. It is detached in ``finally``: the logger object may outlive this
    test, and otherwise the handler would keep capturing output from later
    tests into this StringIO.
    """
    source_code = "x = 3"
    ast_code = gast.parse(source_code)

    stream = io.StringIO()
    log = self.translator_logger.logger
    stdout_handler = logging.StreamHandler(stream)
    log.addHandler(stdout_handler)
    try:
        with mock.patch.object(sys, 'stdout', stream):
            paddle.jit.set_code_level(1)
            logging_utils.log_transformed_code(1, ast_code,
                                               "BasicApiTransformer")

            paddle.jit.set_code_level()
            logging_utils.log_transformed_code(
                logging_utils.LOG_AllTransformer, ast_code, "All Transformers")
    finally:
        log.removeHandler(stdout_handler)

    self.assertIn(source_code, stream.getvalue())
def visit_Call(self, node):
    """Rewrite a grad API call (per ``is_grad_api_node``, presumably
    ``paddle.grad``) into ``paddle.static.gradients``.

    Parameter names are translated as outputs->targets, inputs->inputs,
    grad_outputs->target_gradients and no_grad_vars->no_grad_set. Positional
    arguments are converted to keywords using ``paddle.grad``'s parameter
    order. Any other parameter is dropped with a ``warnings.warn``.

    Fixes: the positional-argument warning used ``kw.arg``, which is left over
    from the keyword loop. With no keyword arguments that raised NameError;
    otherwise it reported the wrong parameter. ``**kwargs`` (``kw.arg is
    None``) also no longer crashes the warning message.
    """
    self.generic_visit(node)
    if not is_grad_api_node(node):
        return node

    # Positional order of paddle.grad's parameters.
    dygraph_grad_parameters = [
        "outputs", "inputs", "grad_outputs", "retain_graph", "create_graph",
        "only_inputs", "allow_unused", "no_grad_vars"
    ]
    to_static_grad_param = {
        "outputs": "targets",
        "inputs": "inputs",
        "grad_outputs": "target_gradients",
        "no_grad_vars": "no_grad_set"
    }

    def _warn_unsupported(param_name):
        # str() keeps this safe for `**kwargs`, whose keyword.arg is None.
        warnings.warn("paddle.grad has unsupported parameter in jit: " +
                      str(param_name) + ", jit will discard it")

    static_keywords = []

    for kw in node.keywords:
        if kw.arg not in dygraph_grad_parameters or kw.arg not in to_static_grad_param:
            _warn_unsupported(kw.arg)
            continue
        # Remove keyword-supplied names so positional indices map onto the
        # remaining parameters.
        dygraph_grad_parameters.remove(kw.arg)
        kw.arg = to_static_grad_param[kw.arg]
        static_keywords.append(kw)

    for i, arg_node in enumerate(node.args):
        arg_name = dygraph_grad_parameters[i]
        if arg_name not in to_static_grad_param:
            _warn_unsupported(arg_name)
            continue
        static_keywords.append(
            gast.keyword(arg=to_static_grad_param[arg_name], value=arg_node))

    node.func = gast.parse('paddle.static.gradients').body[0].value
    node.keywords = static_keywords
    node.args = []
    return node
def _convert(self, func):
    """
    Converts dygraph function into static function.

    The transformed AST is cached by source text. Two functions whose
    dedented source is identical share one transformed tree: whichever is
    converted second reuses the cached result. For example, with

        # A.py
        def foo(x, y):
            z = x + y
            return z

        # B.py
        def foo(x, y):
            z = x + y
            return z

    converting A.foo after B.foo reuses B.foo's transformed AST. In every
    case the origin-info map is updated for the returned static function.
    """
    # Note: In Python2, it will raise OSError when inspect function
    # with decorator directly and function.__wrapped__ holds the actual function.
    func = unwrap(func)
    source_code = func_to_source_code(func)

    # TODO(liym27):
    #  Consider this case: source_code in self._code_to_ast_caches,
    #  but actually they are methods in different classes.
    #  Maybe use (__class__, source_code) as key
    ast_cache = self._code_to_ast_caches
    if source_code not in ast_cache:
        annotated_root = attach_origin_info(gast.parse(source_code), func)
        root_wrapper = self._dygraph_to_static.get_static_ast(annotated_root)
        ast_cache[source_code] = root_wrapper
    else:
        root_wrapper = ast_cache[source_code]

    # Build the static function from the (possibly cached) AST.
    static_func, _ = ast_to_func(root_wrapper.node, func)
    create_and_update_origin_info_map(root_wrapper.node, static_func)
    return static_func
def test_var_env(self):
    """Each test function's scope must contain exactly the expected
    variable-type mapping from ``result_var_type``.

    Debug ``print`` calls were removed. ``assertIn`` and an assertion message
    carrying the variable name give the same diagnostics only on failure.
    """
    for i, func in enumerate(test_funcs):
        var_type = result_var_type[i]
        ast_root = gast.parse(inspect.getsource(func))
        visitor = StaticAnalysisVisitor(ast_root)
        var_env = visitor.get_var_env()

        # There must be 1 sub scope for the test function
        self.assertEqual(1, len(var_env.cur_scope.sub_scopes))
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]

        scope_var_type = var_env.get_scope_var_type()
        self.assertEqual(len(scope_var_type), len(var_type))
        for name in scope_var_type:
            self.assertIn(name, var_type)
            self.assertEqual(scope_var_type[name],
                             var_type[name],
                             msg="Test var name %s" % (name))
def visit_Compare(self, node):
    """Replace a comparison whose operands are all converted shapes with
    ``_jst.convert_shape_compare(left, 'op1', right1, 'op2', right2, ...)``.

    The replacement applies only when the left operand and every comparator
    start with ``_jst.convert_var_shape``. Otherwise the (child-visited)
    node is returned unchanged. The rewrite accounts for differences between
    Python and Paddle comparison semantics.
    """
    self.generic_visit(node)
    shape_prefix = "_jst.convert_var_shape"

    left_str = ast_to_source_code(node.left).strip()
    if not left_str.startswith(shape_prefix):
        return node

    pieces = [left_str]
    for op_node, comparator in zip(node.ops, node.comparators):
        comparator_str = ast_to_source_code(comparator).strip()
        if not comparator_str.startswith(shape_prefix):
            return node
        pieces.append(", '" + cmpop_node_to_str(op_node) + "', " +
                      comparator_str)

    call_src = "_jst.convert_shape_compare({})".format("".join(pieces))
    return gast.parse(call_src).body[0].value