Beispiel #1
0
    def test_nested_loop_vars(self):
        """Check get_loop_var_names on every loop of nested_for_loop_func.

        The i-th loop node (gast.While / gast.For) produced by ``gast.walk``
        is compared with the i-th entry of the expected lists below, so those
        lists follow walk order. The test also fails when the number of loops
        found differs from the number of expectations, instead of passing
        vacuously or dying with a bare IndexError.
        """
        func = self.nested_for_loop_func
        test_func = inspect.getsource(func)
        gast_root = gast.parse(test_func)
        name_visitor = NameVisitor(gast_root)

        self.loop_var_names = [
            set(["j", "two"]),
            set(["i", "three", "b"]),
            set(["i", "j"])
        ]
        self.create_var_names = [set(), set(["b"]), set()]
        num_expected = len(self.loop_var_names)

        # Lazy filter: each loop is checked while the walk is in progress,
        # exactly as with a plain for/if loop over gast.walk.
        loop_nodes = (node for node in gast.walk(gast_root)
                      if isinstance(node, (gast.While, gast.For)))
        num_loops = 0
        for i, node in enumerate(loop_nodes):
            self.assertLess(
                i,
                num_expected,
                msg="found more loops than the {} expected".format(
                    num_expected))
            loop_var_names, create_var_names = name_visitor.get_loop_var_names(
                node)
            self.assertEqual(
                loop_var_names,
                self.loop_var_names[i],
                msg=
                "i = {}\nloop_var_names : {}, \nexpected loop_var_names : {}"
                .format(i, loop_var_names, self.loop_var_names[i]))
            self.assertEqual(
                create_var_names,
                self.create_var_names[i],
                msg=
                "i = {}\ncreate_var_names : {}, \nexpected create_var_names : {}"
                .format(i, create_var_names, self.create_var_names[i]))
            num_loops = i + 1

        # Guard against a vacuous pass when fewer loops than expected are found.
        self.assertEqual(
            num_loops,
            num_expected,
            msg="expected {} loops, found {}".format(num_expected, num_loops))
Beispiel #2
0
    def visit_Assign(self, node):
        """Handle an assignment statement.

        If ``self._update_class_node_dict(node)`` reports a truthy result, the
        method returns None without visiting anything (presumably dropping the
        statement when used as a NodeTransformer — confirm with the caller).
        Otherwise every ``gast.Call`` found under ``node.value`` is passed to
        ``self._visit_Call`` and the original node is returned.
        """
        if self._update_class_node_dict(node):
            return None

        # Generator keeps the walk lazy, so _visit_Call runs interleaved with
        # the traversal just like an explicit for/if loop.
        call_nodes = (sub_node for sub_node in gast.walk(node.value)
                      if isinstance(sub_node, gast.Call))
        for call_node in call_nodes:
            self._visit_Call(call_node)
        return node
Beispiel #3
0
 def visit_Expr(self, node):
     """Handle an expression statement.

     Walks ``node.value``; for each ``gast.Call`` met, if
     ``utils.is_dygraph_api`` accepts it the method returns None immediately
     (presumably removing the statement when used as a NodeTransformer —
     confirm), otherwise the call is passed to ``self._visit_Call``. When no
     dygraph API call is found, the original node is returned.
     """
     call_nodes = (sub_node for sub_node in gast.walk(node.value)
                   if isinstance(sub_node, gast.Call))
     for call_node in call_nodes:
         # TODO(liym27):
         #  Consider a dygraph api which modifies its input or has an output.
         if utils.is_dygraph_api(call_node):
             return None
         self._visit_Call(call_node)
     return node
Beispiel #4
0
 def test_loop_vars(self):
     """Check NameVisitor.get_loop_var_names on every function in loop_funcs.

     ``self.loop_var_names[i]`` and ``self.create_var_names[i]`` hold the
     expected names for ``self.loop_funcs[i]``; every loop node
     (gast.While / gast.For) found in that function is compared against the
     same expectation. Failure messages include the function index.
     """
     for i, func in enumerate(self.loop_funcs):
         test_func = inspect.getsource(func)
         gast_root = gast.parse(test_func)
         name_visitor = NameVisitor(gast_root)
         for node in gast.walk(gast_root):
             if isinstance(node, (gast.While, gast.For)):
                 loop_var_names, create_var_names = name_visitor.get_loop_var_names(
                     node)
                 self.assertEqual(
                     loop_var_names,
                     self.loop_var_names[i],
                     msg=
                     "i = {}\nloop_var_names : {}, \nexpected loop_var_names : {}"
                     .format(i, loop_var_names, self.loop_var_names[i]))
                 self.assertEqual(
                     create_var_names,
                     self.create_var_names[i],
                     msg=
                     "i = {}\ncreate_var_names : {}, \nexpected create_var_names : {}"
                     .format(i, create_var_names, self.create_var_names[i]))
Beispiel #5
0
        def filter_name_nodes_from(root_node, target_var_names):
            """
            Collect the gast.Name nodes under ``root_node`` (inclusively)
            whose ``id`` is in ``target_var_names``.

            ``gast.walk`` yields ``root_node`` itself, so a root that is a
            bare gast.Name is covered without a separate check. (The previous
            special case read ``node.id`` instead of ``root_node.id``.)

            Args:
                root_node: AST node to search; may itself be a gast.Name.
                target_var_names: container of variable names to match.

            Returns:
                set: the matching gast.Name nodes.
            """
            return {
                child_node
                for child_node in gast.walk(root_node)
                if isinstance(child_node, gast.Name)
                and child_node.id in target_var_names
            }
    def _transform_var_shape_if_necessary(self, cond):
        """Rewrite variable-shape references inside ``cond`` in place.

        Every gast.Name / gast.Attribute / gast.Subscript reached by
        ``gast.walk(cond)`` is examined:

        * if its stripped source text is a key of ``self.name_to_var_shape``,
          it is replaced by ``create_choose_shape_node(name, shape)``;
        * otherwise, if ``self._is_var_shape`` accepts it, it is replaced by
          ``create_convert_shape_node(node, None, True)``.

        The replacement is written into the parent obtained from
        ``self.node_to_wrapper_map[node].parent.node``, either as a direct
        attribute or as an element of a list field.

        Args:
            cond: AST node to scan (presumably a condition expression).

        Returns:
            bool: True if at least one candidate node was found.
        """

        def _replacement_for(original, shape_node):
            # A node that is itself a var shape gets wrapped in a convert-shape
            # node; otherwise the prebuilt choose-shape node is used directly.
            if shape_node is original:
                return create_convert_shape_node(shape_node, None, True)
            return shape_node

        found_candidate = False
        for current in gast.walk(cond):
            shape_node = None
            if isinstance(current, (gast.Name, gast.Attribute, gast.Subscript)):
                source_text = ast_to_source_code(current).strip()
                if source_text in self.name_to_var_shape:
                    shape_node = create_choose_shape_node(
                        source_text, self.name_to_var_shape[source_text])
                elif self._is_var_shape(current):
                    shape_node = current

            if not shape_node:
                continue

            found_candidate = True
            parent = self.node_to_wrapper_map.get(current).parent.node
            for field, value in gast.iter_fields(parent):
                if value is current:
                    setattr(parent, field, _replacement_for(current, shape_node))
                    break
                # Some children live inside a list field, e.g. the comparators
                # of a gast.Compare.
                if isinstance(value, list):
                    position = next(
                        (idx for idx, item in enumerate(value)
                         if item is current), None)
                    if position is not None:
                        value[position] = _replacement_for(current, shape_node)
                        break
        return found_candidate
Beispiel #7
0
 def _transform_list_append_in_control_flow(self, node):
     """Rewrite, in place, every node under ``node`` that needs an array write.

     Each node yielded by ``gast.walk(node)`` for which
     ``self._need_to_array_write_node`` is truthy has its ``value`` replaced
     by ``self._to_array_write_node(value)``. Returns None.
     """
     # Generator keeps the predicate check and the rewrite interleaved with
     # the traversal, matching an explicit for/if loop.
     rewrite_targets = (candidate for candidate in gast.walk(node)
                        if self._need_to_array_write_node(candidate))
     for target in rewrite_targets:
         target.value = self._to_array_write_node(target.value)
Beispiel #8
0
 def replace_list_with_tensor_array(self, node):
     """Replace list-valued assignments under ``node`` with tensor arrays.

     For every gast.Assign yielded by ``gast.walk(node)`` that
     ``self._need_to_create_tensor_array`` accepts, the assigned value is
     replaced in place by ``self._create_tensor_array(value)``. Returns None.
     """
     for candidate in gast.walk(node):
         needs_tensor_array = (isinstance(candidate, gast.Assign)
                               and self._need_to_create_tensor_array(candidate))
         if needs_tensor_array:
             candidate.value = self._create_tensor_array(candidate.value)