class TensorShapeTransformer(gast.NodeTransformer):
    """
    This class transforms variable.shape used in Paddle Apis or control flow
    conditions into Static Graph Ast.

    Two kinds of rewrites are performed:
      * `var.shape` / `var.shape[i]` (or a name assigned from one) used as an
        argument of a Paddle API call is replaced by the node built by
        `create_choose_shape_node` / `create_convert_shape_node`.
      * such expressions appearing in the test of `if` / `while` or in the
        iterable of `for` are replaced in place inside their parent node.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # stores origin var string name (like "x" in `x = t.shape`) to
        # static shape var string name (like "x_SUFFIX" in `x_SUFFIX = shape(t)`)
        self.name_to_var_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = (
            self.static_analysis_visitor.get_node_to_wrapper_map())
        var_env = self.static_analysis_visitor.get_var_env()
        # Switch to the first sub-scope (presumably the scope of the function
        # being transformed) before collecting variable types.
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        """Split tuple assignments first, then rewrite `self.root` in place."""
        SplitAssignTransformer(self.root).transform()
        self.visit(self.root)

    def visit_Assign(self, node):
        update_static_shape_var_node = self._update_name_to_var_shape(node)
        if update_static_shape_var_node is not None:
            # Keep the original assignment and insert the statements that
            # compute the static shape variables right after it.
            ret = [node]
            ret.extend(update_static_shape_var_node)
            return ret
        self.generic_visit(node)
        return node

    def visit_Subscript(self, node):
        value_node = node.value
        if isinstance(value_node, gast.Name):
            if value_node.id in self.name_to_var_shape and self._used_by_paddle_api(
                    value_node):
                return create_choose_shape_node(
                    value_node.id, self.name_to_var_shape[value_node.id], node)
        elif isinstance(value_node, gast.Attribute):
            if self._used_by_paddle_api(value_node):
                value_name = ast_to_source_code(value_node).strip()
                if value_name in self.name_to_var_shape:
                    return create_choose_shape_node(
                        value_name, self.name_to_var_shape[value_name], node)
                if self._is_var_shape(value_node):
                    return create_convert_shape_node(value_node, node)
        return node

    def visit_Attribute(self, node):
        if self._used_by_paddle_api(node):
            name = ast_to_source_code(node).strip()
            if name in self.name_to_var_shape:
                return create_choose_shape_node(name,
                                                self.name_to_var_shape[name])
            if self._is_var_shape(node):
                return create_convert_shape_node(node)
        return node

    def visit_Name(self, node):
        if node.id in self.name_to_var_shape:
            if self._used_by_paddle_api(node):
                return create_choose_shape_node(node.id,
                                                self.name_to_var_shape[node.id])
        return node

    def visit_Call(self, node):
        if is_paddle_api(node):
            # Visit gast.Attribute and gast.Name to replace var.shape if necessary.
            self.generic_visit(node)
        # Don't have to visit other APIs
        return node

    def visit_If(self, node):
        # Call generic_visit first to transform var.shape that is used in Paddle Api.
        self.generic_visit(node)
        self._transform_var_shape_if_necessary(node.test)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        self._transform_var_shape_if_necessary(node.test)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        self._transform_var_shape_if_necessary(node.iter)

        # If var.shape is a gast.Name and it is used in range function, transform it
        self._transform_var_shape_in_range(node)
        return node

    def _transform_var_shape_in_range(self, node):
        """
        Replace names bound to shapes in the arguments of `for ... in range(...)`.

        Returns False if `node.iter` is not a call to the plain name `range`,
        True otherwise (even when no argument was replaced).
        """
        assert isinstance(node, gast.For)
        if not isinstance(node.iter, gast.Call):
            return False
        if not isinstance(node.iter.func, gast.Name):
            return False
        if node.iter.func.id != "range":
            return False
        args = node.iter.args
        for idx, arg in enumerate(args):
            if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape:
                args[idx] = create_choose_shape_node(
                    arg.id, self.name_to_var_shape[arg.id])

        return True

    @staticmethod
    def _replace_child_node(parent_node, old_node, new_node):
        """
        Replace `old_node` by `new_node` inside `parent_node`.

        `old_node` may be stored directly in a field (e.g. `gast.If.test`) or
        inside a list field (e.g. `gast.Compare.comparators`). Returns True if
        a replacement happened.
        """
        for field, value in gast.iter_fields(parent_node):
            if value is old_node:
                setattr(parent_node, field, new_node)
                return True
            if isinstance(value, list):
                for i, item in enumerate(value):
                    if item is old_node:
                        value[i] = new_node
                        return True
        return False

    def _transform_var_shape_if_necessary(self, cond):
        """
        Replace shape expressions found anywhere inside `cond` in place.

        Returns True if at least one shape expression was found and its parent
        node could be located.
        """
        need_transformed = False
        for child_node in gast.walk(cond):
            var_shape_node = None
            if isinstance(child_node,
                          (gast.Name, gast.Attribute, gast.Subscript)):
                child_name = ast_to_source_code(child_node).strip()
                if child_name in self.name_to_var_shape:
                    var_shape_node = create_choose_shape_node(
                        child_name, self.name_to_var_shape[child_name])
                elif self._is_var_shape(child_node):
                    var_shape_node = child_node

            if var_shape_node is None:
                continue

            wrapper_node = self.node_to_wrapper_map.get(child_node)
            if wrapper_node is None or wrapper_node.parent is None:
                # Nodes created by earlier transformations are not in
                # node_to_wrapper_map, so their parent cannot be located.
                continue

            need_transformed = True
            if var_shape_node is child_node:
                # A raw `x.shape` / `x.shape[i]`: convert it for control flow.
                new_node = create_convert_shape_node(var_shape_node, None, True)
            else:
                # A name bound to a shape: use the prepared choose-shape node.
                new_node = var_shape_node
            self._replace_child_node(wrapper_node.parent.node, child_node,
                                     new_node)
        return need_transformed

    def _used_by_paddle_api(self, node):
        """
        Whether node is used in paddle api as arguments.
        For example:
            1) Return True in `paddle.relu(x)` where node is `x` (gast.Name)
            2) Return True in `paddle.add(self.x)` where node is `self.x` (gast.Attribute)
            3) Return False in `paddle.add(self.x)` where node is `paddle.add` (gast.Attribute),
               because the role of node is not arguments but `gast.Call.func`.
        Only the nearest enclosing gast.Call is considered.
        """
        assert isinstance(node, (gast.Attribute, gast.Name))
        wrapper_node = self.node_to_wrapper_map.get(node)
        if not wrapper_node:
            # Transformed node is not in node_to_wrapper_map
            return False
        while wrapper_node.parent:
            parent_node = wrapper_node.parent.node
            if isinstance(parent_node, gast.Call):
                # Note(Aurelius84): Filter the case when the role of node is `gast.Call.func`.
                return is_paddle_api(parent_node) and parent_node.func is not node
            wrapper_node = wrapper_node.parent

        return False

    def _is_var_shape(self, node):
        """
        Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise.
        The Paddle API `paddle.shape` itself is not a variable shape.
        """
        if isinstance(node, gast.Subscript):
            # `x.shape[0]`, `x.shape[0][1]`, ...: look through the subscript.
            return self._is_var_shape(node.value)
        if not isinstance(node, gast.Attribute) or node.attr != 'shape':
            return False
        # If node is `paddle.shape`, return False
        return not (isinstance(node.value, gast.Name) and
                    node.value.id == 'paddle')

    def _update_name_to_var_shape(self, node):
        """
        Record shape aliases introduced by the assignment `node` and build the
        statements that compute their static counterparts.

        Returns a non-empty list of statements to insert after `node`, or None
        when the assignment does not involve a variable shape (the caller then
        visits `node` normally).
        """

        def replace_dot(name):
            # replace all '.' into '_'
            return name.replace('.', '_')

        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        value_node = node.value

        update_static_shape_var_node = None
        if isinstance(target_node, gast.Tuple):
            update_static_shape_var_node = []
            for idx, element in enumerate(target_node.elts):
                target_id = ast_to_source_code(element).strip()

                if isinstance(value_node, gast.Name):
                    if value_node.id in self.name_to_var_shape:
                        # TODO(zhhsplendid): is context a problem for the result node of gast.parse?
                        static_shape_var_name = unique_name.generate(
                            replace_dot(target_id) +
                            STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                        static_shape_var_node = gast.parse(
                            static_shape_var_name).body[0].value

                        static_shape_value_name = self.name_to_var_shape[
                            value_node.id]
                        sub_node_str = "{}[{}]".format(static_shape_value_name,
                                                       idx)
                        sub_node = gast.parse(sub_node_str).body[0].value

                        update_static_shape_var_node.append(
                            gast.Assign(
                                targets=[static_shape_var_node],
                                value=sub_node))

                        self.name_to_var_shape[
                            target_id] = static_shape_var_name
                if isinstance(value_node, gast.Attribute):
                    if self._is_var_shape(value_node):  # eg: x.shape
                        static_shape_var_name = unique_name.generate(
                            replace_dot(target_id) +
                            STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                        static_shape_var_node = gast.parse(
                            static_shape_var_name).body[0].value

                        static_shape_value_node = copy.deepcopy(value_node)
                        # x.shape becomes convert_var_shape_simple(x)
                        static_shape_value_node = ShapeAttributeTransformer(
                        ).visit(static_shape_value_node)
                        sub_node_str = "{}[{}]".format(
                            ast_to_source_code(static_shape_value_node).strip(),
                            idx)
                        sub_node = gast.parse(sub_node_str).body[0].value
                        # Note(Aurelius84): Because static_shape_var_name is used in
                        # eval_if_exist_else_none() as plain string, so it will not
                        # be parsed as argument in convert_loop/ifelse. We declare it
                        # as global var because it has unique name.
                        update_static_shape_var_node.append(
                            gast.Global(names=[static_shape_var_name]))

                        update_static_shape_var_node.append(
                            gast.Assign(
                                targets=[static_shape_var_node],
                                value=sub_node))
                        self.name_to_var_shape[
                            target_id] = static_shape_var_name
            # An empty list would make visit_Assign return the node without
            # generic_visit, leaving `x.shape` inside Paddle API calls on the
            # right-hand side untransformed. Report "nothing to do" instead.
            return update_static_shape_var_node or None
        else:
            target_id = ast_to_source_code(target_node).strip()
            if isinstance(value_node, gast.Name):
                if value_node.id in self.name_to_var_shape:
                    static_shape_var_name = unique_name.generate(
                        replace_dot(target_id) +
                        STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                    static_shape_var_node = gast.parse(
                        static_shape_var_name).body[0].value
                    static_shape_value_name = self.name_to_var_shape[
                        value_node.id]
                    static_shape_value_node = gast.parse(
                        static_shape_value_name).body[0].value

                    update_static_shape_var_node = [
                        gast.Assign(
                            targets=[static_shape_var_node],
                            value=static_shape_value_node)
                    ]
                    self.name_to_var_shape[target_id] = static_shape_var_name
            elif self._is_var_shape(value_node):  # eg: x.shape or x.shape[0]
                static_shape_var_name = unique_name.generate(
                    replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                static_shape_var_node = gast.parse(static_shape_var_name).body[
                    0].value
                static_shape_value_node = copy.deepcopy(value_node)
                # x.shape becomes convert_var_shape_simple(x)
                static_shape_value_node = ShapeAttributeTransformer().visit(
                    static_shape_value_node)
                # Declare static_shape_var_name as global var
                update_static_shape_var_node = [
                    gast.Global(names=[static_shape_var_name])
                ]
                update_static_shape_var_node.append(
                    gast.Assign(
                        targets=[static_shape_var_node],
                        value=static_shape_value_node))
                self.name_to_var_shape[target_id] = static_shape_var_name
            return update_static_shape_var_node
class ListTransformer(gast.NodeTransformer):
    """
    This class transforms python list used in control flow into Static Graph Ast.

    A list created from a list literal whose `append` is found inside control
    flow that will be transformed gets its literal replaced by
    `fluid.layers.create_array(...)`, and the appends / index assignments are
    rewritten into `fluid.layers.array_write(...)`. `pop` calls are rewritten
    into `paddle.jit.dy2static.convert_pop(...)`.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of ListTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # list name -> True once `name.append(...)` has been recognized as a
        # tensor-array write, False while it is only known as a list literal.
        self.list_name_to_updated = dict()
        # `name = [...]` assignment nodes, candidates for create_array.
        self.list_nodes = set()

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = (
            self.static_analysis_visitor.get_node_to_wrapper_map())
        var_env = self.static_analysis_visitor.get_var_env()
        # Switch to the first sub-scope (presumably the scope of the function
        # being transformed) before collecting variable types.
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        SplitAssignTransformer(self.root).transform()
        self.visit(self.root)
        self.replace_list_with_tensor_array(self.root)

    def visit_Call(self, node):
        # Visit arguments first so that `pop` calls nested inside other calls
        # (e.g. `foo(list_a.pop())`) are converted as well.
        self.generic_visit(node)
        if isinstance(node.func, gast.Attribute) and node.func.attr == "pop":
            node = self._replace_pop(node)
        return node

    def visit_Assign(self, node):
        if self._update_list_name_to_updated(node):
            return node

        if self._need_to_array_write_node(node):
            return self._transform_slice_to_tensor_write(node)

        self.generic_visit(node)
        return node

    def visit_If(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def replace_list_with_tensor_array(self, node):
        """Replace the literal of every list that became a tensor array."""
        for child_node in gast.walk(node):
            if isinstance(child_node, gast.Assign):
                if self._need_to_create_tensor_array(child_node):
                    child_node.value = self._create_tensor_array()

    def _transform_list_append_in_control_flow(self, node):
        for child_node in gast.walk(node):
            # Only `a.append(x)` expression statements are rewritten here:
            # _to_array_write_node expects an `append` call, which the value
            # of an index assignment (`a[i] = x`) is not.
            if isinstance(child_node, gast.Expr) and \
                    self._need_to_array_write_node(child_node):
                child_node.value = \
                    self._to_array_write_node(child_node.value)

    def _need_to_array_write_node(self, node):
        """
        Return True for `a.append(x)` statements on a tracked list, or for
        `a[i] = x` where `a` is already known to be a tensor array.

        Note: for an `append` statement this marks the list as updated
        (see _is_list_append_tensor).
        """
        if isinstance(node, gast.Expr):
            if isinstance(node.value, gast.Call):
                if self._is_list_append_tensor(node.value):
                    return True

        if isinstance(node, gast.Assign):
            target_node = node.targets[0]
            if isinstance(target_node, gast.Subscript):
                list_name = ast_to_source_code(target_node.value).strip()
                if self.list_name_to_updated.get(list_name):
                    return True

        return False

    def _transform_slice_to_tensor_write(self, node):
        """
        Rewrite `a[i] = x` into `a = fluid.layers.array_write(x=x, i=i, array=a)`.

        Slice assignments (`a[1:3] = x`) and other subscript forms cannot be
        expressed as a single array_write and are returned unchanged.
        """
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        target_name = target_node.value.id
        slice_node = target_node.slice

        if not isinstance(slice_node, gast.Index):
            return node

        value_code = ast_to_source_code(node.value)
        i = "paddle.cast(" \
            "x=paddle.jit.dy2static.to_static_variable({})," \
            "dtype='int64')".format(ast_to_source_code(slice_node))
        assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \
            .format(target_name, value_code, i, target_name)
        return gast.parse(assign_code).body[0]

    def _is_list_append_tensor(self, node):
        """
        a.append(b): a is list, b is Tensor
        self.x.append(b): self.x is list, b is Tensor

        Side effect: when True is returned, the list is marked as updated in
        `self.list_name_to_updated`.
        """
        assert isinstance(node, gast.Call)
        # 1. The func is `append`.
        if not isinstance(node.func, gast.Attribute):
            return False
        if node.func.attr != 'append':
            return False

        # 2. It's a `python list` to call append().
        value_name = astor.to_source(gast.gast_to_ast(node.func.value)).strip()
        if value_name not in self.list_name_to_updated:
            return False

        # 3. The number of arg of append() is one
        # Only one argument is supported in Python list.append()
        if len(node.args) != 1:
            return False

        # TODO(liym27): The arg of append() should be Tensor. But because the
        # type of arg is often wrong with static analysis, the arg is not
        # required to be Tensor here (the scope_var_type_dict based check on
        # NUMPY_NDARRAY / TENSOR / PADDLE_RETURN_TYPES is disabled).
        self.list_name_to_updated[value_name] = True
        return True

    def _need_to_create_tensor_array(self, node):
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        return bool(
            self.list_name_to_updated.get(target_id) and
            node in self.list_nodes)

    def _create_tensor_array(self):
        # Although `dtype='float32'`, other types such as `int32` can also be supported
        func_code = "fluid.layers.create_array(dtype='float32')"
        return gast.parse(func_code).body[0].value

    def _to_array_write_node(self, node):
        """Turn the call `a.append(x)` into an array_write expression node."""
        assert isinstance(node, gast.Call)
        array = astor.to_source(gast.gast_to_ast(node.func.value))
        x = astor.to_source(gast.gast_to_ast(node.args[0]))
        i = "fluid.layers.array_length({})".format(array)
        func_code = "fluid.layers.array_write(x={}, i={}, array={})".format(
            x, i, array)
        return gast.parse(func_code).body[0].value

    def _update_list_name_to_updated(self, node):
        """
        Track `name = [...]` assignments.

        Returns True if `node` assigns a list literal (the name is recorded as
        a not-yet-updated list). Reassigning a tracked, not-yet-updated name to
        anything else stops tracking it.
        """
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        # NOTE: Code like `x, y = a, []` has been transformed to `x=a; y=[]`
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        value_node = node.value
        if isinstance(value_node, gast.List):
            self.list_name_to_updated[target_id] = False
            self.list_nodes.add(node)
            return True
        elif self.list_name_to_updated.get(target_id) is False:
            del self.list_name_to_updated[target_id]
        return False

    def _replace_pop(self, node):
        """
        Replace a pop statement for a list or dict.
        For example:

            list_a = [0,1,2,3,4]
            x = list_a.pop()  # --> convert_pop(list_a)
            y = list_a.pop(1) # --> convert_pop(list_a, 1)

            dict_a = {"red":0, "blue":1, "yellow":2}
            m = dict_a.pop("red")       # --> convert_pop(dict_a, "red")
            n = dict_a.pop("black", 3)  # --> convert_pop(dict_a, "black", 3)

        Calls with more than two arguments are returned unchanged.
        """
        assert isinstance(node, gast.Call)
        assert isinstance(node.func, gast.Attribute)

        target_node = node.func.value
        target_str = ast_to_source_code(target_node).strip()
        args_str = [ast_to_source_code(arg).strip() for arg in node.args]

        # NOTE(liym27):
        # 1. pop stmt for a list if len(args_str) == 0
        # 2. pop stmt for a list or dict if len(args_str) == 1
        # 3. pop stmt for a dict if len(args_str) == 2
        if len(args_str) > 2:
            return node

        new_pop_str = "paddle.jit.dy2static.convert_pop({})".format(
            ", ".join([target_str] + args_str))
        return gast.parse(new_pop_str).body[0].value
class TensorShapeTransformer(gast.NodeTransformer):
    """
    This class transforms Tensor.shape used in Paddle Apis and control flow
    conditions into Static Graph Ast.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # variable name -> the `x.shape` / `x.shape[i]` node it was assigned from
        self.name_to_tensor_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = (
            self.static_analysis_visitor.get_node_to_wrapper_map())
        var_env = self.static_analysis_visitor.get_var_env()
        # Switch to the first sub-scope (presumably the scope of the function
        # being transformed) before collecting variable types.
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        self.visit(self.root)

    def visit_Assign(self, node):
        if self._update_name_to_tensor_shape(node):
            return node
        self.generic_visit(node)
        return node

    def visit_Attribute(self, node):
        if self._used_by_paddle_api(node):
            if self.is_tensor_shape(node):
                return create_api_shape_node(node)
        return node

    def visit_Name(self, node):
        if node.id in self.name_to_tensor_shape:
            if self._used_by_paddle_api(node):
                tensor_shape_node = self.name_to_tensor_shape[node.id]
                return create_api_shape_node(tensor_shape_node)
        return node

    def visit_Call(self, node):
        assert isinstance(node, gast.Call)
        if is_paddle_api(node):
            # Visit gast.Attribute and gast.Name to replace tensor.shape if necessary.
            self.generic_visit(node)

        return node

    def visit_If(self, node):
        # Call generic_visit first to transform Tensor.shape that is used in Paddle Api.
        self.generic_visit(node)
        self._transform_tensor_shape_if_necessary(node.test)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        self._transform_tensor_shape_if_necessary(node.test)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        self._transform_tensor_shape_if_necessary(node.iter)

        # If tensor.shape is a gast.Name and it is used in range function, transform it
        self._transform_tensor_shape_in_range(node)
        return node

    def _transform_tensor_shape_in_range(self, node):
        """
        Replace names bound to tensor shapes in the arguments of
        `for ... in range(...)`. Returns False if `node.iter` is not a call to
        the plain name `range`, True otherwise.
        """
        assert isinstance(node, gast.For)
        if not isinstance(node.iter, gast.Call):
            return False
        if not isinstance(node.iter.func, gast.Name):
            return False
        if node.iter.func.id != "range":
            return False
        args = node.iter.args
        for idx, arg in enumerate(args):
            if isinstance(arg, gast.Name) and arg.id in self.name_to_tensor_shape:
                args[idx] = create_api_shape_node(
                    self.name_to_tensor_shape[arg.id])

        return True

    @staticmethod
    def _replace_child_node(parent_node, old_node, new_node):
        """
        Replace `old_node` by `new_node` inside `parent_node`.

        `old_node` may be stored directly in a field (e.g. `gast.If.test`) or
        inside a list field (e.g. `gast.Compare.comparators`,
        `gast.BoolOp.values`). Returns True if a replacement happened.
        """
        for field, value in gast.iter_fields(parent_node):
            if value is old_node:
                setattr(parent_node, field, new_node)
                return True
            if isinstance(value, list):
                for i, item in enumerate(value):
                    if item is old_node:
                        value[i] = new_node
                        return True
        return False

    def _transform_tensor_shape_if_necessary(self, cond):
        """Replace tensor shape expressions found inside `cond` in place."""
        for child_node in gast.walk(cond):
            tensor_shape_node = None
            if isinstance(child_node, gast.Attribute):
                if self.is_tensor_shape(child_node):
                    tensor_shape_node = child_node
            elif isinstance(child_node, gast.Name):
                if child_node.id in self.name_to_tensor_shape:
                    tensor_shape_node = self.name_to_tensor_shape[
                        child_node.id]

            if tensor_shape_node is None:
                continue

            wrapper_node = self.node_to_wrapper_map.get(child_node)
            if wrapper_node is None or wrapper_node.parent is None:
                # Nodes created by earlier transformations are not in
                # node_to_wrapper_map, so their parent cannot be located.
                continue

            self._replace_child_node(wrapper_node.parent.node, child_node,
                                     create_api_shape_node(tensor_shape_node))

    def _used_by_paddle_api(self, node):
        """
        Return True if the nearest enclosing gast.Call of `node` is a Paddle API
        call, False otherwise (including when `node` is not in
        node_to_wrapper_map).
        """
        assert isinstance(node, (gast.Attribute, gast.Name))
        wrapper_node = self.node_to_wrapper_map.get(node)
        if not wrapper_node:
            # Transformed node is not in node_to_wrapper_map
            return False
        while wrapper_node.parent:
            parent_node = wrapper_node.parent.node
            if isinstance(parent_node, gast.Call):
                return bool(is_paddle_api(parent_node))
            wrapper_node = wrapper_node.parent

        return False

    def is_tensor_shape(self, node):
        """
        Return True if node is like `x.shape` and x is Tensor, return False otherwise.
        """
        assert isinstance(node, gast.Attribute)
        if node.attr != 'shape':
            return False

        try:
            value_id = node.value.id
        except AttributeError:
            return False

        if value_id in self.name_to_tensor_shape:
            return True

        # TODO: `value_id` may be not in scope_var_type_dict if `value_id` is the arg of decorated function
        # Need a better way to confirm whether `value_id` is a Tensor.
        try:
            var_type_set = self.scope_var_type_dict[value_id]
        except KeyError:
            return False

        if NodeVarType.NUMPY_NDARRAY in var_type_set:
            return False
        if NodeVarType.TENSOR not in var_type_set and \
                NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
            return False

        return True

    def _update_name_to_tensor_shape(self, node):
        """
        Record `name = x.shape`, `name = x.shape[i]` or `name = other_alias`.
        Returns True if `node` was recorded as a shape alias.
        """
        assert isinstance(node, gast.Assign)
        # TODO: Consider node has more than one target. eg: x, y = a, Tensor.shape[1]
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        value_node = node.value

        if isinstance(value_node, gast.Name):
            if value_node.id in self.name_to_tensor_shape:
                self.name_to_tensor_shape[
                    target_id] = self.name_to_tensor_shape[value_node.id]
                return True
        if isinstance(value_node, gast.Attribute):
            if self.is_tensor_shape(value_node):  # eg: x.shape
                self.name_to_tensor_shape[target_id] = value_node
                return True
        if isinstance(value_node, gast.Subscript):
            if isinstance(value_node.value, gast.Attribute):
                if self.is_tensor_shape(value_node.value):  # eg: x.shape[0]
                    self.name_to_tensor_shape[target_id] = value_node
                    return True
        return False
class ListTransformer(gast.NodeTransformer):
    """
    This class transforms python list used in control flow into Static Graph Ast.

    A list created from a list literal whose Tensor `append` is found inside
    control flow that will be transformed gets its literal replaced by
    `fluid.layers.create_array(...)`, and the appends / index assignments are
    rewritten into `fluid.layers.array_write(...)`. `pop` calls are rewritten
    into `convert_list_pop(...)`.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of ListTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # list name -> True once `name.append(tensor)` has been recognized as
        # a tensor-array write, False while it is only known as a list literal.
        self.list_name_to_updated = dict()
        # `name = [...]` assignment nodes, candidates for create_array.
        self.list_nodes = set()

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = (
            self.static_analysis_visitor.get_node_to_wrapper_map())
        var_env = self.static_analysis_visitor.get_var_env()
        # Switch to the first sub-scope (presumably the scope of the function
        # being transformed) before collecting variable types.
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        self.visit(self.root)
        self.replace_list_with_tensor_array(self.root)

    def visit_Call(self, node):
        # Visit arguments first so that `pop` calls nested inside other calls
        # (e.g. `foo(list_a.pop())`) are converted as well.
        self.generic_visit(node)
        if isinstance(node.func, gast.Attribute) and node.func.attr == "pop":
            node = self._replace_list_pop(node)
        return node

    def visit_Assign(self, node):
        if self._update_list_name_to_updated(node):
            return node

        if self._need_to_array_write_node(node):
            return self._transform_slice_to_tensor_write(node)

        self.generic_visit(node)
        return node

    def visit_If(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def replace_list_with_tensor_array(self, node):
        """Replace the literal of every list that became a tensor array."""
        for child_node in gast.walk(node):
            if isinstance(child_node, gast.Assign):
                if self._need_to_create_tensor_array(child_node):
                    child_node.value = self._create_tensor_array()

    def _transform_list_append_in_control_flow(self, node):
        for child_node in gast.walk(node):
            # Only `a.append(x)` expression statements are rewritten here:
            # _to_array_write_node expects an `append` call, which the value
            # of an index assignment (`a[i] = x`) is not.
            if isinstance(child_node, gast.Expr) and \
                    self._need_to_array_write_node(child_node):
                child_node.value = \
                    self._to_array_write_node(child_node.value)

    def _need_to_array_write_node(self, node):
        """
        Return True for `a.append(tensor)` statements on a tracked list, or for
        `a[i] = x` where `a` is already known to be a tensor array.

        Note: for an `append` statement this marks the list as updated
        (see _is_list_append_tensor).
        """
        if isinstance(node, gast.Expr):
            if isinstance(node.value, gast.Call):
                if self._is_list_append_tensor(node.value):
                    return True

        if isinstance(node, gast.Assign):
            target_node = node.targets[0]
            if isinstance(target_node, gast.Subscript):
                list_name = ast_to_source_code(target_node.value).strip()
                if self.list_name_to_updated.get(list_name):
                    return True

        return False

    def _transform_slice_to_tensor_write(self, node):
        """
        Rewrite `a[i] = x` into `a = fluid.layers.array_write(x=x, i=i, array=a)`.

        Slice assignments (`a[1:3] = x`) and other subscript forms cannot be
        expressed as a single array_write and are returned unchanged.
        """
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        target_name = target_node.value.id
        slice_node = target_node.slice

        if not isinstance(slice_node, gast.Index):
            return node

        value_code = ast_to_source_code(node.value)
        i = "fluid.layers.cast(" \
            "x=fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable({})," \
            "dtype='int64')".format(ast_to_source_code(slice_node))
        assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \
            .format(target_name, value_code, i, target_name)
        return gast.parse(assign_code).body[0]

    def _is_list_append_tensor(self, node):
        """
        a.append(b): a is list, b is Tensor
        self.x.append(b): self.x is list, b is Tensor

        Side effect: when True is returned, the list is marked as updated in
        `self.list_name_to_updated`.
        """
        assert isinstance(node, gast.Call)
        # 1. The func is `append`.
        if not isinstance(node.func, gast.Attribute):
            return False
        if node.func.attr != 'append':
            return False

        # 2. It's a `python list` to call append().
        value_name = astor.to_source(gast.gast_to_ast(node.func.value)).strip()
        if value_name not in self.list_name_to_updated:
            return False

        # 3. The arg of append() is one `Tensor`
        # Only one argument is supported in Python list.append()
        if len(node.args) != 1:
            return False
        arg = node.args[0]
        if isinstance(arg, gast.Name):
            # TODO: `arg.id` may be not in scope_var_type_dict if `arg.id` is the arg of decorated function
            # Need a better way to confirm whether `arg.id` is a Tensor.
            try:
                var_type_set = self.scope_var_type_dict[arg.id]
            except KeyError:
                return False

            if NodeVarType.NUMPY_NDARRAY in var_type_set:
                return False
            if NodeVarType.TENSOR not in var_type_set and \
                    NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
                return False
        # TODO: a non-Name `arg` (e.g. a Paddle API call such as
        # `list_a.append(fluid.layers.reshape(x))`) is currently accepted
        # without any type check.

        self.list_name_to_updated[value_name] = True
        return True

    def _need_to_create_tensor_array(self, node):
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        return bool(
            self.list_name_to_updated.get(target_id) and
            node in self.list_nodes)

    def _create_tensor_array(self):
        # Although `dtype='float32'`, other types such as `int32` can also be supported
        func_code = "fluid.layers.create_array(dtype='float32')"
        return gast.parse(func_code).body[0].value

    def _to_array_write_node(self, node):
        """Turn the call `a.append(x)` into an array_write expression node."""
        assert isinstance(node, gast.Call)
        array = astor.to_source(gast.gast_to_ast(node.func.value))
        x = astor.to_source(gast.gast_to_ast(node.args[0]))
        i = "fluid.layers.array_length({})".format(array)
        func_code = "fluid.layers.array_write(x={}, i={}, array={})".format(
            x, i, array)
        return gast.parse(func_code).body[0].value

    def _update_list_name_to_updated(self, node):
        """
        Track `name = [...]` assignments.

        Returns True if `node` assigns a list literal (the name is recorded as
        a not-yet-updated list). Reassigning a tracked, not-yet-updated name to
        anything else stops tracking it.
        """
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        # TODO: Consider node has more than one target. eg: x, y = a, []
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        value_node = node.value
        if isinstance(value_node, gast.List):
            self.list_name_to_updated[target_id] = False
            self.list_nodes.add(node)
            return True
        elif self.list_name_to_updated.get(target_id) is False:
            del self.list_name_to_updated[target_id]
        return False

    def _replace_list_pop(self, node):
        """Rewrite `a.pop()` / `a.pop(i)` into `convert_list_pop(a, i)`."""
        assert isinstance(node, gast.Call)
        assert isinstance(node.func, gast.Attribute)

        target_node = node.func.value
        target_str = ast_to_source_code(target_node).strip()

        if node.args:
            idx_str = ast_to_source_code(node.args[0]).strip()
        else:
            idx_str = "None"

        new_call_str = "fluid.dygraph.dygraph_to_static.list_transformer.convert_list_pop({}, {})".format(
            target_str, idx_str)
        return gast.parse(new_call_str).body[0].value
class TensorShapeTransformer(gast.NodeTransformer):
    """
    This class transforms variable.shape used in Paddle Apis or control flow
    conditions into Static Graph Ast.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # variable name -> the `x.shape` / `x.shape[i]` node it was assigned from
        self.name_to_var_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = (
            self.static_analysis_visitor.get_node_to_wrapper_map())
        var_env = self.static_analysis_visitor.get_var_env()
        # Switch to the first sub-scope (presumably the scope of the function
        # being transformed) before collecting variable types.
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        SplitAssignTransformer(self.root).transform()
        self.visit(self.root)

    def visit_Assign(self, node):
        if self._update_name_to_var_shape(node):
            return node
        self.generic_visit(node)
        return node

    def visit_Attribute(self, node):
        if self._used_by_paddle_api(node):
            if self.is_var_shape(node):
                return create_convert_shape_node(node)
        return node

    def visit_Name(self, node):
        if node.id in self.name_to_var_shape:
            if self._used_by_paddle_api(node):
                var_shape_node = self.name_to_var_shape[node.id]
                return create_convert_shape_node(var_shape_node)
        return node

    def visit_Call(self, node):
        assert isinstance(node, gast.Call)
        if is_paddle_api(node):
            # Visit gast.Attribute and gast.Name to replace var.shape if necessary.
            self.generic_visit(node)

        return node

    def visit_If(self, node):
        # Call generic_visit first to transform var.shape that is used in Paddle Api.
        self.generic_visit(node)
        self._transform_var_shape_if_necessary(node.test)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        self._transform_var_shape_if_necessary(node.test)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        self._transform_var_shape_if_necessary(node.iter)

        # If var.shape is a gast.Name and it is used in range function, transform it
        self._transform_var_shape_in_range(node)
        return node

    def _transform_var_shape_in_range(self, node):
        """
        Replace names bound to shapes in the arguments of
        `for ... in range(...)`. Returns False if `node.iter` is not a call to
        the plain name `range`, True otherwise.
        """
        assert isinstance(node, gast.For)
        if not isinstance(node.iter, gast.Call):
            return False
        if not isinstance(node.iter.func, gast.Name):
            return False
        if node.iter.func.id != "range":
            return False
        args = node.iter.args
        for idx, arg in enumerate(args):
            if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape:
                args[idx] = create_convert_shape_node(
                    self.name_to_var_shape[arg.id])

        return True

    @staticmethod
    def _replace_child_node(parent_node, old_node, new_node):
        """
        Replace `old_node` by `new_node` inside `parent_node`.

        `old_node` may be stored directly in a field (e.g. `gast.If.test`) or
        inside a list field (e.g. `gast.Compare.comparators`,
        `gast.BoolOp.values`). Returns True if a replacement happened.
        """
        for field, value in gast.iter_fields(parent_node):
            if value is old_node:
                setattr(parent_node, field, new_node)
                return True
            if isinstance(value, list):
                for i, item in enumerate(value):
                    if item is old_node:
                        value[i] = new_node
                        return True
        return False

    def _transform_var_shape_if_necessary(self, cond):
        """
        Replace shape expressions found inside `cond` in place.

        Returns True if at least one shape expression was found and its parent
        node could be located.
        """
        need_transformed = False
        for child_node in gast.walk(cond):
            var_shape_node = None
            if isinstance(child_node, gast.Attribute):
                if self.is_var_shape(child_node):
                    var_shape_node = child_node
            elif isinstance(child_node, gast.Name):
                if child_node.id in self.name_to_var_shape:
                    var_shape_node = self.name_to_var_shape[child_node.id]

            if var_shape_node is None:
                continue

            wrapper_node = self.node_to_wrapper_map.get(child_node)
            if wrapper_node is None or wrapper_node.parent is None:
                # Nodes created by earlier transformations are not in
                # node_to_wrapper_map, so their parent cannot be located.
                continue

            need_transformed = True
            self._replace_child_node(wrapper_node.parent.node, child_node,
                                     create_convert_shape_node(var_shape_node))
        return need_transformed

    def _used_by_paddle_api(self, node):
        """
        Return True if the nearest enclosing gast.Call of `node` is a Paddle API
        call and `node` is one of its arguments (not its `func`). Returns False
        when `node` is not in node_to_wrapper_map.
        """
        assert isinstance(node, (gast.Attribute, gast.Name))
        wrapper_node = self.node_to_wrapper_map.get(node)
        if not wrapper_node:
            # Transformed node is not in node_to_wrapper_map
            return False
        while wrapper_node.parent:
            parent_node = wrapper_node.parent.node
            if isinstance(parent_node, gast.Call):
                # Filter the case when the role of node is `gast.Call.func`,
                # e.g. `paddle.shape` in `paddle.shape(x)`.
                return bool(is_paddle_api(parent_node)) and \
                    parent_node.func is not node
            wrapper_node = wrapper_node.parent

        return False

    def is_var_shape(self, node):
        """
        Return True if node is like `x.shape` (x being a plain name), return
        False otherwise. The Paddle API `paddle.shape` itself is not a
        variable shape.
        """
        assert isinstance(node, gast.Attribute)
        if node.attr != 'shape':
            return False

        try:
            value_id = node.value.id
        except AttributeError:
            return False

        # If node is `paddle.shape`, return False
        return value_id != 'paddle'

    def _update_name_to_var_shape(self, node):
        """
        Record `name = x.shape`, `name = x.shape[i]` or `name = other_alias`.
        Returns True if `node` was recorded as a shape alias.
        """
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        value_node = node.value

        if isinstance(value_node, gast.Name):
            if value_node.id in self.name_to_var_shape:
                self.name_to_var_shape[target_id] = self.name_to_var_shape[
                    value_node.id]
                return True
        if isinstance(value_node, gast.Attribute):
            if self.is_var_shape(value_node):  # eg: x.shape
                self.name_to_var_shape[target_id] = value_node
                return True
        if isinstance(value_node, gast.Subscript):
            if isinstance(value_node.value, gast.Attribute):
                if self.is_var_shape(value_node.value):  # eg: x.shape[0]
                    self.name_to_var_shape[target_id] = value_node
                    return True
        return False