Example #1
    def _convert_call(self, node, matched_api_name):
        """Convert the call node."""
        new_node = None
        code = pasta.dump(node)
        api_name = pasta.dump(node.func)
        warning_info = get_prompt_info(matched_api_name)
        if warning_info is None:
            warning_info = ''
        if matched_api_name in ALL_MAPPING:
            logger.info("Line %3d start converting API: %s", node.lineno,
                        api_name)
            new_code = self.mapping_api(node)
            if new_code != code:
                try:
                    new_node = pasta.parse(new_code).body[0].value
                    # find the first call name
                    new_api_name = new_code[:new_code.find('(')]
                    detail_msg = self._get_detail_prompt_msg(node, new_node)
                    if detail_msg:
                        warning_info = detail_msg + ' ' + warning_info
                except AttributeError:
                    new_node = pasta.parse(new_code).body[0]
                    new_api_name = new_code
                self._process_log.info(
                    node.lineno, node.col_offset, LOG_FMT_CONVERT_WITH_TIPS %
                    (api_name, new_api_name, warning_info))
        else:
            logger.warning("Line %3d: found unsupported API: %s%s",
                           node.lineno, api_name, warning_info)
            self._process_log.warning(
                node.lineno, node.col_offset,
                LOG_FMT_NOT_CONVERT % (api_name, warning_info))

        return new_node
Example #2
 def test_mixed_tabs_spaces_indentation(self):
     pasta.parse(
         textwrap.dedent('''\
     if a:
             b
     {ONETAB}c
     ''').format(ONETAB='\t'))
Example #3
 def test_mixed_tabs_spaces_indentation(self):
     pasta.parse(
         textwrap.dedent("""\
   if a:
           b
   {ONETAB}c
   """).format(ONETAB='\t'), py_ver)
Example #4
def convert_dynamic_loss_scale(node):
    """Convert dynamic loss scale related TensorFlow APIs."""
    log_msg(
        getattr(node, 'lineno', 'None'),
        "change tf.train.experimental.DynamicLossScale"
        " to ExponentialUpdateLossScaleManager")
    node.func = ast.Name(id="ExponentialUpdateLossScaleManager",
                         ctx=ast.Load())

    def check_arg(node):
        initial_loss_scale = None
        increment_period = None
        multiplier = None
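        # Positional arguments arrive in the order (initial_loss_scale, increment_period, multiplier).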
        for index, arg in enumerate(node.args):
            if index == 0:
                initial_loss_scale = arg
            if index == 1:
                increment_period = arg
            if index == 2:
                multiplier = arg
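        # Keyword arguments are renamed to the corresponding ExponentialUpdateLossScaleManager parameter names.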
        for keyword in node.keywords:
            if keyword.arg == "initial_loss_scale":
                keyword.arg = "init_loss_scale"
                initial_loss_scale = keyword
            if keyword.arg == "increment_period":
                keyword.arg = "incr_every_n_steps"
                increment_period = keyword
            if keyword.arg == "multiplier":
                keyword.arg = "incr_ratio"
                multiplier = keyword
        return (initial_loss_scale, increment_period, multiplier)

    (initial_loss_scale, increment_period, multiplier) = check_arg(node)
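    # Re-emit each argument as a keyword argument, supplying a default when it was not passed.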
    if initial_loss_scale:
        if not isinstance(initial_loss_scale, ast.keyword):
            node.keywords.append(
                ast.keyword(arg="init_loss_scale", value=initial_loss_scale))
    else:
        node.keywords.append(
            ast.keyword(arg="init_loss_scale", value=pasta.parse("2**15")))
    if increment_period:
        if not isinstance(increment_period, ast.keyword):
            node.keywords.append(
                ast.keyword(arg="incr_every_n_steps", value=increment_period))
    else:
        node.keywords.append(
            ast.keyword(arg="incr_every_n_steps", value=pasta.parse("2000")))
    if multiplier:
        if not isinstance(multiplier, ast.keyword):
            node.keywords.append(
                ast.keyword(arg="incr_ratio", value=multiplier))
    else:
        node.keywords.append(
            ast.keyword(arg="incr_ratio", value=pasta.parse("2")))
    node.args = []
    util_global.set_value('need_conver', True)
    return node
Example #5
        def testReplaceChildInvalid(self):
            src = 'def foo():\n  return 1\nx = 1\n'
            replace_with = pasta.parse('bar()', py_ver).body[0]
            t = pasta.parse(src, py_ver)

            parent = t.body[0]
            node_to_replace = t.body[1]
            with self.assertRaises(errors.InvalidAstError):
                ast_utils.replace_child(parent, node_to_replace, replace_with)
Example #6
    def testReplaceChildInBody(self):
        src = 'def foo():\n  a = 0\n  a += 1 # replace this\n  return a\n'
        replace_with = pasta.parse('foo(a + 1)  # trailing comment\n').body[0]
        expected = 'def foo():\n  a = 0\n  foo(a + 1) # replace this\n  return a\n'
        t = pasta.parse(src)

        parent = t.body[0]
        node_to_replace = parent.body[1]
        ast_utils.replace_child(parent, node_to_replace, replace_with)

        self.assertEqual(expected, pasta.dump(t))
Example #7
 def test_default_indentation(self):
   for indent in ('  ', '    ', '\t'):
     src = 'def a():\n' + indent + 'b\n'
     t = pasta.parse(src)
     t.body.extend(ast.parse('def c(): d').body)
     self.assertEqual(codegen.to_str(t),
                      src + 'def c():\n' + indent + 'd\n')
Example #8
def execute_rename(file_path, moved_imports):
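    # Load the user-provided module that defines the imports_to_move list.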
    if six.PY2:
        import imp
        import_from_user = imp.load_source('moved_imports', moved_imports)
    else:
        import importlib.util
        spec = importlib.util.spec_from_file_location("moved_imports",
                                                      moved_imports)
        import_from_user = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(import_from_user)

    with open(file_path, mode='r') as file:
        tree = pasta.parse(file.read())
        for class_to_move in import_from_user.imports_to_move:
            old_path = class_to_move[0]
            new_path = class_to_move[1]
            try:
                rename.rename_external(tree, old_path, new_path)
            except ValueError:
                raise click.ClickException(
                    "An error occurred on the following path: {0},\n "
                    "while trying to rename from: {1} to {2}".format(
                        file_path, old_path, new_path))
        source_code = pasta.dump(tree)

    with open(file_path, mode='w') as file:
        file.write(source_code)
Example #9
 def test_call_no_pos(self):
     """Tests that Call node traversal works without position information."""
     src = 'f(a)'
     t = pasta.parse(src)
     node = ast_utils.find_nodes_by_type(t, (ast.Call, ))[0]
     node.keywords.append(ast.keyword(arg='b', value=ast.Num(n=0)))
     self.assertEqual('f(a, b=0)', pasta.dump(t))
Example #10
def convert_loss_scale_api(node):
    """Convert loss scale related TensorFlow APIs."""
    if isinstance(node.func, ast.Attribute):
        if node.func.attr == "FixedLossScale":
            log_msg(
                getattr(node, 'lineno', 'None'),
                "change tf.train.experimental.FixedLossScale"
                " to FixedLossScaleManager")
            node.func = ast.Name(id="FixedLossScaleManager", ctx=ast.Load())
            if len(node.keywords) == 1:
                node.keywords[0].arg = "loss_scale"
            util_global.set_value('need_conver', True)
            return node
        if node.func.attr == "DynamicLossScale":
            return convert_dynamic_loss_scale(node)
        if node.func.attr == "MixedPrecisionLossScaleOptimizer":
            log_msg(
                getattr(node, 'lineno', 'None'),
                "change tf.train.experimental.MixedPrecisionLossScaleOptimizer"
                " to NPULossScaleOptimizer")
            node.func = ast.Name(id="NPULossScaleOptimizer", ctx=ast.Load())
            for keyword in node.keywords:
                if keyword.arg == "loss_scale":
                    keyword.arg = "loss_scale_manager"
            if (len(util_global.get_value("distributed_mode", "")) != 0):
                node.keywords.append(
                    ast.keyword(arg="is_distributed",
                                value=pasta.parse("True")))
            util_global.set_value('need_conver', True)
            return node
Example #11
 def test_fstring_escaping(self):
   src = 'f"a {{{b} {{c}}"'
   t = pasta.parse(src)
   node = t.body[0].value
   self.assertEqual(
       fmt.get(node, 'content'),
       'f"a {{{__pasta_fstring_val_0__} {{c}}"')
Example #12
 def test_fstring(self):
     src = 'f"a {b} c d {e}"'
     t = pasta.parse(src, py_ver)
     node = t.body[0].value
     self.assertEqual(
         fmt.get(node, 'content'),
         'f"a {__pasta_fstring_val_0__} c d {__pasta_fstring_val_1__}"')
Example #13
def main(coverage_file):
    data = coverage.CoverageData()
    data.read_file(coverage_file)

    for filename in data._lines:
        lines = data.lines(filename)
        assert lines is not None
        if not os.path.exists(filename):
            # It could be unlinked before
            continue
        if not lines:
            print(filename, 'not covered, removing')
            os.unlink(filename)
            continue
        with open(filename) as fp:
            tree = pasta.parse(fp.read())
        new_tree = rewrite(tree, lines)

        try:
            to_write = pasta.dump(new_tree)
        except pasta.base.codegen.PrintError:
            print("Error with file", filename)
            continue

        with open(filename, 'w') as fp:
            fp.write(to_write)
Example #14
def rename_file(file_path, path_to_moved_imports_file):
    """
    Iterates over the content of a file, looking for imports to be changed

    :param str file_path:
        Path of the file being parsed.
    :param str path_to_moved_imports_file:
        Path of the file with the list of changed imports.
    """
    list_with_moved_imports = _get_list_of_moved_imports(
        path_to_moved_imports_file)

    with open(file_path, mode='r') as file:
        tree = pasta.parse(file.read())
        for old_path, new_path in list_with_moved_imports:
            try:
                rename.rename_external(tree, old_path, new_path)
            except ValueError:
                raise click.ClickException(
                    "An error has occurred on the following path: {0},\n "
                    "while trying to rename from: {1} to {2}".format(
                        file_path, old_path, new_path))
        source_code = pasta.dump(tree)

    with open(file_path, mode='w') as file:
        file.write(source_code)
Example #15
    def update_string_pasta(self, text, in_filename):
        """Updates a file using pasta."""
        try:
            t = pasta.parse(text)
        except (SyntaxError, ValueError, TypeError):
            log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
            return 0, "", log, []

        preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(
            t)

        visitor = _PastaEditVisitor(self._api_change_spec)
        visitor.visit(t)

        self._api_change_spec.clear_preprocessing()

        logs = [
            self.format_log(log, None)
            for log in (preprocess_logs + visitor.log)
        ]
        errors = [
            self.format_log(error, in_filename)
            for error in (preprocess_errors + visitor.warnings_and_errors)
        ]
        return 1, pasta.dump(t), logs, errors
Example #16
    def test_indent_levels(self):
        src = textwrap.dedent('''\
        foo('begin')
        if a:
          foo('a1')
          if b:
            foo('b1')
            if c:
              foo('c1')
            foo('b2')
          foo('a2')
        foo('end')
        ''')
        t = pasta.parse(src)
        call_nodes = ast_utils.find_nodes_by_type(t, (ast.Call, ))
        call_nodes.sort(key=lambda node: node.lineno)
        begin, a1, b1, c1, b2, a2, end = call_nodes

        self.assertEqual('', fmt.get(begin, 'indent'))
        self.assertEqual('  ', fmt.get(a1, 'indent'))
        self.assertEqual('    ', fmt.get(b1, 'indent'))
        self.assertEqual('      ', fmt.get(c1, 'indent'))
        self.assertEqual('    ', fmt.get(b2, 'indent'))
        self.assertEqual('  ', fmt.get(a2, 'indent'))
        self.assertEqual('', fmt.get(end, 'indent'))
Example #17
 def test_indent_levels_same_line(self):
     src = 'if a: b; c\n'
     t = pasta.parse(src)
     if_node = t.body[0]
     b, c = if_node.body
     self.assertIsNone(fmt.get(b, 'indent_diff'))
     self.assertIsNone(fmt.get(c, 'indent_diff'))
Example #18
    def _update_base_name(self, class_def_scope):
        """
        Update base name of class.

        Args:
            class_def_scope (ast.ClassDef): Class definition node.
        """
        base_name_mapping = APIAnalysisSpec.base_name_mapping
        class_def_node = class_def_scope.node
        base_class_nodes = class_def_scope.node.bases
        # update base class name
        for base_class_node in base_class_nodes:
            base_name = base_class_node.attr
            if base_name in APIAnalysisSpec.get_network_base_class_names():
                old_code = pasta.dump(base_class_node)
                if base_name in base_name_mapping:
                    new_code = 'nn.' + base_name_mapping[base_class_node.attr]
                    new_node = pasta.parse(new_code)
                    pasta.ast_utils.replace_child(class_def_node,
                                                  base_class_node, new_node)
                    self._process_log.info(
                        base_class_node.lineno, base_class_node.col_offset,
                        LOG_FMT_CONVERT % (old_code, new_code))
                else:
                    self._process_log.info(
                        base_class_node.lineno, base_class_node.col_offset,
                        LOG_FMT_NOT_CONVERT % (old_code, ''))
Example #19
    def _read_input_file(self):
        """Reads input file and parses it as an abstract syntax tree (AST).

        Returns:
            ast.Module: AST representation of the input file.
        """
        with open(self.input_path) as input_file:
            return pasta.parse(input_file.read())
Example #20
    def test_args(self):
      src = """
def func():
  offset_multi = lambda *a: foo(*a)
  add_multi = lambda *a, **k: bar(*a, **k)"""
      t = pasta.parse(src, py_ver)
      print(pasta.dump(t, py_ver))
      self.assertEqual(src, pasta.dump(t, py_ver))
Example #21
    def testRemoveAlias(self):
        src = "from a import b, c"
        tree = pasta.parse(src)
        import_node = tree.body[0]
        alias1 = import_node.names[0]
        ast_utils.remove_child(import_node, alias1)

        self.assertEqual(pasta.dump(tree), "from a import c")
Example #22
  def test_scope_trailing_comma(self):
    template = 'def foo(a, b{trailing_comma}): pass'
    for trailing_comma in ('', ',', ' , '):
      tree = pasta.parse(template.format(trailing_comma=trailing_comma))
      self.assertEqual(trailing_comma.lstrip(' ') + ')',
                       fmt.get(tree.body[0], 'args_suffix'))

    template = 'class Foo(a, b{trailing_comma}): pass'
    for trailing_comma in ('', ',', ' , '):
      tree = pasta.parse(template.format(trailing_comma=trailing_comma))
      self.assertEqual(trailing_comma.lstrip(' ') + ')',
                       fmt.get(tree.body[0], 'bases_suffix'))

    template = 'from mod import (a, b{trailing_comma})'
    for trailing_comma in ('', ',', ' , '):
      tree = pasta.parse(template.format(trailing_comma=trailing_comma))
      self.assertEqual(trailing_comma + ')',
                       fmt.get(tree.body[0], 'names_suffix'))
Example #23
 def test_tabs_below_spaces_and_tab(self):
   for num_spaces in range(1, 8):
     t = pasta.parse(textwrap.dedent('''\
         if a:
         {WS}{ONETAB}if b:
         {ONETAB}{ONETAB}c
         ''').format(ONETAB='\t', WS=' ' * num_spaces))
     node_c = t.body[0].body[0].body[0]
     self.assertEqual(fmt.get(node_c, 'indent_diff'), '\t')
Example #24
        def test_indent_extra_newlines(self):
            src = textwrap.dedent("""\
          if a:

            b
          """)
            t = pasta.parse(src, py_ver)
            if_node = t.body[0]
            b = if_node.body[0]
            self.assertEqual('  ', fmt.get(b, 'indent_diff'))
Example #25
 def test_tab_below_spaces(self):
     for num_spaces in range(1, 8):
         t = pasta.parse(
             textwrap.dedent("""\
     if a:
     {WS}if b:
     {ONETAB}c
     """).format(ONETAB='\t', WS=' ' * num_spaces), py_ver)
         node_c = t.body[0].body[0].body[0]
         self.assertEqual(fmt.get(node_c, 'indent_diff'),
                          ' ' * (8 - num_spaces))
Example #26
    def test_indent_extra_newlines_with_comment(self):
        src = textwrap.dedent('''\
        if a:
            #not here

          b
        ''')
        t = pasta.parse(src)
        if_node = t.body[0]
        b = if_node.body[0]
        self.assertEqual('  ', fmt.get(b, 'indent_diff'))
Example #27
 def test_indent_multiline_string_with_newline(self):
     src = textwrap.dedent('''\
   class A:
     """Doc\n
        string."""
     pass
   ''')
     t = pasta.parse(src, py_ver)
     docstring, pass_stmt = t.body[0].body
     self.assertEqual('  ', fmt.get(docstring, 'indent'))
     self.assertEqual('  ', fmt.get(pass_stmt, 'indent'))
Example #28
    def test_call_illegal_pos(self):
        """Tests that Call node traversal works even with illegal positions."""
        src = 'f(a)'
        t = pasta.parse(src)
        node = ast_utils.find_nodes_by_type(t, (ast.Call, ))[0]
        node.keywords.append(ast.keyword(arg='b', value=ast.Num(n=0)))

        # This position would put b=0 before a, so it should be ignored.
        node.keywords[-1].value.lineno = 0
        node.keywords[-1].value.col_offset = 0

        self.assertEqual('f(a, b=0)', pasta.dump(t))
Example #29
def conver_ast(path, out_path_dst, file_name):
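    """Parse the input script with pasta, run the NPU conversion visitor, and write the converted source."""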
    util_global.set_value('need_conver', False)
    util_global.set_value('is_keras_net', False)
    util_global.set_value('has_hccl_api', False)
    util_global.set_value('is_main_file', False)
    util_global.set_value('has_main_func', False)
    if os.path.join(path, file_name) == util_global.get_value('main', ""):
        util_global.set_value('is_main_file', True)
    with open(os.path.join(path, file_name), "r", encoding='utf-8') as file:
        source = file.read()
    try:
        r_node = pasta.parse(source)
    except Exception as e:
        print(repr(e))
        return

    sys.setrecursionlimit(10000)
    visitor = ConverByAst()
    visitor.visit(r_node)
    ast.fix_missing_locations(r_node)

    (api, lineno) = get_tf_api(os.path.join(path, file_name))
    if len(api) == 0:
        print(
            "No TensorFlow module is imported in script {}.".format(file_name))
    scan_file(path, file_name, api, lineno)

    if util_global.get_value('need_conver', False):
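        # At least one API was converted: add the NPU import and required setup, then write the converted file.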
        insert_npu_import(r_node)
        if not util_global.get_value('has_main_func', False) and (
                util_global.get_value('has_hccl_api', False)
                or util_global.get_value('is_keras_net', False)):
            log_warning(
                'the script uses Keras, Horovod, or dataset.shard but has no main function; '
                'set the -m or --main parameter')
        if util_global.get_value('is_main_file',
                                 False) and util_global.get_value(
                                     'has_hccl_api', False):
            insert_npu_resource_init(r_node)
            insert_npu_resource_shutdown(r_node)
        if util_global.get_value('is_main_file',
                                 False) and util_global.get_value(
                                     'is_keras_net', False):
            insert_keras_sess_npu_config(r_node)
            insert_keras_sess_close(r_node)
        dst_content = pasta.dump(r_node)
        write_output_after_conver(
            os.path.join(util_global.get_value('output'), out_path_dst,
                         file_name), dst_content)

    if file_name.endswith("a.py"):
        write_report_after_conver("only_for_test", file_name,
                                  node_tree(ast.dump(r_node)))
Example #30
        def test_block_suffix(self):
            src_tpl = textwrap.dedent("""\
          {open_block}
            pass #a
            #b
              #c

            #d
          #e
          a
          """)
            test_cases = (
                # first: attribute of the node with the last block
                # second: code snippet to open a block
                ('body', 'def x():'),
                ('body', 'class X:'),
                ('body', 'if x:'),
                ('orelse', 'if x:\n  y\nelse:'),
                ('body', 'if x:\n  y\nelif y:'),
                ('body', 'while x:'),
                ('orelse', 'while x:\n  y\nelse:'),
                ('finalbody', 'try:\n  x\nfinally:'),
                ('body', 'try:\n  x\nexcept:'),
                ('orelse', 'try:\n  x\nexcept:\n  y\nelse:'),
                ('body', 'with x:'),
                ('body', 'with x, y:'),
                ('body', 'with x:\n with y:'),
                ('body', 'for x in y:'),
            )

            def is_node_for_suffix(node, children_attr):
                # Return True if this node contains the 'pass' statement
                val = getattr(node, children_attr, None)
                return isinstance(val, list) and (type(val[0]) == ast27.Pass
                                                  or type(val[0]) == ast3.Pass)

            for children_attr, open_block in test_cases:
                src = src_tpl.format(open_block=open_block)
                t = pasta.parse(src, py_ver)
                node_finder = ast_utils.get_find_node_visitor(
                    lambda node: is_node_for_suffix(node, children_attr),
                    py_ver)
                node_finder.visit(t)
                node = node_finder.results[0]
                expected = '  #b\n    #c\n\n  #d\n'
                actual = str(fmt.get(node, 'block_suffix_%s' % children_attr))
                self.assertMultiLineEqual(
                    expected, actual,
                    'Incorrect suffix for code:\n%s\nNode: %s (line %d)\nDiff:\n%s'
                    % (src, node, node.lineno, '\n'.join(
                        _get_diff(actual, expected))))
                self.assertMultiLineEqual(src, pasta.dump(t, py_ver))
Example #31
  def update_string_pasta(self, text, in_filename):
    """Updates a file using pasta."""
    try:
      t = pasta.parse(text)
    except (SyntaxError, ValueError, TypeError):
      log = "Failed to parse.\n\n" + traceback.format_exc()
      return 0, "", log, []

    visitor = _PastaEditVisitor(self._api_change_spec)
    visitor.visit(t)

    errors = self._format_errors(visitor.errors, in_filename)
    return 1, pasta.dump(t), visitor.log_text(), errors
Example #32
  def update_string_pasta(self, text, in_filename):
    """Updates a file using pasta."""
    try:
      t = pasta.parse(text)
    except (SyntaxError, ValueError, TypeError):
      log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
      return 0, "", log, []

    visitor = _PastaEditVisitor(self._api_change_spec)
    visitor.visit(t)

    logs = [self.format_log(log, None) for log in visitor.log]
    errors = [self.format_log(error, in_filename)
              for error in visitor.warnings_and_errors]
    return 1, pasta.dump(t), logs, errors