Esempio n. 1
0
    def _transform_tensor_shape_in_range(self, node):
        """Rewrite tensor-shape names passed to ``range(...)`` in a for loop.

        Args:
            node: The ``gast.For`` node whose iterator is inspected.

        Returns:
            ``False`` when the loop iterator is not a plain ``range(...)``
            call (nothing is modified); ``True`` otherwise, after every
            positional argument that is a bare name recorded in
            ``self.name_to_tensor_shape`` has been replaced in place by
            the node built by ``create_api_shape_node``.
        """
        assert isinstance(node, gast.For)

        range_call = node.iter
        if not isinstance(range_call, gast.Call):
            return False

        callee = range_call.func
        if not isinstance(callee, gast.Name):
            return False
        if callee.id != "range":
            return False

        # The argument list is mutated in place so the change is visible
        # on the original call node.
        call_args = range_call.args
        for position, argument in enumerate(call_args):
            if not isinstance(argument, gast.Name):
                continue
            if argument.id not in self.name_to_tensor_shape:
                continue
            recorded_shape = self.name_to_tensor_shape[argument.id]
            call_args[position] = create_api_shape_node(recorded_shape)

        return True
Esempio n. 2
0
    def _transform_tensor_shape_if_necessary(self, cond):
        """Replace tensor-shape expressions found inside ``cond``.

        Every node reachable from ``cond`` is inspected.  A node is selected
        when it is either an attribute accepted by ``self.is_tensor_shape``
        or a name recorded in ``self.name_to_tensor_shape``.  Each selected
        node is swapped, inside its parent, for the node built by
        ``create_api_shape_node``.

        The parent is looked up through ``self.node_to_wrapper_map``; the
        child may be stored either directly in one of the parent's fields
        or inside a list-valued field (for example the operands of a
        comparison or the arguments of a call).  Both placements are handled.

        Args:
            cond: Root AST node of the condition expression to rewrite.
        """
        for child_node in gast.walk(cond):
            tensor_shape_node = None
            if isinstance(child_node, gast.Attribute):
                if self.is_tensor_shape(child_node):
                    tensor_shape_node = child_node
            elif isinstance(child_node, gast.Name):
                if child_node.id in self.name_to_tensor_shape:
                    tensor_shape_node = self.name_to_tensor_shape[
                        child_node.id]

            if not tensor_shape_node:
                continue

            wrapper_node = self.node_to_wrapper_map.get(child_node)
            parent_node = wrapper_node.parent.node
            for field, value in gast.iter_fields(parent_node):
                if value is child_node:
                    setattr(parent_node, field,
                            create_api_shape_node(tensor_shape_node))
                    break
                # The child can also live inside a list-valued field; a
                # plain identity check against the field value misses it.
                if isinstance(value, list):
                    position = next(
                        (idx for idx, item in enumerate(value)
                         if item is child_node), None)
                    if position is not None:
                        value[position] = create_api_shape_node(
                            tensor_shape_node)
                        break
Esempio n. 3
0
 def visit_Name(self, node):
     """Swap a recorded tensor-shape name for its shape-API node.

     The replacement happens only when ``node.id`` is a key of
     ``self.name_to_tensor_shape`` and ``self._used_by_paddle_api(node)``
     is truthy; in every other case ``node`` is returned unchanged.
     """
     if node.id not in self.name_to_tensor_shape:
         return node
     if not self._used_by_paddle_api(node):
         return node

     recorded_shape = self.name_to_tensor_shape[node.id]
     return create_api_shape_node(recorded_shape)
Esempio n. 4
0
 def visit_Attribute(self, node):
     """Swap a tensor-shape attribute for its shape-API node.

     ``node`` is replaced by ``create_api_shape_node(node)`` only when both
     ``self._used_by_paddle_api(node)`` and ``self.is_tensor_shape(node)``
     are truthy (checked in that order); otherwise it is returned as is.
     """
     needs_shape_api = (self._used_by_paddle_api(node)
                        and self.is_tensor_shape(node))
     return create_api_shape_node(node) if needs_shape_api else node