예제 #1
0
    def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine,
                                  updated_node: cst.SimpleStatementLine):
        """Add or replace the type annotation of a single-name assignment line.

        Two single-statement shapes are handled, matched against
        ``original_node``:

        * ``x = value`` is rewritten to ``x: T = value`` when
          ``__get_var_type_assign_t`` returns a type for ``x``.
        * ``x: Old = value`` gets its annotation replaced by ``T`` when
          ``__get_var_type_an_assign`` returns a type for ``x``.

        ``T`` is produced by ``resolve_type_alias`` followed by
        ``__name2annotation``; each applied ``(resolved_type, annotation_node)``
        pair is added to ``self.all_applied_types``. The spacing around ``=``
        and any trailing semicolon of the original statement are kept.

        Returns:
            The rewritten line, or ``updated_node`` when no annotation applies.
            ``updated_node`` (not ``original_node``) is returned so that changes
            already made to this line's children are not thrown away.
        """
        if match.matches(
                original_node,
                match.SimpleStatementLine(body=[
                    match.Assign(targets=[
                        match.AssignTarget(target=match.Name(
                            value=match.DoNotCare()))
                    ])
                ])):
            assign = original_node.body[0]
            assign_target = assign.targets[0]
            t = self.__get_var_type_assign_t(assign_target.target.value)
            if t is not None:
                t_annot_node = self._resolve_and_record_annotation(t)
                if t_annot_node is not None:
                    return updated_node.with_changes(body=[
                        cst.AnnAssign(
                            target=assign_target.target,
                            value=assign.value,
                            annotation=t_annot_node,
                            # Carry over the original spacing around "=".
                            equal=cst.AssignEqual(
                                whitespace_after=assign_target.whitespace_after_equal,
                                whitespace_before=assign_target.whitespace_before_equal),
                            semicolon=assign.semicolon)
                    ])
        elif match.matches(
                original_node,
                match.SimpleStatementLine(body=[
                    match.AnnAssign(target=match.Name(value=match.DoNotCare()))
                ])):
            ann_assign = original_node.body[0]
            t = self.__get_var_type_an_assign(ann_assign.target.value)
            if t is not None:
                t_annot_node = self._resolve_and_record_annotation(t)
                if t_annot_node is not None:
                    # with_changes keeps target, value, "=" and semicolon as-is.
                    return updated_node.with_changes(body=[
                        ann_assign.with_changes(annotation=t_annot_node)
                    ])

        return updated_node

    def _resolve_and_record_annotation(self, t):
        """Turn inferred type ``t`` into an annotation node and record it.

        Returns the node built by ``__name2annotation`` for the alias-resolved
        type, after adding ``(resolved_type, annotation_node)`` to
        ``self.all_applied_types``; returns ``None`` (recording nothing) when
        no annotation node could be built.
        """
        t_annot_node_resolved = self.resolve_type_alias(t)
        t_annot_node = self.__name2annotation(t_annot_node_resolved)
        if t_annot_node is not None:
            self.all_applied_types.add((t_annot_node_resolved, t_annot_node))
        return t_annot_node
    def leave_Module(
        self, original_node: cst.Module, updated_node: cst.Module
    ) -> cst.Module:
        """Insert pending module-level declarations right after the imports.

        Inserted, in this order:
          * one bare ``name: annotation`` line per entry of
            ``self.toplevel_annotations``;
          * every class definition from ``self.annotations.class_definitions``
            whose name is not in ``self.visited_classes``.

        When there is nothing to insert the module is returned unchanged.
        """
        pending_classes = [
            class_def
            for class_name, class_def in self.annotations.class_definitions.items()
            if class_name not in self.visited_classes
        ]
        if not (self.toplevel_annotations or pending_classes):
            return updated_node

        head, tail = self._split_module(original_node, updated_node)
        # Guarantee a blank line before the first statement after the imports.
        tail = self._insert_empty_line(tail)

        declarations = [
            cst.SimpleStatementLine([cst.AnnAssign(cst.Name(name), annotation, None)])
            for name, annotation in self.toplevel_annotations.items()
        ]
        return updated_node.with_changes(
            body=[*head, *declarations, *pending_classes, *tail]
        )
 def _annotate_single_target(
     self, node: cst.Assign, updated_node: cst.Assign
 ) -> Union[cst.Assign, cst.AnnAssign]:
     """Annotate a single-target assignment from the collected annotations.

     * Tuple/list destructuring (``a, b = ...``): every element whose full name
       resolves via ``get_full_name_for_node`` is passed to
       ``_add_to_toplevel_annotations``; the statement is returned unchanged.
     * Subscript targets (``d[k] = ...``) are never annotated.
     * Any other target whose full name resolves is pushed onto
       ``self.qualifier``; if ``_qualifier_name()`` then has an entry in
       ``self.annotations.attribute_annotations``, the assignment is rewritten
       to ``name: annotation = value``.

     Returns ``updated_node`` whenever no rewrite happens.
     """
     only_target = node.targets[0].target
     if isinstance(only_target, (cst.Tuple, cst.List)):
         for element in only_target.elements:
             name = get_full_name_for_node(element.value)
             if name:
                 self._add_to_toplevel_annotations(name)
         return updated_node
     if isinstance(only_target, cst.Subscript):
         return updated_node

     name = get_full_name_for_node(only_target)
     if name is None:
         return updated_node

     # Push the target temporarily so _qualifier_name() includes it; the
     # finally guarantees the qualifier stack is restored on every path.
     self.qualifier.append(name)
     try:
         qualified_name = self._qualifier_name()
         if qualified_name not in self.annotations.attribute_annotations:
             return updated_node
         annotation = self.annotations.attribute_annotations[qualified_name]
     finally:
         self.qualifier.pop()
     return cst.AnnAssign(cst.Name(name), annotation, node.value)
예제 #4
0
 def _annotate_single_target(
         self, node: cst.Assign,
         updated_node: cst.Assign) -> Union[cst.Assign, cst.AnnAssign]:
     """Annotate a single-target assignment using ``self.attribute_annotations``.

     * Tuple/list destructuring: for every non-subscript element, the name
       obtained from ``_get_name_as_string(element.value.value)`` is passed to
       ``_add_to_toplevel_annotations``; the statement is returned unchanged.
       (List targets used to fall through to the single-name branch and fail
       on ``target.value``.)
     * Subscript targets are returned unchanged.
     * Otherwise the target's name is pushed onto ``self.qualifier``; if
       ``_qualifier_name()`` then has an entry in ``self.attribute_annotations``
       the assignment is rewritten to ``name: annotation = value``.
     """
     target = node.targets[0].target
     if isinstance(target, (cst.Tuple, cst.List)):
         # pyre-fixme[16]: `BaseAssignTargetExpression` has no attribute `elements`.
         for element in target.elements:
             if not isinstance(element.value, cst.Subscript):
                 name = _get_name_as_string(element.value.value)
                 self._add_to_toplevel_annotations(name)
         return updated_node
     if isinstance(target, cst.Subscript):
         # Subscript assignments are never annotated.
         return updated_node

     # NOTE(review): for an Attribute target such as ``self.x`` this takes the
     # base object's name rather than ``x`` -- confirm this is intended.
     # pyre-fixme[16]: `BaseAssignTargetExpression` has no attribute `value`.
     name = _get_name_as_string(target.value)
     self.qualifier.append(name)
     try:
         qualified_name = self._qualifier_name()
         if qualified_name not in self.attribute_annotations:
             return updated_node
         annotation = self.attribute_annotations[qualified_name]
     finally:
         # Restore the qualifier stack on every path.
         self.qualifier.pop()
     return cst.AnnAssign(cst.Name(name), annotation, node.value)
예제 #5
0
    def leave_Module(
        self, original_node: cst.Module, updated_node: cst.Module
    ) -> cst.Module:
        """Insert missing imports and top-level annotations after the imports.

        Generated files (``self.is_generated``) are returned as
        ``original_node``. If there are neither top-level annotations nor
        imports queued, ``updated_node`` is returned unchanged.

        Otherwise, for each value of ``self.imports`` a
        ``from <module> import <names>`` statement is emitted containing only
        the names (sorted) that are not already bound by one of
        ``self.import_statements``. An aliased import counts by its ``as``
        name, and ``import *`` statements are ignored. These are followed by
        one bare ``name: annotation`` line per entry of
        ``self.toplevel_annotations``. Everything is placed between the
        module's import section and the rest of its body.
        """
        if self.is_generated:
            return original_node
        if not self.toplevel_annotations and not self.imports:
            return updated_node

        # First, find the insertion point for imports.
        statements_before_imports, statements_after_imports = self._split_module(
            original_node, updated_node
        )
        # Make sure there's at least one empty line before the first non-import.
        statements_after_imports = self._insert_empty_line(statements_after_imports)

        # Names already bound by the module's own imports.
        imported = set()
        for statement in self.import_statements:
            if isinstance(statement.names, cst.ImportStar):
                continue
            for alias in statement.names:
                # ``import x as y`` binds ``y``, not ``x``.
                bound = alias.asname if alias.asname else alias
                imported.add(_get_name_as_string(bound.name))

        toplevel_statements = []
        for queued_import in self.imports.values():
            # Filter out anything that has already been imported.
            missing = sorted(queued_import.names.difference(imported))
            if not missing:
                continue
            new_import = cst.ImportFrom(
                module=queued_import.module,
                names=[cst.ImportAlias(cst.Name(name)) for name in missing],
            )
            # SimpleStatementLine needs a sequence of small statements.
            toplevel_statements.append(cst.SimpleStatementLine([new_import]))

        for name, annotation in self.toplevel_annotations.items():
            annotated_assign = cst.AnnAssign(
                cst.Name(name),
                # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                cst.Annotation(annotation.annotation),
                None,
            )
            toplevel_statements.append(cst.SimpleStatementLine([annotated_assign]))

        return updated_node.with_changes(
            body=[
                *statements_before_imports,
                *toplevel_statements,
                *statements_after_imports,
            ]
        )
예제 #6
0
 def _apply_annotation_to_attribute_or_global(
     self,
     name: str,
     annotation: cst.Annotation,
     value: Optional[cst.BaseExpression],
 ) -> cst.AnnAssign:
     """Build ``name: annotation [= value]`` and bump the matching counter.

     An empty ``self.qualifier`` is counted as a global annotation; a
     non-empty one is counted as an attribute annotation.
     """
     counts = self.annotation_counts
     if self.qualifier:
         counts.attribute_annotations += 1
     else:
         counts.global_annotations += 1
     return cst.AnnAssign(cst.Name(name), annotation, value)
예제 #7
0
 def leave_Module(self, node: cst.Module, updated_node: cst.Module) -> cst.CSTNode:
     """Insert a bare ``name: annotation`` line per top-level annotation.

     The lines are spliced into the body at the index returned by
     ``_get_toplevel_index`` and keep the iteration order of
     ``self.toplevel_annotations``. (The previous per-item
     ``body.insert(index, ...)`` put every line at the same index and so
     reversed that order.)
     """
     body = list(updated_node.body)
     index = self._get_toplevel_index(body)
     declarations = [
         cst.SimpleStatementLine([
             cst.AnnAssign(
                 cst.Name(name),
                 # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                 cst.Annotation(annotation.annotation),
                 None,
             )
         ])
         for name, annotation in self.toplevel_annotations.items()
     ]
     body[index:index] = declarations
     return updated_node.with_changes(body=tuple(body))
예제 #8
0
def convert_Assign(
    node: cst.Assign,
    annotation: ast.expr,
    quote_annotations: bool,
) -> Union[
    _FailedToApplyAnnotation,
    cst.AnnAssign,
    List[Union[cst.AnnAssign, cst.Assign]],
]:
    """Apply a type-comment ``annotation`` to the assignment ``node``.

    Returns:
        * ``_FailedToApplyAnnotation()`` when spreading the annotation over the
          targets raises ``_ArityError``;
        * a single ``AnnAssign`` carrying ``node``'s value and semicolon when
          there is exactly one target that binds exactly one name;
        * otherwise a list of value-less ``AnnAssign`` declarations, one per
          bound name, followed by the untouched ``node``.
    """
    # Zip the type and target information together. Mismatched arities are a
    # PEP 484 violation (technically we could use logic beyond the PEP to
    # recover some cases as typing.Tuple, but this should be rare), so we
    # give up on them.
    try:
        annotations = AnnotationSpreader.unpack_annotation(annotation)
        annotated_targets = [
            AnnotationSpreader.annotated_bindings(
                bindings=AnnotationSpreader.unpack_target(assign_target.target),
                annotations=annotations,
            )
            for assign_target in node.targets
        ]
    except _ArityError:
        return _FailedToApplyAnnotation()

    single_binding = len(annotated_targets) == 1 and len(annotated_targets[0]) == 1
    if not single_binding:
        # Multi-target assigns (tuples on the LHS, several "=" tokens, or
        # both) need one type declaration per individual LHS binding.
        declarations = []
        for annotated_bindings in annotated_targets:
            for binding, raw_annotation in annotated_bindings:
                declarations.append(
                    AnnotationSpreader.type_declaration(
                        binding,
                        raw_annotation,
                        quote_annotations=quote_annotations,
                    )
                )
        return [*declarations, node]

    # A simple one-target assignment folds into a single AnnAssign.
    binding, raw_annotation = annotated_targets[0][0]
    converted = _convert_annotation(
        raw=raw_annotation,
        quote_annotations=quote_annotations,
    )
    return cst.AnnAssign(
        target=binding,
        annotation=converted,
        value=node.value,
        semicolon=node.semicolon,
    )
예제 #9
0
 def type_declaration(
     binding: cst.BaseAssignTargetExpression,
     raw_annotation: str,
     quote_annotations: bool,
 ) -> cst.AnnAssign:
     """Return a value-less ``binding: <annotation>`` declaration.

     ``raw_annotation`` is converted by ``_convert_annotation``, which also
     receives ``quote_annotations`` unchanged.
     """
     converted = _convert_annotation(
         raw=raw_annotation,
         quote_annotations=quote_annotations,
     )
     return cst.AnnAssign(target=binding, annotation=converted, value=None)
예제 #10
0
 def _annotate_single_target(self, node: cst.Assign,
                             updated_node: cst.Assign) -> cst.CSTNode:
     """Annotate a single-target assignment using ``self.attribute_annotations``.

     * Tuple/list destructuring: every element that is a plain ``Name`` is
       passed (as a string) to ``_add_to_toplevel_annotations``. Attribute,
       subscript, starred and nested targets are skipped. Previously their CST
       nodes were passed where a name string is expected. The statement is
       returned unchanged.
     * A plain ``Name`` target is pushed onto ``self.qualifier``; if
       ``_qualifier_name()`` then has an entry in ``self.attribute_annotations``
       the assignment is rewritten to ``name: annotation = value``.
     * Any other target (attribute, subscript, ...) is returned unchanged;
       previously its ``.value`` CST node was pushed onto the qualifier.
     """
     target = node.targets[0].target
     if isinstance(target, (cst.Tuple, cst.List)):
         # pyre-fixme[16]: `BaseAssignTargetExpression` has no attribute `elements`.
         for element in target.elements:
             if isinstance(element.value, cst.Name):
                 self._add_to_toplevel_annotations(element.value.value)
         return updated_node
     if not isinstance(target, cst.Name):
         return updated_node

     name = target.value
     self.qualifier.append(name)
     try:
         qualified_name = self._qualifier_name()
         if qualified_name not in self.attribute_annotations:
             return updated_node
         annotation = self.attribute_annotations[qualified_name]
     finally:
         # Restore the qualifier stack on every path.
         self.qualifier.pop()
     return cst.AnnAssign(cst.Name(name), annotation, node.value)
예제 #11
0
    def leave_Module(self, original_node: cst.Module,
                     updated_node: cst.Module) -> cst.Module:
        """Insert queued imports and top-level annotations after the imports.

        Each value of ``self.imports`` is re-emitted as a fresh
        ``ImportFrom`` built from its ``module`` and ``names``. These are
        followed by one bare ``name: annotation`` line per entry of
        ``self.toplevel_annotations``. The new statements go between the
        module's import section (as located by ``_split_module``) and the rest
        of the body. With nothing queued, ``updated_node`` is returned
        unchanged.
        """
        if not (self.toplevel_annotations or self.imports):
            return updated_node

        head, tail = self._split_module(original_node, updated_node)
        # Keep at least one blank line before the first non-import statement.
        tail = self._insert_empty_line(tail)

        new_statements = [
            cst.SimpleStatementLine([
                cst.ImportFrom(
                    module=queued.module,
                    # pyre-fixme[6]: Expected `Union[Sequence[ImportAlias], ImportStar]`
                    #  for 2nd param but got `List[ImportFrom]`.
                    names=queued.names,
                )
            ])
            for queued in self.imports.values()
        ]
        for name, annotation in self.toplevel_annotations.items():
            new_statements.append(
                cst.SimpleStatementLine([
                    cst.AnnAssign(
                        cst.Name(name),
                        # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                        cst.Annotation(annotation.annotation),
                        None,
                    )
                ]))

        return updated_node.with_changes(body=[*head, *new_statements, *tail])
예제 #12
0
def assign_properties(
        p: typing.Dict[str, typing.Tuple[Metadata, Type]],
        is_classvar=False) -> typing.Iterable[cst.SimpleStatementLine]:
    """Yield one annotated declaration line per property in ``p``.

    Entries are visited in ``sort_items(p)`` order, and names rejected by
    ``bad_name`` are skipped. Each yielded line is ``name: <ann>``, where
    ``<ann>`` is the type's ``annotation``. When ``is_classvar`` is true the
    line is ``name: ClassVar[<ann>]`` instead. Each line is preceded by a
    blank line and then one ``# ...`` comment per string from
    ``metadata_lines(metadata)``.
    """
    for name, metadata_and_tp in sort_items(p):
        if bad_name(name):
            continue
        metadata, tp = metadata_and_tp

        annotation_expr = tp.annotation
        if is_classvar:
            annotation_expr = cst.Subscript(
                cst.Name("ClassVar"),
                [cst.SubscriptElement(cst.Index(annotation_expr))],
            )
        declaration = cst.AnnAssign(cst.Name(name), cst.Annotation(annotation_expr))

        comment_lines = [
            cst.EmptyLine(comment=cst.Comment("# " + text))
            for text in metadata_lines(metadata)
        ]
        yield cst.SimpleStatementLine(
            [declaration],
            leading_lines=[cst.EmptyLine(), *comment_lines],
        )
예제 #13
0
def _optional_str() -> cst.Subscript:
    """Return a fresh ``Optional[str]`` subscript node for the test cases."""
    return cst.Subscript(
        cst.Name("Optional"),
        (cst.SubscriptElement(cst.Index(cst.Name("str"))),),
    )


def _spaced_optional_str_assign() -> cst.AnnAssign:
    """Return a fresh node for ``foo :  Optional[str]  =  5`` (explicit spacing)."""
    return cst.AnnAssign(
        target=cst.Name("foo"),
        annotation=cst.Annotation(
            annotation=_optional_str(),
            whitespace_before_indicator=cst.SimpleWhitespace(" "),
            whitespace_after_indicator=cst.SimpleWhitespace("  "),
        ),
        equal=cst.AssignEqual(
            whitespace_before=cst.SimpleWhitespace("  "),
            whitespace_after=cst.SimpleWhitespace("  "),
        ),
        value=cst.Integer("5"),
    )


class AnnAssignTest(CSTNodeTest):
    """Tests for ``cst.AnnAssign``: valid node/code pairs and invalid construction."""

    @data_provider(
        (
            # Simple assignment creation case.
            {
                "node": cst.AnnAssign(
                    cst.Name("foo"), cst.Annotation(cst.Name("str")), cst.Integer("5")
                ),
                "code": "foo: str = 5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 12)),
            },
            # Annotation creation without assignment
            {
                "node": cst.AnnAssign(cst.Name("foo"), cst.Annotation(cst.Name("str"))),
                "code": "foo: str",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 8)),
            },
            # Complex annotation creation
            {
                "node": cst.AnnAssign(
                    cst.Name("foo"), cst.Annotation(_optional_str()), cst.Integer("5")
                ),
                "code": "foo: Optional[str] = 5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 22)),
            },
            # Simple assignment parser case.
            {
                "node": cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=cst.Name("str"),
                                whitespace_before_indicator=cst.SimpleWhitespace(""),
                            ),
                            equal=cst.AssignEqual(),
                            value=cst.Integer("5"),
                        ),
                    )
                ),
                "code": "foo: str = 5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
            # Annotation without assignment
            {
                "node": cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=cst.Name("str"),
                                whitespace_before_indicator=cst.SimpleWhitespace(""),
                            ),
                            value=None,
                        ),
                    )
                ),
                "code": "foo: str\n",
                "parser": parse_statement,
                "expected_position": None,
            },
            # Complex annotation
            {
                "node": cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=_optional_str(),
                                whitespace_before_indicator=cst.SimpleWhitespace(""),
                            ),
                            equal=cst.AssignEqual(),
                            value=cst.Integer("5"),
                        ),
                    )
                ),
                "code": "foo: Optional[str] = 5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
            # Whitespace test
            {
                "node": _spaced_optional_str_assign(),
                "code": "foo :  Optional[str]  =  5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 26)),
            },
            {
                "node": cst.SimpleStatementLine((_spaced_optional_str_assign(),)),
                "code": "foo :  Optional[str]  =  5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Run each valid case through ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": lambda: cst.AnnAssign(
                    target=cst.Name("foo"),
                    annotation=cst.Annotation(cst.Name("str")),
                    equal=cst.AssignEqual(),
                    value=None,
                ),
                "expected_re": "Must have a value when specifying an AssignEqual.",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """An ``AnnAssign`` with an ``=`` but no value must be rejected."""
        self.assert_invalid(**kwargs)