def _create_import_from_annotation(self, returns: cst.Annotation) -> cst.Annotation:
    """Rewrite a dotted or subscripted return annotation via the import helpers.

    * ``cst.Attribute`` annotations are passed to ``self._add_annotation_to_imports``
      and its result is wrapped in a new ``cst.Annotation``.
    * ``cst.Subscript`` annotations are passed to ``self._handle_Subscript`` and
      its result is wrapped in a new ``cst.Annotation``.
    * Any other annotation is returned unchanged (the same ``returns`` object).
    """
    inner = returns.annotation
    if isinstance(inner, cst.Subscript):
        rewritten = self._handle_Subscript(inner)
    elif isinstance(inner, cst.Attribute):
        rewritten = self._add_annotation_to_imports(inner)
    else:
        return returns
    return cst.Annotation(annotation=rewritten)
def _handle_Annotation(self, annotation: cst.Annotation) -> cst.Annotation:
    """Dispatch an annotation to the handler matching its expression type.

    String annotations are returned untouched; subscripts go through
    ``self._handle_Subscript`` and names/attributes (``NAME_OR_ATTRIBUTE``)
    through ``self._handle_NameOrAttribute``, each result re-wrapped in a new
    ``cst.Annotation``.

    Raises:
        ValueError: if the annotation expression is of any other node type.
    """
    expr = annotation.annotation
    if isinstance(expr, cst.SimpleString):
        return annotation
    if isinstance(expr, cst.Subscript):
        handled = self._handle_Subscript(expr)
    elif isinstance(expr, NAME_OR_ATTRIBUTE):
        handled = self._handle_NameOrAttribute(expr)
    else:
        raise ValueError(f"Unexpected annotation node: {expr}")
    return cst.Annotation(annotation=handled)
def _create_import_from_annotation(self, returns: cst.Annotation) -> cst.Annotation:
    """Rewrite a dotted or subscripted return annotation via the import helpers.

    Attribute annotations go through ``self._add_annotation_to_imports``;
    subscripts go through ``self._handle_Subscript`` except when the subscripted
    base is the bare name ``Type`` (``Type[...]``), which is returned as-is.
    Every other annotation is returned unchanged as well.
    """
    inner = returns.annotation
    if isinstance(inner, cst.Attribute):
        return cst.Annotation(annotation=self._add_annotation_to_imports(inner))
    if not isinstance(inner, cst.Subscript):
        return returns
    base = inner.value
    if isinstance(base, cst.Name) and base.value == "Type":
        return returns
    return cst.Annotation(annotation=self._handle_Subscript(inner))
def leave_FunctionDef(self, node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
    """Add a return annotation inferred from the function's ``return`` statements.

    Pops this function's entry from ``self.stack``. NOTE(review): presumably the
    entry is the list of ``cst.Return`` nodes collected while visiting the body
    (or ``None`` to skip the function) -- confirm against the visitor that pushes it.

    * ``None`` entry: the function is returned unchanged.
    * Empty entry (no returns at all): annotate ``-> None``.
    * The body ends in a ``return`` and it is the only one: annotate with the
      type of its literal value -- ``None`` for a bare ``return``/``return None``,
      ``bytes``/``str`` for string literals, ``bool`` for ``True``/``False``,
      ``int`` and ``float`` for numeric literals. Other values are left alone.
    * Otherwise: annotate ``-> None`` if every return is bare or ``return None``.
    """
    returns = self.stack.pop()
    if returns is None:
        return updated_node

    def annotate(type_name):
        return updated_node.with_changes(
            returns=cst.Annotation(annotation=cst.Name(value=type_name)))

    def is_none(value):
        # A bare ``return`` (no value) or an explicit ``return None``.
        return value is None or (isinstance(value, cst.Name) and value.value == "None")

    def is_bytes(value):
        # Implicit concatenation must be all-bytes or all-str, so the first part
        # decides. The prefix may be any case/order mix: b, B, rb, bR, ...
        if isinstance(value, cst.ConcatenatedString):
            value = value.left
        return isinstance(value, cst.SimpleString) and "b" in value.prefix.lower()

    if not returns:
        return annotate("None")

    # Look inside a SimpleStatementLine so both indented bodies and one-line
    # ``def f(): return x`` suites are recognised as ending in a return.
    last = node.body.body[-1]
    if isinstance(last, cst.SimpleStatementLine):
        last = last.body[-1]

    if isinstance(last, cst.Return) and len(returns) == 1:
        value = returns[0].value
        if is_none(value):
            return annotate("None")
        if isinstance(value, cst.BaseString):
            return annotate("bytes" if is_bytes(value) else "str")
        if isinstance(value, cst.Name) and value.value in ("False", "True"):
            return annotate("bool")
        if isinstance(value, cst.Integer):
            return annotate("int")
        if isinstance(value, cst.Float):
            return annotate("float")
        return updated_node

    if all(is_none(r.value) for r in returns):
        return annotate("None")
    return updated_node
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
    """Turn types documented in the function's docstring into real annotations.

    The evaluated docstring is passed to ``gather_types``, which yields a
    rewritten docstring and a mapping of names to type strings (``RETURN`` keys
    the return type). The return type and any parameter types are parsed into
    annotation expressions, and the docstring expression is replaced with the
    rewritten text. Functions without a non-empty plain/concatenated string
    docstring are returned unchanged.
    """
    docstring = None
    docstring_node = get_docstring_node(updated_node.body)
    if docstring_node and isinstance(docstring_node.value, (cst.SimpleString, cst.ConcatenatedString)):
        docstring = docstring_node.value.evaluated_value
    if not docstring:
        return updated_node

    new_docstring, types = gather_types(docstring)

    if types.get(RETURN):
        # Parse the type text (as done for parameters below) instead of
        # stuffing e.g. "List[int]" into a single cst.Name.
        updated_node = updated_node.with_changes(
            returns=cst.Annotation(cst.parse_expression(types.pop(RETURN))),
        )

    if types:
        def get_annotation(p: cst.Param) -> Optional[cst.Annotation]:
            pname = p.name.value
            if types.get(pname):
                return cst.Annotation(cst.parse_expression(types[pname]))
            return None

        updated_node = updated_node.with_changes(
            params=update_parameters(updated_node.params, get_annotation, False))

    # new_docstring derives from the *evaluated* docstring (escapes already
    # resolved). Re-escape before quoting: otherwise backslashes are
    # re-interpreted (e.g. "\b"), and an embedded or trailing '"' ends the
    # literal early. NOTE(review): assumes gather_types returns a str here.
    escaped = new_docstring.replace("\\", "\\\\").replace('"""', '\\"\\"\\"')
    if escaped.endswith('"'):
        escaped = escaped[:-1] + '\\"'
    new_docstring_node = cst.SimpleString('"""%s"""' % escaped)
    return updated_node.deep_replace(docstring_node, cst.Expr(new_docstring_node))
def leave_FunctionDef(
    self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef:
    """Give every function matching ``self.matcher`` an explicit ``-> None``.

    Non-matching functions are returned unchanged.
    """
    if not matchers.matches(updated_node, self.matcher):
        return updated_node
    none_annotation = cst.Annotation(cst.Name(value="None"))
    return updated_node.with_changes(returns=none_annotation)
def leave_Module(
    self, original_node: cst.Module, updated_node: cst.Module
) -> cst.Module:
    """Insert the collected imports and top-level annotations into the module.

    Generated files are returned as the *original* node. When there is nothing
    to add, ``updated_node`` is returned as-is. Otherwise the module is split
    by ``self._split_module`` and the new statements go between the two halves:
    * one ``from M import ...`` per entry of ``self.imports``, listing (sorted)
      only names not already bound by ``self.import_statements``;
    * then one bare ``name: annotation`` per entry of ``self.toplevel_annotations``.
    """
    if self.is_generated:
        return original_node
    if not (self.toplevel_annotations or self.imports):
        return updated_node

    # First, find the insertion point for imports
    leading, trailing = self._split_module(original_node, updated_node)
    # Make sure there's at least one empty line before the first non-import
    trailing = self._insert_empty_line(trailing)

    # Names already bound by existing imports (the alias, when one is given).
    bound_names = set()
    for existing in self.import_statements:
        aliases = existing.names
        if isinstance(aliases, cst.ImportStar):
            continue
        for alias in aliases:
            binding = alias.asname if alias.asname else alias
            if binding:
                bound_names.add(_get_name_as_string(binding.name))

    additions = []
    for pending in self.imports.values():
        # Filter out anything that has already been imported.
        missing = sorted(pending.names.difference(bound_names))
        if not missing:
            continue
        import_from = cst.ImportFrom(
            module=pending.module,
            names=[cst.ImportAlias(cst.Name(n)) for n in missing],
        )
        additions.append(cst.SimpleStatementLine([import_from]))

    for name, annotation in self.toplevel_annotations.items():
        declaration = cst.AnnAssign(
            cst.Name(name),
            # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
            cst.Annotation(annotation.annotation),
            None,
        )
        additions.append(cst.SimpleStatementLine([declaration]))

    return updated_node.with_changes(body=[*leading, *additions, *trailing])
def _create_import_from_annotation(self, returns: cst.CSTNode) -> cst.CSTNode:
    """Shorten a dotted annotation ``a.b.C`` to ``C`` and register its import.

    For an attribute annotation, ``C`` is passed to ``self._add_to_imports``
    along with the module expression ``a.b`` and the key computed by
    ``_get_attribute_as_string``; a new ``cst.Annotation`` holding only ``C``
    is returned. Any other annotation is returned unchanged.
    """
    # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
    annotation = returns.annotation
    if not isinstance(annotation, cst.Attribute):
        return returns
    module_expr = annotation.value
    key = _get_attribute_as_string(module_expr)
    imported_alias = cst.ImportAlias(name=annotation.attr)
    self._add_to_imports([imported_alias], module_expr, key)
    return cst.Annotation(annotation=annotation.attr)
def leave_Module(self, node: cst.Module, updated_node: cst.Module) -> cst.CSTNode:
    """Insert one bare ``name: annotation`` statement per top-level annotation.

    The statements are spliced into the module body at the index returned by
    ``self._get_toplevel_index``, in ``self.toplevel_annotations`` order.
    """
    body = list(updated_node.body)
    index = self._get_toplevel_index(body)
    new_statements = []
    for name, annotation in self.toplevel_annotations.items():
        annotated_assign = cst.AnnAssign(
            cst.Name(name),
            # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
            cst.Annotation(annotation.annotation),
            None,
        )
        new_statements.append(cst.SimpleStatementLine([annotated_assign]))
    # Splice as one group: inserting each statement at the same index would
    # emit them in reverse order.
    body[index:index] = new_statements
    return updated_node.with_changes(body=tuple(body))
def _convert_annotation(
    raw: str,
    quote_annotations: bool,
) -> cst.Annotation:
    """
    Convert a raw annotation - which is a string coming from a type comment -
    into a suitable libcst Annotation node.

    If `quote_annotations`, we'll always quote annotations unless they are
    builtin types. The reason for this is to make the codemod safer to apply
    on legacy code where type comments may well include invalid types
    that would crash at runtime.

    When quoting, backslashes are escaped and a quote character not present in
    the text is chosen, so annotations such as ``Literal["a"]`` still produce a
    valid string literal whose value is exactly ``raw``.
    """
    if _is_builtin(raw):
        return cst.Annotation(annotation=cst.Name(value=raw))
    if not quote_annotations:
        try:
            return cst.Annotation(annotation=cst.parse_expression(raw))
        except cst.ParserSyntaxError:
            pass
    escaped = raw.replace("\\", "\\\\")
    if '"' not in escaped:
        quoted = f'"{escaped}"'
    elif "'" not in escaped:
        quoted = f"'{escaped}'"
    else:
        quoted = '"' + escaped.replace('"', '\\"') + '"'
    return cst.Annotation(annotation=cst.SimpleString(quoted))
def parameters(
    self, type: typing.Literal["function", "classmethod", "method"]
) -> cst.Parameters:
    """Render the stored signature pieces as a ``cst.Parameters`` node.

    Every parameter carries its stored ``.annotation``; optional ones get
    ``...`` as their default. Optional positional groups are ordered through
    ``possibly_order_dict``. ``type`` selects an implicit first parameter:
    ``"classmethod"`` prepends ``cls`` and ``"method"`` prepends ``self`` to
    the positional-only group; ``"function"`` adds nothing.
    """

    def required(name, value):
        return cst.Param(cst.Name(name), cst.Annotation(value.annotation))

    def optional(name, value):
        return cst.Param(
            cst.Name(name), cst.Annotation(value.annotation), default=cst.Ellipsis()
        )

    posonly = [required(k, v) for k, v in self.pos_only_required.items()]
    posonly += [
        optional(k, v)
        for k, v in possibly_order_dict(
            self.pos_only_optional, self.pos_only_optional_ordering
        ).items()
    ]
    implicit_first = {"classmethod": "cls", "method": "self"}.get(type)
    if implicit_first is not None:
        posonly.insert(0, cst.Param(cst.Name(implicit_first)))

    positional = [required(k, v) for k, v in self.pos_or_kw_required.items()]
    positional += [
        optional(k, v)
        for k, v in possibly_order_dict(
            self.pos_or_kw_optional, self.pos_or_kw_optional_ordering
        ).items()
    ]

    if self.var_pos:
        star_arg = required(self.var_pos[0], self.var_pos[1])
    else:
        star_arg = cst.MaybeSentinel.DEFAULT
    star_kwarg = required(self.var_kw[0], self.var_kw[1]) if self.var_kw else None

    keyword_only = [required(k, v) for k, v in self.kw_only_required.items()]
    keyword_only += [optional(k, v) for k, v in self.kw_only_optional.items()]

    return cst.Parameters(
        posonly_params=posonly,
        params=positional,
        star_arg=star_arg,
        star_kwarg=star_kwarg,
        kwonly_params=keyword_only,
    )
def test_annotation(self) -> None:
    # The template's annotation slot accepts either a plain annotation
    # expression (the normal case) or a whole cst.Annotation node (a special
    # case); both must render the same source.
    expected = "x: int = 5\n"
    for annotation_arg in (cst.Name("int"), cst.Annotation(cst.Name("int"))):
        statement = parse_template_statement(
            "x: {type} = {val}",
            type=annotation_arg,
            val=cst.Integer("5"),
        )
        self.assertEqual(self.code(statement), expected)
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
    """Add the collected ``from ... import`` statements and top-level annotations.

    The module is split by ``self._split_module``; the new statements go
    between the two halves, imports first (one per entry of ``self.imports``),
    then one bare ``name: annotation`` per entry of ``self.toplevel_annotations``.
    Returns ``updated_node`` untouched when there is nothing to add.
    """
    if not (self.toplevel_annotations or self.imports):
        return updated_node

    # First, find the insertion point for imports
    leading, trailing = self._split_module(original_node, updated_node)
    # Make sure there's at least one empty line before the first non-import
    trailing = self._insert_empty_line(trailing)

    additions = []
    for pending_import in self.imports.values():
        rebuilt = cst.ImportFrom(
            module=pending_import.module,
            # pyre-fixme[6]: Expected `Union[Sequence[ImportAlias], ImportStar]`
            # for 2nd param but got `List[ImportFrom]`.
            names=pending_import.names,
        )
        additions.append(cst.SimpleStatementLine([rebuilt]))

    for name, annotation in self.toplevel_annotations.items():
        declaration = cst.AnnAssign(
            cst.Name(name),
            # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
            cst.Annotation(annotation.annotation),
            None,
        )
        additions.append(cst.SimpleStatementLine([declaration]))

    return updated_node.with_changes(body=[*leading, *additions, *trailing])
def assign_properties(
    p: typing.Dict[str, typing.Tuple[Metadata, Type]], is_classvar=False
) -> typing.Iterable[cst.SimpleStatementLine]:
    """Yield an annotated declaration ``name: T`` for each property in ``p``.

    Entries are visited in ``sort_items`` order; names rejected by ``bad_name``
    are skipped. With ``is_classvar`` the type is wrapped as ``ClassVar[T]``.
    Each statement is preceded by a blank line and one ``# ...`` comment line
    per item produced by ``metadata_lines(metadata)``.
    """
    for name, entry in sort_items(p):
        if bad_name(name):
            continue
        metadata, tp = entry
        annotation = tp.annotation
        if is_classvar:
            annotation = cst.Subscript(
                cst.Name("ClassVar"),
                [cst.SubscriptElement(cst.Index(annotation))],
            )
        declaration = cst.AnnAssign(cst.Name(name), cst.Annotation(annotation))
        comment_lines = [
            cst.EmptyLine(comment=cst.Comment("# " + line))
            for line in metadata_lines(metadata)
        ]
        yield cst.SimpleStatementLine(
            [declaration],
            leading_lines=[cst.EmptyLine(), *comment_lines],
        )
class AnnAssignTest(CSTNodeTest):
    """Round-trip and validation tests for ``cst.AnnAssign``."""

    @data_provider(
        (
            # Simple assignment creation case.
            {
                "node": cst.AnnAssign(
                    cst.Name("foo"), cst.Annotation(cst.Name("str")), cst.Integer("5")
                ),
                "code": "foo: str = 5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 12)),
            },
            # Annotation creation without assignment
            {
                "node": cst.AnnAssign(cst.Name("foo"), cst.Annotation(cst.Name("str"))),
                "code": "foo: str",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 8)),
            },
            # Complex annotation creation
            {
                "node": cst.AnnAssign(
                    cst.Name("foo"),
                    cst.Annotation(
                        cst.Subscript(
                            cst.Name("Optional"),
                            (cst.SubscriptElement(cst.Index(cst.Name("str"))),),
                        )
                    ),
                    cst.Integer("5"),
                ),
                "code": "foo: Optional[str] = 5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 22)),
            },
            # Simple assignment parser case.
            {
                "node": cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=cst.Name("str"),
                                whitespace_before_indicator=cst.SimpleWhitespace(""),
                            ),
                            equal=cst.AssignEqual(),
                            value=cst.Integer("5"),
                        ),
                    )
                ),
                "code": "foo: str = 5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
            # Annotation without assignment
            {
                "node": cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=cst.Name("str"),
                                whitespace_before_indicator=cst.SimpleWhitespace(""),
                            ),
                            value=None,
                        ),
                    )
                ),
                "code": "foo: str\n",
                "parser": parse_statement,
                "expected_position": None,
            },
            # Complex annotation
            {
                "node": cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=cst.Subscript(
                                    cst.Name("Optional"),
                                    (cst.SubscriptElement(cst.Index(cst.Name("str"))),),
                                ),
                                whitespace_before_indicator=cst.SimpleWhitespace(""),
                            ),
                            equal=cst.AssignEqual(),
                            value=cst.Integer("5"),
                        ),
                    )
                ),
                "code": "foo: Optional[str] = 5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
            # Whitespace test: the code string must be 26 characters long to
            # match expected_position, hence the double-space whitespace fields.
            {
                "node": cst.AnnAssign(
                    target=cst.Name("foo"),
                    annotation=cst.Annotation(
                        annotation=cst.Subscript(
                            cst.Name("Optional"),
                            (cst.SubscriptElement(cst.Index(cst.Name("str"))),),
                        ),
                        whitespace_before_indicator=cst.SimpleWhitespace(" "),
                        whitespace_after_indicator=cst.SimpleWhitespace("  "),
                    ),
                    equal=cst.AssignEqual(
                        whitespace_before=cst.SimpleWhitespace("  "),
                        whitespace_after=cst.SimpleWhitespace("  "),
                    ),
                    value=cst.Integer("5"),
                ),
                "code": "foo :  Optional[str]  =  5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 26)),
            },
            # Whitespace test, parser case
            {
                "node": cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=cst.Subscript(
                                    cst.Name("Optional"),
                                    (cst.SubscriptElement(cst.Index(cst.Name("str"))),),
                                ),
                                whitespace_before_indicator=cst.SimpleWhitespace(" "),
                                whitespace_after_indicator=cst.SimpleWhitespace("  "),
                            ),
                            equal=cst.AssignEqual(
                                whitespace_before=cst.SimpleWhitespace("  "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                            value=cst.Integer("5"),
                        ),
                    )
                ),
                "code": "foo :  Optional[str]  =  5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": (
                    lambda: cst.AnnAssign(
                        target=cst.Name("foo"),
                        annotation=cst.Annotation(cst.Name("str")),
                        equal=cst.AssignEqual(),
                        value=None,
                    )
                ),
                "expected_re": "Must have a value when specifying an AssignEqual.",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
class LambdaCreationTest(CSTNodeTest):
    """Code generation and validation tests for manually constructed ``cst.Lambda`` nodes."""

    @data_provider(
        (
            # Simple lambda
            (cst.Lambda(cst.Parameters(), cst.Integer("5")), "lambda: 5"),
            # Test basic positional params
            (
                cst.Lambda(
                    cst.Parameters(
                        params=(cst.Param(cst.Name("bar")), cst.Param(cst.Name("baz")))
                    ),
                    cst.Integer("5"),
                ),
                "lambda bar, baz: 5",
            ),
            # Test basic positional default params
            (
                cst.Lambda(
                    cst.Parameters(
                        default_params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                            cst.Param(cst.Name("baz"), default=cst.Integer("5")),
                        )
                    ),
                    cst.Integer("5"),
                ),
                'lambda bar = "one", baz = 5: 5',
            ),
            # Mixed positional and default params.
            (
                cst.Lambda(
                    cst.Parameters(
                        params=(cst.Param(cst.Name("bar")),),
                        default_params=(
                            cst.Param(cst.Name("baz"), default=cst.Integer("5")),
                        ),
                    ),
                    cst.Integer("5"),
                ),
                "lambda bar, baz = 5: 5",
            ),
            # Test kwonly params
            (
                cst.Lambda(
                    cst.Parameters(
                        kwonly_params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                            cst.Param(cst.Name("baz")),
                        )
                    ),
                    cst.Integer("5"),
                ),
                'lambda *, bar = "one", baz: 5',
            ),
            # Mixed params and kwonly_params
            (
                cst.Lambda(
                    cst.Parameters(
                        params=(
                            cst.Param(cst.Name("first")),
                            cst.Param(cst.Name("second")),
                        ),
                        kwonly_params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                            cst.Param(cst.Name("baz")),
                            cst.Param(cst.Name("biz"), default=cst.SimpleString('"two"')),
                        ),
                    ),
                    cst.Integer("5"),
                ),
                'lambda first, second, *, bar = "one", baz, biz = "two": 5',
            ),
            # Mixed default_params and kwonly_params
            (
                cst.Lambda(
                    cst.Parameters(
                        default_params=(
                            cst.Param(cst.Name("first"), default=cst.Float("1.0")),
                            cst.Param(cst.Name("second"), default=cst.Float("1.5")),
                        ),
                        kwonly_params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                            cst.Param(cst.Name("baz")),
                            cst.Param(cst.Name("biz"), default=cst.SimpleString('"two"')),
                        ),
                    ),
                    cst.Integer("5"),
                ),
                'lambda first = 1.0, second = 1.5, *, bar = "one", baz, biz = "two": 5',
            ),
            # Mixed params, default_params, and kwonly_params
            (
                cst.Lambda(
                    cst.Parameters(
                        params=(
                            cst.Param(cst.Name("first")),
                            cst.Param(cst.Name("second")),
                        ),
                        default_params=(
                            cst.Param(cst.Name("third"), default=cst.Float("1.0")),
                            cst.Param(cst.Name("fourth"), default=cst.Float("1.5")),
                        ),
                        kwonly_params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                            cst.Param(cst.Name("baz")),
                            cst.Param(cst.Name("biz"), default=cst.SimpleString('"two"')),
                        ),
                    ),
                    cst.Integer("5"),
                ),
                'lambda first, second, third = 1.0, fourth = 1.5, *, bar = "one", baz, biz = "two": 5',
                CodeRange((1, 0), (1, 84)),
            ),
            # Test star_arg
            (
                cst.Lambda(
                    cst.Parameters(star_arg=cst.Param(cst.Name("params"))),
                    cst.Integer("5"),
                ),
                "lambda *params: 5",
            ),
            # star_arg together with kwonly_params
            (
                cst.Lambda(
                    cst.Parameters(
                        star_arg=cst.Param(cst.Name("params")),
                        kwonly_params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                            cst.Param(cst.Name("baz")),
                            cst.Param(cst.Name("biz"), default=cst.SimpleString('"two"')),
                        ),
                    ),
                    cst.Integer("5"),
                ),
                'lambda *params, bar = "one", baz, biz = "two": 5',
            ),
            # Mixed params default_params, star_arg and kwonly_params
            (
                cst.Lambda(
                    cst.Parameters(
                        params=(
                            cst.Param(cst.Name("first")),
                            cst.Param(cst.Name("second")),
                        ),
                        default_params=(
                            cst.Param(cst.Name("third"), default=cst.Float("1.0")),
                            cst.Param(cst.Name("fourth"), default=cst.Float("1.5")),
                        ),
                        star_arg=cst.Param(cst.Name("params")),
                        kwonly_params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                            cst.Param(cst.Name("baz")),
                            cst.Param(cst.Name("biz"), default=cst.SimpleString('"two"')),
                        ),
                    ),
                    cst.Integer("5"),
                ),
                'lambda first, second, third = 1.0, fourth = 1.5, *params, bar = "one", baz, biz = "two": 5',
            ),
            # Test star_kwarg
            (
                cst.Lambda(
                    cst.Parameters(star_kwarg=cst.Param(cst.Name("kwparams"))),
                    cst.Integer("5"),
                ),
                "lambda **kwparams: 5",
            ),
            # Test star_arg and star_kwarg
            (
                cst.Lambda(
                    cst.Parameters(
                        star_arg=cst.Param(cst.Name("params")),
                        star_kwarg=cst.Param(cst.Name("kwparams")),
                    ),
                    cst.Integer("5"),
                ),
                "lambda *params, **kwparams: 5",
            ),
            # Inner whitespace: the lambda node (without its parens) must span
            # columns 2-13, hence two spaces after the ``lambda`` keyword.
            (
                cst.Lambda(
                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
                    whitespace_after_lambda=cst.SimpleWhitespace("  "),
                    params=cst.Parameters(),
                    colon=cst.Colon(whitespace_after=cst.SimpleWhitespace(" ")),
                    body=cst.Integer("5"),
                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
                ),
                "( lambda  : 5 )",
                CodeRange((1, 2), (1, 13)),
            ),
        )
    )
    def test_valid(
        self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
    ) -> None:
        self.validate_node(node, code, expected_position=position)

    @data_provider(
        (
            (
                lambda: cst.Lambda(
                    cst.Parameters(params=(cst.Param(cst.Name("arg")),)),
                    cst.Integer("5"),
                    lpar=(cst.LeftParen(),),
                ),
                "left paren without right paren",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(params=(cst.Param(cst.Name("arg")),)),
                    cst.Integer("5"),
                    rpar=(cst.RightParen(),),
                ),
                "right paren without left paren",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(params=(cst.Param(cst.Name("arg")),)),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "at least one space after lambda",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        default_params=(
                            cst.Param(cst.Name("arg"), default=cst.Integer("5")),
                        )
                    ),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "at least one space after lambda",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(star_arg=cst.Param(cst.Name("arg"))),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "at least one space after lambda",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(kwonly_params=(cst.Param(cst.Name("arg")),)),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "at least one space after lambda",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"))),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "at least one space after lambda",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        star_kwarg=cst.Param(cst.Name("bar"), equal=cst.AssignEqual())
                    ),
                    cst.Integer("5"),
                ),
                "Must have a default when specifying an AssignEqual.",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"), star="***")),
                    cst.Integer("5"),
                ),
                r"Must specify either '', '\*' or '\*\*' for star.",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                        )
                    ),
                    cst.Integer("5"),
                ),
                "Cannot have defaults for params",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(default_params=(cst.Param(cst.Name("bar")),)),
                    cst.Integer("5"),
                ),
                "Must have defaults for default_params",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(star_arg=cst.ParamStar()), cst.Integer("5")
                ),
                "Must have at least one kwonly param if ParamStar is used.",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(params=(cst.Param(cst.Name("bar"), star="*"),)),
                    cst.Integer("5"),
                ),
                "Expecting a star prefix of ''",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        default_params=(
                            cst.Param(
                                cst.Name("bar"),
                                default=cst.SimpleString('"one"'),
                                star="*",
                            ),
                        )
                    ),
                    cst.Integer("5"),
                ),
                "Expecting a star prefix of ''",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(kwonly_params=(cst.Param(cst.Name("bar"), star="*"),)),
                    cst.Integer("5"),
                ),
                "Expecting a star prefix of ''",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(star_arg=cst.Param(cst.Name("bar"), star="**")),
                    cst.Integer("5"),
                ),
                r"Expecting a star prefix of '\*'",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"), star="*")),
                    cst.Integer("5"),
                ),
                r"Expecting a star prefix of '\*\*'",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        params=(
                            cst.Param(
                                cst.Name("arg"),
                                annotation=cst.Annotation(cst.Name("str")),
                            ),
                        )
                    ),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "Lambda params cannot have type annotations",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        default_params=(
                            cst.Param(
                                cst.Name("arg"),
                                default=cst.Integer("5"),
                                annotation=cst.Annotation(cst.Name("str")),
                            ),
                        )
                    ),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "Lambda params cannot have type annotations",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        star_arg=cst.Param(
                            cst.Name("arg"), annotation=cst.Annotation(cst.Name("str"))
                        )
                    ),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "Lambda params cannot have type annotations",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        kwonly_params=(
                            cst.Param(
                                cst.Name("arg"),
                                annotation=cst.Annotation(cst.Name("str")),
                            ),
                        )
                    ),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "Lambda params cannot have type annotations",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        star_kwarg=cst.Param(
                            cst.Name("arg"), annotation=cst.Annotation(cst.Name("str"))
                        )
                    ),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "Lambda params cannot have type annotations",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        self.assert_invalid(get_node, expected_re)
def get_annotation(p: cst.Param) -> Optional[cst.Annotation]:
    """Build an annotation for ``p`` from its entry in ``types``, if any.

    ``types`` is a free name here. NOTE(review): presumably a mapping of
    parameter name to type string bound in an enclosing scope -- confirm.
    Returns ``None`` when the parameter has no (truthy) entry.
    """
    type_text = types.get(p.name.value)
    if not type_text:
        return None
    return cst.Annotation(cst.parse_expression(type_text))
def test_from_function_data(self) -> None:
    three_parameters = [
        cst.Param(name=cst.Name(param_name), annotation=None)
        for param_name in ("x1", "x2", "x3")
    ]
    annotated_self = [
        cst.Param(
            name=cst.Name("self"),
            annotation=cst.Annotation(cst.Name("Foo")),
        )
    ]
    # Each row: (is_return_annotated, annotated_parameter_count,
    #            is_method_or_classmethod, parameters, expected kind)
    cases = [
        (True, 3, False, three_parameters, FunctionAnnotationKind.FULLY_ANNOTATED),
        (True, 0, False, three_parameters, FunctionAnnotationKind.PARTIALLY_ANNOTATED),
        (False, 0, False, three_parameters, FunctionAnnotationKind.NOT_ANNOTATED),
        (False, 1, False, three_parameters, FunctionAnnotationKind.PARTIALLY_ANNOTATED),
        # An untyped `self` parameter of a method does not count for partial
        # annotation. As per PEP 484, we need an explicitly annotated parameter.
        (False, 1, True, three_parameters, FunctionAnnotationKind.NOT_ANNOTATED),
        (False, 2, True, three_parameters, FunctionAnnotationKind.PARTIALLY_ANNOTATED),
        # An annotated `self` suffices to make Pyre typecheck the method.
        (False, 1, True, annotated_self, FunctionAnnotationKind.PARTIALLY_ANNOTATED),
        (False, 0, True, [], FunctionAnnotationKind.NOT_ANNOTATED),
    ]
    for return_annotated, annotated_count, is_method, params, expected in cases:
        self.assertEqual(
            FunctionAnnotationKind.from_function_data(
                is_return_annotated=return_annotated,
                annotated_parameter_count=annotated_count,
                is_method_or_classmethod=is_method,
                parameters=params,
            ),
            expected,
        )
def return_type_annotation(self) -> typing.Optional[cst.Annotation]:
    """Wrap ``self.return_type.annotation`` in a new ``cst.Annotation``.

    Returns ``None`` when ``self.return_type`` is falsy.
    """
    if not self.return_type:
        return None
    return cst.Annotation(self.return_type.annotation)