def _convert_call(self, node, matched_api_name):
    """Convert the call node.

    Args:
        node (ast.Call): Call node to convert.
        matched_api_name (str): API name the call was matched to; used to look
            up prompt text and to check membership in ``ALL_MAPPING``.

    Returns:
        The replacement node when ``matched_api_name`` is in ``ALL_MAPPING``
        and the mapped code differs from the original code; otherwise ``None``.
    """
    new_node = None
    code = pasta.dump(node)
    api_name = pasta.dump(node.func)
    warning_info = get_prompt_info(matched_api_name)
    if warning_info is None:
        warning_info = ''
    if matched_api_name in ALL_MAPPING:
        logger.info("Line %3d start converting API: %s", node.lineno, api_name)
        new_code = self.mapping_api(node)
        if new_code != code:
            new_stmt = pasta.parse(new_code).body[0]
            try:
                # Statements such as an expression statement carry the
                # converted expression in `.value`.
                new_node = new_stmt.value
            except AttributeError:
                # No `.value`: use the parsed statement itself.
                new_node = new_stmt
                new_api_name = new_code
            else:
                # The API name is everything before the first '('. Without a
                # '(' keep the whole code; slicing with find()'s -1 would drop
                # the last character.
                paren_pos = new_code.find('(')
                new_api_name = new_code[:paren_pos] if paren_pos != -1 else new_code
                detail_msg = self._get_detail_prompt_msg(node, new_node)
                if detail_msg:
                    warning_info = detail_msg + ' ' + warning_info
            self._process_log.info(
                node.lineno, node.col_offset,
                LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info))
    else:
        logger.warning("Line %3d: found unsupported API: %s%s",
                       node.lineno, api_name, warning_info)
        self._process_log.warning(
            node.lineno, node.col_offset,
            LOG_FMT_NOT_CONVERT % (api_name, warning_info))
    return new_node
def test_mixed_tabs_spaces_indentation(self): pasta.parse( textwrap.dedent('''\ if a: b {ONETAB}c ''').format(ONETAB='\t'))
def test_mixed_tabs_spaces_indentation(self): pasta.parse( textwrap.dedent("""\ if a: b {ONETAB}c """).format(ONETAB='\t'), py_ver)
def convert_dynamic_loss_scale(node):
    """Rewrite a tf.train.experimental.DynamicLossScale call in place.

    The call target becomes ``ExponentialUpdateLossScaleManager`` and every
    argument ends up as a keyword argument under its NPU parameter name.
    Arguments the caller omitted are added with explicit values (``2**15``,
    ``2000`` and ``2``). Positional arguments are dropped after being
    re-emitted as keywords, and the ``need_conver`` flag is set.

    Args:
        node: The ``ast.Call`` node to convert; it is mutated.

    Returns:
        The same, mutated node.
    """
    log_msg(
        getattr(node, 'lineno', 'None'),
        "change tf.train.experimental.DynamicLossScale"
        " to ExponentialUpdateLossScaleManager")
    node.func = ast.Name(id="ExponentialUpdateLossScaleManager", ctx=ast.Load())

    # One row per positional slot: (TF keyword, NPU keyword, fallback source).
    param_table = (
        ("initial_loss_scale", "init_loss_scale", "2**15"),
        ("increment_period", "incr_every_n_steps", "2000"),
        ("multiplier", "incr_ratio", "2"),
    )
    supplied = [None] * len(param_table)

    # Positional arguments fill the slots in order; extra ones are ignored.
    for slot, positional in enumerate(node.args[:len(param_table)]):
        supplied[slot] = positional

    # Keywords are renamed in place and override a positional for the slot.
    for kw in node.keywords:
        for slot, (tf_name, npu_name, _) in enumerate(param_table):
            if kw.arg == tf_name:
                kw.arg = npu_name
                supplied[slot] = kw

    for (_, npu_name, fallback), value in zip(param_table, supplied):
        if not value:
            node.keywords.append(
                ast.keyword(arg=npu_name, value=pasta.parse(fallback)))
        elif not isinstance(value, ast.keyword):
            # A positional argument is re-emitted as a keyword argument.
            node.keywords.append(ast.keyword(arg=npu_name, value=value))

    node.args = []
    util_global.set_value('need_conver', True)
    return node
def testReplaceChildInvalid(self):
    """replace_child raises InvalidAstError for a node outside the parent."""
    source = 'def foo():\n return 1\nx = 1\n'
    replacement = pasta.parse('bar()', py_ver).body[0]
    tree = pasta.parse(source, py_ver)
    func_def = tree.body[0]
    # The module-level assignment is a sibling of foo, not one of its children.
    stray_assign = tree.body[1]
    with self.assertRaises(errors.InvalidAstError):
        ast_utils.replace_child(func_def, stray_assign, replacement)
def testReplaceChildInBody(self):
    """Replacing a body statement keeps the replaced statement's comment.

    The replacement's own trailing comment is not carried over.
    """
    source = 'def foo():\n a = 0\n a += 1 # replace this\n return a\n'
    replacement = pasta.parse('foo(a + 1) # trailing comment\n').body[0]
    want = 'def foo():\n a = 0\n foo(a + 1) # replace this\n return a\n'
    tree = pasta.parse(source)
    func_def = tree.body[0]
    ast_utils.replace_child(func_def, func_def.body[1], replacement)
    self.assertEqual(want, pasta.dump(tree))
def test_default_indentation(self):
    """Newly added code reuses the indentation already present in the tree.

    ``def c(): d`` comes from ``ast.parse`` and has no formatting of its own.
    When printed after a pasta-parsed ``def a()``, its body must use the same
    indent string as ``a``'s body.
    """
    # Distinct indent strings, so each iteration covers a different style.
    for indent in ('  ', '    ', '\t'):
        src = 'def a():\n' + indent + 'b\n'
        t = pasta.parse(src)
        t.body.extend(ast.parse('def c(): d').body)
        self.assertEqual(codegen.to_str(t),
                         src + 'def c():\n' + indent + 'd\n')
def execute_rename(file_path, moved_imports):
    """Rewrite the imports of ``file_path`` in place.

    Args:
        file_path: Path of the Python source file to rewrite.
        moved_imports: Path of a Python module that defines
            ``imports_to_move``, an iterable whose items hold the old import
            path at index 0 and the new one at index 1.

    Raises:
        click.ClickException: If pasta cannot rename one of the paths; the
            file is not written in that case.
    """
    if six.PY2:
        import imp
        import_from_user = imp.load_source('moved_imports', moved_imports)
    else:
        import importlib.util
        spec = importlib.util.spec_from_file_location("moved_imports", moved_imports)
        import_from_user = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(import_from_user)

    with open(file_path, mode='r') as source_file:
        tree = pasta.parse(source_file.read())

    for class_to_move in import_from_user.imports_to_move:
        old_path = class_to_move[0]
        new_path = class_to_move[1]
        try:
            rename.rename_external(tree, old_path, new_path)
        except ValueError:
            # Raise (not merely construct) the exception so the failure stops
            # the rewrite instead of being silently ignored.
            raise click.ClickException(
                "Some error happened on the following path: {0}.\n "
                "While trying to rename from: {1} to {2}".format(
                    file_path, old_path, new_path))

    source_code = pasta.dump(tree)
    with open(file_path, mode='w') as target_file:
        target_file.write(source_code)
def test_call_no_pos(self):
    """A keyword added without position info is printed after existing args."""
    tree = pasta.parse('f(a)')
    call = ast_utils.find_nodes_by_type(tree, (ast.Call, ))[0]
    # Built with plain ast, so the new keyword has no lineno/col_offset.
    extra_kw = ast.keyword(arg='b', value=ast.Num(n=0))
    call.keywords.append(extra_kw)
    self.assertEqual('f(a, b=0)', pasta.dump(tree))
def convert_loss_scale_api(node):
    """Convert loss-scale related TensorFlow API calls in place.

    Handles attribute calls named ``FixedLossScale``, ``DynamicLossScale``
    (delegated to ``convert_dynamic_loss_scale``) and
    ``MixedPrecisionLossScaleOptimizer``.

    Args:
        node: An ``ast.Call`` node; it is mutated when converted.

    Returns:
        The converted node, or ``None`` when ``node.func`` is not an
        attribute access naming one of the handled APIs.
    """
    if not isinstance(node.func, ast.Attribute):
        return None
    api = node.func.attr

    if api == "FixedLossScale":
        log_msg(getattr(node, 'lineno', 'None'),
                "change tf.train.experimental.FixedLossScale"
                " to FixedLossScaleManager")
        node.func = ast.Name(id="FixedLossScaleManager", ctx=ast.Load())
        # A sole keyword argument is renamed to the manager's parameter name.
        if len(node.keywords) == 1:
            node.keywords[0].arg = "loss_scale"
        util_global.set_value('need_conver', True)
        return node

    if api == "DynamicLossScale":
        return convert_dynamic_loss_scale(node)

    if api == "MixedPrecisionLossScaleOptimizer":
        log_msg(getattr(node, 'lineno', 'None'),
                "change tf.train.experimental.MixedPrecisionLossScaleOptimizer"
                " to NPULossScaleOptimizer")
        node.func = ast.Name(id="NPULossScaleOptimizer", ctx=ast.Load())
        for kw in node.keywords:
            if kw.arg == "loss_scale":
                kw.arg = "loss_scale_manager"
        # A non-empty distributed_mode setting marks the optimizer distributed.
        if len(util_global.get_value("distributed_mode", "")) > 0:
            node.keywords.append(
                ast.keyword(arg="is_distributed", value=pasta.parse("True")))
        util_global.set_value('need_conver', True)
        return node

    return None
def test_fstring_escaping(self):
    """Doubled braces stay literal; only the real field becomes a placeholder."""
    tree = pasta.parse('f"a {{{b} {{c}}"')
    fstring_node = tree.body[0].value
    want = 'f"a {{{__pasta_fstring_val_0__} {{c}}"'
    self.assertEqual(fmt.get(fstring_node, 'content'), want)
def test_fstring(self):
    """Each f-string replacement field gets a numbered placeholder in order."""
    tree = pasta.parse('f"a {b} c d {e}"', py_ver)
    fstring_node = tree.body[0].value
    want = 'f"a {__pasta_fstring_val_0__} c d {__pasta_fstring_val_1__}"'
    self.assertEqual(fmt.get(fstring_node, 'content'), want)
def main(coverage_file):
    """Process every file recorded in a coverage data file.

    Files that no longer exist are skipped. Files with no covered lines are
    deleted. Every other file is parsed with pasta and passed to ``rewrite``
    together with its covered line numbers. The result is written back to the
    same path unless pasta fails to print it.

    Args:
        coverage_file: Path to a data file readable by
            ``coverage.CoverageData.read_file``.
    """
    data = coverage.CoverageData()
    data.read_file(coverage_file)
    # Public accessor; the private ``_lines`` mapping is not populated when
    # the data was recorded as arcs (branch coverage).
    for filename in data.measured_files():
        lines = data.lines(filename)
        assert lines is not None
        if not os.path.exists(filename):
            # The file may have been deleted since coverage was recorded.
            continue
        if not lines:
            print(filename, 'not covered, removing')
            os.unlink(filename)
            continue
        with open(filename) as fp:
            tree = pasta.parse(fp.read())
        new_tree = rewrite(tree, lines)
        try:
            to_write = pasta.dump(new_tree)
        except pasta.base.codegen.PrintError:
            print("Error with file", filename)
            continue
        with open(filename, 'w') as fp:
            fp.write(to_write)
def rename_file(file_path, path_to_moved_imports_file):
    """
    Rewrite the imports of one file according to a list of moved imports.

    :param str file_path: Path of the file being parsed and rewritten.
    :param str path_to_moved_imports_file: Path of the file with the list of
        changed imports.
    :raises click.ClickException: if pasta fails to rename one of the paths;
        the file is left untouched in that case.
    """
    moves = _get_list_of_moved_imports(path_to_moved_imports_file)

    with open(file_path, mode='r') as source_file:
        source_text = source_file.read()
    tree = pasta.parse(source_text)

    for old_path, new_path in moves:
        try:
            rename.rename_external(tree, old_path, new_path)
        except ValueError:
            raise click.ClickException(
                "An error has occurred on the following path: {0} ,\n "
                "while trying to rename from: {1} to {2}".format(
                    file_path, old_path, new_path))

    # Render before opening for write so a printing error cannot truncate it.
    rewritten = pasta.dump(tree)
    with open(file_path, mode='w') as target_file:
        target_file.write(rewritten)
def update_string_pasta(self, text, in_filename):
    """Apply the API change spec to ``text`` using pasta.

    Returns:
        A tuple ``(status, new_text, logs, errors)``. On a parse failure it
        is ``(0, "", [message_with_traceback], [])``; otherwise ``status`` is
        1 and the logs/errors come from preprocessing and the edit visitor.
    """
    try:
        tree = pasta.parse(text)
    except (SyntaxError, ValueError, TypeError):
        return 0, "", ["ERROR: Failed to parse.\n" + traceback.format_exc()], []

    spec = self._api_change_spec
    preprocess_logs, preprocess_errors = spec.preprocess(tree)
    editor = _PastaEditVisitor(spec)
    editor.visit(tree)
    spec.clear_preprocessing()

    formatted_logs = [
        self.format_log(entry, None)
        for entry in preprocess_logs + editor.log
    ]
    formatted_errors = [
        self.format_log(entry, in_filename)
        for entry in preprocess_errors + editor.warnings_and_errors
    ]
    return 1, pasta.dump(tree), formatted_logs, formatted_errors
def test_indent_levels(self):
    """Each call records the absolute indentation of its nesting level.

    The source nests blocks two spaces at a time, so the expected indents
    grow by two spaces per level and shrink again on the way out.
    """
    src = textwrap.dedent('''\
        foo('begin')
        if a:
          foo('a1')
          if b:
            foo('b1')
            if c:
              foo('c1')
            foo('b2')
          foo('a2')
        foo('end')
        ''')
    t = pasta.parse(src)
    call_nodes = ast_utils.find_nodes_by_type(t, (ast.Call, ))
    call_nodes.sort(key=lambda node: node.lineno)
    begin, a1, b1, c1, b2, a2, end = call_nodes

    # Distinct depths need distinct indents; a single shared ' ' could not
    # tell the nesting levels apart.
    self.assertEqual('', fmt.get(begin, 'indent'))
    self.assertEqual('  ', fmt.get(a1, 'indent'))
    self.assertEqual('    ', fmt.get(b1, 'indent'))
    self.assertEqual('      ', fmt.get(c1, 'indent'))
    self.assertEqual('    ', fmt.get(b2, 'indent'))
    self.assertEqual('  ', fmt.get(a2, 'indent'))
    self.assertEqual('', fmt.get(end, 'indent'))
def test_indent_levels_same_line(self):
    """Statements on the same line as their block header have no indent_diff."""
    tree = pasta.parse('if a: b; c\n')
    first_stmt, second_stmt = tree.body[0].body
    for stmt in (first_stmt, second_stmt):
        self.assertIsNone(fmt.get(stmt, 'indent_diff'))
def _update_base_name(self, class_def_scope):
    """
    Update base name of class.

    Base classes written as attribute accesses whose attribute name is one of
    ``APIAnalysisSpec.get_network_base_class_names()`` are replaced with
    ``nn.<mapped name>`` when ``base_name_mapping`` has an entry. Otherwise a
    "not converted" entry is logged. Bases without an ``attr`` (for example
    plain names) are skipped.

    Args:
        class_def_scope: Scope object whose ``node`` attribute is the
            ``ast.ClassDef`` to update.
    """
    base_name_mapping = APIAnalysisSpec.base_name_mapping
    # Loop-invariant: fetch the list of network base names once.
    network_base_names = APIAnalysisSpec.get_network_base_class_names()
    class_def_node = class_def_scope.node
    # Iterate over a snapshot because replace_child may modify ``bases``.
    for base_class_node in list(class_def_node.bases):
        # Plain names such as ``object`` have no ``attr`` and cannot match.
        base_name = getattr(base_class_node, 'attr', None)
        if base_name not in network_base_names:
            continue
        old_code = pasta.dump(base_class_node)
        if base_name in base_name_mapping:
            new_code = 'nn.' + base_name_mapping[base_name]
            new_node = pasta.parse(new_code)
            pasta.ast_utils.replace_child(class_def_node, base_class_node, new_node)
            self._process_log.info(
                base_class_node.lineno, base_class_node.col_offset,
                LOG_FMT_CONVERT % (old_code, new_code))
        else:
            self._process_log.info(
                base_class_node.lineno, base_class_node.col_offset,
                LOG_FMT_NOT_CONVERT % (old_code, ''))
def _read_input_file(self):
    """Parse the file at ``self.input_path`` with pasta.

    Returns:
        ast.Module: Tree for the whole input file.
    """
    with open(self.input_path) as handle:
        source = handle.read()
    return pasta.parse(source)
def test_args(self):
    """Lambdas taking *args and **kwargs round-trip through pasta unchanged."""
    src = """
def func():
  offset_multi = lambda *a: foo(*a)
  add_multi = lambda *a, **k: bar(*a, **k)"""
    t = pasta.parse(src, py_ver)
    # No debug print: the assertion reports the dumped text on failure.
    self.assertEqual(src, pasta.dump(t, py_ver))
def testRemoveAlias(self):
    """Removing the first alias of a from-import keeps the remaining one."""
    tree = pasta.parse("from a import b, c")
    import_from = tree.body[0]
    ast_utils.remove_child(import_from, import_from.names[0])
    self.assertEqual(pasta.dump(tree), "from a import c")
def test_scope_trailing_comma(self):
    """A trailing comma before ')' is stored in the list's suffix attribute.

    For function args and class bases the leading spaces are not part of the
    suffix; for parenthesised import names the comma is kept verbatim.
    """
    cases = (
        ('def foo(a, b{trailing_comma}): pass', 'args_suffix', True),
        ('class Foo(a, b{trailing_comma}): pass', 'bases_suffix', True),
        ('from mod import (a, b{trailing_comma})', 'names_suffix', False),
    )
    for template, suffix_attr, strip_leading in cases:
        for trailing_comma in ('', ',', ' , '):
            tree = pasta.parse(template.format(trailing_comma=trailing_comma))
            if strip_leading:
                comma = trailing_comma.lstrip(' ')
            else:
                comma = trailing_comma
            self.assertEqual(comma + ')', fmt.get(tree.body[0], suffix_attr))
def test_tabs_below_spaces_and_tab(self):
    """indent_diff of `c` is one tab when both indents end in tabs.

    `if b:` is indented with 1-7 spaces followed by a tab, and `c` with two
    tabs; for every space count the expected indent_diff of `c` is a tab.
    """
    for num_spaces in range(1, 8):
        t = pasta.parse(textwrap.dedent('''\
            if a:
            {WS}{ONETAB}if b:
            {ONETAB}{ONETAB}c
            ''').format(ONETAB='\t', WS=' ' * num_spaces))
        node_c = t.body[0].body[0].body[0]
        self.assertEqual(fmt.get(node_c, 'indent_diff'), '\t')
def test_indent_extra_newlines(self):
    """A blank line between a block header and its body does not change indent_diff."""
    src = textwrap.dedent("""\
        if a:

         b
        """)
    t = pasta.parse(src, py_ver)
    if_node = t.body[0]
    b = if_node.body[0]
    self.assertEqual(' ', fmt.get(b, 'indent_diff'))
def test_tab_below_spaces(self):
    """A tab indent below a space indent is measured up to column 8.

    `if b:` is indented with 1-7 spaces and `c` with one tab; the expected
    indent_diff of `c` is the spaces remaining between them and column 8.
    """
    for num_spaces in range(1, 8):
        t = pasta.parse(
            textwrap.dedent("""\
                if a:
                {WS}if b:
                {ONETAB}c
                """).format(ONETAB='\t', WS=' ' * num_spaces), py_ver)
        node_c = t.body[0].body[0].body[0]
        self.assertEqual(fmt.get(node_c, 'indent_diff'), ' ' * (8 - num_spaces))
def test_indent_extra_newlines_with_comment(self):
    """indent_diff comes from the body statement, not a preceding comment.

    The comment line is indented deeper than `b`; the expected indent_diff
    matches `b`'s own one-space indent.
    """
    src = textwrap.dedent('''\
        if a:
            #not here

         b
        ''')
    t = pasta.parse(src)
    if_node = t.body[0]
    b = if_node.body[0]
    self.assertEqual(' ', fmt.get(b, 'indent_diff'))
def test_indent_multiline_string_with_newline(self):
    """A multi-line docstring does not disturb indent tracking.

    The docstring literal spans several lines (an escaped newline followed by
    a real one); both the docstring statement and the following `pass` must
    report a one-space indent.
    """
    src = textwrap.dedent('''\
        class A:
         """Doc\n
           string."""
         pass
        ''')
    t = pasta.parse(src, py_ver)
    docstring, pass_stmt = t.body[0].body
    self.assertEqual(' ', fmt.get(docstring, 'indent'))
    self.assertEqual(' ', fmt.get(pass_stmt, 'indent'))
def test_call_illegal_pos(self):
    """A keyword whose position would precede existing args is printed last."""
    tree = pasta.parse('f(a)')
    call = ast_utils.find_nodes_by_type(tree, (ast.Call, ))[0]
    new_value = ast.Num(n=0)
    # Line 0 / column 0 would place b=0 before a, so it must be ignored.
    new_value.lineno = 0
    new_value.col_offset = 0
    call.keywords.append(ast.keyword(arg='b', value=new_value))
    self.assertEqual('f(a, b=0)', pasta.dump(tree))
def conver_ast(path, out_path_dst, file_name):
    """Convert one script and write the converted source.

    Resets the per-file flags in ``util_global``, parses ``path/file_name``
    with pasta and runs ``ConverByAst`` over it. It then scans the file's
    TensorFlow API usage and inserts the NPU import, resource init/shutdown
    and Keras session handling that the flags call for. The result is written
    under ``<output>/out_path_dst/file_name``. A parse failure is printed and
    the file is skipped.

    Args:
        path: Directory containing the script.
        out_path_dst: Output sub-directory, joined under the ``output`` value.
        file_name: Name of the script inside ``path``.
    """
    source_path = os.path.join(path, file_name)

    for flag_name in ('need_conver', 'is_keras_net', 'has_hccl_api',
                      'is_main_file', 'has_main_func'):
        util_global.set_value(flag_name, False)
    if source_path == util_global.get_value('main', ""):
        util_global.set_value('is_main_file', True)

    with open(source_path, "r", encoding='utf-8') as file:
        source = file.read()
    try:
        r_node = pasta.parse(source)
    except Exception as e:
        # Best effort: report the failure and skip this file.
        print(repr(e))
        return

    sys.setrecursionlimit(10000)
    ConverByAst().visit(r_node)
    ast.fix_missing_locations(r_node)

    (api, lineno) = get_tf_api(source_path)
    if len(api) == 0:
        print("No Tensorflow module is imported in script {}.".format(file_name))
    scan_file(path, file_name, api, lineno)

    def flag(name):
        # Flags are re-read at each use; earlier steps may have updated them.
        return util_global.get_value(name, False)

    if flag('need_conver'):
        insert_npu_import(r_node)
    if not flag('has_main_func') and (flag('has_hccl_api') or flag('is_keras_net')):
        log_warning(
            'the network of keras and horovod, or using dataset.shard script do not have main func, '
            'should set -m or --main parameter')
    if flag('is_main_file') and flag('has_hccl_api'):
        insert_npu_resource_init(r_node)
        insert_npu_resource_shutdown(r_node)
    if flag('is_main_file') and flag('is_keras_net'):
        insert_keras_sess_npu_config(r_node)
        insert_keras_sess_close(r_node)

    dst_content = pasta.dump(r_node)
    write_output_after_conver(
        os.path.join(util_global.get_value('output'), out_path_dst, file_name),
        dst_content)
    if file_name.endswith("a.py"):
        write_report_after_conver("only_for_test", file_name,
                                  node_tree(ast.dump(r_node)))
def test_block_suffix(self):
    """Comments after a block's last statement are stored as its block suffix.

    For every kind of block, the lines between ``pass #a`` and the dedented
    ``#e`` (``#b``, ``#c``, the blank line and ``#d``) must be recorded in
    ``block_suffix_<attr>`` of the node owning the ``pass``. The source must
    also round-trip unchanged.
    """
    src_tpl = textwrap.dedent("""\
        {open_block}
          pass #a
          #b
            #c

          #d
        #e
        a
        """)
    test_cases = (
        # first: attribute of the node with the last block
        # second: code snippet to open a block
        ('body', 'def x():'),
        ('body', 'class X:'),
        ('body', 'if x:'),
        ('orelse', 'if x:\n y\nelse:'),
        ('body', 'if x:\n y\nelif y:'),
        ('body', 'while x:'),
        ('orelse', 'while x:\n y\nelse:'),
        ('finalbody', 'try:\n x\nfinally:'),
        ('body', 'try:\n x\nexcept:'),
        ('orelse', 'try:\n x\nexcept:\n y\nelse:'),
        ('body', 'with x:'),
        ('body', 'with x, y:'),
        ('body', 'with x:\n with y:'),
        ('body', 'for x in y:'),
    )

    def is_node_for_suffix(node, children_attr):
        # True if this node's `children_attr` list starts with the 'pass'.
        val = getattr(node, children_attr, None)
        return (isinstance(val, list) and bool(val)
                and isinstance(val[0], (ast27.Pass, ast3.Pass)))

    for children_attr, open_block in test_cases:
        src = src_tpl.format(open_block=open_block)
        t = pasta.parse(src, py_ver)
        node_finder = ast_utils.get_find_node_visitor(
            lambda node: is_node_for_suffix(node, children_attr), py_ver)
        node_finder.visit(t)
        node = node_finder.results[0]
        # Must mirror the template: '#b'/'#d' at 2 spaces, '#c' at 4.
        expected = '  #b\n    #c\n\n  #d\n'
        actual = str(fmt.get(node, 'block_suffix_%s' % children_attr))
        self.assertMultiLineEqual(
            expected, actual,
            'Incorrect suffix for code:\n%s\nNode: %s (line %d)\nDiff:\n%s' %
            (src, node, node.lineno, '\n'.join(_get_diff(actual, expected))))
        self.assertMultiLineEqual(src, pasta.dump(t, py_ver))
def update_string_pasta(self, text, in_filename):
    """Apply the API change spec to ``text`` using pasta.

    Returns:
        A tuple ``(status, new_text, log, errors)``. On a parse failure it is
        ``(0, "", message_with_traceback, [])``; otherwise ``status`` is 1,
        ``log`` is the visitor's log text and ``errors`` the formatted errors.
    """
    try:
        tree = pasta.parse(text)
    except (SyntaxError, ValueError, TypeError):
        return 0, "", "Failed to parse.\n\n" + traceback.format_exc(), []

    editor = _PastaEditVisitor(self._api_change_spec)
    editor.visit(tree)
    formatted_errors = self._format_errors(editor.errors, in_filename)
    return 1, pasta.dump(tree), editor.log_text(), formatted_errors
def update_string_pasta(self, text, in_filename):
    """Apply the API change spec to ``text`` using pasta.

    Returns:
        A tuple ``(status, new_text, logs, errors)``. On a parse failure it
        is ``(0, "", [message_with_traceback], [])``; otherwise ``status`` is
        1 and the logs/errors are the visitor's entries, formatted.
    """
    try:
        tree = pasta.parse(text)
    except (SyntaxError, ValueError, TypeError):
        return 0, "", ["ERROR: Failed to parse.\n" + traceback.format_exc()], []

    editor = _PastaEditVisitor(self._api_change_spec)
    editor.visit(tree)
    formatted_logs = [self.format_log(entry, None) for entry in editor.log]
    formatted_errors = [
        self.format_log(entry, in_filename)
        for entry in editor.warnings_and_errors
    ]
    return 1, pasta.dump(tree), formatted_logs, formatted_errors