def visit_Call(self, node):
    """Visit a ``gast.Call`` node, descending only into paddle API calls.

    When ``is_paddle_api`` recognises the call, its children (including the
    ``gast.Attribute`` / ``gast.Name`` nodes among its arguments) are visited
    so that ``tensor.shape`` accesses can be replaced if necessary. Any other
    call is returned without visiting its children.

    Args:
        node (gast.Call): the call node being visited.

    Returns:
        The same ``node`` object.
    """
    assert isinstance(node, gast.Call)
    if not is_paddle_api(node):
        return node
    # Visit gast.Attribute and gast.Name to replace tensor.shape if necessary.
    self.generic_visit(node)
    return node
def _used_by_paddle_api(self, node):
    """Tell whether ``node`` is used as an argument of a paddle API call.

    Starting from ``node``'s entry in ``self.node_to_wrapper_map``, walk up
    the chain of wrapper parents until the nearest enclosing ``gast.Call`` is
    found. The answer is decided by that call alone.

    Examples:
        1) ``paddle.relu(x)`` with node ``x`` (gast.Name) -> True
        2) ``paddle.add(self.x)`` with node ``self.x`` (gast.Attribute) -> True
        3) ``paddle.add(self.x)`` with node ``paddle.add`` (gast.Attribute)
           -> False, because that node plays the role of ``gast.Call.func``,
           not of an argument.

    Args:
        node (gast.Attribute | gast.Name): the node to inspect.

    Returns:
        bool: True only if the nearest enclosing call is a paddle API call
        and ``node`` is not its ``func``. False if ``node`` has no wrapper
        entry or no enclosing ``gast.Call`` exists.
    """
    assert isinstance(node, (gast.Attribute, gast.Name))
    current = self.node_to_wrapper_map.get(node)
    if not current:
        # Transformed node is not in node_to_wrapper_map
        return False

    ancestor = current.parent
    while ancestor:
        candidate = ancestor.node
        if isinstance(candidate, gast.Call):
            # Note(Aurelius84): Filter the case when the role of node is
            # `gast.Call.func`.
            return bool(is_paddle_api(candidate) and candidate.func != node)
        ancestor = ancestor.parent
    return False
def is_to_variable(node):
    """Return whether ``node`` is a ``to_variable`` / ``to_tensor`` call.

    The callee's source text is checked against a suffix that depends on the
    kind of API:
      * a dygraph API (``utils.is_dygraph_api``) must end with ``to_variable``;
      * otherwise, a paddle API (``utils.is_paddle_api``) must end with
        ``to_tensor``.

    Args:
        node (gast.Call): the call node to inspect.

    Returns:
        bool: False for calls that are neither dygraph nor paddle APIs.
    """
    assert isinstance(node, gast.Call)
    callee_name = utils.ast_to_source_code(node.func).strip()
    if utils.is_dygraph_api(node):
        expected_suffix = "to_variable"
    elif utils.is_paddle_api(node):
        expected_suffix = "to_tensor"
    else:
        return False
    return callee_name.endswith(expected_suffix)
def is_grad_api_node(node):
    """Return whether ``node`` is a paddle API call whose name ends in ``grad``.

    Calls whose source text contains ``no_grad`` are excluded. For these, a
    warning is emitted because ``paddle.no_grad`` is only supported for
    inference under ``@to_static``.

    Args:
        node (gast.Call): the call node to inspect.

    Returns:
        bool: False for non-paddle calls and for ``no_grad`` calls.
    """
    assert isinstance(node, gast.Call)
    api_name = utils.ast_to_source_code(node.func).strip()
    if not utils.is_paddle_api(node):
        return False
    if 'no_grad' not in api_name:
        return api_name.endswith("grad")
    warnings.warn(
        "paddle.no_grad is only supported for inference model, and not supported for training under @to_static."
    )
    return False
def visit_Call(self, node):
    """Route a call's callee through ``convert_call`` unless it is exempt.

    Children are visited first. Paddle API calls and calls for which
    ``self._is_builtin_call`` is true are returned unchanged. For any other
    call, ``node.func`` is replaced in place by the expression
    ``fluid.dygraph.dygraph_to_static.convert_call(<original callee source>)``.

    Args:
        node (gast.Call): the call node being visited.

    Returns:
        The same ``node`` object, possibly with a rewritten ``func``.
    """
    self.generic_visit(node)
    # Paddle APIs and builtins are kept exactly as written.
    if is_paddle_api(node) or self._is_builtin_call(node):
        return node

    wrapped_src = "fluid.dygraph.dygraph_to_static.convert_call({})".format(
        ast_to_source_code(node.func).strip())
    # Parse the wrapper as a module and pull out the expression of its
    # single statement.
    node.func = gast.parse(wrapped_src).body[0].value
    return node
def _used_by_paddle_api(self, node):
    """Tell whether ``node`` is used as an argument of a paddle API call.

    Walks up the wrapper chain recorded in ``self.node_to_wrapper_map`` until
    the nearest enclosing ``gast.Call`` is found. Only that call decides the
    result.

    Examples:
        1) ``paddle.relu(x)`` with node ``x`` (gast.Name) -> True
        2) ``paddle.add(self.x)`` with node ``self.x`` (gast.Attribute) -> True
        3) ``paddle.add(self.x)`` with node ``paddle.add`` (gast.Attribute)
           -> False: that node is the call's ``func`` (the callee), not one of
           its arguments.

    Args:
        node (gast.Attribute | gast.Name): the node to inspect.

    Returns:
        bool: True only if the nearest enclosing call is a paddle API call
        and ``node`` is not its ``func``. False if ``node`` has no wrapper
        entry or no enclosing ``gast.Call`` exists.
    """
    assert isinstance(node, (gast.Attribute, gast.Name))
    wrapper_node = self.node_to_wrapper_map.get(node)
    if not wrapper_node:
        # Transformed node is not in node_to_wrapper_map
        return False
    while wrapper_node.parent:
        parent_node = wrapper_node.parent.node
        if isinstance(parent_node, gast.Call):
            # When ``node`` is the callee itself (``gast.Call.func``), it is
            # not being passed to the API, so it must not count as "used by"
            # it. Identity check: we mean this exact node object.
            return bool(is_paddle_api(parent_node)) and parent_node.func is not node
        wrapper_node = wrapper_node.parent
    return False
def _no_need_convert_call(self, node):
    """Determine whether a call can skip the `convert_call` transformation.

    A call does not need to be transformed when its callee satisfies one of
    the following conditions:
      1. It's an API of paddle.
      2. It's a Python builtin function other than ``len``.

    Args:
        node (gast.Call): the call node to inspect.

    Returns:
        bool: True if the call can be left as-is. False otherwise, including
        when the callee expression cannot be resolved in this scope.
    """
    assert isinstance(node, gast.Call)
    if is_paddle_api(node):
        return True

    func_str = ast_to_source_code(node.func).strip()
    try:
        from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin
        # NOTE(review): `eval` executes the callee's source text within this
        # method's scope. Names it cannot resolve raise and fall through to
        # the `except` below. The expression is evaluated exactly once so a
        # callee expression with side effects is not run twice, and the
        # imported helpers keep their names instead of being rebound to bools.
        callee = eval(func_str)
        return is_builtin(callee) and not is_builtin_len(callee)
    except Exception:
        # Deliberate best-effort: anything that cannot be resolved here is
        # treated as needing conversion.
        return False
def visit_Call(self, node):
    """Visit a call node, descending only into paddle API calls.

    For a call recognised by ``is_paddle_api``, its children are visited so
    that ``var.shape`` accesses can be replaced if necessary. Other calls are
    returned without visiting their children.

    Args:
        node: the call node being visited.

    Returns:
        The same ``node`` object.
    """
    if not is_paddle_api(node):
        # Don't have to visit other APIs
        return node
    # Visit gast.Attribute and gast.Name to replace var.shape if necessary.
    self.generic_visit(node)
    return node
def visit_Call(self, node):
    """Visit a call node and count it when it is a paddle API call.

    Delegates to ``self._visit_Call`` (its return value is discarded), then
    increments ``self.is_control_flow_num`` by one if ``is_paddle_api``
    recognises the call.

    Args:
        node: the call node being visited.

    Returns:
        The same ``node`` object.
    """
    self._visit_Call(node)
    found_paddle_call = is_paddle_api(node)
    if found_paddle_call:
        self.is_control_flow_num += 1
    return node