コード例 #1
0
 def _create_import_from_annotation(self, returns: cst.Annotation) -> cst.Annotation:
     """Rewrite ``returns`` so that the names it references get imported.

     * Attribute annotations are passed to ``self._add_annotation_to_imports``
       and replaced by the node that call returns.
     * Subscript annotations are rebuilt from ``self._handle_Subscript``.
     * Any other annotation is returned unchanged (the same object).
     """
     inner = returns.annotation
     if isinstance(inner, cst.Attribute):
         return cst.Annotation(annotation=self._add_annotation_to_imports(inner))
     if not isinstance(inner, cst.Subscript):
         return returns
     return cst.Annotation(annotation=self._handle_Subscript(inner))
コード例 #2
0
 def _handle_Annotation(self, annotation: cst.Annotation) -> cst.Annotation:
     """Return ``annotation`` with its inner expression processed.

     String annotations are returned untouched. Subscripts are routed through
     ``self._handle_Subscript``, and names/attributes (``NAME_OR_ATTRIBUTE``)
     through ``self._handle_NameOrAttribute``.

     Raises:
         ValueError: if the annotation expression is any other node type.
     """
     expr = annotation.annotation
     if isinstance(expr, cst.SimpleString):
         return annotation
     if isinstance(expr, cst.Subscript):
         handled = self._handle_Subscript(expr)
     elif isinstance(expr, NAME_OR_ATTRIBUTE):
         handled = self._handle_NameOrAttribute(expr)
     else:
         raise ValueError(f"Unexpected annotation node: {expr}")
     return cst.Annotation(annotation=handled)
コード例 #3
0
 def _create_import_from_annotation(self, returns: cst.Annotation) -> cst.Annotation:
     """Rewrite ``returns`` so that the names it references get imported.

     * Attribute annotations go through ``self._add_annotation_to_imports``.
     * Subscript annotations go through ``self._handle_Subscript``, except
       ``Type[...]`` (a bare ``Type`` name as the subscript base), which is
       returned unchanged.
     * Anything else is returned unchanged.
     """
     inner = returns.annotation
     if isinstance(inner, cst.Attribute):
         return cst.Annotation(annotation=self._add_annotation_to_imports(inner))
     if isinstance(inner, cst.Subscript):
         base = inner.value
         is_type_subscript = isinstance(base, cst.Name) and base.value == "Type"
         if not is_type_subscript:
             return cst.Annotation(annotation=self._handle_Subscript(inner))
     return returns
コード例 #4
0
ファイル: annotations.py プロジェクト: kalekseev/nplint
 def leave_FunctionDef(self, node: cst.FunctionDef,
                       updated_node: cst.FunctionDef) -> cst.FunctionDef:
     """Add a return annotation inferred from the function's ``return`` statements.

     Pops this function's entry from ``self.stack``. The entry is presumably
     pushed on entry and filled with the ``Return`` nodes seen in the body;
     each item is read through its ``.value``.

     * ``None``: the function is returned unchanged.
     * empty, or every return is bare / ``return None``: annotated ``-> None``.
     * exactly one return, and the body's last statement is a ``return``:
       a literal value is mapped to ``bytes``, ``str``, ``bool``, ``int`` or
       ``float``.

     Everything else is returned unchanged.
     """

     def _annotate(type_name: str) -> cst.FunctionDef:
         # Attach ``-> type_name`` to the already-transformed node.
         return updated_node.with_changes(returns=cst.Annotation(
             annotation=cst.Name(value=type_name)))

     def _returns_none(ret) -> bool:
         # A bare ``return`` (value is None) or an explicit ``return None``.
         value = ret.value
         return value is None or (isinstance(value, cst.Name)
                                  and value.value == "None")

     def _ends_with_return(suite: cst.BaseSuite) -> bool:
         # True when the body's last statement is a ``return``. This covers
         # indented bodies and one-line bodies (``def f(): return 1``), whose
         # SimpleStatementSuite holds the small statements directly.
         statements = suite.body
         if not statements:
             return False
         last = statements[-1]
         if isinstance(last, cst.SimpleStatementLine):
             return bool(last.body) and isinstance(last.body[-1], cst.Return)
         return isinstance(last, cst.Return)

     def _is_bytes_literal(string: cst.BaseString) -> bool:
         # The bytes prefix may be any case and may be combined with ``r``
         # (b, B, br, rb, Rb, ...). Implicit concatenation cannot mix bytes and
         # str, so the left-most part decides. f-strings are never bytes.
         if isinstance(string, cst.ConcatenatedString):
             return _is_bytes_literal(string.left)
         if isinstance(string, cst.SimpleString):
             prefix = string.value.split('"', 1)[0].split("'", 1)[0]
             return "b" in prefix.lower()
         return False

     returns = self.stack.pop()
     if returns is None:
         return updated_node
     # No collected returns, or only bare / ``None`` returns.
     if not returns or all(_returns_none(r) for r in returns):
         return _annotate("None")
     # If the body does not end in a ``return``, control may reach its end and
     # return None implicitly, so a single literal return does not settle the type.
     if not _ends_with_return(node.body) or len(returns) != 1:
         return updated_node
     rvalue = returns[0].value
     if isinstance(rvalue, cst.BaseString):
         return _annotate("bytes" if _is_bytes_literal(rvalue) else "str")
     if isinstance(rvalue, cst.Name) and rvalue.value in ("False", "True"):
         return _annotate("bool")
     if isinstance(rvalue, cst.Integer):
         return _annotate("int")
     if isinstance(rvalue, cst.Float):
         return _annotate("float")
     return updated_node
コード例 #5
0
    def leave_FunctionDef(self, original_node: cst.FunctionDef,
                          updated_node: cst.FunctionDef) -> cst.FunctionDef:
        """Move type information found in the docstring into annotations.

        The evaluated docstring is passed to ``gather_types``, which yields the
        docstring text to keep plus a mapping of types. A non-empty ``RETURN``
        entry becomes the return annotation. Entries keyed by a parameter name
        become that parameter's annotation (applied via ``update_parameters``).
        The docstring statement is then replaced by one holding the new text.
        Functions without a non-empty string docstring are returned unchanged.
        """

        def _quote_docstring(text: str) -> str:
            # ``text`` is the *evaluated* docstring, so backslashes must be
            # escaped again. Quotes are escaped so that a trailing ``"`` or an
            # embedded ``\"\"\"`` cannot terminate the literal early.
            escaped = text.replace("\\", "\\\\")
            if escaped.endswith('"'):
                escaped = escaped[:-1] + '\\"'
            escaped = escaped.replace('"""', '\\"\\"\\"')
            return '"""%s"""' % escaped

        docstring = None
        docstring_node = get_docstring_node(updated_node.body)
        if docstring_node and isinstance(
                docstring_node.value, (cst.SimpleString, cst.ConcatenatedString)):
            docstring = docstring_node.value.evaluated_value
        if not docstring:
            return updated_node

        new_docstring, types = gather_types(docstring)
        if types.get(RETURN):
            # NOTE(review): unlike the parameter types below, the return type
            # is not parsed; compound types end up as the raw text of a single
            # Name node.
            updated_node = updated_node.with_changes(returns=cst.Annotation(
                cst.Name(types.pop(RETURN))), )

        if types:

            def get_annotation(p: cst.Param) -> Optional[cst.Annotation]:
                pname = p.name.value
                if types.get(pname):
                    return cst.Annotation(cst.parse_expression(types[pname]))
                return None

            updated_node = updated_node.with_changes(params=update_parameters(
                updated_node.params, get_annotation, False))

        new_docstring_node = cst.SimpleString(_quote_docstring(new_docstring))
        # ``docstring_node`` is still part of ``updated_node``: the body
        # subtree was not replaced by the ``with_changes`` calls above.
        return updated_node.deep_replace(docstring_node,
                                         cst.Expr(new_docstring_node))
コード例 #6
0
    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Annotate functions that match ``self.matcher`` with ``-> None``.

        Functions that do not match are returned unchanged.
        """
        if not matchers.matches(updated_node, self.matcher):
            return updated_node
        none_annotation = cst.Annotation(cst.Name(value="None"))
        return updated_node.with_changes(returns=none_annotation)
コード例 #7
0
    def leave_Module(
        self, original_node: cst.Module, updated_node: cst.Module
    ) -> cst.Module:
        """Add the pending ``from … import …`` lines and top-level annotations.

        Generated modules (``self.is_generated``) are returned as the original
        node. Otherwise, the module body is split by ``self._split_module``.
        The new statements are placed between the two halves: first one
        ``ImportFrom`` per entry of ``self.imports`` (names sorted, minus names
        already bound by ``self.import_statements``), then one bare
        ``name: annotation`` per entry of ``self.toplevel_annotations``.
        """
        if self.is_generated:
            return original_node
        if not (self.toplevel_annotations or self.imports):
            return updated_node

        # First, find the insertion point for imports
        head, tail = self._split_module(original_node, updated_node)
        # Make sure there's at least one empty line before the first non-import
        tail = self._insert_empty_line(tail)

        # Names already bound by existing ``from`` imports (the alias when
        # one is given); star imports bind nothing we can name.
        already_imported = set()
        for existing in self.import_statements:
            aliases = existing.names
            if isinstance(aliases, cst.ImportStar):
                continue
            for alias in aliases:
                bound = alias.asname if alias.asname else alias
                already_imported.add(_get_name_as_string(bound.name))

        new_statements = []
        for pending in self.imports.values():
            # Filter out anything that has already been imported.
            missing = sorted(pending.names.difference(already_imported))
            if not missing:
                continue
            import_from = cst.ImportFrom(
                module=pending.module,
                names=[cst.ImportAlias(cst.Name(n)) for n in missing],
            )
            new_statements.append(cst.SimpleStatementLine([import_from]))

        new_statements.extend(
            cst.SimpleStatementLine([
                cst.AnnAssign(
                    cst.Name(name),
                    # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                    cst.Annotation(annotation.annotation),
                    None,
                )
            ])
            for name, annotation in self.toplevel_annotations.items()
        )

        return updated_node.with_changes(body=[*head, *new_statements, *tail])
コード例 #8
0
 def _create_import_from_annotation(self,
                                    returns: cst.CSTNode) -> cst.CSTNode:
     """Import the attribute named by an ``a.b.C``-style annotation and shorten it.

     For an Attribute annotation, ``attr`` is registered through
     ``self._add_to_imports`` together with the attribute's ``value`` node and
     its string form (from ``_get_attribute_as_string``). The annotation is
     then reduced to ``attr`` alone. Any other annotation is returned unchanged.
     """
     # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
     annotation = returns.annotation
     if not isinstance(annotation, cst.Attribute):
         return returns
     module = annotation.value
     key = _get_attribute_as_string(module)
     self._add_to_imports([cst.ImportAlias(name=annotation.attr)], module, key)
     return cst.Annotation(annotation=annotation.attr)
コード例 #9
0
 def leave_Module(self, node: cst.Module, updated_node: cst.Module) -> cst.CSTNode:
     """Insert a bare ``name: annotation`` statement per top-level annotation.

     The statements are placed at the index returned by
     ``self._get_toplevel_index`` and keep the iteration order of
     ``self.toplevel_annotations``.
     """
     body = list(updated_node.body)
     index = self._get_toplevel_index(body)
     annotated_assigns = [
         cst.SimpleStatementLine([
             cst.AnnAssign(
                 cst.Name(name),
                 # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                 cst.Annotation(annotation.annotation),
                 None,
             )
         ])
         for name, annotation in self.toplevel_annotations.items()
     ]
     # Splice all statements in as one slice. Inserting them one at a time at
     # the same index would emit them in reverse order.
     body[index:index] = annotated_assigns
     return updated_node.with_changes(body=tuple(body))
コード例 #10
0
def _convert_annotation(
    raw: str,
    quote_annotations: bool,
) -> cst.Annotation:
    """
    Convert a raw annotation - which is a string coming from a type
    comment - into a suitable libcst Annotation node.

    If `quote_annotations`, we'll always quote annotations unless they are builtin
    types. The reason for this is to make the codemod safer to apply
    on legacy code where type comments may well include invalid types
    that would crash at runtime.

    Otherwise the annotation is parsed as an expression. If parsing fails, it
    is quoted as a fallback. The quoted form is always a valid string literal
    whose value is exactly `raw`, even when `raw` contains quotes (for
    example `Literal["a"]`) or backslashes.
    """
    if _is_builtin(raw):
        return cst.Annotation(annotation=cst.Name(value=raw))
    if not quote_annotations:
        try:
            return cst.Annotation(annotation=cst.parse_expression(raw))
        except cst.ParserSyntaxError:
            # Not a valid expression: emit it as a string annotation instead.
            pass
    escaped = raw.replace("\\", "\\\\")
    if '"' not in escaped:
        literal = f'"{escaped}"'
    elif "'" not in escaped:
        # Switch quote style rather than escaping, to keep the result readable.
        literal = f"'{escaped}'"
    else:
        literal = '"' + escaped.replace('"', '\\"') + '"'
    return cst.Annotation(annotation=cst.SimpleString(literal))
コード例 #11
0
ファイル: apis.py プロジェクト: vfdev-5/python-record-api
    def parameters(
        self, type: typing.Literal["function", "classmethod", "method"]
    ) -> cst.Parameters:
        """Build the libcst ``Parameters`` node for this recorded signature.

        Every parameter is annotated with its recorded type's ``annotation``.
        Optional parameters additionally get ``...`` as their default. The
        positional-only and positional-or-keyword optional groups are ordered
        by ``possibly_order_dict``. For ``"classmethod"`` / ``"method"``, an
        unannotated ``cls`` / ``self`` is prepended to the positional-only group.
        """

        def _required(items: typing.Iterable) -> typing.List[cst.Param]:
            return [
                cst.Param(cst.Name(name), cst.Annotation(tp.annotation))
                for name, tp in items
            ]

        def _optional(items: typing.Iterable) -> typing.List[cst.Param]:
            return [
                cst.Param(cst.Name(name),
                          cst.Annotation(tp.annotation),
                          default=cst.Ellipsis())
                for name, tp in items
            ]

        posonly = _required(self.pos_only_required.items()) + _optional(
            possibly_order_dict(self.pos_only_optional,
                                self.pos_only_optional_ordering).items())
        implicit_first = {"classmethod": "cls", "method": "self"}.get(type)
        if implicit_first is not None:
            posonly.insert(0, cst.Param(cst.Name(implicit_first)))

        positional = _required(self.pos_or_kw_required.items()) + _optional(
            possibly_order_dict(self.pos_or_kw_optional,
                                self.pos_or_kw_optional_ordering).items())

        if self.var_pos:
            var_name, var_type = self.var_pos[0], self.var_pos[1]
            star_arg = cst.Param(cst.Name(var_name),
                                 cst.Annotation(var_type.annotation))
        else:
            star_arg = cst.MaybeSentinel.DEFAULT

        if self.var_kw:
            kw_name, kw_type = self.var_kw[0], self.var_kw[1]
            star_kwarg = cst.Param(cst.Name(kw_name),
                                   cst.Annotation(kw_type.annotation))
        else:
            star_kwarg = None

        keyword_only = _required(self.kw_only_required.items()) + _optional(
            self.kw_only_optional.items())

        return cst.Parameters(
            posonly_params=posonly,
            params=positional,
            star_arg=star_arg,
            star_kwarg=star_kwarg,
            kwonly_params=keyword_only,
        )
コード例 #12
0
    def test_annotation(self) -> None:
        """An annotation placeholder accepts a bare expression or a ``cst.Annotation``.

        The first case inserts an annotation expression normally; the second
        passes a ready-made Annotation node, which is handled as a special case.
        Both must render identically.
        """
        expected = "x: int = 5\n"
        placeholders = (cst.Name("int"), cst.Annotation(cst.Name("int")))
        for type_node in placeholders:
            statement = parse_template_statement(
                "x: {type} = {val}", type=type_node, val=cst.Integer("5"),
            )
            self.assertEqual(self.code(statement), expected)
コード例 #13
0
ファイル: apply_annotations.py プロジェクト: nimar/pyre-check
    def leave_Module(self, original_node: cst.Module,
                     updated_node: cst.Module) -> cst.Module:
        """Add the pending imports and top-level annotations to the module body.

        The body is split by ``self._split_module``. Between the two halves
        go one ``ImportFrom`` per entry of ``self.imports``, followed by one
        bare ``name: annotation`` per entry of ``self.toplevel_annotations``.
        Nothing is done when there is nothing to add.
        """
        if not (self.toplevel_annotations or self.imports):
            return updated_node

        # First, find the insertion point for imports
        head, tail = self._split_module(original_node, updated_node)
        # Make sure there's at least one empty line before the first non-import
        tail = self._insert_empty_line(tail)

        import_lines = [
            cst.SimpleStatementLine([
                cst.ImportFrom(
                    module=pending.module,
                    # pyre-fixme[6]: Expected `Union[Sequence[ImportAlias], ImportStar]`
                    #  for 2nd param but got `List[ImportFrom]`.
                    names=pending.names,
                )
            ])
            for pending in self.imports.values()
        ]
        annotation_lines = [
            cst.SimpleStatementLine([
                cst.AnnAssign(
                    cst.Name(name),
                    # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                    cst.Annotation(annotation.annotation),
                    None,
                )
            ])
            for name, annotation in self.toplevel_annotations.items()
        ]

        return updated_node.with_changes(
            body=[*head, *import_lines, *annotation_lines, *tail])
コード例 #14
0
ファイル: apis.py プロジェクト: vfdev-5/python-record-api
def assign_properties(
        p: typing.Dict[str, typing.Tuple[Metadata, Type]],
        is_classvar=False) -> typing.Iterable[cst.SimpleStatementLine]:
    """Yield one bare ``name: annotation`` statement per property in ``p``.

    Entries are visited in ``sort_items`` order, and names rejected by
    ``bad_name`` are skipped. With ``is_classvar`` the annotation is wrapped in
    ``ClassVar[...]``. Each statement is preceded by an empty line and then one
    ``# ...`` comment line per entry of ``metadata_lines(metadata)``.
    """
    for name, entry in sort_items(p):
        if bad_name(name):
            continue
        metadata, tp = entry
        annotation = tp.annotation
        if is_classvar:
            annotation = cst.Subscript(
                cst.Name("ClassVar"),
                [cst.SubscriptElement(cst.Index(annotation))])
        statement = cst.AnnAssign(cst.Name(name), cst.Annotation(annotation))
        comment_lines = [
            cst.EmptyLine(comment=cst.Comment("# " + text))
            for text in metadata_lines(metadata)
        ]
        yield cst.SimpleStatementLine(
            [statement],
            leading_lines=[cst.EmptyLine(), *comment_lines],
        )
コード例 #15
0
class AnnAssignTest(CSTNodeTest):
    """Tests for building, rendering and validating ``cst.AnnAssign`` nodes."""

    @data_provider((
        # Simple assignment creation case.
        {
            "node":
            cst.AnnAssign(cst.Name("foo"), cst.Annotation(cst.Name("str")),
                          cst.Integer("5")),
            "code":
            "foo: str = 5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 12)),
        },
        # Annotation creation without assignment
        {
            "node": cst.AnnAssign(cst.Name("foo"),
                                  cst.Annotation(cst.Name("str"))),
            "code": "foo: str",
            "parser": None,
            "expected_position": CodeRange((1, 0), (1, 8)),
        },
        # Complex annotation creation
        {
            "node":
            cst.AnnAssign(
                cst.Name("foo"),
                cst.Annotation(
                    cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    )),
                cst.Integer("5"),
            ),
            "code":
            "foo: Optional[str] = 5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 22)),
        },
        # Simple assignment parser case.
        {
            "node":
            cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Name("str"),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                equal=cst.AssignEqual(),
                value=cst.Integer("5"),
            ), )),
            "code":
            "foo: str = 5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
        # Annotation without assignment
        {
            "node":
            cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Name("str"),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                value=None,
            ), )),
            "code":
            "foo: str\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
        # Complex annotation
        {
            "node":
            cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                equal=cst.AssignEqual(),
                value=cst.Integer("5"),
            ), )),
            "code":
            "foo: Optional[str] = 5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
        # Whitespace test
        {
            "node":
            cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(" "),
                    whitespace_after_indicator=cst.SimpleWhitespace("  "),
                ),
                equal=cst.AssignEqual(
                    whitespace_before=cst.SimpleWhitespace("  "),
                    whitespace_after=cst.SimpleWhitespace("  "),
                ),
                value=cst.Integer("5"),
            ),
            "code":
            "foo :  Optional[str]  =  5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 26)),
        },
        # Same whitespace case, wrapped in a SimpleStatementLine and checked
        # against parse_statement.
        {
            "node":
            cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(" "),
                    whitespace_after_indicator=cst.SimpleWhitespace("  "),
                ),
                equal=cst.AssignEqual(
                    whitespace_before=cst.SimpleWhitespace("  "),
                    whitespace_after=cst.SimpleWhitespace("  "),
                ),
                value=cst.Integer("5"),
            ), )),
            "code":
            "foo :  Optional[str]  =  5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Check each valid case with ``validate_node`` (inherited from CSTNodeTest).

        Each case supplies the node, its expected source ``code``, an optional
        ``parser`` and an optional ``expected_position`` for the node's range.
        """
        self.validate_node(**kwargs)

    # An AnnAssign that has an AssignEqual but no value is invalid.
    @data_provider(({
        "get_node": (lambda: cst.AnnAssign(
            target=cst.Name("foo"),
            annotation=cst.Annotation(cst.Name("str")),
            equal=cst.AssignEqual(),
            value=None,
        )),
        "expected_re":
        "Must have a value when specifying an AssignEqual.",
    }, ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Check each invalid case with ``assert_invalid``.

        ``get_node`` builds the node lazily, and ``expected_re`` is the
        expected error-message pattern.
        """
        self.assert_invalid(**kwargs)
コード例 #16
0
class LambdaCreationTest(CSTNodeTest):
    @data_provider((
        # Simple lambda
        (cst.Lambda(cst.Parameters(), cst.Integer("5")), "lambda: 5"),
        # Test basic positional params
        (
            cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("bar")),
                                       cst.Param(cst.Name("baz")))),
                cst.Integer("5"),
            ),
            "lambda bar, baz: 5",
        ),
        # Test basic positional default params
        (
            cst.Lambda(
                cst.Parameters(default_params=(
                    cst.Param(cst.Name("bar"),
                              default=cst.SimpleString('"one"')),
                    cst.Param(cst.Name("baz"), default=cst.Integer("5")),
                )),
                cst.Integer("5"),
            ),
            'lambda bar = "one", baz = 5: 5',
        ),
        # Mixed positional and default params.
        (
            cst.Lambda(
                cst.Parameters(
                    params=(cst.Param(cst.Name("bar")), ),
                    default_params=(cst.Param(cst.Name("baz"),
                                              default=cst.Integer("5")), ),
                ),
                cst.Integer("5"),
            ),
            "lambda bar, baz = 5: 5",
        ),
        # Test kwonly params
        (
            cst.Lambda(
                cst.Parameters(kwonly_params=(
                    cst.Param(cst.Name("bar"),
                              default=cst.SimpleString('"one"')),
                    cst.Param(cst.Name("baz")),
                )),
                cst.Integer("5"),
            ),
            'lambda *, bar = "one", baz: 5',
        ),
        # Mixed params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed default_params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    default_params=(
                        cst.Param(cst.Name("first"), default=cst.Float("1.0")),
                        cst.Param(cst.Name("second"),
                                  default=cst.Float("1.5")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first = 1.0, second = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params, default_params, and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    default_params=(
                        cst.Param(cst.Name("third"), default=cst.Float("1.0")),
                        cst.Param(cst.Name("fourth"),
                                  default=cst.Float("1.5")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *, bar = "one", baz, biz = "two": 5',
            CodeRange((1, 0), (1, 84)),
        ),
        # Test star_arg
        (
            cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("params"))),
                cst.Integer("5"),
            ),
            "lambda *params: 5",
        ),
        # Typed star_arg, include kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params")),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda *params, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params default_params, star_arg and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    default_params=(
                        cst.Param(cst.Name("third"), default=cst.Float("1.0")),
                        cst.Param(cst.Name("fourth"),
                                  default=cst.Float("1.5")),
                    ),
                    star_arg=cst.Param(cst.Name("params")),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *params, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_arg and star_kwarg
        (
            cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("kwparams"))),
                cst.Integer("5"),
            ),
            "lambda **kwparams: 5",
        ),
        # Test star_arg and kwarg
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params")),
                    star_kwarg=cst.Param(cst.Name("kwparams")),
                ),
                cst.Integer("5"),
            ),
            "lambda *params, **kwparams: 5",
        ),
        # Inner whitespace
        (
            cst.Lambda(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                whitespace_after_lambda=cst.SimpleWhitespace("  "),
                params=cst.Parameters(),
                colon=cst.Colon(whitespace_after=cst.SimpleWhitespace(" ")),
                body=cst.Integer("5"),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( lambda  : 5 )",
            CodeRange((1, 2), (1, 13)),
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        self.validate_node(node, code, expected_position=position)

    @data_provider((
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("arg"), default=cst.Integer("5")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("arg"))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"),
                                                    equal=cst.AssignEqual())),
                cst.Integer("5"),
            ),
            "Must have a default when specifying an AssignEqual.",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"),
                                                    star="***")),
                cst.Integer("5"),
            ),
            r"Must specify either '', '\*' or '\*\*' for star.",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(
                    cst.Name("bar"), default=cst.SimpleString('"one"')), )),
                cst.Integer("5"),
            ),
            "Cannot have defaults for params",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(cst.Name("bar")), )),
                cst.Integer("5"),
            ),
            "Must have defaults for default_params",
        ),
        (
            lambda: cst.Lambda(cst.Parameters(star_arg=cst.ParamStar()),
                               cst.Integer("5")),
            "Must have at least one kwonly param if ParamStar is used.",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("bar"), star="*"), )
                               ),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("bar"),
                    default=cst.SimpleString('"one"'),
                    star="*",
                ), )),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(cst.Name("bar"),
                                                        star="*"), )),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("bar"), star="**")),
                cst.Integer("5"),
            ),
            r"Expecting a star prefix of '\*'",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"), star="*")
                               ),
                cst.Integer("5"),
            ),
            r"Expecting a star prefix of '\*\*'",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(
                    cst.Name("arg"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("arg"),
                    default=cst.Integer("5"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("arg"),
                                                  annotation=cst.Annotation(
                                                      cst.Name("str")))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(
                    cst.Name("arg"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"),
                                                    annotation=cst.Annotation(
                                                        cst.Name("str")))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        self.assert_invalid(get_node, expected_re)
コード例 #17
0
 def get_annotation(p: cst.Param) -> Optional[cst.Annotation]:
     """Build an annotation for parameter ``p`` from the ``types`` mapping.

     ``types`` is not defined in this function; it is looked up from an
     enclosing/global scope (assumed to map parameter names to type
     expressions given as source strings -- confirm against the caller).

     Returns:
         A ``cst.Annotation`` wrapping the parsed type expression, or ``None``
         when ``types`` has no entry for the parameter or the entry is falsy
         (e.g. an empty string).
     """
     type_source = types.get(p.name.value)
     if not type_source:
         return None
     return cst.Annotation(cst.parse_expression(type_source))
コード例 #18
0
 def test_from_function_data(self) -> None:
     """Exercise ``FunctionAnnotationKind.from_function_data`` over a table.

     Each row holds the four keyword arguments passed to
     ``from_function_data`` and the kind the test expects back. Rows are
     asserted in order with ``assertEqual``, so the first mismatch fails.
     """
     unannotated_params = [
         cst.Param(name=cst.Name(param_name), annotation=None)
         for param_name in ("x1", "x2", "x3")
     ]
     annotated_self_only = [
         cst.Param(
             name=cst.Name("self"),
             annotation=cst.Annotation(cst.Name("Foo")),
         )
     ]
     # Row layout: (is_return_annotated, annotated_parameter_count,
     #              is_method_or_classmethod, parameters, expected kind)
     cases = (
         (True, 3, False, unannotated_params,
          FunctionAnnotationKind.FULLY_ANNOTATED),
         (True, 0, False, unannotated_params,
          FunctionAnnotationKind.PARTIALLY_ANNOTATED),
         (False, 0, False, unannotated_params,
          FunctionAnnotationKind.NOT_ANNOTATED),
         (False, 1, False, unannotated_params,
          FunctionAnnotationKind.PARTIALLY_ANNOTATED),
         # An untyped `self` parameter of a method does not count toward
         # partial annotation; per PEP 484 an explicitly annotated parameter
         # is required.
         (False, 1, True, unannotated_params,
          FunctionAnnotationKind.NOT_ANNOTATED),
         (False, 2, True, unannotated_params,
          FunctionAnnotationKind.PARTIALLY_ANNOTATED),
         # An annotated `self` is enough to make Pyre typecheck the method.
         (False, 1, True, annotated_self_only,
          FunctionAnnotationKind.PARTIALLY_ANNOTATED),
         (False, 0, True, [], FunctionAnnotationKind.NOT_ANNOTATED),
     )
     for returns_annotated, annotated_count, is_method, params, expected in cases:
         self.assertEqual(
             FunctionAnnotationKind.from_function_data(
                 is_return_annotated=returns_annotated,
                 annotated_parameter_count=annotated_count,
                 is_method_or_classmethod=is_method,
                 parameters=params,
             ),
             expected,
         )
コード例 #19
0
ファイル: apis.py プロジェクト: vfdev-5/python-record-api
 def return_type_annotation(self) -> typing.Optional[cst.Annotation]:
     """Return a new ``cst.Annotation`` wrapping ``self.return_type.annotation``.

     Returns ``None`` when ``self.return_type`` is falsy (e.g. ``None``).
     """
     if not self.return_type:
         return None
     # NOTE(review): only the inner expression is carried over, so the new
     # node presumably gets default whitespace fields instead of the
     # original annotation's -- confirm this is the intent.
     return cst.Annotation(self.return_type.annotation)