def test_nested_loop_vars(self):
    """Check NameVisitor's loop/created variable names for a nested loop.

    Parses the source of ``self.nested_for_loop_func`` and, for every
    ``gast.While``/``gast.For`` node in ``gast.walk`` order, compares the
    result of ``NameVisitor.get_loop_var_names`` with the expected entry at
    the same position in ``self.loop_var_names`` / ``self.create_var_names``.
    """
    source = inspect.getsource(self.nested_for_loop_func)
    gast_root = gast.parse(source)
    name_visitor = NameVisitor(gast_root)

    # Expected results, one entry per loop node in gast.walk order.
    self.loop_var_names = [
        {"j", "two"},
        {"i", "three", "b"},
        {"i", "j"},
    ]
    self.create_var_names = [set(), {"b"}, set()]

    # Lazy filter: loop nodes are produced in the same order, and at the
    # same time, as the original walk-and-check loop.
    loop_nodes = (
        candidate for candidate in gast.walk(gast_root)
        if isinstance(candidate, (gast.While, gast.For))
    )
    for i, loop_node in enumerate(loop_nodes):
        loop_var_names, create_var_names = name_visitor.get_loop_var_names(
            loop_node)
        self.assertEqual(
            loop_var_names,
            self.loop_var_names[i],
            msg="loop_var_names : {}, \nexpected loop_var_names : {}".format(
                loop_var_names, self.loop_var_names[i]))
        self.assertEqual(
            create_var_names,
            self.create_var_names[i],
            msg="i = {}\ncreate_var_names : {}, \nexpected create_var_names : {}"
            .format(i, create_var_names, self.create_var_names[i]))
def visit_Assign(self, node):
    """Visit an assignment statement.

    If ``self._update_class_node_dict(node)`` is truthy, returns None.
    Otherwise, calls ``self._visit_Call`` on every ``gast.Call`` found under
    ``node.value`` (via ``gast.walk``) and returns ``node`` unchanged.

    NOTE(review): if the enclosing class is an ``ast.NodeTransformer``-style
    visitor, returning None removes the statement from the tree — confirm.
    """
    if not self._update_class_node_dict(node):
        # Generator keeps the walk lazy, so any tree edits made by
        # _visit_Call interleave with the traversal exactly as before.
        call_nodes = (
            sub_node for sub_node in gast.walk(node.value)
            if isinstance(sub_node, gast.Call)
        )
        for call_node in call_nodes:
            self._visit_Call(call_node)
        return node
    return None
def visit_Expr(self, node):
    """Visit an expression statement.

    Walks ``node.value``. At the first ``gast.Call`` for which
    ``utils.is_dygraph_api`` is true, returns None immediately. Every other
    ``gast.Call`` encountered before that is passed to ``self._visit_Call``.
    If no dygraph API call is found, returns ``node``.

    NOTE(review): if the enclosing class is an ``ast.NodeTransformer``-style
    visitor, returning None drops the whole statement — confirm intent.
    """
    for sub_node in gast.walk(node.value):
        if not isinstance(sub_node, gast.Call):
            continue
        # TODO(liym27):
        # Considers that a dygraph api which modifies the input or has a output.
        if utils.is_dygraph_api(sub_node):
            return None
        self._visit_Call(sub_node)
    return node
def test_loop_vars(self):
    """Check NameVisitor's loop/created variable names for each loop function.

    For the k-th function in ``self.loop_funcs``, every ``gast.While`` /
    ``gast.For`` node in its parsed source must yield
    ``self.loop_var_names[k]`` and ``self.create_var_names[k]`` from
    ``NameVisitor.get_loop_var_names``.
    """
    for idx, func in enumerate(self.loop_funcs):
        gast_root = gast.parse(inspect.getsource(func))
        name_visitor = NameVisitor(gast_root)
        for node in gast.walk(gast_root):
            if not isinstance(node, (gast.While, gast.For)):
                continue
            loop_var_names, create_var_names = name_visitor.get_loop_var_names(
                node)
            self.assertEqual(loop_var_names, self.loop_var_names[idx])
            self.assertEqual(create_var_names, self.create_var_names[idx])
def filter_name_nodes_from(root_node, target_var_names):
    """
    Collect the ``gast.Name`` nodes under ``root_node`` (inclusively) whose
    ``id`` is in ``target_var_names``.

    Args:
        root_node: the AST node to search; ``root_node`` itself is included.
        target_var_names: container of variable names to keep (membership is
            tested with ``in``).

    Returns:
        set: the matching ``gast.Name`` node objects.
    """
    # gast.walk yields root_node itself first, so a Name root needs no
    # special case. (The previous special case referenced an undefined
    # name ``node`` and raised NameError whenever root_node was a Name.)
    return {
        child_node
        for child_node in gast.walk(root_node)
        if isinstance(child_node, gast.Name)
        and child_node.id in target_var_names
    }
def _transform_var_shape_if_necessary(self, cond):
    """Rewrite shape-related sub-expressions of ``cond`` in place.

    Walks every node of ``cond`` with ``gast.walk``. For each
    ``gast.Name`` / ``gast.Attribute`` / ``gast.Subscript`` node:

    * if its source text is a key of ``self.name_to_var_shape``, it is
      replaced by the node returned from ``create_choose_shape_node``;
    * otherwise, if ``self._is_var_shape(node)`` is true, it is replaced by
      ``create_convert_shape_node(node, None, True)``.

    The replacement is written into the parent node, which is found through
    ``self.node_to_wrapper_map``. The child is either a direct field of the
    parent or an element of one of its list fields.

    Args:
        cond: AST node to scan (presumably the test expression of an
            if/while — TODO confirm against callers).

    Returns:
        bool: True if at least one sub-node was replaced, otherwise False.
    """
    need_transformed = False
    for child_node in gast.walk(cond):
        var_shape_node = None
        if isinstance(child_node,
                      (gast.Name, gast.Attribute, gast.Subscript)):
            child_name = ast_to_source_code(child_node).strip()
            if child_name in self.name_to_var_shape:
                # Known shape variable: use the node built for it.
                var_shape_node = create_choose_shape_node(
                    child_name, self.name_to_var_shape[child_name])
            elif self._is_var_shape(child_node):
                # Shape expression used directly: mark the node itself;
                # it is wrapped with create_convert_shape_node below.
                var_shape_node = child_node
        if var_shape_node:
            need_transformed = True
            # NOTE(review): assumes every walked node has an entry in
            # node_to_wrapper_map; a None from .get() would raise
            # AttributeError on the next line.
            wrapper_node = self.node_to_wrapper_map.get(child_node)
            parent_node = wrapper_node.parent.node
            for field, value in gast.iter_fields(parent_node):
                # Case 1: child_node is stored directly in this field.
                if child_node is value:
                    if var_shape_node is child_node:
                        setattr(
                            parent_node, field,
                            create_convert_shape_node(
                                var_shape_node, None, True))
                    else:
                        setattr(parent_node, field, var_shape_node)
                    break
                # Some child_node may be in a list such as gast.Compare
                # Case 2: child_node is an element of a list field.
                if isinstance(value, list):
                    has_converted_shape = False
                    for i, v in enumerate(value):
                        if child_node is v:
                            if var_shape_node is child_node:
                                value[i] = create_convert_shape_node(
                                    var_shape_node, None, True)
                            else:
                                value[i] = var_shape_node
                            has_converted_shape = True
                            break
                    # Only the first occurrence is replaced; stop scanning fields.
                    if has_converted_shape:
                        break
    return need_transformed
def _transform_list_append_in_control_flow(self, node):
    """Rewrite qualifying sub-nodes of ``node`` in place.

    For each node yielded by ``gast.walk(node)`` (``node`` included) for
    which ``self._need_to_array_write_node`` is true, its ``.value`` is
    replaced by ``self._to_array_write_node(<old value>)``. Returns None.
    """
    for sub_node in gast.walk(node):
        if not self._need_to_array_write_node(sub_node):
            continue
        original_value = sub_node.value
        sub_node.value = self._to_array_write_node(original_value)
def replace_list_with_tensor_array(self, node):
    """Rewrite qualifying assignments under ``node`` in place.

    For every ``gast.Assign`` yielded by ``gast.walk(node)`` for which
    ``self._need_to_create_tensor_array`` is true, the assigned value is
    replaced by ``self._create_tensor_array(<old value>)``. Returns None.
    """
    # Lazy generator: assignments are visited in walk order, interleaved
    # with the in-place edits exactly as in a plain loop.
    assign_nodes = (
        sub_node for sub_node in gast.walk(node)
        if isinstance(sub_node, gast.Assign)
    )
    for assign_node in assign_nodes:
        if self._need_to_create_tensor_array(assign_node):
            assign_node.value = self._create_tensor_array(assign_node.value)