def _transform_tensor_shape_in_range(self, node):
    """Rewrite tracked shape variables passed to ``range`` in a ``for`` loop.

    ``node`` must be a ``gast.For``. When its iterable is a call of the bare
    name ``range``, every positional argument that is a ``gast.Name`` whose id
    is a key of ``self.name_to_tensor_shape`` is replaced (in place, inside
    ``node.iter.args``) by ``create_api_shape_node`` applied to the mapped
    value. Other arguments are left untouched.

    Args:
        node: the ``gast.For`` node to inspect.

    Returns:
        True if ``node.iter`` is a ``range(...)`` call (whether or not any
        argument was rewritten), otherwise False.
    """
    assert isinstance(node, gast.For)
    call = node.iter
    is_range_call = (
        isinstance(call, gast.Call)
        and isinstance(call.func, gast.Name)
        and call.func.id == "range"
    )
    if not is_range_call:
        return False

    shapes = self.name_to_tensor_shape

    def _maybe_shape(arg):
        # Only direct Name arguments are considered; nested expressions
        # such as ``range(n + 1)`` are not inspected here.
        if isinstance(arg, gast.Name) and arg.id in shapes:
            return create_api_shape_node(shapes[arg.id])
        return arg

    # Slice assignment keeps the same list object attached to the Call node.
    call.args[:] = [_maybe_shape(arg) for arg in call.args]
    return True
def _transform_tensor_shape_if_necessary(self, cond):
    """Replace tensor-shape references inside ``cond`` with shape-API nodes.

    Every node of the ``cond`` sub-tree is visited via ``gast.walk``. A node
    is rewritten when it is either:

    * a ``gast.Attribute`` for which ``self.is_tensor_shape`` is true (the
      attribute itself becomes the shape source), or
    * a ``gast.Name`` whose id is a key of ``self.name_to_tensor_shape`` (the
      mapped value becomes the shape source).

    The node is located in its parent through ``self.node_to_wrapper_map``
    and swapped for ``create_api_shape_node(<shape source>)``. Children held
    directly in a field and children held inside a list-valued field (such
    as ``Compare.comparators``, ``BoolOp.values`` or ``Call.args``) are both
    handled.

    Args:
        cond: AST sub-tree to scan, e.g. the test expression of an
            ``if``/``while`` statement. It is modified in place.
    """

    def _replace_in_parent(parent_node, old_node, shape_source):
        # Swap ``old_node`` for a new shape-API node wherever it sits among
        # ``parent_node``'s fields. The new node is only built once a slot
        # is found.
        for field, value in gast.iter_fields(parent_node):
            if value is old_node:
                setattr(parent_node, field, create_api_shape_node(shape_source))
                return
            # Some children live in list fields (e.g. Compare.comparators),
            # so the field value itself is a list, not the child.
            if isinstance(value, list):
                for idx, item in enumerate(value):
                    if item is old_node:
                        value[idx] = create_api_shape_node(shape_source)
                        return

    for child_node in gast.walk(cond):
        tensor_shape_node = None
        if isinstance(child_node, gast.Attribute):
            if self.is_tensor_shape(child_node):
                tensor_shape_node = child_node
        elif isinstance(child_node, gast.Name):
            if child_node.id in self.name_to_tensor_shape:
                tensor_shape_node = self.name_to_tensor_shape[child_node.id]
        if not tensor_shape_node:
            continue
        # NOTE(review): assumes every node under ``cond`` has an entry in
        # node_to_wrapper_map; a missing entry raises AttributeError here.
        wrapper_node = self.node_to_wrapper_map.get(child_node)
        parent_node = wrapper_node.parent.node
        _replace_in_parent(parent_node, child_node, tensor_shape_node)
def visit_Name(self, node):
    """Substitute a tracked shape variable when it feeds a Paddle API call.

    If ``node.id`` is a key of ``self.name_to_tensor_shape`` and
    ``self._used_by_paddle_api(node)`` is truthy, return
    ``create_api_shape_node`` applied to the mapped value. Otherwise return
    ``node`` unchanged. The API-usage predicate is only consulted for names
    present in the mapping.
    """
    shape_map = self.name_to_tensor_shape
    if node.id not in shape_map or not self._used_by_paddle_api(node):
        return node
    return create_api_shape_node(shape_map[node.id])
def visit_Attribute(self, node):
    """Wrap a tensor-shape attribute in a shape-API node when a Paddle API uses it.

    ``self.is_tensor_shape`` is consulted only when
    ``self._used_by_paddle_api(node)`` is truthy. When both hold, the result
    is ``create_api_shape_node(node)``; otherwise ``node`` is returned
    unchanged. This method does not call ``generic_visit``.
    """
    should_rewrite = self._used_by_paddle_api(node) and self.is_tensor_shape(node)
    return create_api_shape_node(node) if should_rewrite else node