def test_suite(self) -> None:
    """Statement suites of either flavour can fill a ``{suite}`` placeholder.

    A ``SimpleStatementSuite`` and an ``IndentedBlock`` are each substituted
    into two templates: one with the placeholder on the ``if`` header line,
    and one with it on the following line.
    """
    cases = (
        (
            "if x is True: {suite}\n",
            cst.SimpleStatementSuite(body=(cst.Pass(),)),
            "if x is True: pass\n",
        ),
        (
            "if x is True: {suite}\n",
            cst.IndentedBlock(body=(cst.SimpleStatementLine((cst.Pass(),)),)),
            "if x is True:\n pass\n",
        ),
        (
            "if x is True:\n {suite}\n",
            cst.SimpleStatementSuite(body=(cst.Pass(),)),
            "if x is True: pass\n",
        ),
        (
            "if x is True:\n {suite}\n",
            cst.IndentedBlock(body=(cst.SimpleStatementLine((cst.Pass(),)),)),
            "if x is True:\n pass\n",
        ),
    )
    for template, suite_node, expected_code in cases:
        rendered = parse_template_module(template, suite=suite_node)
        self.assertEqual(rendered.code, expected_code)
class ReturnCreateTest(CSTNodeTest):
    """Rendering, position and validation tests for ``cst.Return``."""

    @data_provider(
        (
            # A bare ``return`` with no value.
            {
                "node": cst.SimpleStatementLine([cst.Return()]),
                "code": "return\n",
                "expected_position": CodeRange((1, 0), (1, 6)),
            },
            # ``return`` followed by a name.
            {
                "node": cst.SimpleStatementLine([cst.Return(value=cst.Name("abc"))]),
                "code": "return abc\n",
                "expected_position": CodeRange((1, 0), (1, 10)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Each node renders to ``code`` and spans ``expected_position``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Removing the space after the keyword must be rejected.
            {
                "get_node": lambda: cst.Return(
                    value=cst.Name("abc"),
                    whitespace_after_return=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'return'.",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Building the node fails validation with ``expected_re``."""
        self.assert_invalid(**kwargs)
def leave_Module(
    self, original_node: cst.Module, updated_node: cst.Module
) -> cst.Module:
    """Insert missing ``from ... import`` lines and top-level annotations.

    Generated files (``self.is_generated``) are returned as ``original_node``.
    If there are neither annotations nor imports to add, ``updated_node`` is
    returned unchanged. Otherwise the new statements are spliced in at the
    import insertion point chosen by ``self._split_module``.

    Names already bound by ``self.import_statements`` (their alias if one is
    given) are not imported again. Star imports are ignored for this check.
    """
    if self.is_generated:
        return original_node
    if not (self.toplevel_annotations or self.imports):
        return updated_node

    # First, find the insertion point for imports.
    head, tail = self._split_module(original_node, updated_node)
    # Make sure there's at least one empty line before the first non-import.
    tail = self._insert_empty_line(tail)

    already_bound = set()
    for existing in self.import_statements:
        aliases = existing.names
        if isinstance(aliases, cst.ImportStar):
            continue
        for alias in aliases:
            # The bound name is the ``as`` target when present.
            binder = alias.asname if alias.asname else alias
            if binder:
                already_bound.add(_get_name_as_string(binder.name))

    additions = []
    for pending in self.imports.values():
        # Filter out anything that has already been imported.
        missing = sorted(pending.names.difference(already_bound))
        if not missing:
            continue
        rebuilt = cst.ImportFrom(
            module=pending.module,
            names=[cst.ImportAlias(cst.Name(missing_name)) for missing_name in missing],
        )
        # SimpleStatementLine needs a subscriptable sequence of statements.
        additions.append(cst.SimpleStatementLine([rebuilt]))

    for global_name, annotation in self.toplevel_annotations.items():
        declaration = cst.AnnAssign(
            cst.Name(global_name),
            # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
            cst.Annotation(annotation.annotation),
            None,
        )
        additions.append(cst.SimpleStatementLine([declaration]))

    return updated_node.with_changes(body=[*head, *additions, *tail])
def import_to_node_single(imp: SortableImport, module: cst.Module) -> cst.BaseStatement:
    """Render ``imp`` as a one-line ``import`` / ``from ... import`` statement.

    Entries of ``imp.comments.before`` that start with ``#`` become comment
    lines above the statement; any other entry becomes a blank line. All
    inline and trailing comments of the statement and of its items are joined
    with ``COMMENT_INDENT`` into a single end-of-line comment.

    ``module`` is accepted for signature parity and is not read here.
    """
    leading_lines = []
    for text in imp.comments.before:
        if text.startswith("#"):
            leading_lines.append(cst.EmptyLine(indent=True, comment=cst.Comment(text)))
        else:
            leading_lines.append(cst.EmptyLine(indent=False))

    collected = list(imp.comments.first_inline)
    aliases: List[cst.ImportAlias] = []
    for item in imp.items:
        imported = name_to_node(item.name)
        rename = cst.AsName(name=cst.Name(item.asname)) if item.asname else None
        aliases.append(cst.ImportAlias(name=imported, asname=rename))
        collected.extend(item.comments.before)
        collected.extend(item.comments.inline)
        collected.extend(item.comments.following)
    collected.extend(imp.comments.final)
    collected.extend(imp.comments.last_inline)

    if collected:
        trailing = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(COMMENT_INDENT.join(collected)),
        )
    else:
        trailing = cst.TrailingWhitespace()

    if not imp.stem:
        # Plain ``import a, b``.
        statement = cst.Import(names=aliases)
    else:
        # ``from [.]*stem import ...``; a stem of only dots has no module name.
        stem, ndots = split_relative(imp.stem)
        statement = cst.ImportFrom(
            module=name_to_node(stem) if stem else None,
            names=aliases,
            relative=(cst.Dot(),) * ndots,
        )

    return cst.SimpleStatementLine(
        body=[statement],
        leading_lines=leading_lines,
        trailing_whitespace=trailing,
    )
def __get_required_imports(self):
    """Build the import statements required by the applied type annotations.

    Returns a list of ``cst.SimpleStatementLine`` nodes, in this order:

    * ``from typing import ...`` for type names that appear in ``PY_TYPING_MOD``;
    * ``from collections import ...`` for type names in ``PY_COLLECTION_MOD``;
    * one ``import <mod>`` per module referenced through attribute access
      inside an applied annotation.

    Names within each group are emitted in sorted order. Iterating the sets
    directly would make the generated import order depend on string-hash
    randomization, so it would change from run to run.
    """
    def find_required_modules(all_types):
        # For every Attribute in an annotation, record the first Name found
        # inside it (presumably the root of the dotted chain, e.g. ``a`` in
        # ``a.b.c`` -- depends on ``match.findall``'s traversal order).
        req_mod = set()
        for _, a_node in all_types:
            attributes = match.findall(
                a_node.annotation,
                match.Attribute(value=match.DoNotCare(), attr=match.DoNotCare()))
            for attribute in attributes:
                names = match.findall(attribute, match.Name(value=match.DoNotCare()))
                req_mod.add(names[0].value)
        return req_mod

    def from_import(module_name, imported_names):
        return cst.SimpleStatementLine(body=[
            cst.ImportFrom(module=cst.Name(value=module_name), names=[
                cst.ImportAlias(name=cst.Name(value=t), asname=None)
                for t in imported_names
            ]),
        ])

    req_imports = []
    all_req_mods = find_required_modules(self.all_applied_types)
    # Every identifier-like token in the applied type strings.
    all_type_names = set(
        chain.from_iterable(
            regex.findall(r"\w+", t[0]) for t in self.all_applied_types))

    typing_imports = sorted(PY_TYPING_MOD & all_type_names)
    collection_imports = sorted(PY_COLLECTION_MOD & all_type_names)

    if typing_imports:
        req_imports.append(from_import("typing", typing_imports))
    if collection_imports:
        req_imports.append(from_import("collections", collection_imports))
    for mod_name in sorted(all_req_mods):
        req_imports.append(
            cst.SimpleStatementLine(body=[
                cst.Import(names=[
                    cst.ImportAlias(name=cst.Name(value=mod_name), asname=None)
                ])
            ]))
    return req_imports
def leave_Module(
    self, original_node: cst.Module, updated_node: cst.Module
) -> cst.Module:
    """Add top-level annotations and not-yet-present class definitions.

    Class definitions from ``self.annotations.class_definitions`` whose name
    is not in ``self.visited_classes`` are appended after the annotated
    globals at the insertion point chosen by ``self._split_module``. If there
    is nothing to add, ``updated_node`` is returned unchanged.
    """
    unseen_classes = [
        definition
        for class_name, definition in self.annotations.class_definitions.items()
        if class_name not in self.visited_classes
    ]
    if not (self.toplevel_annotations or unseen_classes):
        return updated_node

    # First, find the insertion point for imports.
    head, tail = self._split_module(original_node, updated_node)
    # Make sure there's at least one empty line before the first non-import.
    tail = self._insert_empty_line(tail)

    additions = [
        cst.SimpleStatementLine([cst.AnnAssign(cst.Name(global_name), annotation, None)])
        for global_name, annotation in self.toplevel_annotations.items()
    ]
    additions.extend(unseen_classes)

    return updated_node.with_changes(body=[*head, *additions, *tail])
def _get_expression_transform(
        before: Callable[..., Any], after: Callable[..., Any]) -> ExpressionTransform:
    """Compile a ``before``/``after`` function pair into an ExpressionTransform.

    The first parsed expression of ``before`` becomes a matcher, built with the
    argument matchers derived from ``before``'s parameters. The first parsed
    expression of ``after`` becomes the replacement node.

    Raises:
        Exception: if the resulting matcher (or the matcher it wraps) is a
            top-level ``DoNotCare``.
    """
    before_expression = function_parser.parse(before)[0]
    matcher = craftier.matcher.from_node(
        before_expression, function_parser.args_to_matchers(before))

    wrapped = getattr(matcher, "matcher", None)
    if any(isinstance(candidate, libcst.matchers.DoNotCareSentinel)
           for candidate in (matcher, wrapped)):
        raise Exception(
            f"DoNotCare matcher is forbidden at top level in `{before.__name__}`"
        )

    replacement_expression = function_parser.parse(after)[0]
    # Technically this is not correct as some expressions, like function
    # calls, binary operations, names, etc., are wrapped in an `Expr` node.
    statement = libcst.SimpleStatementLine(
        body=[cast(libcst.BaseSmallStatement, replacement_expression)])
    wrapper = libcst.metadata.MetadataWrapper(libcst.Module(body=[statement]))

    # Read the node back through ``wrapper.module`` rather than using
    # ``statement`` directly -- presumably because the wrapper holds its own
    # copy of the tree that metadata lookups are keyed on.
    first_line = cast(libcst.SimpleStatementLine, wrapper.module.body[0])
    return ExpressionTransform(
        before=matcher, after=first_line.body[0], wrapper=wrapper)
def leave_FunctionDef(self, original_node: cst.FunctionDef,
                      updated_node: cst.FunctionDef) -> cst.FunctionDef:
    """Queue write-back and return statements when leaving the tracked scope.

    Only acts when ``original_node`` is ``self.scope``. For every entry of
    ``self.names_to_attr``, an assignment is appended to ``self.tail``. The
    assigned value folds the recorded guarded states for that name, with a
    trailing unguarded default of the name itself. If any returns were
    recorded, one ``return`` folding them is appended as well. The function
    node is returned unchanged.

    Raises:
        SyntaxError: 'Cannot prove function always returns', when folding the
            recorded returns raises ``IncompleteGaurdError``.
    """
    if original_node is not self.scope:
        return updated_node

    tail = self.tail
    for name, attr in self.names_to_attr.items():
        # Build a new list instead of appending to the stored one, so the
        # default write-back entry is not added to ``self.attr_states``.
        state = [*self.attr_states.get(name, []), ([], cst.Name(name))]
        attr_val = _fold_conditions(_simplify_gaurds(state), self.strict)
        tail.append(to_stmt(make_assign(attr, attr_val)))

    if self.returns:
        try:
            return_val = _fold_conditions(
                _simplify_gaurds(self.returns), self.strict)
        except IncompleteGaurdError:
            raise SyntaxError(
                'Cannot prove function always returns') from None
        tail.append(cst.SimpleStatementLine([cst.Return(value=return_val)]))

    return updated_node
def leave_Module(
    self,
    original_node: cst.Module,
    updated_node: cst.Module,
) -> cst.Module:
    """Insert top-level annotations, new TypeVars and unseen class stubs.

    New statements go at the insertion point chosen by ``self._split_module``,
    in this order:

    1. annotated globals from ``self.toplevel_annotations``;
    2. TypeVar statements from ``self.annotations.typevars`` whose name is not
       already in ``self.typevars``;
    3. class definitions whose name is not in ``self.visited_classes``.

    Each TypeVar is preceded by a bare ``cst.Newline`` and the group is
    followed by one. They presumably render as separating blank lines.
    NOTE(review): ``Newline`` is not a statement type, so type checkers may
    object to it in a module body.

    ``self.annotation_counts`` is updated with the number of TypeVars and
    classes added.
    """
    unseen_classes = [
        definition
        for class_name, definition in self.annotations.class_definitions.items()
        if class_name not in self.visited_classes
    ]

    # NOTE: The entire change will also be abandoned if
    # self.annotation_counts is all 0s, so if adding any new category make
    # sure to record it there.
    nothing_to_add = (not self.toplevel_annotations
                      and not unseen_classes
                      and not self.annotations.typevars)
    if nothing_to_add:
        return updated_node

    # First, find the insertion point for imports.
    head, tail = self._split_module(original_node, updated_node)
    # Make sure there's at least one empty line before the first non-import.
    tail = self._insert_empty_line(tail)

    additions = [
        cst.SimpleStatementLine([
            self._apply_annotation_to_attribute_or_global(
                name=global_name,
                annotation=annotation,
                value=None,
            )
        ])
        for global_name, annotation in self.toplevel_annotations.items()
    ]

    # TypeVar definitions could be scattered through the file, so do not
    # attempt to put new ones with existing ones, just add them at the top.
    pending_typevars = [
        statement
        for var_name, statement in self.annotations.typevars.items()
        if var_name not in self.typevars
    ]
    if pending_typevars:
        for statement in pending_typevars:
            additions.extend((cst.Newline(), statement))
            self.annotation_counts.typevars_and_generics_added += 1
        additions.append(cst.Newline())

    self.annotation_counts.classes_added = len(unseen_classes)
    additions.extend(unseen_classes)

    return updated_node.with_changes(body=[*head, *additions, *tail])
class NamedExprTest(CSTNodeTest):
    """Tests for the matrix-multiplication operators ``@`` and ``@=``.

    NOTE(review): despite the class name, nothing here exercises named
    expressions (``:=``). The name is kept so test IDs stay stable.
    """

    @data_provider(
        (
            # Binary ``@`` expression.
            {
                "parser": parse_expression_as(python_version="3.8"),
                "code": "a @ b",
                "node": cst.BinaryOperation(
                    left=cst.Name("a"),
                    operator=cst.MatrixMultiply(),
                    right=cst.Name("b"),
                ),
            },
            # Augmented ``@=`` assignment statement.
            {
                "parser": parse_statement_as(python_version="3.8"),
                "code": "a @= b\n",
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.AugAssign(
                            target=cst.Name("a"),
                            operator=cst.MatrixMultiplyAssign(),
                            value=cst.Name("b"),
                        ),
                    ),
                ),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Each node renders to ``code`` and round-trips through ``parser``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Accepted from 3.6 onwards, rejected on 3.3.
            {
                "code": "a @ b",
                "parser": parse_expression_as(python_version="3.6"),
                "expect_success": True,
            },
            {
                "code": "a @ b",
                "parser": parse_expression_as(python_version="3.3"),
                "expect_success": False,
            },
            {
                "code": "a @= b",
                "parser": parse_statement_as(python_version="3.6"),
                "expect_success": True,
            },
            {
                "code": "a @= b",
                "parser": parse_statement_as(python_version="3.3"),
                "expect_success": False,
            },
        )
    )
    def test_versions(self, **kwargs: Any) -> None:
        """Parsing succeeds or fails according to ``expect_success``."""
        skip_for_native = is_native() and not kwargs.get("expect_success", True)
        if skip_for_native:
            self.skipTest("parse errors are disabled for native parser")
        self.assert_parses(**kwargs)
def leave_Module(self, original_node: cst.Module,
                 updated_node: cst.Module) -> cst.Module:
    """Insert every collected ``from ... import`` and top-level annotation.

    Nothing is filtered here: each entry of ``self.imports`` is re-emitted as
    an ``ImportFrom`` with its module and names. The new statements go at the
    insertion point chosen by ``self._split_module``. If there is nothing to
    add, ``updated_node`` is returned unchanged.
    """
    if not (self.toplevel_annotations or self.imports):
        return updated_node

    # First, find the insertion point for imports.
    head, tail = self._split_module(original_node, updated_node)
    # Make sure there's at least one empty line before the first non-import.
    tail = self._insert_empty_line(tail)

    additions = []
    for pending in self.imports.values():
        rebuilt = cst.ImportFrom(
            module=pending.module,
            # pyre-fixme[6]: Expected `Union[Sequence[ImportAlias], ImportStar]`
            # for 2nd param but got `List[ImportFrom]`.
            names=pending.names,
        )
        # SimpleStatementLine needs a subscriptable sequence of statements.
        additions.append(cst.SimpleStatementLine([rebuilt]))

    for global_name, annotation in self.toplevel_annotations.items():
        declaration = cst.AnnAssign(
            cst.Name(global_name),
            # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
            cst.Annotation(annotation.annotation),
            None,
        )
        additions.append(cst.SimpleStatementLine([declaration]))

    return updated_node.with_changes(body=[*head, *additions, *tail])
def leave_Module(self, original_node, updated_node):
    """Run the parent ``leave_Module``, then prepend the collected imports.

    ``self.imports`` holds small import statements. They are rendered into a
    throw-away module, sorted with isort's ``SortImports``, re-parsed, and
    placed before the parent's resulting body.
    """
    final_node = super().leave_Module(original_node, updated_node)

    import_lines = [cst.SimpleStatementLine([imp]) for imp in self.imports]
    imports_source = cst.Module(body=import_lines).code
    sorted_imports = cst.parse_module(
        SortImports(file_contents=imports_source).output)

    # Add imports back to the top of the module. ``Module.body`` is only
    # typed as a Sequence (a parsed body need not be a list), and
    # ``tuple + list`` raises TypeError. Unpacking works for any sequence.
    new_body = [*sorted_imports.body, *final_node.body]
    return final_node.with_changes(body=new_body)
def test_parse_multiple_statements(self) -> None:
    """``function_parser.parse`` on a two-statement body yields ``x`` first.

    The nested function below is passed as an object to
    ``function_parser.parse`` (which presumably reads its body), so the
    nested function is deliberately left exactly as written: two bare
    expression statements, ``x`` then ``y``.
    """
    # pylint: disable=pointless-statement
    def test_function(x, y):
        x
        y
    # pylint: enable=pointless-statement

    # Only the first parsed statement is checked: a line holding ``x``.
    self.assert_node_equal(
        function_parser.parse(test_function)[0],
        libcst.SimpleStatementLine(body=[libcst.Expr(libcst.Name("x"))]),
    )
def test_deprecated_non_element_construction(self) -> None:
    """A ``Subscript`` whose ``slice`` is a bare ``cst.Index`` still renders.

    This is the construction form the test name calls deprecated: an
    ``Index`` is passed directly as ``slice`` rather than a sequence of
    ``SubscriptElement``. It must still produce ``foo[1]``.
    """
    subscript = cst.Subscript(
        value=cst.Name(value="foo"),
        slice=cst.Index(value=cst.Integer(value="1")),
    )
    statement = cst.SimpleStatementLine(body=[cst.Expr(value=subscript)])
    rendered = cst.Module(body=[statement]).code
    self.assertEqual(rendered, "foo[1]\n")
def leave_Module(self, node: cst.Module,
                 updated_node: cst.Module) -> cst.CSTNode:
    """Insert an annotated declaration for every top-level annotation.

    The declarations are placed at the index returned by
    ``self._get_toplevel_index``, in ``self.toplevel_annotations`` iteration
    order. ``node`` (the original module) is not read.
    """
    body = list(updated_node.body)
    index = self._get_toplevel_index(body)

    declarations = [
        cst.SimpleStatementLine([
            cst.AnnAssign(
                cst.Name(name),
                # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                cst.Annotation(annotation.annotation),
                None,
            )
        ])
        for name, annotation in self.toplevel_annotations.items()
    ]
    # Splice all declarations in at once. Inserting them one at a time at
    # the same index would emit them in reverse order.
    body[index:index] = declarations

    return updated_node.with_changes(body=tuple(body))
def on_leave(
    self, original_node: CSTNodeT, updated_node: CSTNodeT
) -> Union[CSTNodeT, RemovalSentinel, FlattenSentinel[cst.SimpleStatementLine]]:
    """Split every ``SimpleStatementLine`` into one line per small statement.

    ``a; b`` becomes two lines ``a`` and ``b``, with each statement's
    semicolon reset to the default. The original line's ``leading_lines``
    (blank/comment lines above it) stay on the first new line. Its
    ``trailing_whitespace`` (including any end-of-line comment) stays on the
    last new line. Without this, every line's comments would be deleted.
    All other nodes are returned unchanged.
    """
    if not isinstance(updated_node, cst.SimpleStatementLine):
        return updated_node

    statements = updated_node.body
    last_index = len(statements) - 1
    return FlattenSentinel([
        cst.SimpleStatementLine(
            [stmt.with_changes(semicolon=cst.MaybeSentinel.DEFAULT)],
            leading_lines=updated_node.leading_lines if i == 0 else (),
            trailing_whitespace=(
                updated_node.trailing_whitespace
                if i == last_index
                else cst.TrailingWhitespace()
            ),
        )
        for i, stmt in enumerate(statements)
    ])
def body(
    self,
) -> typing.Iterable[typing.Union[cst.BaseCompoundStatement, cst.SimpleStatementLine]]:
    """Lazily yield the statements of the generated module, in order.

    The sequence is:

    1. ``from typing import *``;
    2. the property declarations from ``self.properties``;
    3. the function definitions and overloads;
    4. one class definition per entry of ``self.classes``, in
       ``sort_items`` order.
    """
    star_import = cst.ImportFrom(cst.Name("typing"), names=cst.ImportStar())
    yield cst.SimpleStatementLine([star_import])
    yield from assign_properties(self.properties)
    yield from function_defs(self.function_overloads, self.functions, "function")
    yield from (
        class_.class_def(name) for name, class_ in sort_items(self.classes)
    )
def test_statement(self) -> None:
    """Compound, simple-line and bare small statements all fit a statement slot.

    Three placeholders in a statement list are filled with an ``If``, a
    ``SimpleStatementLine`` and a bare ``Pass`` respectively.
    """
    if_statement = cst.If(
        test=cst.Name("foo"),
        body=cst.SimpleStatementSuite((cst.Pass(),)),
    )
    call_line = cst.SimpleStatementLine((cst.Expr(cst.Call(cst.Name("bar"))),))

    module = parse_template_module(
        "{statement1}\n{statement2}\n{statement3}\n",
        statement1=if_statement,
        statement2=call_line,
        statement3=cst.Pass(),
    )

    expected = "if foo: pass\nbar()\npass\n"
    self.assertEqual(module.code, expected)
def generate_import(name, obj, func_obj=None, file_imports=None):
    """
    Generate an import statement for a (name, runtime object) pair.

    Returns a statement node that binds ``name`` to ``obj``. Returns ``None``
    when no import is needed or none can be produced:

    * ``name`` is ``'self'`` or is already in ``inliner.base_globls``;
    * ``obj``'s defining module is unknown, is ``typing``, is ``__main__``,
      or is ``builtins``;
    * ``obj`` is not a module, class or function, and ``func_obj`` is missing
      or its module is unknown or ``__main__``.

    Args:
        name: the global name the inlined code uses for ``obj``.
        obj: the runtime value bound to ``name``.
        func_obj: a function whose module is used as the import source for
            values that are neither modules, classes nor functions.
        file_imports: optional mapping from a name to the import node that
            binds it in the object's file; used verbatim when present.
    """
    inliner = ctx_inliner.get()

    # HACK? is this still needed?
    if name == 'self':
        return None

    # If the name is already in scope, don't need to import it
    if name in inliner.base_globls:
        # TODO: name conflicts? e.g. host imports json as x, and
        # another module imports foo as x
        return None

    # If the name appears directly in an import statement in the object's file,
    # then use that import
    if file_imports is not None and name in file_imports:
        return cst.SimpleStatementLine([file_imports[name]])

    # If we're importing a module, then add an import directly
    if inspect.ismodule(obj):
        mod_name = obj.__name__
        return parse_statement(f'import {mod_name} as {name}'
                               if name != mod_name else f'import {mod_name}')

    # Get module where global is defined
    mod = inspect.getmodule(obj)
    # TODO: When is mod None?
    if mod is None or mod is typing or mod.__name__ == '__main__':
        return None

    # Can't import builtins
    if mod is __builtins__ or mod is builtins:
        return None

    # If the value is a class or function, then import it from the defining
    # module
    if inspect.isclass(obj) or inspect.isfunction(obj):
        # ``name`` may be a caller-side alias (e.g. ``from m import A as B``),
        # in which case ``m`` has no attribute ``B``. Import the object under
        # the name its module actually exposes and re-alias it.
        own_name = getattr(obj, '__name__', name)
        if (own_name != name
                and getattr(mod, name, None) is not obj
                and getattr(mod, own_name, None) is obj):
            return parse_statement(
                f'from {mod.__name__} import {own_name} as {name}')
        return parse_statement(f'from {mod.__name__} import {name}')

    # Otherwise import it from the module using the global
    if func_obj is not None:
        func_mod = inspect.getmodule(func_obj)
        # getmodule returns None when the module can't be determined; there
        # is nothing to import from in that case.
        if func_mod is None or func_mod.__name__ == '__main__':
            return None
        return parse_statement(f'from {func_mod.__name__} import {name}')

    return None
def test_module_config_for_parsing(self) -> None:
    """A statement parsed with a module's parsing config inherits its newline.

    The module uses a carriage-return newline. Parsing a statement with
    ``module.config_for_parsing`` should leave the header newline at
    ``value=None`` (the module default) instead of an explicit carriage return.
    """
    module = parse_module("pass\r")
    statement = parse_statement("if True:\r pass", config=module.config_for_parsing)

    # This would be "\r" if we didn't pass the module config forward.
    header = cst.TrailingWhitespace(newline=cst.Newline(value=None))
    expected = cst.If(
        test=cst.Name(value="True"),
        body=cst.IndentedBlock(
            body=[cst.SimpleStatementLine(body=[cst.Pass()])],
            header=header,
        ),
    )
    self.assertEqual(statement, expected)
def test_deep_replace_identity(self) -> None:
    """Deep-replacing a module with itself as the target swaps in the new module.

    The whole parsed module is the node being replaced, so the result is the
    replacement module (an empty header line followed by ``break``).
    """
    old_code = """ pass """
    new_code = """ break """

    module = cst.parse_module(dedent(old_code))
    replacement = cst.Module(
        header=(cst.EmptyLine(),),
        body=(cst.SimpleStatementLine(body=(cst.Break(),)),),
    )
    new_module = module.deep_replace(module, replacement)

    self.assertEqual(new_module.code, dedent(new_code))
def function_def(
    self,
    name: str,
    type: typing.Literal["function", "classmethod", "method"],
    indent=0,
    overload=False,
) -> cst.FunctionDef:
    """Build a ``def name(...)`` node from this object's pieces.

    Decorators are added in this order: ``@overload`` when ``overload`` is
    true, then ``@classmethod`` when ``type == "classmethod"``. Parameters
    come from ``self.parameters(type)``. Each item yielded by
    ``self.body(indent)`` becomes its own statement line. The return
    annotation is ``self.return_type_annotation``.
    """
    decorator_names = []
    if overload:
        decorator_names.append("overload")
    if type == "classmethod":
        decorator_names.append("classmethod")
    decorators: typing.List[cst.Decorator] = [
        cst.Decorator(cst.Name(decorator)) for decorator in decorator_names
    ]

    function_name = cst.Name(name)
    params = self.parameters(type)
    block = cst.IndentedBlock(
        [cst.SimpleStatementLine([statement]) for statement in self.body(indent)])

    return cst.FunctionDef(
        name=function_name,
        params=params,
        body=block,
        decorators=decorators,
        returns=self.return_type_annotation,
    )
def astNodeFromAssertion(transform: Transform, match: Match, index=0) -> Sequence[cst.CSTNode]:
    """Build the node for nested scope ``index`` of ``transform.assertion``.

    Recurses through ``transform.assertion.nested_scopes`` from ``index`` down
    to the innermost scope. At each level:

    * If the scope's capture was matched (found in ``match.by_name``), the
      captured node is reused. It is renamed to ``elemName(node=node)`` and,
      for ``FunctionDef``/``ClassDef`` only, given a rebuilt body that
      includes the inner scope's nodes.
    * Otherwise a fresh node is built with ``astNodeFrom``. Unless this is
      the innermost scope, it is given a body holding the inner scope's nodes.

    When ``transform.destructive`` is set, body statements deep-equal to
    ``match.path[-1].node`` are dropped before the inner nodes are appended.
    An empty body is replaced with a single ``pass`` line.

    Returns:
        A one-element list containing the built node.
    """
    # TODO: aggregate intersect possible_nodes_per_prop and raise on multiple
    # results (some kind of "ambiguity error"). Also need to match with anchor placement
    cur_scope_expr = transform.assertion.nested_scopes[index]
    cur_capture = match.by_name.get(cur_scope_expr.capture.name)
    # Default to the capture's name; replaced below when a node was captured.
    name = cur_scope_expr.capture.name
    body = ()
    node = BodyType = None
    if cur_capture is not None:
        node = cur_capture.node
        name = elemName(node=node)
        # presumably (existing statements, body node class) -- helper not visible here
        body, BodyType = getNodeBody(node)
    if index < len(transform.assertion.nested_scopes) - 1:
        # inner doesn't make sense for ( and , nesting scopes...
        inner = astNodeFromAssertion(transform, match, index+1)
        if transform.destructive:
            # Drop the statement at the end of the match path before nesting.
            body = [s for s in body if not s.deep_equals(
                match.path[-1].node)]
        body = (*body, *inner)
        # Fall back to an indented block when no body class was discovered.
        BodyType = BodyType or cst.IndentedBlock
    if not body:
        body = [cst.SimpleStatementLine(body=[cst.Pass()])]
    if cur_capture is not None:
        # XXX: perhaps I should use tuples instead of lists to abide by the immutability of cst
        return [node.with_changes(
            name=cst.Name(name),
            **({
                'body': BodyType(body=body),
                # probably need a better way to do this, ideally just ignore excess kwargs
            } if isinstance(node, (cst.FunctionDef, cst.ClassDef)) else {})
        )]
    else:
        unrefed_node = astNodeFrom(
            scope_expr=cur_scope_expr, ctx=transform, match=match)
        # NOTE: need a generic way to "place" the next scope in a node
        if index < len(transform.assertion.nested_scopes) - 1:
            return [unrefed_node.with_changes(body=BodyType(body=body))]
        else:
            return [unrefed_node]
def _get_assert_replacement(self, node: cst.Assert):
    """Translate ``assert test[, msg]`` into an explicit ``if``/``raise``.

    The result is::

        if not (test):
            raise AssertionError(msg)

    If the assert has a message, that expression is passed to
    ``AssertionError`` unchanged. Otherwise the source text of the assert
    statement itself (e.g. ``'assert x > 0'``) is used as a string literal.
    """
    test = node.test
    # ``not`` binds more tightly than ``and``/``or``, conditional expressions
    # and lambdas. Without parentheses, ``assert a or b`` would become
    # ``if not a or b:``, which means ``(not a) or b`` and inverts the check.
    if (isinstance(test, (cst.BooleanOperation, cst.IfExp, cst.Lambda, cst.NamedExpr))
            and not test.lpar):
        test = test.with_changes(lpar=[cst.LeftParen()], rpar=[cst.RightParen()])

    if node.msg is not None:
        # ``msg`` is already an expression node. Passing it through keeps the
        # user's message. Calling ``repr`` on it would embed the CST node's
        # Python repr instead.
        message = node.msg
    else:
        source = cst.Module(body=[]).code_for_node(node)
        message = cst.SimpleString(value=repr(source))

    return cst.If(
        test=cst.UnaryOperation(
            operator=cst.Not(),
            expression=test,
        ),
        body=cst.IndentedBlock(body=[
            cst.SimpleStatementLine(body=[
                cst.Raise(exc=cst.Call(
                    func=cst.Name(value="AssertionError", ),
                    args=[
                        cst.Arg(value=message, ),
                    ],
                ), ),
            ]),
        ], ),
    )
def with_added_imports(
        module_node: cst.Module,
        import_nodes: Sequence[Union[cst.Import, cst.ImportFrom]]) -> cst.Module:
    """
    Return ``module_node`` with ``import_nodes`` inserted after its first import.

    Each node in ``import_nodes`` gets its own statement line. The lines are
    placed, in the given order, directly after the first top-level line that
    ``_is_import_line`` accepts.

    Raises:
        RuntimeError: if no such import line exists in the module.
    """
    new_body: List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]] = []
    inserted = False
    for statement in module_node.body:
        new_body.append(statement)
        if inserted or not _is_import_line(statement):
            continue
        new_body.extend(
            cst.SimpleStatementLine(body=(import_node,)) for import_node in import_nodes)
        inserted = True

    if not inserted:
        raise RuntimeError("Failed to add imports")
    return module_node.with_changes(body=tuple(new_body))
def assign_properties(
        p: typing.Dict[str, typing.Tuple[Metadata, Type]],
        is_classvar=False) -> typing.Iterable[cst.SimpleStatementLine]:
    """Yield an annotated declaration ``name: T`` for each property in ``p``.

    Entries whose name ``bad_name`` rejects are skipped. With
    ``is_classvar=True`` the annotation is wrapped as ``ClassVar[T]``. Each
    declaration is preceded by a blank line, then one ``# ...`` comment line
    per entry of ``metadata_lines(metadata)``.
    """
    for prop_name, metadata_and_type in sort_items(p):
        if bad_name(prop_name):
            continue
        metadata, prop_type = metadata_and_type

        annotation_expr = prop_type.annotation
        if is_classvar:
            annotation_expr = cst.Subscript(
                cst.Name("ClassVar"),
                [cst.SubscriptElement(cst.Index(annotation_expr))])

        declaration = cst.AnnAssign(cst.Name(prop_name), cst.Annotation(annotation_expr))
        comment_lines = [
            cst.EmptyLine(comment=cst.Comment("# " + text))
            for text in metadata_lines(metadata)
        ]
        yield cst.SimpleStatementLine(
            [declaration],
            leading_lines=[cst.EmptyLine(), *comment_lines],
        )
def type_declaration_statements(
    bindings: UnpackedBindings,
    annotations: UnpackedAnnotations,
    leading_lines: Sequence[cst.EmptyLine],
    quote_annotations: bool,
) -> List[cst.SimpleStatementLine]:
    """Return one type-declaration line per (binding, annotation) pair.

    Pairs come from ``AnnotationSpreader.annotated_bindings``. Only the first
    generated line carries ``leading_lines``; later lines get none.
    """
    statements: List[cst.SimpleStatementLine] = []
    pairs = AnnotationSpreader.annotated_bindings(
        bindings=bindings,
        annotations=annotations,
    )
    for binding, raw_annotation in pairs:
        declaration = AnnotationSpreader.type_declaration(
            binding=binding,
            raw_annotation=raw_annotation,
            quote_annotations=quote_annotations,
        )
        statements.append(
            cst.SimpleStatementLine(
                body=[declaration],
                leading_lines=[] if statements else leading_lines,
            )
        )
    return statements
def import_to_node_multi(imp: SortableImport, module: cst.Module) -> cst.BaseStatement:
    """Render ``imp`` as a parenthesized, one-name-per-line ``from`` import.

    Produces::

        # comments before
        from stem import (  # first inline comments
            # item.comments.before
            name as alias,  # item.comments.inline
            # item.comments.following
        )  # last inline comments

    Item indentation uses ``module.default_indent``. Every item, including
    the last, gets a trailing comma.

    Raises:
        ValueError: if ``imp`` has no ``stem`` (plain ``import foo`` cannot be
            rendered over multiple lines).
    """
    body: List[cst.BaseSmallStatement] = []
    names: List[cst.ImportAlias] = []
    prev: Optional[cst.ImportAlias] = None
    following: List[str] = []
    lpar_lines: List[cst.EmptyLine] = []
    lpar_inline: cst.TrailingWhitespace = cst.TrailingWhitespace()

    item_count = len(imp.items)
    for idx, item in enumerate(imp.items):
        name = name_to_node(item.name)
        asname = cst.AsName(
            name=cst.Name(item.asname)) if item.asname else None

        # Leading comments actually have to be trailing comments on the previous node.
        # That means putting them on the lpar node for the first item
        if item.comments.before:
            lines = [
                cst.EmptyLine(
                    indent=True,
                    comment=cst.Comment(c),
                    whitespace=cst.SimpleWhitespace(module.default_indent),
                ) for c in item.comments.before
            ]
            if prev is None:
                lpar_lines.extend(lines)
            else:
                # Mutates, in place, the ``empty_lines`` list built for the
                # previous item's comma below (a plain list, so extendable).
                prev.comma.whitespace_after.empty_lines.extend(
                    lines)  # type: ignore

        # all items except the last needs whitespace to indent the *next* line/item
        indent = idx != (len(imp.items) - 1)

        first_line = cst.TrailingWhitespace()
        inline = COMMENT_INDENT.join(item.comments.inline)
        if inline:
            first_line = cst.TrailingWhitespace(
                whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
                comment=cst.Comment(inline),
            )

        # The import's own ``final`` comments are attached after the last item.
        if idx == item_count - 1:
            following = item.comments.following + imp.comments.final
        else:
            following = item.comments.following

        after = cst.ParenthesizedWhitespace(
            indent=True,
            first_line=first_line,
            empty_lines=[
                cst.EmptyLine(
                    indent=True,
                    comment=cst.Comment(c),
                    whitespace=cst.SimpleWhitespace(module.default_indent),
                ) for c in following
            ],
            last_line=cst.SimpleWhitespace(
                module.default_indent if indent else ""),
        )

        node = cst.ImportAlias(
            name=name,
            asname=asname,
            comma=cst.Comma(whitespace_after=after),
        )
        names.append(node)
        prev = node

    # from foo import (
    #     bar
    # )
    if imp.stem:
        stem, ndots = split_relative(imp.stem)
        # A stem of only dots (``from . import x``) has no module name.
        if not stem:
            module_name = None
        else:
            module_name = name_to_node(stem)
        relative = (cst.Dot(), ) * ndots

        # inline comment following lparen
        if imp.comments.first_inline:
            inline = COMMENT_INDENT.join(imp.comments.first_inline)
            lpar_inline = cst.TrailingWhitespace(
                whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
                comment=cst.Comment(inline),
            )

        body = [
            cst.ImportFrom(
                module=module_name,
                names=names,
                relative=relative,
                lpar=cst.LeftParen(
                    whitespace_after=cst.ParenthesizedWhitespace(
                        indent=True,
                        first_line=lpar_inline,
                        empty_lines=lpar_lines,
                        last_line=cst.SimpleWhitespace(module.default_indent),
                    ), ),
                rpar=cst.RightParen(),
            )
        ]

    # import foo
    else:
        raise ValueError("can't render basic imports on multiple lines")

    # comment lines above import
    leading_lines = [
        cst.EmptyLine(indent=True, comment=cst.Comment(line))
        if line.startswith("#") else cst.EmptyLine(indent=False)
        for line in imp.comments.before
    ]

    # inline comments following import/rparen
    if imp.comments.last_inline:
        inline = COMMENT_INDENT.join(imp.comments.last_inline)
        trailing = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(inline))
    else:
        trailing = cst.TrailingWhitespace()

    return cst.SimpleStatementLine(
        body=body,
        leading_lines=leading_lines,
        trailing_whitespace=trailing,
    )
class WhileTest(CSTNodeTest):
    """Rendering, position and validation tests for ``cst.While``."""

    @data_provider(
        (
            # Simple while block
            # pyre-fixme[6]: Incompatible parameter type
            {
                "node": cst.While(
                    test=cst.Call(func=cst.Name("iter")),
                    body=cst.SimpleStatementSuite((cst.Pass(), )),
                ),
                "code": "while iter(): pass\n",
                "parser": parse_statement,
            },
            # While block with else
            {
                "node": cst.While(
                    test=cst.Call(func=cst.Name("iter")),
                    body=cst.SimpleStatementSuite((cst.Pass(), )),
                    orelse=cst.Else(body=cst.SimpleStatementSuite((cst.Pass(), ))),
                ),
                "code": "while iter(): pass\nelse: pass\n",
                "parser": parse_statement,
            },
            # Indentation inherited from an enclosing block
            {
                "node": DummyIndentedBlock(
                    " ",
                    cst.While(
                        test=cst.Call(func=cst.Name("iter")),
                        body=cst.SimpleStatementSuite((cst.Pass(), )),
                    ),
                ),
                "code": " while iter(): pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (1, 22)),
            },
            # While with an indented body
            {
                "node": DummyIndentedBlock(
                    " ",
                    cst.While(
                        test=cst.Call(func=cst.Name("iter")),
                        body=cst.IndentedBlock(
                            body=(cst.SimpleStatementLine(body=(cst.Pass(), )), )),
                    ),
                ),
                "code": " while iter():\n pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (2, 12)),
            },
            # Leading comment lines
            {
                "node": cst.While(
                    test=cst.Call(func=cst.Name("iter")),
                    body=cst.IndentedBlock(
                        body=(cst.SimpleStatementLine(body=(cst.Pass(), )), )),
                    leading_lines=(
                        cst.EmptyLine(comment=cst.Comment("# leading comment")), ),
                ),
                "code": "# leading comment\nwhile iter():\n pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((2, 0), (3, 8)),
            },
            # Leading comment lines on both the while and its else
            {
                "node": cst.While(
                    test=cst.Call(func=cst.Name("iter")),
                    body=cst.IndentedBlock(
                        body=(cst.SimpleStatementLine(body=(cst.Pass(), )), )),
                    orelse=cst.Else(
                        body=cst.IndentedBlock(
                            body=(cst.SimpleStatementLine(body=(cst.Pass(), )), )),
                        leading_lines=(
                            cst.EmptyLine(comment=cst.Comment("# else comment")), ),
                    ),
                    leading_lines=(
                        cst.EmptyLine(comment=cst.Comment("# leading comment")), ),
                ),
                "code": "# leading comment\nwhile iter():\n pass\n# else comment\nelse:\n pass\n",
                "parser": None,
                "expected_position": CodeRange((2, 0), (6, 8)),
            },
            # Weird spacing rules: no space needed before a parenthesized test
            {
                "node": cst.While(
                    test=cst.Call(
                        func=cst.Name("iter"),
                        lpar=(cst.LeftParen(), ),
                        rpar=(cst.RightParen(), ),
                    ),
                    body=cst.SimpleStatementSuite((cst.Pass(), )),
                    whitespace_after_while=cst.SimpleWhitespace(""),
                ),
                "code": "while(iter()): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 19)),
            },
            # Explicit whitespace around the test
            {
                "node": cst.While(
                    test=cst.Call(func=cst.Name("iter")),
                    body=cst.SimpleStatementSuite((cst.Pass(), )),
                    whitespace_after_while=cst.SimpleWhitespace(" "),
                    whitespace_before_colon=cst.SimpleWhitespace(" "),
                ),
                "code": "while iter() : pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 21)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Each node renders to ``code``; positions and parsing are checked too."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # An unparenthesized test glued to the keyword must be rejected.
            {
                "get_node": lambda: cst.While(
                    test=cst.Call(func=cst.Name("iter")),
                    body=cst.SimpleStatementSuite((cst.Pass(), )),
                    whitespace_after_while=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'while' keyword",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Building the node fails validation with ``expected_re``."""
        self.assert_invalid(**kwargs)
def to_stmt(node: cst.BaseSmallStatement) -> cst.SimpleStatementLine:
    """Wrap one small statement (assignment, return, ...) in its own line."""
    statement_line = cst.SimpleStatementLine(body=[node])
    return statement_line