def test_suite(self) -> None:
        """Every kind of statement suite can fill a template slot expecting a suite.

        A one-line ``SimpleStatementSuite`` and an ``IndentedBlock`` are each
        substituted into two layouts (suite on the header line, and suite on
        its own indented line). The rendered code depends only on the kind of
        suite supplied, not on where the placeholder sits in the template.
        """
        def make_simple_suite() -> cst.SimpleStatementSuite:
            return cst.SimpleStatementSuite(body=(cst.Pass(),))

        def make_indented_block() -> cst.IndentedBlock:
            return cst.IndentedBlock(body=(cst.SimpleStatementLine((cst.Pass(),)),))

        inline_template = "if x is True: {suite}\n"
        block_template = "if x is True:\n    {suite}\n"
        inline_code = "if x is True: pass\n"
        block_code = "if x is True:\n    pass\n"

        cases = (
            (inline_template, make_simple_suite, inline_code),
            (inline_template, make_indented_block, block_code),
            (block_template, make_simple_suite, inline_code),
            (block_template, make_indented_block, block_code),
        )
        for template, build_suite, expected_code in cases:
            # A fresh node per case, exactly like the unrolled original.
            rendered = parse_template_module(template, suite=build_suite())
            self.assertEqual(rendered.code, expected_code)
Ejemplo n.º 2
0
class ReturnCreateTest(CSTNodeTest):
    """Construction and validation tests for ``cst.Return`` statements."""

    @data_provider(
        (
            # A bare ``return`` with no value.
            {
                "node": cst.SimpleStatementLine([cst.Return()]),
                "code": "return\n",
                "expected_position": CodeRange((1, 0), (1, 6)),
            },
            # ``return`` followed by a name.
            {
                "node": cst.SimpleStatementLine([cst.Return(cst.Name("abc"))]),
                "code": "return abc\n",
                "expected_position": CodeRange((1, 0), (1, 10)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Check each node against its expected code and position via ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                # With no whitespace, the value would sit directly against the
                # ``return`` keyword, which construction must reject.
                "get_node": lambda: cst.Return(
                    cst.Name("abc"),
                    whitespace_after_return=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'return'.",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Check that building the node fails with the expected message."""
        self.assert_invalid(**kwargs)
Ejemplo n.º 3
0
    def leave_Module(
        self, original_node: cst.Module, updated_node: cst.Module
    ) -> cst.Module:
        """Insert pending ``from`` imports and top-level annotations into the module.

        Returns ``original_node`` untouched when ``self.is_generated`` is set,
        and ``updated_node`` untouched when there is nothing to add. Otherwise
        the new lines are placed between the two halves returned by
        ``self._split_module``:

        * ``from`` imports, leaving out names already bound by
          ``self.import_statements``;
        * bare ``name: annotation`` declarations.
        """
        if self.is_generated:
            return original_node
        if not self.toplevel_annotations and not self.imports:
            return updated_node

        # First, find the insertion point for imports
        statements_before_imports, statements_after_imports = self._split_module(
            original_node, updated_node
        )

        # Make sure there's at least one empty line before the first non-import
        statements_after_imports = self._insert_empty_line(statements_after_imports)

        # Names already bound by existing import statements. ``import a as b``
        # binds ``b``, so the alias wins when there is one.
        imported = set()
        for statement in self.import_statements:
            if isinstance(statement.names, cst.ImportStar):
                continue
            for alias in statement.names:
                bound = alias.asname if alias.asname else alias
                imported.add(_get_name_as_string(bound.name))

        toplevel_statements = []
        for import_statement in self.imports.values():
            # Filter out anything that has already been imported.
            missing_names = sorted(import_statement.names.difference(imported))
            if not missing_names:
                continue
            import_from = cst.ImportFrom(
                module=import_statement.module,
                names=[cst.ImportAlias(cst.Name(name)) for name in missing_names],
            )
            # Add import statements to module body.
            # Need to assign an Iterable, and the argument to SimpleStatementLine
            # must be subscriptable.
            toplevel_statements.append(cst.SimpleStatementLine([import_from]))

        for name, annotation in self.toplevel_annotations.items():
            annotated_assign = cst.AnnAssign(
                cst.Name(name),
                # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                cst.Annotation(annotation.annotation),
                None,
            )
            toplevel_statements.append(cst.SimpleStatementLine([annotated_assign]))

        return updated_node.with_changes(
            body=[
                *statements_before_imports,
                *toplevel_statements,
                *statements_after_imports,
            ]
        )
Ejemplo n.º 4
0
def import_to_node_single(imp: SortableImport,
                          module: cst.Module) -> cst.BaseStatement:
    """Render a sortable import as one single-line import statement.

    Lines recorded before the import become leading lines: a comment line
    when the text starts with ``#``, otherwise a blank line. Every other
    comment attached to the import or its items is joined with
    ``COMMENT_INDENT`` into one trailing comment on the same line.

    ``imp.stem`` selects ``from <stem> import ...``; otherwise a plain
    ``import ...`` is produced. The ``module`` argument is not used here.
    """
    leading_lines = []
    for before_line in imp.comments.before:
        if before_line.startswith("#"):
            leading_lines.append(
                cst.EmptyLine(indent=True, comment=cst.Comment(before_line)))
        else:
            leading_lines.append(cst.EmptyLine(indent=False))

    aliases: List[cst.ImportAlias] = []
    collected_comments = list(imp.comments.first_inline)
    for item in imp.items:
        alias_name = name_to_node(item.name)
        alias_as = cst.AsName(name=cst.Name(item.asname)) if item.asname else None
        aliases.append(cst.ImportAlias(name=alias_name, asname=alias_as))
        collected_comments.extend(item.comments.before)
        collected_comments.extend(item.comments.inline)
        collected_comments.extend(item.comments.following)
    collected_comments.extend(imp.comments.final)
    collected_comments.extend(imp.comments.last_inline)

    if collected_comments:
        trailing = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(COMMENT_INDENT.join(collected_comments)))
    else:
        trailing = cst.TrailingWhitespace()

    if not imp.stem:
        statement: cst.BaseSmallStatement = cst.Import(names=aliases)
    else:
        stem, ndots = split_relative(imp.stem)
        statement = cst.ImportFrom(
            module=name_to_node(stem) if stem else None,
            names=aliases,
            relative=(cst.Dot(), ) * ndots)

    return cst.SimpleStatementLine(
        body=[statement],
        leading_lines=leading_lines,
        trailing_whitespace=trailing,
    )
Ejemplo n.º 5
0
    def __get_required_imports(self):
        """Build the import statements needed by the applied type annotations.

        Returns a list of ``cst.SimpleStatementLine`` nodes, in this order:

        * ``from typing import ...`` for names that appear in the applied
          type strings and in ``PY_TYPING_MOD``;
        * ``from collections import ...`` for names in ``PY_COLLECTION_MOD``;
        * one ``import <module>`` per root module referenced by a dotted
          annotation (e.g. ``np`` for ``np.ndarray``).

        Names and modules are emitted in sorted order, so the generated code
        does not depend on set iteration order (which varies between runs
        under string-hash randomisation).
        """
        def find_required_modules(all_types):
            # For each dotted annotation, record the first Name found inside
            # the Attribute node; for a plain dotted chain this is its root.
            req_mod = set()
            for _, a_node in all_types:
                attributes = match.findall(
                    a_node.annotation,
                    match.Attribute(value=match.DoNotCare(),
                                    attr=match.DoNotCare()))
                for attribute in attributes:
                    # An Attribute always has a Name as ``attr``, so this list
                    # is never empty.
                    names = match.findall(attribute,
                                          match.Name(value=match.DoNotCare()))
                    req_mod.add(names[0].value)
            return req_mod

        def from_import(module_name, names):
            # ``from <module_name> import a, b, ...`` with deterministic order.
            return cst.SimpleStatementLine(body=[
                cst.ImportFrom(module=cst.Name(value=module_name),
                               names=[
                                   cst.ImportAlias(name=cst.Name(value=t),
                                                   asname=None)
                                   for t in sorted(names)
                               ]),
            ])

        req_imports = []
        all_req_mods = find_required_modules(self.all_applied_types)
        all_type_names = set(
            chain.from_iterable(
                regex.findall(r"\w+", t[0]) for t in self.all_applied_types))

        typing_imports = PY_TYPING_MOD & all_type_names
        collection_imports = PY_COLLECTION_MOD & all_type_names

        if typing_imports:
            req_imports.append(from_import("typing", typing_imports))
        if collection_imports:
            req_imports.append(from_import("collections", collection_imports))
        for mod_name in sorted(all_req_mods):
            req_imports.append(
                cst.SimpleStatementLine(body=[
                    cst.Import(names=[
                        cst.ImportAlias(name=cst.Name(value=mod_name),
                                        asname=None)
                    ])
                ]))

        return req_imports
Ejemplo n.º 6
0
    def leave_Module(
        self, original_node: cst.Module, updated_node: cst.Module
    ) -> cst.Module:
        """Add top-level annotations and never-visited class definitions.

        Class definitions from ``self.annotations`` whose names are not in
        ``self.visited_classes`` are appended after the annotated top-level
        assignments. Everything goes between the halves returned by
        ``self._split_module``. If there is nothing to add, ``updated_node``
        is returned unchanged.
        """
        unseen_classes = [
            class_def
            for class_name, class_def in self.annotations.class_definitions.items()
            if class_name not in self.visited_classes
        ]
        if not (self.toplevel_annotations or unseen_classes):
            return updated_node

        # Split at the import boundary, then make sure there's at least one
        # empty line before the first non-import.
        head, tail = self._split_module(original_node, updated_node)
        tail = self._insert_empty_line(tail)

        additions = [
            cst.SimpleStatementLine(
                [cst.AnnAssign(cst.Name(var_name), var_annotation, None)]
            )
            for var_name, var_annotation in self.toplevel_annotations.items()
        ]
        additions += unseen_classes

        return updated_node.with_changes(body=[*head, *additions, *tail])
Ejemplo n.º 7
0
def _get_expression_transform(
        before: Callable[..., Any],
        after: Callable[..., Any]) -> ExpressionTransform:
    """Compile a ``before``/``after`` function pair into an ``ExpressionTransform``.

    The first expression parsed from ``before`` becomes the matcher, with the
    function's arguments turned into sub-matchers. The first expression
    parsed from ``after`` becomes the replacement node. It is wrapped in a
    one-statement module held by a ``MetadataWrapper`` so metadata can be
    resolved for it.

    Raises:
        Exception: if the top-level matcher, or the matcher it wraps, is a
            ``DoNotCare`` sentinel.
    """
    matcher = craftier.matcher.from_node(
        function_parser.parse(before)[0],
        function_parser.args_to_matchers(before),
    )
    for candidate in (matcher, getattr(matcher, "matcher", None)):
        if isinstance(candidate, libcst.matchers.DoNotCareSentinel):
            raise Exception(
                f"DoNotCare matcher is forbidden at top level in `{before.__name__}`"
            )

    # Strictly speaking this is not correct: some expressions, like function
    # calls, binary operations and names, are wrapped in an `Expr` node.
    replacement_expression = function_parser.parse(after)[0]
    holder_line = libcst.SimpleStatementLine(
        body=[cast(libcst.BaseSmallStatement, replacement_expression)])
    wrapper = libcst.metadata.MetadataWrapper(libcst.Module(body=[holder_line]))

    # Take the node from the wrapper's module, not from ``holder_line``, so it
    # is the instance the wrapper resolves metadata for.
    wrapped_line = cast(libcst.SimpleStatementLine, wrapper.module.body[0])
    return ExpressionTransform(before=matcher,
                               after=wrapped_line.body[0],
                               wrapper=wrapper)
Ejemplo n.º 8
0
    def leave_FunctionDef(self, original_node: cst.FunctionDef,
                          updated_node: cst.FunctionDef) -> cst.FunctionDef:
        """Queue attribute write-backs and the folded return for the tracked scope.

        Does nothing unless ``original_node`` is ``self.scope``. For each
        entry of ``self.names_to_attr`` it folds the recorded guarded states
        into one value and appends ``<attr> = <value>`` to ``self.tail``. If
        guarded returns were recorded, it appends one folded ``return``
        statement. The node itself is always returned unchanged.

        Raises:
            SyntaxError: if the guarded returns cannot be proven to cover every
                path through the function.
        """
        if original_node is not self.scope:
            return updated_node

        tail = self.tail
        for name, attr in self.names_to_attr.items():
            # Default write-back: the name's own current value with an empty
            # guard list. Build a new list rather than appending to the one
            # stored in ``self.attr_states``.
            state = [*self.attr_states.get(name, []), ([], cst.Name(name))]
            attr_val = _fold_conditions(_simplify_gaurds(state), self.strict)
            tail.append(to_stmt(make_assign(attr, attr_val)))

        if self.returns:
            try:
                return_val = _fold_conditions(
                    _simplify_gaurds(self.returns), self.strict)
            except IncompleteGaurdError:
                raise SyntaxError(
                    'Cannot prove function always returns') from None
            tail.append(cst.SimpleStatementLine([cst.Return(value=return_val)]))

        return updated_node
Ejemplo n.º 9
0
    def leave_Module(
        self,
        original_node: cst.Module,
        updated_node: cst.Module,
    ) -> cst.Module:
        """Add top-level annotations, new TypeVars and never-visited classes.

        Everything goes between the halves returned by ``self._split_module``,
        in this order:

        1. annotated globals built by
           ``self._apply_annotation_to_attribute_or_global``;
        2. TypeVar definitions whose names are not in ``self.typevars``;
        3. class definitions whose names are not in ``self.visited_classes``.

        ``self.annotation_counts`` records how many TypeVars and classes were
        added. Returns ``updated_node`` unchanged when there is nothing to add.
        """
        unseen_classes = [
            class_def
            for class_name, class_def in self.annotations.class_definitions.items()
            if class_name not in self.visited_classes
        ]

        # NOTE: The entire change will also be abandoned if
        # self.annotation_counts is all 0s, so if adding any new category make
        # sure to record it there.
        has_work = (self.toplevel_annotations or unseen_classes
                    or self.annotations.typevars)
        if not has_work:
            return updated_node

        before_imports, after_imports = self._split_module(
            original_node, updated_node)
        # Make sure there's at least one empty line before the first non-import
        after_imports = self._insert_empty_line(after_imports)

        additions = [
            cst.SimpleStatementLine([
                self._apply_annotation_to_attribute_or_global(
                    name=global_name,
                    annotation=global_annotation,
                    value=None,
                )
            ])
            for global_name, global_annotation in self.toplevel_annotations.items()
        ]

        # TypeVar definitions could be scattered through the file, so do not
        # attempt to put new ones with existing ones, just add them at the top.
        missing_typevars = [
            definition
            for typevar_name, definition in self.annotations.typevars.items()
            if typevar_name not in self.typevars
        ]
        if missing_typevars:
            # NOTE(review): ``cst.Newline`` is not a statement node. It is put
            # straight into the body, presumably so codegen emits a line break
            # around each TypeVar.
            for definition in missing_typevars:
                additions.append(cst.Newline())
                additions.append(definition)
            self.annotation_counts.typevars_and_generics_added += len(missing_typevars)
            additions.append(cst.Newline())

        self.annotation_counts.classes_added = len(unseen_classes)
        additions.extend(unseen_classes)

        return updated_node.with_changes(body=[
            *before_imports,
            *additions,
            *after_imports,
        ])
Ejemplo n.º 10
0
class NamedExprTest(CSTNodeTest):
    """Tests for the matrix-multiplication operator (``@`` and ``@=``).

    NOTE: despite the class name, nothing here involves named expressions.
    The cases cover ``cst.MatrixMultiply`` and ``cst.MatrixMultiplyAssign``.
    """

    @data_provider(
        (
            # Binary ``@`` parsed as an expression.
            {
                "node": cst.BinaryOperation(
                    left=cst.Name("a"),
                    operator=cst.MatrixMultiply(),
                    right=cst.Name("b"),
                ),
                "code": "a @ b",
                "parser": parse_expression_as(python_version="3.8"),
            },
            # Augmented ``@=`` parsed as a statement.
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.AugAssign(
                            target=cst.Name("a"),
                            operator=cst.MatrixMultiplyAssign(),
                            value=cst.Name("b"),
                        ),
                    ),
                ),
                "code": "a @= b\n",
                "parser": parse_statement_as(python_version="3.8"),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    # Each snippet must parse under a 3.6 grammar and fail under a 3.3 grammar.
    @data_provider(
        tuple(
            {
                "code": code,
                "parser": parser_factory(python_version=version),
                "expect_success": should_parse,
            }
            for code, parser_factory in (
                ("a @ b", parse_expression_as),
                ("a @= b", parse_statement_as),
            )
            for version, should_parse in (("3.6", True), ("3.3", False))
        )
    )
    def test_versions(self, **kwargs: Any) -> None:
        expect_success = kwargs.get("expect_success", True)
        if is_native() and not expect_success:
            self.skipTest("parse errors are disabled for native parser")
        self.assert_parses(**kwargs)
Ejemplo n.º 11
0
    def leave_Module(self, original_node: cst.Module,
                     updated_node: cst.Module) -> cst.Module:
        """Insert pending ``from`` imports and top-level annotations.

        The new statements go between the halves returned by
        ``self._split_module``: every entry of ``self.imports`` first, then a
        bare ``name: annotation`` line per top-level annotation. Returns
        ``updated_node`` unchanged when there is nothing to add.
        """
        if not (self.toplevel_annotations or self.imports):
            return updated_node

        before_imports, after_imports = self._split_module(
            original_node, updated_node)
        # Make sure there's at least one empty line before the first non-import
        after_imports = self._insert_empty_line(after_imports)

        # The argument to SimpleStatementLine must be subscriptable, hence the
        # single-element lists.
        new_statements = [
            cst.SimpleStatementLine([
                cst.ImportFrom(
                    module=pending.module,
                    # pyre-fixme[6]: Expected `Union[Sequence[ImportAlias], ImportStar]`
                    #  for 2nd param but got `List[ImportFrom]`.
                    names=pending.names,
                )
            ])
            for pending in self.imports.values()
        ]
        new_statements.extend(
            cst.SimpleStatementLine([
                cst.AnnAssign(
                    cst.Name(var_name),
                    # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                    cst.Annotation(var_annotation.annotation),
                    None,
                )
            ])
            for var_name, var_annotation in self.toplevel_annotations.items()
        )

        return updated_node.with_changes(
            body=[*before_imports, *new_statements, *after_imports])
Ejemplo n.º 12
0
    def leave_Module(self, original_node, updated_node):
        """Prepend the collected imports, sorted, to the transformed module.

        The parent ``leave_Module`` runs first. The import nodes gathered in
        ``self.imports`` are then rendered to source, sorted with
        ``SortImports`` (presumably isort's), re-parsed, and placed before the
        parent's resulting body.
        """
        final_node = super().leave_Module(original_node, updated_node)
        imports_str = cst.Module(
            body=[cst.SimpleStatementLine([i]) for i in self.imports]).code
        sorted_imports = cst.parse_module(
            SortImports(file_contents=imports_str).output)

        # Add imports back to the top of the module. ``Module.body`` is only
        # typed as a Sequence (the parser may produce a tuple), so unpack both
        # bodies; ``tuple + list`` would raise TypeError.
        new_body = [*sorted_imports.body, *final_node.body]

        return final_node.with_changes(body=new_body)
Ejemplo n.º 13
0
    def test_parse_multiple_statements(self) -> None:
        """For a multi-statement function body, ``parse`` yields ``x`` first."""
        # pylint: disable=pointless-statement
        def test_function(x, y):
            x
            y

        # pylint: enable=pointless-statement
        first_statement = function_parser.parse(test_function)[0]
        expected = libcst.SimpleStatementLine(
            body=[libcst.Expr(libcst.Name("x"))])
        self.assert_node_equal(first_statement, expected)
Ejemplo n.º 14
0
    def test_deprecated_non_element_construction(self) -> None:
        """A Subscript whose slice is a bare ``cst.Index`` still renders correctly."""
        # The slice is not wrapped in a SubscriptElement: the deprecated form.
        subscript = cst.Subscript(
            value=cst.Name(value="foo"),
            slice=cst.Index(value=cst.Integer(value="1")),
        )
        statement = cst.SimpleStatementLine(body=[cst.Expr(value=subscript)])
        rendered = cst.Module(body=[statement]).code
        self.assertEqual(rendered, "foo[1]\n")
Ejemplo n.º 15
0
 def leave_Module(self, node: cst.Module, updated_node: cst.Module) -> cst.CSTNode:
     """Insert a bare ``name: annotation`` line for each top-level annotation.

     The declarations are inserted at the position given by
     ``self._get_toplevel_index``, in the iteration order of
     ``self.toplevel_annotations``.
     """
     body = list(updated_node.body)
     index = self._get_toplevel_index(body)
     declarations = [
         cst.SimpleStatementLine([
             cst.AnnAssign(
                 cst.Name(name),
                 # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                 cst.Annotation(annotation.annotation),
                 None,
             )
         ])
         for name, annotation in self.toplevel_annotations.items()
     ]
     # One slice assignment keeps the declarations in order. Repeated
     # ``insert(index, ...)`` at the same index would emit them reversed.
     body[index:index] = declarations
     return updated_node.with_changes(body=tuple(body))
Ejemplo n.º 16
0
 def on_leave(
     self, original_node: CSTNodeT, updated_node: CSTNodeT
 ) -> Union[CSTNodeT, RemovalSentinel,
            FlattenSentinel[cst.SimpleStatementLine]]:
     """Split every ``SimpleStatementLine`` into one line per small statement.

     ``a; b  # c`` becomes ``a`` and ``b  # c`` on separate lines, with each
     statement's semicolon reset to the default. The original line's leading
     lines (blank lines and comments above it) stay on the first new line.
     Its trailing whitespace and comment stay on the last new line, so no
     comments are lost. Any other node is returned unchanged.
     """
     if not isinstance(updated_node, cst.SimpleStatementLine):
         return updated_node

     statements = updated_node.body
     last = len(statements) - 1
     return FlattenSentinel([
         cst.SimpleStatementLine(
             [stmt.with_changes(semicolon=cst.MaybeSentinel.DEFAULT)],
             leading_lines=updated_node.leading_lines if i == 0 else (),
             trailing_whitespace=(
                 updated_node.trailing_whitespace
                 if i == last else cst.TrailingWhitespace()
             ),
         )
         for i, stmt in enumerate(statements)
     ])
Ejemplo n.º 17
0
    def body(
        self,
    ) -> typing.Iterable[typing.Union[cst.BaseCompoundStatement,
                                      cst.SimpleStatementLine]]:
        """Yield the module-level statements built from this object, in order.

        The order is:

        1. a ``from typing import *`` line;
        2. the statements from ``assign_properties(self.properties)``;
        3. the function definitions from ``function_defs``;
        4. one class definition per entry of ``self.classes``, in
           ``sort_items`` order.
        """
        star_import = cst.ImportFrom(cst.Name("typing"), names=cst.ImportStar())
        yield cst.SimpleStatementLine([star_import])
        yield from assign_properties(self.properties)
        yield from function_defs(
            self.function_overloads, self.functions, "function")
        yield from (
            class_.class_def(class_name)
            for class_name, class_ in sort_items(self.classes)
        )
Ejemplo n.º 18
0
 def test_statement(self) -> None:
     """Statements of several kinds can fill slots in a statement list.

     A compound statement, a full simple-statement line and a bare small
     statement are each accepted where a statement is expected.
     """
     if_statement = cst.If(
         test=cst.Name("foo"),
         body=cst.SimpleStatementSuite((cst.Pass(),)),
     )
     call_line = cst.SimpleStatementLine((cst.Expr(cst.Call(cst.Name("bar"))),))
     module = parse_template_module(
         "{statement1}\n{statement2}\n{statement3}\n",
         statement1=if_statement,
         statement2=call_line,
         statement3=cst.Pass(),
     )
     self.assertEqual(module.code, "if foo: pass\nbar()\npass\n")
Ejemplo n.º 19
0
def generate_import(name, obj, func_obj=None, file_imports=None):
    """
    Generate an import statement for a (name, runtime object) pair.

    Returns a statement node that binds ``name``, or ``None`` when no import
    is needed or possible. That is the case when:

    * the name is ``self``;
    * the name is already a global of the inliner's base scope;
    * the value lives in ``__main__``, ``typing`` or builtins;
    * the defining module cannot be determined.

    Args:
        name: the name the inlined code uses for ``obj``.
        obj: the runtime value bound to ``name``.
        func_obj: optional function whose module is used as the import source
            for values that are neither modules, classes nor functions.
        file_imports: optional mapping from names to import nodes found in
            the object's file; a matching entry is reused as-is.
    """
    inliner = ctx_inliner.get()

    # HACK? is this still needed?
    if name == 'self':
        return None

    # If the name is already in scope, don't need to import it
    if name in inliner.base_globls:
        # TODO: name conflicts? e.g. host imports json as x, and
        # another module imports foo as x
        return None

    # If the name appears directly in an import statement in the object's file,
    # then use that import
    if file_imports is not None and name in file_imports:
        return cst.SimpleStatementLine([file_imports[name]])

    # If we're importing a module, then add an import directly
    if inspect.ismodule(obj):
        mod_name = obj.__name__
        return parse_statement(f'import {mod_name} as {name}'
                               if name != mod_name else f'import {mod_name}')

    # Get module where global is defined
    mod = inspect.getmodule(obj)

    # TODO: When is mod None?
    if mod is None or mod is typing or mod.__name__ == '__main__':
        return None

    # Can't import builtins
    if mod is __builtins__ or mod is builtins:
        return None

    # If the value is a class or function, then import it from the defining
    # module. NOTE(review): assumes ``obj`` is bound as ``name`` in that module.
    if inspect.isclass(obj) or inspect.isfunction(obj):
        return parse_statement(f'from {mod.__name__} import {name}')

    # Otherwise import it from the module using the global
    if func_obj is not None:
        func_mod = inspect.getmodule(func_obj)
        # getmodule() returns None when the module can't be determined; treat
        # that like __main__ instead of crashing on ``None.__name__``.
        if func_mod is None or func_mod.__name__ == '__main__':
            return None
        return parse_statement(f'from {func_mod.__name__} import {name}')

    return None
Ejemplo n.º 20
0
 def test_module_config_for_parsing(self) -> None:
     """A statement parsed with a module's parser config adopts its newline style."""
     cr_module = parse_module("pass\r")
     parsed = parse_statement("if True:\r    pass",
                              config=cr_module.config_for_parsing)

     # The header's newline value would be "\r" instead of None if the module
     # config were not passed forward.
     expected_header = cst.TrailingWhitespace(newline=cst.Newline(value=None))
     expected_body = cst.IndentedBlock(
         body=[cst.SimpleStatementLine(body=[cst.Pass()])],
         header=expected_header,
     )
     expected = cst.If(test=cst.Name(value="True"), body=expected_body)
     self.assertEqual(parsed, expected)
Ejemplo n.º 21
0
    def test_deep_replace_identity(self) -> None:
        """``deep_replace`` can swap the root module for an entirely new one."""
        source_before = """
            pass
        """
        source_after = """
            break
        """

        original = cst.parse_module(dedent(source_before))
        replacement = cst.Module(
            header=(cst.EmptyLine(),),
            body=(cst.SimpleStatementLine(body=(cst.Break(),)),),
        )
        replaced = original.deep_replace(original, replacement)
        self.assertEqual(replaced.code, dedent(source_after))
Ejemplo n.º 22
0
 def function_def(
     self,
     name: str,
     type: typing.Literal["function", "classmethod", "method"],
     indent=0,
     overload=False,
 ) -> cst.FunctionDef:
     """Build a ``FunctionDef`` node called ``name`` for this signature.

     ``@overload`` is added when ``overload`` is true, followed by
     ``@classmethod`` when ``type`` is ``"classmethod"``. Parameters come
     from ``self.parameters(type)``. The body holds each small statement
     from ``self.body(indent)`` on its own line, and the return annotation is
     ``self.return_type_annotation``.
     """
     decorator_names = []
     if overload:
         decorator_names.append("overload")
     if type == "classmethod":
         decorator_names.append("classmethod")
     decorators: typing.List[cst.Decorator] = [
         cst.Decorator(cst.Name(decorator)) for decorator in decorator_names
     ]

     params = self.parameters(type)
     body_lines = [cst.SimpleStatementLine([stmt]) for stmt in self.body(indent)]
     return cst.FunctionDef(
         name=cst.Name(name),
         params=params,
         body=cst.IndentedBlock(body_lines),
         decorators=decorators,
         returns=self.return_type_annotation,
     )
Ejemplo n.º 23
0
def astNodeFromAssertion(transform: Transform,
                         match: Match,
                         index: int = 0) -> Sequence[cst.CSTNode]:
    """Build the node for nesting level ``index`` of a transform's assertion.

    Recurses through ``transform.assertion.nested_scopes`` from ``index``
    inward; nodes built for deeper scopes are appended to this level's body.

    * If this scope's capture appears in ``match.by_name``, the captured node
      is reused with its ``name`` set from ``elemName``. Only for
      ``FunctionDef`` / ``ClassDef`` is its ``body`` replaced with the
      rebuilt one.
    * Otherwise a node is built with ``astNodeFrom``. It gets the rebuilt
      body only when there are deeper scopes.

    An empty body is replaced by a single ``pass`` statement.

    Returns:
        A one-element list holding the resulting node.
    """
    # TODO: aggregate intersect possible_nodes_per_prop and raise on multiple
    # results (some kind of "ambiguity error"). Also need to match with anchor placement
    cur_scope_expr = transform.assertion.nested_scopes[index]
    cur_capture = match.by_name.get(cur_scope_expr.capture.name)
    name = cur_scope_expr.capture.name
    body = ()
    node = BodyType = None

    # Matched scope: start from the captured node and (presumably) its
    # existing body and suite type, as returned by getNodeBody.
    if cur_capture is not None:
        node = cur_capture.node
        name = elemName(node=node)
        body, BodyType = getNodeBody(node)

    # Not the innermost scope: build the deeper scopes and append them.
    if index < len(transform.assertion.nested_scopes) - 1:
        # inner doesn't make sense for ( and , nesting scopes...
        inner = astNodeFromAssertion(transform, match, index+1)
        if transform.destructive:
            # Drop statements deep-equal to the last node on the match path.
            body = [s for s in body if not s.deep_equals(
                match.path[-1].node)]
        body = (*body, *inner)

    BodyType = BodyType or cst.IndentedBlock
    # An empty suite is not valid Python; fill it with ``pass``.
    if not body:
        body = [cst.SimpleStatementLine(body=[cst.Pass()])]

    if cur_capture is not None:
        # XXX: perhaps I should use tuples instead of lists to abide by the immutability of cst
        return [node.with_changes(
            name=cst.Name(name),
            **({
                'body': BodyType(body=body),
                # probably need a better way to do this, ideally just ignore excess kwargs
            } if isinstance(node, (cst.FunctionDef, cst.ClassDef)) else {})
        )]
    else:
        unrefed_node = astNodeFrom(
            scope_expr=cur_scope_expr, ctx=transform, match=match)
        # NOTE: need a generic way to "place" the next scope in a node
        if index < len(transform.assertion.nested_scopes) - 1:
            return [unrefed_node.with_changes(body=BodyType(body=body))]
        else:
            return [unrefed_node]
Ejemplo n.º 24
0
 def _get_assert_replacement(self, node: cst.Assert):
     message = node.msg or str(cst.Module(body=[node]).code)
     return cst.If(
         test=cst.UnaryOperation(
             operator=cst.Not(),
             expression=node.test,  # Todo: parenthesize?
         ),
         body=cst.IndentedBlock(body=[
             cst.SimpleStatementLine(body=[
                 cst.Raise(exc=cst.Call(
                     func=cst.Name(value="AssertionError", ),
                     args=[
                         cst.Arg(value=cst.SimpleString(value=repr(message),
                                                        ), ),
                     ],
                 ), ),
             ]),
         ], ),
     )
Ejemplo n.º 25
0
def with_added_imports(
        module_node: cst.Module,
        import_nodes: Sequence[Union[cst.Import,
                                     cst.ImportFrom]]) -> cst.Module:
    """
    Adds the imports `import_nodes` right after the first import line in `module_node`.

    Each import node is wrapped in its own ``SimpleStatementLine`` and the
    nodes keep their given order.

    Raises:
        RuntimeError: if no line of `module_node` is an import line.
    """
    body = list(module_node.body)
    first_import = next(
        (position for position, statement in enumerate(body)
         if _is_import_line(statement)),
        None,
    )
    if first_import is None:
        raise RuntimeError("Failed to add imports")

    insert_at = first_import + 1
    body[insert_at:insert_at] = [
        cst.SimpleStatementLine(body=(import_node,))
        for import_node in import_nodes
    ]
    return module_node.with_changes(body=tuple(body))
Ejemplo n.º 26
0
def assign_properties(
        p: typing.Dict[str, typing.Tuple[Metadata, Type]],
        is_classvar=False) -> typing.Iterable[cst.SimpleStatementLine]:
    """Yield a bare annotated declaration (``name: T``) for each property.

    Properties are visited in ``sort_items`` order, and names rejected by
    ``bad_name`` are skipped. With ``is_classvar`` the annotation becomes
    ``ClassVar[T]``. Each declaration is preceded by a blank line, then one
    ``# ...`` comment line per entry of ``metadata_lines(metadata)``.
    """
    for prop_name, metadata_and_tp in sort_items(p):
        if bad_name(prop_name):
            continue
        metadata, tp = metadata_and_tp

        annotation = tp.annotation
        if is_classvar:
            annotation = cst.Subscript(
                cst.Name("ClassVar"),
                [cst.SubscriptElement(cst.Index(annotation))])
        declaration = cst.AnnAssign(cst.Name(prop_name), cst.Annotation(annotation))

        comment_lines = [
            cst.EmptyLine(comment=cst.Comment("# " + text))
            for text in metadata_lines(metadata)
        ]
        yield cst.SimpleStatementLine(
            [declaration],
            leading_lines=[cst.EmptyLine(), *comment_lines],
        )
Ejemplo n.º 27
0
 def type_declaration_statements(
     bindings: UnpackedBindings,
     annotations: UnpackedAnnotations,
     leading_lines: Sequence[cst.EmptyLine],
     quote_annotations: bool,
 ) -> List[cst.SimpleStatementLine]:
     """Build one type-declaration line per annotated binding.

     Pairs come from ``AnnotationSpreader.annotated_bindings``, and each is
     turned into a statement by ``AnnotationSpreader.type_declaration``. Only
     the first resulting line carries ``leading_lines``; later lines get none.
     """
     statements: List[cst.SimpleStatementLine] = []
     pairs = AnnotationSpreader.annotated_bindings(
         bindings=bindings,
         annotations=annotations,
     )
     for binding, raw_annotation in pairs:
         declaration = AnnotationSpreader.type_declaration(
             binding=binding,
             raw_annotation=raw_annotation,
             quote_annotations=quote_annotations,
         )
         statements.append(
             cst.SimpleStatementLine(
                 body=[declaration],
                 # ``statements`` is empty only while building the first line.
                 leading_lines=[] if statements else leading_lines,
             )
         )
     return statements
Ejemplo n.º 28
0
def _indented_comment_lines(comments: List[str],
                            module: cst.Module) -> List[cst.EmptyLine]:
    """Build one indented comment line per entry in *comments*.

    Each line carries the module's ``default_indent`` as extra whitespace so
    the comment lines up with the names inside the parenthesized import.
    Returns a fresh list so callers may extend it in place.
    """
    return [
        cst.EmptyLine(
            indent=True,
            comment=cst.Comment(c),
            whitespace=cst.SimpleWhitespace(module.default_indent),
        ) for c in comments
    ]


def _inline_comment_whitespace(comments: List[str]) -> cst.TrailingWhitespace:
    """Join *comments* into one trailing comment separated by COMMENT_INDENT.

    Returns plain trailing whitespace when the joined text is empty.
    """
    inline = COMMENT_INDENT.join(comments)
    if not inline:
        return cst.TrailingWhitespace()
    return cst.TrailingWhitespace(
        whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
        comment=cst.Comment(inline),
    )


def import_to_node_multi(imp: SortableImport,
                         module: cst.Module) -> cst.BaseStatement:
    """Render *imp* as a parenthesized ``from ... import (...)`` statement.

    Each imported name is placed on its own indented line and followed by a
    comma. Comments attached to the import are preserved:

    * ``imp.comments.before``: lines above the statement. Entries starting
      with ``#`` become comment lines; any other entry becomes a blank line.
    * ``imp.comments.first_inline``: trailing comment after ``(``.
    * ``item.comments.before``: comment lines directly above that item.
    * ``item.comments.inline``: trailing comment after that item's comma.
    * ``item.comments.following``: comment lines after that item.
    * ``imp.comments.final``: comment lines after the last item, before ``)``.
    * ``imp.comments.last_inline``: trailing comment after ``)``.

    Args:
        imp: the import to render; it must have a ``stem`` (a ``from`` import).
        module: module whose ``default_indent`` is used for the name lines.

    Returns:
        A ``SimpleStatementLine`` holding a single ``ImportFrom`` node.

    Raises:
        ValueError: if *imp* has no stem (``import foo``); that form has no
            parenthesized multi-line rendering.
    """
    # import foo -- cannot be split across lines; fail before building nodes.
    if not imp.stem:
        raise ValueError("can't render basic imports on multiple lines")

    names: List[cst.ImportAlias] = []
    prev: Optional[cst.ImportAlias] = None
    lpar_lines: List[cst.EmptyLine] = []

    item_count = len(imp.items)
    for idx, item in enumerate(imp.items):
        is_last = idx == item_count - 1
        name = name_to_node(item.name)
        asname = cst.AsName(
            name=cst.Name(item.asname)) if item.asname else None

        # Comments *above* an item must be emitted as empty lines in the
        # whitespace after the preceding token: the "(" for the first item,
        # otherwise the previous item's comma.
        if item.comments.before:
            lines = _indented_comment_lines(item.comments.before, module)
            if prev is None:
                lpar_lines.extend(lines)
            else:
                # empty_lines is the list built for the previous item below,
                # so it can be extended in place; they land after that
                # item's own "following" comments.
                prev.comma.whitespace_after.empty_lines.extend(
                    lines)  # type: ignore

        # The import-level "final" comments belong after the last item.
        following = item.comments.following
        if is_last:
            following = following + imp.comments.final

        after = cst.ParenthesizedWhitespace(
            indent=True,
            first_line=_inline_comment_whitespace(item.comments.inline),
            empty_lines=_indented_comment_lines(following, module),
            # Every item but the last indents the line holding the next item;
            # after the last item no extra indent is added before ")".
            last_line=cst.SimpleWhitespace(
                "" if is_last else module.default_indent),
        )

        node = cst.ImportAlias(
            name=name,
            asname=asname,
            comma=cst.Comma(whitespace_after=after),
        )
        names.append(node)
        prev = node

    # from foo import (
    #     bar
    # )
    stem, ndots = split_relative(imp.stem)
    module_name = name_to_node(stem) if stem else None
    relative = (cst.Dot(), ) * ndots

    import_from = cst.ImportFrom(
        module=module_name,
        names=names,
        relative=relative,
        lpar=cst.LeftParen(
            whitespace_after=cst.ParenthesizedWhitespace(
                indent=True,
                # inline comment following lparen
                first_line=_inline_comment_whitespace(
                    imp.comments.first_inline),
                empty_lines=lpar_lines,
                last_line=cst.SimpleWhitespace(module.default_indent),
            ), ),
        rpar=cst.RightParen(),
    )

    # comment lines above import; non-comment entries become blank lines
    leading_lines = [
        cst.EmptyLine(indent=True, comment=cst.Comment(line))
        if line.startswith("#") else cst.EmptyLine(indent=False)
        for line in imp.comments.before
    ]

    return cst.SimpleStatementLine(
        body=[import_from],
        leading_lines=leading_lines,
        # inline comments following import/rparen
        trailing_whitespace=_inline_comment_whitespace(
            imp.comments.last_inline),
    )
Ejemplo n.º 29
0
class WhileTest(CSTNodeTest):
    """Rendering, position and validation tests for ``cst.While`` nodes."""

    @data_provider((
        # Simple while block
        # pyre-fixme[6]: Incompatible parameter type
        {
            "node": cst.While(
                cst.Call(cst.Name("iter")),
                cst.SimpleStatementSuite((cst.Pass(),)),
            ),
            "code": "while iter(): pass\n",
            "parser": parse_statement,
        },
        # While block with else
        {
            "node": cst.While(
                cst.Call(cst.Name("iter")),
                cst.SimpleStatementSuite((cst.Pass(),)),
                cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
            ),
            "code": "while iter(): pass\nelse: pass\n",
            "parser": parse_statement,
        },
        # indentation
        {
            "node": DummyIndentedBlock(
                "    ",
                cst.While(
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
            ),
            "code": "    while iter(): pass\n",
            "parser": None,
            "expected_position": CodeRange((1, 4), (1, 22)),
        },
        # while an indented body
        {
            "node": DummyIndentedBlock(
                "    ",
                cst.While(
                    cst.Call(cst.Name("iter")),
                    cst.IndentedBlock(
                        (cst.SimpleStatementLine((cst.Pass(),)),)
                    ),
                ),
            ),
            "code": "    while iter():\n        pass\n",
            "parser": None,
            "expected_position": CodeRange((1, 4), (2, 12)),
        },
        # leading_lines
        {
            "node": cst.While(
                cst.Call(cst.Name("iter")),
                cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                leading_lines=(
                    cst.EmptyLine(comment=cst.Comment("# leading comment")),
                ),
            ),
            "code": "# leading comment\nwhile iter():\n    pass\n",
            "parser": parse_statement,
            "expected_position": CodeRange((2, 0), (3, 8)),
        },
        # leading_lines on both the while statement and its else clause
        {
            "node": cst.While(
                cst.Call(cst.Name("iter")),
                cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                cst.Else(
                    cst.IndentedBlock(
                        (cst.SimpleStatementLine((cst.Pass(),)),)
                    ),
                    leading_lines=(
                        cst.EmptyLine(comment=cst.Comment("# else comment")),
                    ),
                ),
                leading_lines=(
                    cst.EmptyLine(comment=cst.Comment("# leading comment")),
                ),
            ),
            "code": "# leading comment\nwhile iter():\n    pass\n# else comment\nelse:\n    pass\n",
            "parser": None,
            "expected_position": CodeRange((2, 0), (6, 8)),
        },
        # Weird spacing rules: a parenthesized test needs no space after
        # "while".
        {
            "node": cst.While(
                cst.Call(
                    cst.Name("iter"),
                    lpar=(cst.LeftParen(),),
                    rpar=(cst.RightParen(),),
                ),
                cst.SimpleStatementSuite((cst.Pass(),)),
                whitespace_after_while=cst.SimpleWhitespace(""),
            ),
            "code": "while(iter()): pass\n",
            "parser": parse_statement,
            "expected_position": CodeRange((1, 0), (1, 19)),
        },
        # Whitespace
        {
            "node": cst.While(
                cst.Call(cst.Name("iter")),
                cst.SimpleStatementSuite((cst.Pass(),)),
                whitespace_after_while=cst.SimpleWhitespace("  "),
                whitespace_before_colon=cst.SimpleWhitespace("  "),
            ),
            "code": "while  iter()  : pass\n",
            "parser": parse_statement,
            "expected_position": CodeRange((1, 0), (1, 21)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Check each case's node against its expected code via ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider((
        # Unlike the parenthesized case above, a bare call directly after
        # "while" with no whitespace is rejected.
        {
            "get_node": lambda: cst.While(
                cst.Call(cst.Name("iter")),
                cst.SimpleStatementSuite((cst.Pass(),)),
                whitespace_after_while=cst.SimpleWhitespace(""),
            ),
            "expected_re": "Must have at least one space after 'while' keyword",
        },
    ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Check each invalid node is rejected via ``assert_invalid``."""
        self.assert_invalid(**kwargs)
Ejemplo n.º 30
0
def to_stmt(node: cst.BaseSmallStatement) -> cst.SimpleStatementLine:
    """Wrap a single small statement in its own ``SimpleStatementLine``."""
    statements = [node]
    return cst.SimpleStatementLine(body=statements)