def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine,
                              updated_node: cst.SimpleStatementLine):
    """Apply an inferred type annotation to a single-name assignment line.

    Two statement-line shapes are recognised (matched on ``original_node``):

    * ``x = value`` -- a lone ``Assign`` with one ``Name`` target. The type
      comes from ``self.__get_var_type_assign_t`` and the line is rewritten
      to ``x: T = value``, keeping the spacing around ``=``.
    * ``x: A = value`` / ``x: A`` -- a lone ``AnnAssign`` with a ``Name``
      target. The type comes from ``self.__get_var_type_an_assign`` and the
      annotation is replaced.

    In both cases the type goes through ``self.resolve_type_alias`` and then
    ``self.__name2annotation``; when an annotation node results, the pair
    ``(resolved_type, annotation_node)`` is added to
    ``self.all_applied_types``.

    Returns the rewritten line, or ``updated_node`` when nothing applies.
    """
    if match.matches(
            original_node,
            match.SimpleStatementLine(body=[
                match.Assign(targets=[
                    match.AssignTarget(target=match.Name(
                        value=match.DoNotCare()))
                ])
            ])):
        assign = original_node.body[0]
        assign_target = assign.targets[0]
        t = self.__get_var_type_assign_t(assign_target.target.value)
        if t is not None:
            t_annot_node_resolved = self.resolve_type_alias(t)
            t_annot_node = self.__name2annotation(t_annot_node_resolved)
            if t_annot_node is not None:
                self.all_applied_types.add(
                    (t_annot_node_resolved, t_annot_node))
                return updated_node.with_changes(body=[
                    cst.AnnAssign(
                        target=assign_target.target,
                        value=assign.value,
                        annotation=t_annot_node,
                        # Keep the source's spacing around ``=``.
                        equal=cst.AssignEqual(
                            whitespace_after=assign_target.whitespace_after_equal,
                            whitespace_before=assign_target.whitespace_before_equal))
                ])
    elif match.matches(
            original_node,
            match.SimpleStatementLine(body=[
                match.AnnAssign(target=match.Name(value=match.DoNotCare()))
            ])):
        ann_assign = original_node.body[0]
        t = self.__get_var_type_an_assign(ann_assign.target.value)
        if t is not None:
            t_annot_node_resolved = self.resolve_type_alias(t)
            t_annot_node = self.__name2annotation(t_annot_node_resolved)
            if t_annot_node is not None:
                self.all_applied_types.add(
                    (t_annot_node_resolved, t_annot_node))
                return updated_node.with_changes(body=[
                    cst.AnnAssign(target=ann_assign.target,
                                  value=ann_assign.value,
                                  annotation=t_annot_node,
                                  equal=ann_assign.equal)
                ])
    # No annotation applied. Return ``updated_node`` rather than
    # ``original_node``: returning the original would discard any changes
    # libcst has already made to this line's children.
    return updated_node
def leave_Module(
    self, original_node: cst.Module, updated_node: cst.Module
) -> cst.Module:
    """Insert pending top-level declarations and not-yet-seen classes.

    Class definitions from ``self.annotations.class_definitions`` whose names
    are not in ``self.visited_classes`` are collected. When neither they nor
    ``self.toplevel_annotations`` contribute anything, ``updated_node`` is
    returned untouched. Otherwise the body is split at the insertion point
    computed by ``self._split_module`` and, at that point, one
    ``name: annotation`` declaration per top-level annotation is inserted,
    followed by the collected class definitions.
    """
    unvisited_classes = [
        class_def
        for class_name, class_def in self.annotations.class_definitions.items()
        if class_name not in self.visited_classes
    ]
    if not (self.toplevel_annotations or unvisited_classes):
        return updated_node

    # Split the module body at the insertion point for new statements.
    leading, trailing = self._split_module(original_node, updated_node)
    # Keep at least one blank line before the first non-import statement.
    trailing = self._insert_empty_line(trailing)

    inserted = [
        cst.SimpleStatementLine([cst.AnnAssign(cst.Name(name), annotation, None)])
        for name, annotation in self.toplevel_annotations.items()
    ]
    inserted += unvisited_classes
    return updated_node.with_changes(body=[*leading, *inserted, *trailing])
def _annotate_single_target(
    self, node: cst.Assign, updated_node: cst.Assign
) -> Union[cst.Assign, cst.AnnAssign]:
    """Annotate an assignment that has exactly one ``=`` target.

    * Tuple/list targets (``a, b = ...``): every element whose full name can
      be resolved is passed to ``self._add_to_toplevel_annotations``; the
      assignment itself is returned unchanged.
    * Subscript and attribute targets (``x[0] = ...``, ``obj.attr = ...``)
      are returned unchanged.
    * Plain name targets: if the qualified name (``self.qualifier`` plus the
      name) is a key of ``self.annotations.attribute_annotations``, the
      assignment becomes ``name: annotation = value``.

    ``self.qualifier`` is left as it was on entry.
    """
    only_target = node.targets[0].target
    if isinstance(only_target, (cst.Tuple, cst.List)):
        for element in only_target.elements:
            name = get_full_name_for_node(element.value)
            if name:
                self._add_to_toplevel_annotations(name)
    elif isinstance(only_target, (cst.Attribute, cst.Subscript)):
        # An attribute's full name is dotted ("obj.attr") and a subscript has
        # none; neither can be rebuilt as ``cst.Name(name)`` below.
        pass
    else:
        name = get_full_name_for_node(only_target)
        if name is not None:
            self.qualifier.append(name)
            qualified_name = self._qualifier_name()
            self.qualifier.pop()
            attribute_annotations = self.annotations.attribute_annotations
            if qualified_name in attribute_annotations:
                annotation = attribute_annotations[qualified_name]
                return cst.AnnAssign(cst.Name(name), annotation, node.value)
    return updated_node
def _annotate_single_target(
        self, node: cst.Assign,
        updated_node: cst.Assign) -> Union[cst.Assign, cst.AnnAssign]:
    """Annotate an assignment that has exactly one ``=`` target.

    * Tuple/list targets (``a, b = ...``, ``[a, *b] = ...``): every element
      that is a bare (optionally starred) name is passed to
      ``self._add_to_toplevel_annotations``; the assignment is returned
      unchanged.
    * Plain name targets: if the qualified name (``self.qualifier`` plus the
      name) is a key of ``self.attribute_annotations``, the assignment
      becomes ``name: annotation = value``.
    * Any other target (subscript, attribute) is returned unchanged.

    ``self.qualifier`` is left as it was on entry.
    """
    target = node.targets[0].target
    if isinstance(target, (cst.Tuple, cst.List)):
        for element in target.elements:
            # Only bare names bind a variable. For attributes/subscripts,
            # ``element.value.value`` is a sub-expression (``a`` for ``a.b``),
            # and nested tuples have no ``.value`` at all.
            if isinstance(element.value, cst.Name):
                self._add_to_toplevel_annotations(
                    _get_name_as_string(element.value.value))
        return updated_node
    if not isinstance(target, cst.Name):
        # For ``x[i] = ...`` / ``obj.attr = ...`` the ``.value`` is the object
        # expression, not the assigned name; looking it up could annotate the
        # wrong variable (e.g. turn ``a.b = 1`` into ``a: T = 1``).
        return updated_node
    name = _get_name_as_string(target.value)
    self.qualifier.append(name)
    qualified_name = self._qualifier_name()
    self.qualifier.pop()
    if qualified_name in self.attribute_annotations:
        annotation = self.attribute_annotations[qualified_name]
        return cst.AnnAssign(cst.Name(name), annotation, node.value)
    return updated_node
def leave_Module(
    self, original_node: cst.Module, updated_node: cst.Module
) -> cst.Module:
    """Insert missing ``from ... import`` lines and top-level declarations.

    Generated files (``self.is_generated``) are returned as ``original_node``.
    When ``self.toplevel_annotations`` and ``self.imports`` are both empty,
    ``updated_node`` is returned as-is. Otherwise, at the insertion point
    computed by ``self._split_module``, this emits:

    1. one ``from <module> import ...`` line per entry of ``self.imports``,
       listing (sorted) only names not already bound by a statement in
       ``self.import_statements``; entries with nothing left are skipped;
    2. one ``name: annotation`` line per entry of ``self.toplevel_annotations``.
    """
    if self.is_generated:
        return original_node
    if not (self.toplevel_annotations or self.imports):
        return updated_node

    leading, trailing = self._split_module(original_node, updated_node)
    # Keep at least one blank line before the first non-import statement.
    trailing = self._insert_empty_line(trailing)

    already_imported = set()
    for existing in self.import_statements:
        aliases = existing.names
        if isinstance(aliases, cst.ImportStar):
            continue
        for alias in aliases:
            # ``import a as b`` binds ``b``; without ``as`` the imported name
            # itself is bound. ``AsName`` and ``ImportAlias`` both expose it
            # as ``.name``.
            bound = alias.asname or alias
            if bound:
                already_imported.add(_get_name_as_string(bound.name))

    new_statements = []
    for pending in self.imports.values():
        missing = sorted(pending.names.difference(already_imported))
        if not missing:
            continue
        # SimpleStatementLine takes a sequence of small statements.
        new_statements.append(
            cst.SimpleStatementLine([
                cst.ImportFrom(
                    module=pending.module,
                    names=[cst.ImportAlias(cst.Name(n)) for n in missing],
                )
            ])
        )

    for name, annotation in self.toplevel_annotations.items():
        declaration = cst.AnnAssign(
            cst.Name(name),
            # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
            cst.Annotation(annotation.annotation),
            None,
        )
        new_statements.append(cst.SimpleStatementLine([declaration]))

    return updated_node.with_changes(
        body=[*leading, *new_statements, *trailing]
    )
def _apply_annotation_to_attribute_or_global(
    self,
    name: str,
    annotation: cst.Annotation,
    value: Optional[cst.BaseExpression],
) -> cst.AnnAssign:
    """Build ``name: annotation [= value]`` and record it in the counters.

    With an empty ``self.qualifier`` the declaration is counted in
    ``self.annotation_counts.global_annotations``; otherwise in
    ``self.annotation_counts.attribute_annotations``.
    """
    counts = self.annotation_counts
    if self.qualifier:
        counts.attribute_annotations += 1
    else:
        counts.global_annotations += 1
    return cst.AnnAssign(cst.Name(name), annotation, value)
def leave_Module(self, node: cst.Module, updated_node: cst.Module) -> cst.CSTNode:
    """Insert a ``name: annotation`` line for each top-level annotation.

    The lines are placed at ``self._get_toplevel_index(body)``, in the
    iteration order of ``self.toplevel_annotations``.
    """
    body = list(updated_node.body)
    index = self._get_toplevel_index(body)
    declarations = [
        cst.SimpleStatementLine([
            cst.AnnAssign(
                cst.Name(name),
                # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                cst.Annotation(annotation.annotation),
                None,
            )
        ])
        for name, annotation in self.toplevel_annotations.items()
    ]
    # A single slice assignment keeps the declarations in order; repeatedly
    # calling ``body.insert(index, ...)`` would emit them reversed.
    body[index:index] = declarations
    return updated_node.with_changes(body=tuple(body))
def convert_Assign(
    node: cst.Assign,
    annotation: ast.expr,
    quote_annotations: bool,
) -> Union[
    _FailedToApplyAnnotation,
    cst.AnnAssign,
    List[Union[cst.AnnAssign, cst.Assign]],
]:
    """Apply a type comment's ``annotation`` to the assignment ``node``.

    Returns:
        * ``_FailedToApplyAnnotation()`` when the types and targets cannot be
          paired up (``_ArityError``);
        * a single ``cst.AnnAssign`` (keeping ``node.value`` and
          ``node.semicolon``) when there is one ``=`` target, it is not a
          tuple/list, and it yields exactly one binding;
        * otherwise a list of value-less type declarations, one per binding,
          followed by the unchanged ``node``.
    """
    # Zip the type and target information together. If there are mismatched
    # arities, this is a PEP 484 violation (technically we could use
    # logic beyond the PEP to recover some cases as typing.Tuple, but this
    # should be rare) so we give up.
    try:
        annotations = AnnotationSpreader.unpack_annotation(annotation)
        annotated_targets = [
            AnnotationSpreader.annotated_bindings(
                bindings=AnnotationSpreader.unpack_target(target.target),
                annotations=annotations,
            )
            for target in node.targets
        ]
    except _ArityError:
        return _FailedToApplyAnnotation()
    if (
        len(annotated_targets) == 1
        and len(annotated_targets[0]) == 1
        # An unpacking target such as ``(a,) = x`` or ``[*a] = x`` can bind a
        # single name, but ``a: T = x`` would bind ``x`` itself instead of
        # unpacking it, and PEP 526 does not allow annotating an unpacking.
        # Such targets take the declaration-plus-assign path below.
        and not isinstance(node.targets[0].target, (cst.Tuple, cst.List))
    ):
        # We can convert simple one-target assignments into a single AnnAssign
        binding, raw_annotation = annotated_targets[0][0]
        return cst.AnnAssign(
            target=binding,
            annotation=_convert_annotation(
                raw=raw_annotation,
                quote_annotations=quote_annotations,
            ),
            value=node.value,
            semicolon=node.semicolon,
        )
    else:
        # For multi-target assigns (regardless of whether they are using tuples
        # on the LHS or multiple `=` tokens or both), we need to add a type
        # declaration per individual LHS target.
        type_declarations = [
            AnnotationSpreader.type_declaration(
                binding,
                raw_annotation,
                quote_annotations=quote_annotations,
            )
            for annotated_bindings in annotated_targets
            for binding, raw_annotation in annotated_bindings
        ]
        return [
            *type_declarations,
            node,
        ]
def type_declaration(
    binding: cst.BaseAssignTargetExpression,
    raw_annotation: str,
    quote_annotations: bool,
) -> cst.AnnAssign:
    """Return a value-less ``binding: <annotation>`` declaration.

    ``raw_annotation`` is converted to an annotation node by
    ``_convert_annotation``, which is passed ``quote_annotations``.
    """
    converted = _convert_annotation(
        raw=raw_annotation,
        quote_annotations=quote_annotations,
    )
    return cst.AnnAssign(target=binding, annotation=converted, value=None)
def _annotate_single_target(self, node: cst.Assign,
                            updated_node: cst.Assign) -> cst.CSTNode:
    """Annotate an assignment that has exactly one ``=`` target.

    * Tuple/list targets (``a, b = ...``, ``[a, *b] = ...``): the identifier
      of every element that is a bare (optionally starred) name is passed to
      ``self._add_to_toplevel_annotations``; the assignment is returned
      unchanged.
    * Plain name targets: if the qualified name (``self.qualifier`` plus the
      name) is a key of ``self.attribute_annotations``, the assignment
      becomes ``name: annotation = value``.
    * Any other target (subscript, attribute) is returned unchanged.

    ``self.qualifier`` is left as it was on entry.
    """
    target = node.targets[0].target
    if isinstance(target, (cst.Tuple, cst.List)):
        for element in target.elements:
            # Only bare names carry an identifier string in ``.value``; for
            # attributes/subscripts it is a sub-expression node, and nested
            # tuples have no ``.value`` at all.
            if isinstance(element.value, cst.Name):
                self._add_to_toplevel_annotations(element.value.value)
        return updated_node
    if not isinstance(target, cst.Name):
        # ``x[i] = ...`` / ``obj.attr = ...``: ``target.value`` is a node, not
        # an identifier, so it cannot form a qualified name or a ``cst.Name``.
        return updated_node
    name = target.value
    self.qualifier.append(name)
    qualified_name = self._qualifier_name()
    self.qualifier.pop()
    if qualified_name in self.attribute_annotations:
        annotation = self.attribute_annotations[qualified_name]
        return cst.AnnAssign(cst.Name(name), annotation, node.value)
    return updated_node
def leave_Module(self, original_node: cst.Module,
                 updated_node: cst.Module) -> cst.Module:
    """Insert queued ``from ... import`` lines and top-level declarations.

    Returns ``updated_node`` unchanged when ``self.imports`` and
    ``self.toplevel_annotations`` are both empty. Otherwise, at the insertion
    point computed by ``self._split_module``, emits one ``ImportFrom`` line
    per entry of ``self.imports`` followed by one ``name: annotation`` line
    per entry of ``self.toplevel_annotations``.
    """
    if not (self.toplevel_annotations or self.imports):
        return updated_node

    leading, trailing = self._split_module(original_node, updated_node)
    # Keep at least one blank line before the first non-import statement.
    trailing = self._insert_empty_line(trailing)

    # SimpleStatementLine takes a sequence of small statements, hence the
    # one-element lists below.
    import_lines = [
        cst.SimpleStatementLine([
            cst.ImportFrom(
                module=queued.module,
                # pyre-fixme[6]: Expected `Union[Sequence[ImportAlias], ImportStar]`
                # for 2nd param but got `List[ImportFrom]`.
                names=queued.names,
            )
        ])
        for queued in self.imports.values()
    ]
    declaration_lines = [
        cst.SimpleStatementLine([
            cst.AnnAssign(
                cst.Name(name),
                # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                cst.Annotation(annotation.annotation),
                None,
            )
        ])
        for name, annotation in self.toplevel_annotations.items()
    ]
    return updated_node.with_changes(body=[
        *leading,
        *import_lines,
        *declaration_lines,
        *trailing,
    ])
def assign_properties(
        p: typing.Dict[str, typing.Tuple[Metadata, Type]],
        is_classvar=False) -> typing.Iterable[cst.SimpleStatementLine]:
    """Yield one ``name: annotation`` statement per entry of ``p``.

    Entries are visited in ``sort_items(p)`` order; names for which
    ``bad_name`` is true are skipped. Each value is a ``(metadata, tp)`` pair:
    the annotation is ``tp.annotation``, wrapped as ``ClassVar[...]`` when
    ``is_classvar`` is true. Each statement is preceded by one blank line and
    then a ``# ...`` comment line for every line of ``metadata_lines(metadata)``.
    """
    for name, entry in sort_items(p):
        if bad_name(name):
            continue
        metadata, tp = entry
        annotation = tp.annotation
        if is_classvar:
            annotation = cst.Subscript(
                cst.Name("ClassVar"),
                [cst.SubscriptElement(cst.Index(annotation))],
            )
        declaration = cst.AnnAssign(cst.Name(name), cst.Annotation(annotation))
        comment_lines = [
            cst.EmptyLine(comment=cst.Comment("# " + text))
            for text in metadata_lines(metadata)
        ]
        yield cst.SimpleStatementLine(
            [declaration],
            leading_lines=[cst.EmptyLine(), *comment_lines],
        )
class AnnAssignTest(CSTNodeTest):
    """Construction, code generation and parsing tests for ``cst.AnnAssign``."""

    @data_provider(
        (
            # Simple assignment creation case.
            {
                "node": cst.AnnAssign(
                    cst.Name("foo"), cst.Annotation(cst.Name("str")), cst.Integer("5")
                ),
                "code": "foo: str = 5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 12)),
            },
            # Annotation creation without assignment
            {
                "node": cst.AnnAssign(cst.Name("foo"), cst.Annotation(cst.Name("str"))),
                "code": "foo: str",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 8)),
            },
            # Complex annotation creation
            {
                "node": cst.AnnAssign(
                    cst.Name("foo"),
                    cst.Annotation(
                        cst.Subscript(
                            cst.Name("Optional"),
                            (cst.SubscriptElement(cst.Index(cst.Name("str"))),),
                        )
                    ),
                    cst.Integer("5"),
                ),
                "code": "foo: Optional[str] = 5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 22)),
            },
            # Simple assignment parser case.
            {
                "node": cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=cst.Name("str"),
                                whitespace_before_indicator=cst.SimpleWhitespace(""),
                            ),
                            equal=cst.AssignEqual(),
                            value=cst.Integer("5"),
                        ),
                    )
                ),
                "code": "foo: str = 5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
            # Annotation without assignment
            {
                "node": cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=cst.Name("str"),
                                whitespace_before_indicator=cst.SimpleWhitespace(""),
                            ),
                            value=None,
                        ),
                    )
                ),
                "code": "foo: str\n",
                "parser": parse_statement,
                "expected_position": None,
            },
            # Complex annotation
            {
                "node": cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=cst.Subscript(
                                    cst.Name("Optional"),
                                    (cst.SubscriptElement(cst.Index(cst.Name("str"))),),
                                ),
                                whitespace_before_indicator=cst.SimpleWhitespace(""),
                            ),
                            equal=cst.AssignEqual(),
                            value=cst.Integer("5"),
                        ),
                    )
                ),
                "code": "foo: Optional[str] = 5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
            # Whitespace test: non-default spacing around ":" and "=".
            # "foo :  Optional[str]  =  5" is 26 columns wide.
            {
                "node": cst.AnnAssign(
                    target=cst.Name("foo"),
                    annotation=cst.Annotation(
                        annotation=cst.Subscript(
                            cst.Name("Optional"),
                            (cst.SubscriptElement(cst.Index(cst.Name("str"))),),
                        ),
                        whitespace_before_indicator=cst.SimpleWhitespace(" "),
                        whitespace_after_indicator=cst.SimpleWhitespace("  "),
                    ),
                    equal=cst.AssignEqual(
                        whitespace_before=cst.SimpleWhitespace("  "),
                        whitespace_after=cst.SimpleWhitespace("  "),
                    ),
                    value=cst.Integer("5"),
                ),
                "code": "foo :  Optional[str]  =  5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 26)),
            },
            # Same whitespace, round-tripped through the parser.
            {
                "node": cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=cst.Subscript(
                                    cst.Name("Optional"),
                                    (cst.SubscriptElement(cst.Index(cst.Name("str"))),),
                                ),
                                whitespace_before_indicator=cst.SimpleWhitespace(" "),
                                whitespace_after_indicator=cst.SimpleWhitespace("  "),
                            ),
                            equal=cst.AssignEqual(
                                whitespace_before=cst.SimpleWhitespace("  "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                            value=cst.Integer("5"),
                        ),
                    )
                ),
                "code": "foo :  Optional[str]  =  5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": (
                    lambda: cst.AnnAssign(
                        target=cst.Name("foo"),
                        annotation=cst.Annotation(cst.Name("str")),
                        equal=cst.AssignEqual(),
                        value=None,
                    )
                ),
                "expected_re": "Must have a value when specifying an AssignEqual.",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)