def test_modify_node():
    """S3SessionRenamer renames the ``session`` keyword to ``sagemaker_session``."""
    call_node = ast_call("S3Downloader.download(session=sess)")
    renamer = renamed_params.S3SessionRenamer()
    renamer.modify_node(call_node)

    assert pasta.dump(call_node) == "S3Downloader.download(sagemaker_session=sess)"
Exemple #2
0
def test_model_modify_node():
    """ModelImageURIRenamer renames the ``image`` keyword to ``image_uri``."""
    call_node = ast_call("TensorFlowModel(image=my_image)")
    renamer = renamed_params.ModelImageURIRenamer()
    renamer.modify_node(call_node)

    assert pasta.dump(call_node) == "TensorFlowModel(image_uri=my_image)"
def test_modify_node_set_model_dir_and_image_name(retrieve_image_uri,
                                                  boto_session):
    """Legacy-mode TensorFlow constructors get an image URI and model_dir=False.

    Every constructor variant below must be rewritten to the same call, and the
    image URI must be looked up with the arguments asserted at the end of the
    loop body.
    """
    boto_session.return_value.region_name = REGION_NAME
    expected = "TensorFlow(image_uri='{}', model_dir=False)".format(IMAGE_URI)

    tf_constructors = (
        "TensorFlow()",
        "TensorFlow(script_mode=False)",
        # A well-formed S3 URI; the upgrader replaces the value with False.
        "TensorFlow(model_dir='s3://bucket/model')",
    )
    modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()

    for constructor in tf_constructors:
        node = ast_call(constructor)
        modifier.modify_node(node)

        assert expected == pasta.dump(node)
        retrieve_image_uri.assert_called_with(
            "tensorflow",
            REGION_NAME,
            instance_type="ml.m4.xlarge",
            version="1.11.0",
            py_version="py2",
            image_scope="training",
        )
Exemple #4
0
 def test_add_existing_import_normal_import_aliased(self):
     """An aliased plain import is reused; its parent package is imported anew."""
     tree = ast.parse('import a.b.c as d')
     parent_name = import_utils.add_import(tree, 'a.b', from_import=False)
     self.assertEqual('a.b', parent_name)
     existing_alias = import_utils.add_import(tree, 'a.b.c', from_import=False)
     self.assertEqual('d', existing_alias)
     self.assertEqual('import a.b\nimport a.b.c as d\n', pasta.dump(tree))
 def test_add_existing_import_aliased_with_asname(self):
     """An existing aliased from-import wins over a different requested asname."""
     tree = pasta.ast_parse('from a.b import c as d', py_ver)
     bound_name = import_utils.add_import(tree, 'a.b.c', py_ver, asname='e')
     self.assertEqual('d', bound_name)
     self.assertEqual('from a.b import c as d\n', pasta.dump(tree, py_ver))
 def test_add_existing_import_normal_import(self):
     """A parent package of an existing plain import needs no new import."""
     tree = pasta.ast_parse('import a.b.c', py_ver)
     bound_name = import_utils.add_import(tree, 'a.b', py_ver,
                                          from_import=False)
     self.assertEqual('a.b', bound_name)
     self.assertEqual('import a.b.c\n', pasta.dump(tree, py_ver))
 def test_add_import_with_asname_with_conflict(self):
     """A requested alias that clashes with an existing name gets a suffix."""
     tree = ast.parse('def c(): pass\n')
     bound_name = import_utils.add_import(
         tree, 'a.b', asname='c', from_import=True)
     self.assertEqual('c_1', bound_name)
     expected_source = 'from a import b as c_1\ndef c():\n  pass\n'
     self.assertEqual(expected_source, pasta.dump(tree))
Exemple #8
0
    def _replace_external_reference(self):
        """
        Replace external reference statements.

        For every external reference that has an import node:
          * if its name is in ``APIAnalysisSpec.get_convertible_external_names()``
            and also has an entry ``(name_to_import, as_name)`` in
            ``APIAnalysisSpec.import_name_mapping``, the import node is replaced
            in the tree by a newly built import and the conversion is logged;
          * otherwise, if the name is not convertible and starts with
            ``torch.``, a warning suggesting manual conversion is logged;
          * anything else is left untouched.

        Returns:
            dict, key is external name, value is the new replaced node.
        """
        all_name_mappings = APIAnalysisSpec.import_name_mapping
        # Loop-invariant: fetch the convertible names once, not per reference.
        convertible_names = APIAnalysisSpec.get_convertible_external_names()
        names_replaced_with = dict()
        for ref_info in self._code_analyzer.external_references.values():
            external_ref_info = ref_info['external_ref_info']
            import_node = ref_info['parent_node']
            if import_node is None:
                continue
            code = self._dump_without_prefix(import_node)
            import_parent_node = self._code_analyzer.root_scope.parent(import_node)
            # replace import with new name
            if external_ref_info.name in convertible_names:
                # Convertible names without a mapping entry are left as they are.
                if external_ref_info.name in all_name_mappings:
                    replace_info = all_name_mappings[external_ref_info.name]
                    new_node = self._make_import(name_to_import=replace_info[0], as_name=replace_info[1])
                    new_code = pasta.dump(new_node)
                    pasta.ast_utils.replace_child(import_parent_node, import_node, new_node)
                    names_replaced_with[external_ref_info.name] = new_node
                    self._process_log.info(import_node.lineno, import_node.col_offset, LOG_FMT_CONVERT %
                                           (code.strip(), new_code.strip()))
            elif external_ref_info.name.startswith('torch.'):
                self._process_log.warning(import_node.lineno, import_node.col_offset, LOG_FMT_NOT_CONVERT %
                                          (code.strip(), LOG_SUGGESTION_MANUAL_CONVERT))
        return names_replaced_with
Exemple #9
0
def main(coverage_file):
    """Strip uncovered code from the files recorded in a coverage data file.

    For every file measured in *coverage_file*:
      * files that no longer exist on disk are skipped;
      * files with no executed line are deleted;
      * other files are parsed with pasta, passed to ``rewrite`` together with
        their executed line numbers, and written back in place. Files pasta
        cannot print after rewriting are reported and left untouched.

    Args:
        coverage_file: Path of the coverage data file to read.
    """
    data = coverage.CoverageData()
    data.read_file(coverage_file)

    # Use the public measured_files() accessor instead of the private
    # ``_lines`` attribute: that attribute is None when the data holds branch
    # (arc) measurements or nothing at all, and iterating it raised TypeError.
    for filename in data.measured_files():
        lines = data.lines(filename)
        assert lines is not None
        if not os.path.exists(filename):
            # The file may have been removed since the data was collected.
            continue
        if not lines:
            print(filename, 'not covered, removing')
            os.unlink(filename)
            continue
        with open(filename) as fp:
            tree = pasta.parse(fp.read())
        new_tree = rewrite(tree, lines)

        try:
            to_write = pasta.dump(new_tree)
        except pasta.base.codegen.PrintError:
            print("Error with file", filename)
            continue

        with open(filename, 'w') as fp:
            fp.write(to_write)
Exemple #10
0
def test_estimator_modify_node():
    """EstimatorImageURIRenamer renames ``image_name`` to ``image_uri``."""
    renamer = renamed_params.EstimatorImageURIRenamer()
    call_node = ast_call("TensorFlow(image_name=my_image)")
    renamer.modify_node(call_node)

    assert pasta.dump(call_node) == "TensorFlow(image_uri=my_image)"
Exemple #11
0
def test_import_check_and_modify_node_random_import():
    """An unrelated import passes through check_and_modify_node unchanged."""
    renamer = tfs.TensorFlowServingImportRenamer()

    source = "import random"
    import_node = ast_import(source)
    renamer.check_and_modify_node(import_node)
    assert pasta.dump(import_node) == source
Exemple #12
0
 def test_add_existing_import(self):
   """Adding an already-present from-import returns None and changes nothing."""
   existing = ast.ImportFrom(level=0, module='a.b',
                             names=[ast.alias(name='c', asname=None)])
   tree = ast.Module(body=[existing])
   self.assertIsNone(import_utils.add_import(tree, 'a.b.c'))
   self.assertEqual('from a.b import c\n', pasta.dump(tree))
Exemple #13
0
  def visit_Attribute(self, node):  # pylint: disable=invalid-name
    """Handle bare Attributes i.e. [tf.foo, tf.bar].

    When the attribute resolves to a full name, any warning for that name is
    emitted, then a rename and a conversion to a function call are attempted;
    if either reports a change, the node is stale and its children are not
    visited. Otherwise a module deprecation warning may be added, using the
    source of the outermost Attribute enclosing this node, before descending.
    """
    assert self._stack[-1] is node

    full_name = self._get_full_name(node)
    if not full_name:
      self.generic_visit(node)
      return

    parent = self._stack[-2]

    # The warning must come first; a rename could change the name it keys on.
    self._maybe_add_warning(node, full_name)

    # After a modification the node is invalid, and only simple nodes are
    # modified, so there is nothing worth descending into.
    if (self._maybe_rename(parent, node, full_name) or
        self._maybe_change_to_function_call(parent, node, full_name)):
      return

    # Climb past enclosing Attribute nodes to the outermost one. The
    # isinstance test alone ends the loop: a bare Attribute is never root.
    depth = 1
    while isinstance(self._stack[-(depth + 1)], ast.Attribute):
      depth += 1
    whole_name = pasta.dump(self._stack[-depth])

    self._maybe_add_module_deprecation_warning(node, full_name, whole_name)
    self.generic_visit(node)
Exemple #14
0
 def test_call_no_pos(self):
     """Tests that Call node traversal works without position information."""
     tree = pasta.parse('f(a)')
     call = ast_utils.find_nodes_by_type(tree, (ast.Call, ))[0]
     # The appended keyword is synthesized, so it carries no position info.
     call.keywords.append(ast.keyword(arg='b', value=ast.Num(n=0)))
     self.assertEqual('f(a, b=0)', pasta.dump(tree))
def test_arg_order_modify_node():
    """ModelConfigArgModifier turns a leading positional instance_type into a keyword.

    Each case is a pair of (original call, expected rewritten call).
    """
    modifier = airflow.ModelConfigArgModifier()

    cases = [
        ("model_config(instance_type, model)", "model_config(model, instance_type=instance_type)"),
        (
            "model_config('ml.m4.xlarge', 'my-model')",
            "model_config('my-model', instance_type='ml.m4.xlarge')",
        ),
        (
            "model_config('ml.m4.xlarge', model='my-model')",
            "model_config(instance_type='ml.m4.xlarge', model='my-model')",
        ),
        (
            "model_config_from_estimator(instance_type, estimator, task_id, task_type)",
            "model_config_from_estimator(estimator, task_id, task_type, instance_type=instance_type)",
        ),
        (
            "model_config_from_estimator(instance_type, estimator, task_id=task_id, task_type=task_type)",
            "model_config_from_estimator(estimator, instance_type=instance_type, task_id=task_id, task_type=task_type)",
        ),
    ]

    for source, rewritten in cases:
        call_node = ast_call(source)
        modifier.modify_node(call_node)
        assert pasta.dump(call_node) == rewritten
Exemple #16
0
  def visit_Attribute(self, node):  # pylint: disable=invalid-name
    """Handle bare Attributes i.e. [tf.foo, tf.bar].

    When the attribute resolves to a full name, any warning for that name is
    emitted, then a rename and a conversion to a function call are attempted;
    if either reports a change, the node is stale and its children are not
    visited. Otherwise a module deprecation warning may be added, using the
    source of the outermost Attribute enclosing this node, before descending.
    """
    assert self._stack[-1] is node

    full_name = self._get_full_name(node)
    if not full_name:
      self.generic_visit(node)
      return

    parent = self._stack[-2]

    # The warning must come first; a rename could change the name it keys on.
    self._maybe_add_warning(node, full_name)

    # After a modification the node is invalid, and only simple nodes are
    # modified, so there is nothing worth descending into.
    if (self._maybe_rename(parent, node, full_name) or
        self._maybe_change_to_function_call(parent, node, full_name)):
      return

    # Climb past enclosing Attribute nodes to the outermost one. The
    # isinstance test alone ends the loop: a bare Attribute is never root.
    depth = 1
    while isinstance(self._stack[-(depth + 1)], ast.Attribute):
      depth += 1
    whole_name = pasta.dump(self._stack[-depth])

    self._maybe_add_module_deprecation_warning(node, full_name, whole_name)
    self.generic_visit(node)
Exemple #17
0
    def update_string_pasta(self, text, in_filename):
        """Updates a file using pasta.

        The source is parsed, preprocessed by the API change spec, edited by a
        ``_PastaEditVisitor``, and the spec's preprocessing state is cleared.

        Returns:
            A 4-tuple ``(status, new_source, logs, errors)``. On a parse
            failure it is ``(0, "", [message_with_traceback], [])``; otherwise
            status is 1 and logs/errors combine the preprocessing and visitor
            entries, each passed through ``self.format_log``.
        """
        try:
            tree = pasta.parse(text)
        except (SyntaxError, ValueError, TypeError):
            return 0, "", ["ERROR: Failed to parse.\n" + traceback.format_exc()], []

        tree, preprocess_logs, preprocess_errors = (
            self._api_change_spec.preprocess(tree))

        visitor = _PastaEditVisitor(self._api_change_spec)
        visitor.visit(tree)

        self._api_change_spec.clear_preprocessing()

        all_logs = preprocess_logs + visitor.log
        all_errors = preprocess_errors + visitor.warnings_and_errors
        logs = [self.format_log(entry, None) for entry in all_logs]
        errors = [self.format_log(entry, in_filename) for entry in all_errors]
        return 1, pasta.dump(tree), logs, errors
def test_modify_node():
    """EstimatorCreateModelImageURIRenamer renames ``image`` to ``image_uri``."""
    renamer = renamed_params.EstimatorCreateModelImageURIRenamer()
    call_node = ast_call("estimator.create_model(image=my_image)")
    renamer.modify_node(call_node)

    assert pasta.dump(call_node) == "estimator.create_model(image_uri=my_image)"
Exemple #19
0
    def visit_Call(self, node):
        """Callback function when visit AST tree.

        Tries to convert the visited call via ``self.match_api`` and
        ``self._convert_call``. A produced node replaces ``node`` in its parent,
        is recorded in ``self._new_call_nodes`` and takes ``node``'s place on
        ``self._stack`` before the children are visited. Matched calls that are
        not converted are reported as warnings.

        Args:
            node (ast.Call): The call node being visited.

        Raises:
            Exception: Anything raised while visiting the children is logged with
                the original and (if converted) new source, then re-raised.
        """
        code = pasta.dump(node)
        api_name = pasta.dump(node.func)

        # The parent node first call is equal to this node, skip when parent node is replaced.
        # This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to
        # P.Reshape()(out, (out.size(0), -1)), will skip P.Reshape() in following visiting.
        # Access from the penultimate element in reverse order.
        for parent_node in self._stack[-2::-1]:
            if parent_node in self._new_call_nodes and pasta.dump(
                    parent_node).startswith(api_name):
                return
        parent = self._stack[-2]
        new_node = None
        replaced = False
        matched_api_name, match_case = self.match_api(
            node.func, self._is_forward_function)
        if match_case in (ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED):
            new_node = self._convert_call(node, matched_api_name)
        elif match_case in (ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND):
            self._process_log.warning(node.lineno, node.col_offset,
                                      LOG_FMT_NOT_CONVERT % (api_name, ''))

        if parent and new_node:
            update_line_col = _LineColEditVisitor()
            update_line_col.update(new_node, node)
            pasta.ast_utils.replace_child(parent, node, new_node)
            self._new_call_nodes.append(new_node)

            node = new_node
            self._stack[-1] = node
            replaced = True
        try:
            self.generic_visit(node)
        except Exception:
            # Dump the converted source lazily, only on this error path, so the
            # log shows the real new code without slowing successful visits.
            new_code = code
            if replaced:
                try:
                    new_code = pasta.dump(node)
                except Exception:  # pylint: disable=broad-except
                    # Never let a dump failure mask the error being re-raised.
                    new_code = '<converted code could not be dumped>'
            logger.error('original code:%s, new code:%s',
                         code,
                         new_code,
                         exc_info=True)
            raise
Exemple #20
0
    def testRemoveAlias(self):
        """Removing the first alias of a from-import keeps the remaining one."""
        tree = pasta.parse("from a import b, c")
        import_from = tree.body[0]
        ast_utils.remove_child(import_from, import_from.names[0])

        self.assertEqual(pasta.dump(tree), "from a import c")
Exemple #21
0
  def test_merge_from_import(self):
    """add_import merges into an existing from-import only when allowed."""
    existing = ast.ImportFrom(level=0, module='a.b',
                              names=[ast.alias(name='c', asname=None)])
    tree = ast.Module(body=[existing])

    # x is explicitly not merged, so it gets its own from-import.
    unmerged = import_utils.add_import(tree, 'a.b.x', merge_from_imports=False)
    self.assertEqual('x', unmerged)
    self.assertEqual('from a.b import x\nfrom a.b import c\n',
                     pasta.dump(tree))

    # y may be merged and joins the first matching import.
    merged = import_utils.add_import(tree, 'a.b.y', merge_from_imports=True)
    self.assertEqual('y', merged)
    self.assertEqual('from a.b import x, y\nfrom a.b import c\n',
                     pasta.dump(tree))
Exemple #22
0
 def _get_detail_prompt_msg(old_node, new_node):
     """Get detail converted prompt information.

     A message is produced only when both nodes are ``ast.Call`` and their
     callee source text (``pasta.dump`` of ``.func``) is identical.

     Args:
         old_node: The node before conversion.
         new_node: The node after conversion.

     Returns:
         str or None: The prompt message, or None when none applies.
     """
     if not (isinstance(old_node, ast.Call) and isinstance(new_node, ast.Call)):
         return None
     old_api_name = pasta.dump(old_node.func)
     new_api_name = pasta.dump(new_node.func)
     if new_api_name != old_api_name:
         return None

     old_parameter_num = len(old_node.args) + len(old_node.keywords)
     new_parameter_num = len(new_node.args) + len(new_node.keywords)
     if old_parameter_num > 1:
         return 'Parameters are converted.'
     if old_parameter_num == 0 and new_parameter_num == 0:
         return 'The API name is converted to mindspore API'
     return 'Parameter is converted.'
Exemple #23
0
def test_modify_node():
    """DistributionParameterRenamer renames ``distributions`` to ``distribution``."""
    renamer = renamed_params.DistributionParameterRenamer()
    call_node = ast_call(
        "TensorFlow(distributions={'parameter_server': {'enabled': True}})")
    renamer.modify_node(call_node)

    renamed = "TensorFlow(distribution={'parameter_server': {'enabled': True}})"
    assert pasta.dump(call_node) == renamed
Exemple #24
0
 def _dump_without_prefix(node):
     """Get the python source for an AST, minus its recorded prefix.

     pasta stores formatting preceding a node under the 'prefix' key; that
     many leading characters are dropped from ``pasta.dump(node)``.
     """
     prefix = pasta.base.formatting.get(node, 'prefix')
     skip = len(prefix) if prefix else 0
     return pasta.dump(node)[skip:]
Exemple #25
0
def test_import_from_modify_node(import_statements):
    """Every supplied import statement is rewritten to import image_uris."""
    renamer = image_uris.ImageURIRetrieveImportFromRenamer()

    for statement in import_statements:
        import_node = ast_import(statement)
        renamer.modify_node(import_node)
        assert pasta.dump(import_node) == "from sagemaker import image_uris"
    def test_merge_from_import(self):
        """add_import merges into an existing from-import only when allowed."""
        tree = ast.parse('from a.b import c')

        # x is explicitly not merged, so it gets its own from-import.
        unmerged = import_utils.add_import(
            tree, 'a.b.x', merge_from_imports=False)
        self.assertEqual('x', unmerged)
        self.assertEqual('from a.b import x\nfrom a.b import c\n',
                         pasta.dump(tree))

        # y may be merged and joins the first matching import.
        merged = import_utils.add_import(
            tree, 'a.b.y', merge_from_imports=True)
        self.assertEqual('y', merged)
        self.assertEqual('from a.b import x, y\nfrom a.b import c\n',
                         pasta.dump(tree))
 def test_add_normal_import_with_asname(self):
     """A plain (non-from) import can be added under an alias."""
     tree = ast.parse('')
     bound_name = import_utils.add_import(
         tree, 'a.b.c', asname='d', from_import=False)
     self.assertEqual('d', bound_name)
     self.assertEqual('import a.b.c as d\n', pasta.dump(tree))
def test_create_endpoint_modify_node():
    """``deployment_image`` becomes ``image_uri`` for each CREATE_ENDPOINT_TEMPLATES form."""
    renamer = renamed_params.SessionCreateEndpointImageURIRenamer()

    for template in CREATE_ENDPOINT_TEMPLATES:
        call_node = ast_call(template.format("deployment_image=my_image"))
        renamer.modify_node(call_node)
        assert pasta.dump(call_node) == template.format("image_uri=my_image")
def test_create_model_modify_node():
    """``primary_container_image`` becomes ``image_uri`` for each CREATE_MODEL_TEMPLATES form."""
    renamer = renamed_params.SessionCreateModelImageURIRenamer()

    for template in CREATE_MODEL_TEMPLATES:
        call_node = ast_call(template.format("primary_container_image=my_image"))
        renamer.modify_node(call_node)
        assert pasta.dump(call_node) == template.format("image_uri=my_image")
 def test_add_single_name_from_import_with_asname(self):
     """A single-component name requested as a from-import is a plain import."""
     tree = ast.parse('')
     bound_name = import_utils.add_import(
         tree, 'foo', asname='bar', from_import=True)
     self.assertEqual('bar', bound_name)
     self.assertEqual('import foo as bar\n', pasta.dump(tree))
 def test_add_from_import_with_asname(self):
     """A dotted name requested as a from-import is added under an alias."""
     tree = ast.parse('')
     bound_name = import_utils.add_import(
         tree, 'a.b.c', asname='d', from_import=True)
     self.assertEqual('d', bound_name)
     self.assertEqual('from a.b import c as d\n', pasta.dump(tree))
Exemple #32
0
  def update_string_pasta(self, text, in_filename):
    """Updates a file using pasta.

    Returns:
      A 4-tuple ``(status, new_source, log, errors)``. On a parse failure it
      is ``(0, "", message_with_traceback, [])`` where the log is a string;
      otherwise status is 1, log is ``visitor.log_text()`` and errors come
      from ``self._format_errors``.
    """
    try:
      tree = pasta.parse(text)
    except (SyntaxError, ValueError, TypeError):
      return 0, "", "Failed to parse.\n\n" + traceback.format_exc(), []

    visitor = _PastaEditVisitor(self._api_change_spec)
    visitor.visit(tree)

    formatted_errors = self._format_errors(visitor.errors, in_filename)
    return 1, pasta.dump(tree), visitor.log_text(), formatted_errors
Exemple #33
0
  def update_string_pasta(self, text, in_filename):
    """Updates a file using pasta.

    Returns:
      A 4-tuple ``(status, new_source, logs, errors)``. On a parse failure it
      is ``(0, "", [message_with_traceback], [])``; otherwise status is 1 and
      the visitor's log and warning/error entries are passed through
      ``self.format_log`` (errors with ``in_filename``).
    """
    try:
      tree = pasta.parse(text)
    except (SyntaxError, ValueError, TypeError):
      return 0, "", ["ERROR: Failed to parse.\n" + traceback.format_exc()], []

    visitor = _PastaEditVisitor(self._api_change_spec)
    visitor.visit(tree)

    logs = [self.format_log(entry, None) for entry in visitor.log]
    errors = [self.format_log(entry, in_filename)
              for entry in visitor.warnings_and_errors]
    return 1, pasta.dump(tree), logs, errors