예제 #1
0
 def test_annotate_ancestry(self) -> None:
     """
     Tests that `annotate_ancestry` adds a `_location` attribute to each
     top-level statement of a module, holding the name of its assignment target
     """
     # Annotated assignment: `dataset_name: str = <value>`
     dataset_name_node = AnnAssign(
         annotation=Name("str", Load()),
         simple=1,
         target=Name("dataset_name", Store()),
         value=set_value("~/tensorflow_datasets"),
         expr=None,
         expr_target=None,
         expr_annotation=None,
     )
     # Plain assignment: `epochs = <value>`
     epochs_node = Assign(
         annotation=None,
         simple=1,
         targets=[Name("epochs", Store())],
         value=set_value("333"),
         expr=None,
         expr_target=None,
         expr_annotation=None,
         **maybe_type_comment
     )
     node = Module(body=[dataset_name_node, epochs_node], stmt=None)

     # Freshly constructed nodes must not carry a location yet
     for statement in node.body:
         self.assertFalse(hasattr(statement, "_location"))

     annotate_ancestry(node)

     # After annotation each statement is located by its target name
     expected_locations = (["dataset_name"], ["epochs"])
     for statement, expected in zip(node.body, expected_locations):
         self.assertEqual(statement._location, expected)
예제 #2
0
def ast_parse(
    source,
    filename="<unknown>",
    mode="exec",
    skip_annotate=False,
    skip_docstring_remit=False,
):
    """
    Parse Python source into an AST, optionally annotating ancestry and
    re-emitting the leading docstring with normalised indentation

    :param source: Python source
    :type source: ```str```

    :param filename: Filename being parsed
    :type filename: ```str```

    :param mode: 'exec' to compile a module, 'single' to compile a single (interactive) statement,
      or 'eval' to compile an expression.
    :type mode: ```Literal['exec', 'single', 'eval']```

    :param skip_annotate: Don't run `annotate_ancestry`
    :type skip_annotate: ```bool```

    :param skip_docstring_remit: Don't parse & emit the docstring as a replacement for current docstring
    :type skip_docstring_remit: ```bool```

    :returns: AST node
    :rtype: ```AST```
    """
    tree = parse(source, filename=filename, mode=mode)
    if not skip_annotate:
        annotate_ancestry(tree)

    # Only these node kinds are checked for a docstring; anything else is returned untouched
    if skip_docstring_remit or not isinstance(
        tree, (Module, ClassDef, FunctionDef, AsyncFunctionDef)
    ):
        return tree

    doc = get_docstring(tree)
    if doc is not None:
        # Reindent docstring; `get_docstring` returned a value, so the first
        # body statement is the docstring expression that gets overwritten
        tree.body[0].value.value = "\n{tab}{docstring}\n{tab}".format(
            tab=tab, docstring=reindent(doc)
        )
    return tree
예제 #3
0
    def test_from_class_with_body_in_method_to_method_with_body(self) -> None:
        """
        Tests the roundtrip of a method with a body: parse `C.function_name` out of the
        class AST into IR, emit a function from that IR, and compare against the original method
        """
        annotate_ancestry(class_with_method_and_body_types_ast)

        # The expected ("gold") node is the first FunctionDef in the class body
        is_function_def = rpartial(isinstance, FunctionDef)
        function_def = next(
            node
            for node in class_with_method_and_body_types_ast.body
            if is_function_def(node)
        )

        # Reindent docstring of the gold node in place
        function_def.body[0].value.value = "\n{tab}{docstring}\n{tab}".format(
            tab=tab, docstring=reindent(ast.get_docstring(function_def))
        )

        located_method = find_in_ast(
            "C.function_name".split("."),
            class_with_method_and_body_types_ast,
        )
        ir = parse.function(located_method)

        gen_ast = emit.function(
            ir,
            emit_default_doc=False,
            function_name="function_name",
            function_type="self",
            indent_level=1,
            emit_separating_tab=True,
            emit_as_kwonlyargs=False,
        )

        # Debugging aid: `gen_ast` can be written to a scratch file with `emit.file`
        # to inspect the generated source by hand.

        run_ast_test(self, gen_ast=gen_ast, gold=function_def)
예제 #4
0
 def test_find_in_ast_self(self) -> None:
     """
     Tests that `find_in_ast` returns the expected node when the search
     names the root itself, is empty, or names a function in a module
     """
     # Searching a class for its own name is expected to give back that class
     run_ast_test(self, find_in_ast(["ConfigClass"], class_ast), class_ast)

     # An empty search against a module is expected to give back the module
     module = Module(body=[], type_ignores=[], stmt=None)
     run_ast_test(self, find_in_ast([], module), module)

     # A module holding one empty, argument-less function
     no_arguments = arguments(
         args=[],
         defaults=[],
         kw_defaults=[],
         kwarg=None,
         kwonlyargs=[],
         posonlyargs=[],
         vararg=None,
         arg=None,
     )
     call_peril = FunctionDef(
         name="call_peril",
         args=no_arguments,
         body=[],
         decorator_list=[],
         lineno=None,
         arguments_args=None,
         identifier_name=None,
         stmt=None,
     )
     module_with_fun = Module(body=[call_peril], stmt=None)
     annotate_ancestry(module_with_fun)

     # Searching by function name is expected to give back the module's first statement
     found = find_in_ast(["call_peril"], module_with_fun)
     run_ast_test(
         self,
         found,
         module_with_fun.body[0],
         skip_black=True,
     )
예제 #5
0
def sync_property(
    input_eval,
    input_param,
    input_ast,
    input_filename,
    output_param,
    output_param_wrap,
    output_ast,
):
    """
    Sync a single property: build (or find) a replacement node from the input
    and rewrite `output_ast` at the location named by `output_param`

    :param input_eval: Whether to evaluate the `param`, or just leave it
    :type input_eval: ```bool```

    :param input_param: Location within file of property.
       Can be top level like `'a'` for `a=5` or with the `.` syntax as in `output_params`.
       Only top-level names are supported when `input_eval` is true.
    :type input_param: ```str```

    :param input_ast: AST of the input file
    :type input_ast: ```AST```

    :param input_filename: Filename of the input (used in `eval`)
    :type input_filename: ```str```

    :param output_param: Parameters to update. E.g., `'A.F'` for `class A: F = None`, `'f.g'` for `def f(g): pass`
    :type output_param: ```str```

    :param output_param_wrap: Wrap all input_str params with this. E.g., `Optional[Union[{output_param}, str]]`
    :type output_param_wrap: ```Optional[str]```

    :param output_ast: AST of the output file
    :type output_ast: ```AST```

    :return: New AST derived from `output_ast`
    :rtype: ```AST```
    """
    search = list(strip_split(output_param, "."))

    if not input_eval:
        # Take the node straight out of the input AST
        annotate_ancestry(input_ast)
        assert isinstance(input_ast, ast.Module)
        input_search = list(strip_split(input_param, "."))
        replacement_node = find_in_ast(input_search, input_ast)
    else:
        if input_param.count(".") != 0:
            raise NotImplementedError("Anything not on the top-level of the module")

        # NOTE: this executes the input module; only use on trusted input.
        namespace = {}
        code = compile(input_ast, filename=input_filename, mode="exec")
        output = eval(code, namespace)
        assert output is None

        # The evaluated value becomes the annotation of a value-less
        # annotated assignment named after the last output path component
        replacement_node = ast.AnnAssign(
            annotation=it2literal(namespace[input_param]),
            simple=1,
            target=ast.Name(search[-1], ast.Store()),
            value=None,
            expr=None,
            expr_annotation=None,
            expr_target=None,
        )

    assert replacement_node is not None

    if output_param_wrap is not None:
        if not hasattr(replacement_node, "annotation"):
            raise NotImplementedError(type(replacement_node).__name__)
        if replacement_node.annotation is not None:
            # Render the current annotation to source, substitute it into the
            # wrapper template, then parse the result back into an expression
            wrapped_source = output_param_wrap.format(
                output_param=to_code(replacement_node.annotation)
            )
            replacement_node.annotation = ast.parse(wrapped_source).body[0].value

    rewriter = RewriteAtQuery(search=search, replacement_node=replacement_node)
    gen_ast = rewriter.visit(output_ast)
    assert rewriter.replaced is True, "Failed to update with {!r}".format(
        to_code(replacement_node)
    )
    return gen_ast