class RequireDescriptiveNameRule(CstLintRule):
    """
    Reports classes, functions/methods and parameters whose name consists of a
    single character, because such names do not describe what they stand for.
    """

    VALID = [
        Valid("""
            class DescriptiveName:
                pass
            """),
        Valid("""
            class ThisClass:
                def this_method(self):
                    pass
            """),
        Valid("""
            def descriptive_function():
                pass
            """),
        Valid("""
            def function(descriptive, parameter):
                pass
            """),
    ]

    INVALID = [
        Invalid("""
            class T:
                pass
            """),
        Invalid("""
            class ThisClass:
                def m(self):
                    pass
            """),
        Invalid("""
            def f():
                pass
            """),
        Invalid("""
            def fun(a):
                pass
            """),
    ]

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        self._validate_name_length(node, "class")

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        self._validate_name_length(node, "function")

    def visit_Param(self, node: cst.Param) -> None:
        self._validate_name_length(node, "parameter")

    def _validate_name_length(
        self,
        node: Union[cst.ClassDef, cst.FunctionDef, cst.Param],
        nodetype: str,
    ) -> None:
        """Report ``node`` when its name is exactly one character long.

        ``nodetype`` is the human-readable kind of node ("class", "function"
        or "parameter") that is interpolated into the module-level ``MESSAGE``.
        """
        name = node.name.value
        if len(name) != 1:
            return
        message = MESSAGE.format(nodetype=nodetype, nodename=name)
        self.report(node, message=message)
class NoRedundantLambdaRule(CstLintRule):
    """A lambda function which has a single objective of
    passing all its arguments to another callable can
    be safely replaced by that callable."""

    VALID = [
        Valid("lambda x: foo(y)"),
        Valid("lambda x: foo(x, y)"),
        Valid("lambda x, y: foo(x)"),
        Valid("lambda *, x: foo(x)"),
        Valid("lambda x = y: foo(x)"),
        Valid("lambda x, y: foo(y, x)"),
        Valid("lambda self: self.func()"),
        Valid("lambda x, y: foo(y=x, x=y)"),
        Valid("lambda x, y, *z: foo(x, y, z)"),
        Valid("lambda x, y, **z: foo(x, y, z)"),
        # The callee mentions a lambda parameter, so it cannot be hoisted out
        # of the lambda without changing what that name refers to.
        Valid("lambda f: f(f)"),
        Valid("lambda x: x.method(x)"),
    ]
    INVALID = [
        Invalid("lambda: self.func()", expected_replacement="self.func"),
        Invalid("lambda x: foo(x)", expected_replacement="foo"),
        Invalid(
            "lambda x, y, z: (t + u).math_call(x, y, z)",
            expected_replacement="(t + u).math_call",
        ),
    ]

    @staticmethod
    def _is_simple_parameter_spec(node: cst.Parameters) -> bool:
        """Return True if ``node`` holds only plain positional-or-keyword
        parameters without defaults.

        Any ``*args``/bare ``*``, ``**kwargs``, keyword-only or
        positional-only parameter makes the spec non-simple.
        """
        if (node.star_kwarg is not None or len(node.kwonly_params) > 0
                or len(node.posonly_params) > 0
                or not isinstance(node.star_arg, cst.MaybeSentinel)):
            return False

        return all(param.default is None for param in node.params)

    def visit_Lambda(self, node: cst.Lambda) -> None:
        """Flag ``lambda a, b: f(a, b)`` and offer ``f`` as the replacement."""
        # The body must be a call that forwards every parameter positionally,
        # in declaration order, and passes nothing else.
        if not m.matches(
                node,
                m.Lambda(
                    params=m.MatchIfTrue(self._is_simple_parameter_spec),
                    body=m.Call(args=[
                        m.Arg(value=m.Name(value=param.name.value),
                              star="",
                              keyword=None) for param in node.params.params
                    ]),
                ),
        ):
            return

        call = cst.ensure_type(node.body, cst.Call)

        # The autofix hoists ``call.func`` out of the lambda. If that expression
        # mentions one of the lambda's parameters (``lambda f: f(f)``,
        # ``lambda x: x.method(x)``), the hoisted name would refer to whatever
        # it means in the enclosing scope instead, so the lambda is not
        # redundant. The check is conservative: attribute names spelled like a
        # parameter (``lambda x: obj.x(x)``) are skipped as well.
        param_names = {param.name.value for param in node.params.params}
        if m.findall(
                call.func,
                m.Name(value=m.MatchIfTrue(lambda name: name in param_names)),
        ):
            return

        full_name = get_full_name_for_node(call)
        if full_name is None:
            full_name = "function"

        self.report(
            node,
            UNNECESSARY_LAMBDA.format(function=full_name),
            replacement=call.func,
        )
# Example #3
# 0
class NoRedundantFStringRule(CstLintRule):
    """
    Flags f-strings that contain no ``{...}`` placeholders and rewrites them as
    plain string literals (keeping any other prefix such as ``r``).
    """

    MESSAGE: str = "f-string doesn't have placeholders, remove redundant f-string."

    VALID = [
        Valid('good: str = "good"'),
        Valid('good: str = f"with_arg{arg}"'),
        Valid('good = "good{arg1}".format(1234)'),
        Valid('good = "good".format()'),
        Valid('good = "good" % {}'),
        Valid('good = "good" % ()'),
        Valid('good = rf"good\t+{bar}"'),
    ]

    INVALID = [
        Invalid(
            'bad: str = f"bad" + "bad"',
            line=1,
            expected_replacement='bad: str = "bad" + "bad"',
        ),
        Invalid(
            "bad: str = f'bad'",
            line=1,
            expected_replacement="bad: str = 'bad'",
        ),
        Invalid(
            "bad: str = rf'bad\t+'",
            line=1,
            expected_replacement="bad: str = r'bad\t+'",
        ),
        Invalid(
            'bad: str = f"no args but messing up {{ braces }}"',
            line=1,
            expected_replacement=
            'bad: str = "no args but messing up { braces }"',
        ),
    ]

    def visit_FormattedString(self, node: cst.FormattedString) -> None:
        # Redundant means: exactly one part, and that part is literal text.
        text_only = m.FormattedString(parts=(m.FormattedStringText(), ))
        if m.matches(node, text_only):
            text = cst.ensure_type(node.parts[0], cst.FormattedStringText).value
            # Doubled braces are escapes only inside f-strings; a plain string
            # spells a literal brace with a single character.
            unescaped = text.replace("{{", "{").replace("}}", "}")
            prefix = node.start.replace("f", "").replace("F", "")
            literal = cst.SimpleString(prefix + unescaped + node.end)
            self.report(node, replacement=literal)
class NoInheritFromObjectRule(CstLintRule):
    """
    In Python 3, a class is inherited from ``object`` by default.
    Explicitly inheriting from ``object`` is redundant, so removing it keeps the code simpler.
    """

    MESSAGE = "Inheriting from object is a no-op.  'class Foo:' is just fine =)"
    VALID = [
        Valid("class A(something):    pass"),
        Valid(
            """
            class A:
                pass"""
        ),
    ]
    INVALID = [
        Invalid(
            """
            class B(object):
                pass""",
            line=1,
            column=1,
            expected_replacement="""
                class B:
                    pass""",
        ),
        Invalid(
            """
            class B(object, A):
                pass""",
            line=1,
            column=1,
            expected_replacement="""
                class B(A):
                    pass""",
        ),
        Invalid(
            """
            class B(A, object):
                pass""",
            line=1,
            column=1,
            expected_replacement="""
                class B(A):
                    pass""",
        ),
    ]

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Report (and autofix) a class that lists ``object`` among its bases."""
        original_bases = tuple(node.bases)
        kept_bases = [
            base
            for base in original_bases
            if not m.matches(base.value, m.Name("object"))
        ]
        if len(kept_bases) == len(original_bases):
            # No explicit ``object`` base: nothing to report.
            return

        if not kept_bases and not node.keywords:
            # ``class Foo(object):`` -> ``class Foo:``; the parens would be empty,
            # so drop them too.
            new_classdef = node.with_changes(
                bases=(),
                lpar=cst.MaybeSentinel.DEFAULT,
                rpar=cst.MaybeSentinel.DEFAULT,
            )
        else:
            if not node.keywords and original_bases[-1] is not kept_bases[-1]:
                # ``object`` was the final argument, so the base now in last
                # position still carries the comma that separated it from
                # ``object``. Reset it so ``class B(A, object)`` becomes
                # ``class B(A)`` rather than ``class B(A, )``.
                kept_bases[-1] = kept_bases[-1].with_changes(
                    comma=cst.MaybeSentinel.DEFAULT
                )
            # Arguments remain, so keep the original parentheses (and any
            # whitespace/newlines attached to them).
            new_classdef = node.with_changes(bases=tuple(kept_bases))

        # report warning and autofix
        self.report(node, replacement=new_classdef)
# Example #5
# 0
class NoRedundantListComprehensionRule(CstLintRule):
    """
    A derivative of flake8-comprehensions's C407 rule.

    ``any([...])`` / ``all([...])`` build the whole list before testing it;
    passing a generator expression lets them short-circuit.
    """

    VALID = [
        Valid("any(val for val in iterable)"),
        Valid("all(val for val in iterable)"),
        # C407 would complain about these, but we won't
        Valid("frozenset([val for val in iterable])"),
        Valid("max([val for val in iterable])"),
        Valid("min([val for val in iterable])"),
        Valid("sorted([val for val in iterable])"),
        Valid("sum([val for val in iterable])"),
        Valid("tuple([val for val in iterable])"),
        # A starred argument cannot be an unparenthesized generator expression.
        Valid("any(*[val for val in iterable])"),
    ]
    INVALID = [
        Invalid(
            "any([val for val in iterable])",
            expected_replacement="any(val for val in iterable)",
        ),
        Invalid(
            "all([val for val in iterable])",
            expected_replacement="all(val for val in iterable)",
        ),
        Invalid(
            "any([val for val in iterable],)",
            expected_replacement="any(val for val in iterable)",
        ),
    ]

    def visit_Call(self, node: cst.Call) -> None:
        """Flag ``any``/``all`` called with a single list comprehension."""
        # This set excludes frozenset, max, min, sorted, sum, and tuple, which C407 would warn
        # about, because none of those functions short-circuit.
        if not m.matches(
                node,
                m.Call(
                    func=m.Name("all") | m.Name("any"),
                    # Only a plain positional argument: ``*[...]`` or
                    # ``kw=[...]`` cannot hold an unparenthesized generator.
                    args=[m.Arg(value=m.ListComp(), star="", keyword=None)],
                ),
        ):
            return

        arg = node.args[0]
        list_comp = cst.ensure_type(arg.value, cst.ListComp)
        generator = cst.GeneratorExp(elt=list_comp.elt,
                                     for_in=list_comp.for_in,
                                     lpar=[],
                                     rpar=[])
        # Drop any trailing comma: ``any(x for x in y,)`` is a SyntaxError,
        # because an unparenthesized generator must be the sole argument.
        new_arg = arg.with_changes(value=generator,
                                   comma=cst.MaybeSentinel.DEFAULT)
        self.report(
            node,
            UNNECESSARY_LIST_COMPREHENSION.format(
                func=cst.ensure_type(node.func, cst.Name).value),
            replacement=node.with_changes(args=(new_arg, )),
        )
class NoAssertEqualsRule(CstLintRule):
    """
    Flags ``self.assertEquals(...)``, a deprecated alias (see https://docs.python.org/2/library/unittest.html#deprecated-aliases
    and https://bugs.python.org/issue9424), and rewrites it to the standard ``self.assertEqual(...)``.
    """

    MESSAGE: str = (
        '"assertEquals" is deprecated, use "assertEqual" instead.\n'
        + "See https://docs.python.org/2/library/unittest.html#deprecated-aliases and https://bugs.python.org/issue9424."
    )
    VALID = [Valid("self.assertEqual(a, b)")]
    INVALID = [
        Invalid(
            "self.assertEquals(a, b)",
            expected_replacement="self.assertEqual(a, b)",
        )
    ]

    def visit_Call(self, node: cst.Call) -> None:
        deprecated_call = m.Call(
            func=m.Attribute(value=m.Name("self"), attr=m.Name("assertEquals"))
        )
        if not m.matches(node, deprecated_call):
            return
        # Rename only the attribute name; arguments and whitespace stay as-is.
        method_name = cst.ensure_type(node.func, cst.Attribute).attr
        fixed_call = node.with_deep_changes(old_node=method_name, value="assertEqual")
        self.report(node, replacement=fixed_call)
# Example #7
# 0
class UseLintFixmeCommentRule(CstLintRule):
    """
    To silence a lint warning, use a ``lint-fixme`` comment (when you plan to fix the issue later) or a
    ``lint-ignore`` comment (when the lint warning is not valid).
    The comment must be on its own line and follow the format ``lint-fixme: RULE_NAMES EXTRA_COMMENTS``.
    It suppresses the lint warnings with the given RULE_NAMES on the next line.
    RULE_NAMES can be one or more lint rule names separated by commas.
    ``noqa`` is deprecated and not supported. Explicitly naming the rules to suppress in a
    ``lint-fixme`` comment is preferred over implicit ``noqa`` comments, which can
    silence warnings unexpectedly.
    """

    MESSAGE: str = "noqa is deprecated. Use `lint-fixme` or `lint-ignore` instead."

    VALID = [
        Valid("""
            # lint-fixme: UseFstringRule
            "%s" % "hi"
            """),
        Valid("""
            # lint-ignore: UsePlusForStringConcatRule
            'ab' 'cd'
            """),
    ]
    INVALID = [
        Invalid("fn() # noqa"),
        Invalid("""
            (
             1,
             2,  # noqa
            )
            """),
        Invalid("""
            class C:
                # noqa
                ...
            """),
    ]

    def visit_Comment(self, node: cst.Comment) -> None:
        """Report any comment whose first six characters are ``# noqa`` (case-insensitive)."""
        marker = "# noqa"
        leading = node.value[:len(marker)]
        if leading.lower() != marker:
            return
        self.report(node)
class UseFstringRule(CstLintRule):
    """
    Flags ``str.format()`` calls on string literals and printf-style ``%``
    formatting of (non-bytes) string literals, and asks for f-strings instead.
    """

    MESSAGE: str = (
        "As mentioned in the [Contributing Guidelines]" +
        "(https://github.com/TheAlgorithms/Python/blob/master/CONTRIBUTING.md), "
        + "please do not use printf style formatting or `str.format()`. " +
        "Use [f-string](https://realpython.com/python-f-strings/) instead to be "
        + "more readable and efficient.")

    VALID = [
        Valid("assigned='string'; f'testing {assigned}'"),
        Valid("'simple string'"),
        Valid("'concatenated' + 'string'"),
        Valid("b'bytes %s' % 'string'.encode('utf-8')"),
    ]

    INVALID = [
        Invalid("'hello, {name}'.format(name='you')"),
        Invalid("'hello, %s' % 'you'"),
        Invalid("r'raw string value=%s' % val"),
    ]

    def visit_Call(self, node: cst.Call) -> None:
        literal_format_call = m.Call(func=m.Attribute(
            value=m.SimpleString(), attr=m.Name(value="format")))
        if m.matches(node, literal_format_call):
            self.report(node)

    def visit_BinaryOperation(self, node: cst.BinaryOperation) -> None:
        printf_style = m.BinaryOperation(left=m.SimpleString(),
                                         operator=m.Modulo())
        if not m.matches(node, printf_style):
            return
        left_literal = cst.ensure_type(node.left, cst.SimpleString)
        # SimpleString can be bytes and fstring don't support bytes.
        # https://www.python.org/dev/peps/pep-0498/#no-binary-f-strings
        if isinstance(left_literal.evaluated_value, str):
            self.report(node)
# Example #9
# 0
class AwaitAsyncCallRule(CstLintRule):
    """
    Enforces calls to coroutines are preceded by the ``await`` keyword. Awaiting on a coroutine will execute it while
    simply calling a coroutine returns a coroutine object (https://docs.python.org/3/library/asyncio-task.html#coroutines).

    The rule inspects the test of ``if``/``while`` statements, the value of assignments and of bare expression
    statements, and proposes wrapping the offending call or attribute access in ``await``. Whether an expression
    is awaitable is decided from the type string that ``TypeInferenceProvider`` reports for it.
    """

    MESSAGE: str = (
        "Async function call will only be executed with `await` statement. Did you forget to add `await`? "
        +
        "If you intend to not await, please add comment to disable this warning: # lint-fixme: AwaitAsyncCallRule "
    )

    # Inferred type strings are looked up per node in _get_awaitable_replacement.
    METADATA_DEPENDENCIES = (TypeInferenceProvider, )

    VALID = [
        Valid("""
            async def async_func():
                await async_foo()
            """),
        Valid("""
            def foo(): pass
            foo()
            """),
        Valid("""
            async def foo(): pass
            async def bar():
                await foo()
            """),
        Valid("""
            async def foo(): pass
            async def bar():
                x = await foo()
            """),
        Valid("""
            async def foo() -> bool: pass
            async def bar():
                while not await foo(): pass
            """),
        Valid("""
            import asyncio
            async def foo(): pass
            asyncio.run(foo())
            """),
    ]
    INVALID = [
        Invalid(
            """
            async def foo(): pass
            async def bar():
                foo()
            """,
            expected_replacement="""
            async def foo(): pass
            async def bar():
                await foo()
            """,
        ),
        Invalid(
            """
            class Foo:
                async def _attr(self): pass
            obj = Foo()
            obj._attr
            """,
            expected_replacement="""
            class Foo:
                async def _attr(self): pass
            obj = Foo()
            await obj._attr
            """,
        ),
        Invalid(
            """
            class Foo:
                async def _method(self): pass
            obj = Foo()
            obj._method()
            """,
            expected_replacement="""
            class Foo:
                async def _method(self): pass
            obj = Foo()
            await obj._method()
            """,
        ),
        Invalid(
            """
            class Foo:
                async def _method(self): pass
            obj = Foo()
            result = obj._method()
            """,
            expected_replacement="""
            class Foo:
                async def _method(self): pass
            obj = Foo()
            result = await obj._method()
            """,
        ),
        Invalid(
            """
            class Foo:
                async def bar(): pass
            class NodeUser:
                async def get():
                    do_stuff()
                    return Foo()
            user = NodeUser.get().bar()
            """,
            expected_replacement="""
            class Foo:
                async def bar(): pass
            class NodeUser:
                async def get():
                    do_stuff()
                    return Foo()
            user = await NodeUser.get().bar()
            """,
        ),
        Invalid(
            """
            class Foo:
                async def _attr(self): pass
            obj = Foo()
            attribute = obj._attr
            """,
            expected_replacement="""
            class Foo:
                async def _attr(self): pass
            obj = Foo()
            attribute = await obj._attr
            """,
        ),
        Invalid(
            code="""
            async def foo() -> bool: pass
            x = True
            if x and foo(): pass
            """,
            expected_replacement="""
            async def foo() -> bool: pass
            x = True
            if x and await foo(): pass
            """,
        ),
        Invalid(
            code="""
            async def foo() -> bool: pass
            x = True
            are_both_true = x and foo()
            """,
            expected_replacement="""
            async def foo() -> bool: pass
            x = True
            are_both_true = x and await foo()
            """,
        ),
        Invalid(
            """
            async def foo() -> bool: pass
            if foo():
                do_stuff()
            """,
            expected_replacement="""
            async def foo() -> bool: pass
            if await foo():
                do_stuff()
            """,
        ),
        Invalid(
            """
            async def foo() -> bool: pass
            if not foo():
                do_stuff()
            """,
            expected_replacement="""
            async def foo() -> bool: pass
            if not await foo():
                do_stuff()
            """,
        ),
        Invalid(
            """
            class Foo:
                async def _attr(self): pass
                def bar(self):
                    if self._attr: pass
            """,
            expected_replacement="""
            class Foo:
                async def _attr(self): pass
                def bar(self):
                    if await self._attr: pass
            """,
        ),
        Invalid(
            """
            class Foo:
                async def _attr(self): pass
                def bar(self):
                    if not self._attr: pass
            """,
            expected_replacement="""
            class Foo:
                async def _attr(self): pass
                def bar(self):
                    if not await self._attr: pass
            """,
        ),
        # Case where only cst.Attribute node's `attr` returns awaitable
        Invalid(
            """
            class Foo:
                async def _attr(self): pass
            def bar() -> Foo:
                return Foo()
            attribute = bar()._attr
            """,
            expected_replacement="""
            class Foo:
                async def _attr(self): pass
            def bar() -> Foo:
                return Foo()
            attribute = await bar()._attr
            """,
        ),
        # Case where only cst.Attribute node's `value` returns awaitable
        Invalid(
            """
            class Foo:
                def _attr(self): pass
            async def bar():
                await do_stuff()
                return Foo()
            attribute = bar()._attr
            """,
            expected_replacement="""
            class Foo:
                def _attr(self): pass
            async def bar():
                await do_stuff()
                return Foo()
            attribute = await bar()._attr
            """,
        ),
        Invalid(
            """
            async def bar() -> bool: pass
            while bar(): pass
            """,
            expected_replacement="""
            async def bar() -> bool: pass
            while await bar(): pass
            """,
        ),
        Invalid(
            """
            async def bar() -> bool: pass
            while not bar(): pass
            """,
            expected_replacement="""
            async def bar() -> bool: pass
            while not await bar(): pass
            """,
        ),
        Invalid(
            """
            class Foo:
                @classmethod
                async def _method(cls): pass
            Foo._method()
            """,
            expected_replacement="""
            class Foo:
                @classmethod
                async def _method(cls): pass
            await Foo._method()
            """,
        ),
        Invalid(
            """
            class Foo:
                @staticmethod
                async def _method(self): pass
            Foo._method()
            """,
            expected_replacement="""
            class Foo:
                @staticmethod
                async def _method(self): pass
            await Foo._method()
            """,
        ),
    ]

    @staticmethod
    def _is_awaitable_callable(annotation: str) -> bool:
        """Return True if ``annotation`` describes a callable whose return type is ``typing.Coroutine``.

        Only annotation strings starting with ``typing.Callable``, ``typing.ClassMethod`` or
        ``StaticMethod`` are considered. The string is parsed as Python source, and the return type is
        read from the second element of its outermost subscript.
        """
        # NOTE(review): unlike the other two prefixes, "StaticMethod" carries no "typing." qualifier --
        # confirm it matches the type string the provider emits for static methods.
        if not (annotation.startswith("typing.Callable")
                or annotation.startswith("typing.ClassMethod")
                or annotation.startswith("StaticMethod")):
            # Exit early if this is not even a `typing.Callable` annotation.
            return False
        try:
            # Wrap this in a try-except since the type annotation may not be parse-able as a module.
            # If it is not parse-able, we know it's not what we are looking for anyway, so return `False`.
            parsed_ann = cst.parse_module(annotation)
        except Exception:
            return False
        # If passed annotation does not match the expected annotation structure for a `typing.Callable` with
        # typing.Coroutine as the return type, matched_callable_ann will simply be `None`.
        # The expected structure of an awaitable callable annotation from Pyre is: typing.Callable()[[...], typing.Coroutine[...]]
        matched_callable_ann: Optional[Dict[str, Union[
            Sequence[cst.CSTNode], cst.CSTNode]]] = m.extract(
                parsed_ann,
                m.Module(body=[
                    m.SimpleStatementLine(body=[
                        m.Expr(value=m.Subscript(slice=[
                            m.SubscriptElement(),
                            m.SubscriptElement(slice=m.Index(value=m.Subscript(
                                value=m.SaveMatchedNode(
                                    m.Attribute(),
                                    "base_return_type",
                                )))),
                        ], ))
                    ]),
                ]),
            )
        if (matched_callable_ann is not None
                and "base_return_type" in matched_callable_ann):
            base_return_type = get_full_name_for_node(
                cst.ensure_type(matched_callable_ann["base_return_type"],
                                cst.CSTNode))
            return (base_return_type is not None
                    and base_return_type == "typing.Coroutine")
        return False

    def _get_awaitable_replacement(self,
                                   node: cst.CSTNode) -> Optional[cst.CSTNode]:
        """Wrap ``node`` in ``await`` if its inferred type is a coroutine or a coroutine-returning callable.

        Returns None when no type is recorded for ``node``, when the type is not awaitable, or when
        ``node`` is not an expression.
        """
        annotation = self.get_metadata(TypeInferenceProvider, node, None)
        if annotation is not None and (
                annotation.startswith("typing.Coroutine")
                or self._is_awaitable_callable(annotation)):
            if isinstance(node, cst.BaseExpression):
                return cst.Await(expression=node)
        return None

    def _get_async_attr_replacement(
            self, node: cst.Attribute) -> Optional[cst.CSTNode]:
        """Return an awaited version of the attribute access ``node``, or None.

        If the receiver (``node.value``) is a call that needs awaiting, only that inner call is
        rewritten; otherwise the whole attribute access is checked.
        """
        value = node.value
        if m.matches(value, m.Call()):
            value = cast(cst.Call, value)
            value_replacement = self._get_async_call_replacement(value)
            if value_replacement is not None:
                return node.with_changes(value=value_replacement)
        return self._get_awaitable_replacement(node)

    def _get_async_call_replacement(self,
                                    node: cst.Call) -> Optional[cst.CSTNode]:
        """Return an awaited version of the call ``node``, or None.

        When the callee is an attribute, the attribute (and its receiver) is checked first; if that
        yields a replacement, the call is rebuilt around the rewritten callee. Otherwise the call
        expression itself is checked.
        """
        func = node.func
        if m.matches(func, m.Attribute()):
            func = cast(cst.Attribute, func)
            attr_func_replacement = self._get_async_attr_replacement(func)
            if attr_func_replacement is not None:
                # NOTE(review): when the rewritten callee is itself an ``Await`` node, the generated
                # source reads ``await obj.attr(...)``, which Python parses as awaiting the call's
                # result -- confirm this is the intended mechanism.
                return node.with_changes(func=attr_func_replacement)
        return self._get_awaitable_replacement(node)

    def _get_async_expr_replacement(
            self, node: cst.CSTNode) -> Optional[cst.CSTNode]:
        """Return ``node`` rewritten with the required ``await``(s), or None if nothing needs awaiting.

        Handles calls, attribute accesses, ``not`` expressions (recursing into the operand) and
        ``and``/``or`` operations (recursing into both operands). Any other expression kind yields None.
        """
        if m.matches(node, m.Call()):
            node = cast(cst.Call, node)
            return self._get_async_call_replacement(node)
        elif m.matches(node, m.Attribute()):
            node = cast(cst.Attribute, node)
            return self._get_async_attr_replacement(node)
        elif m.matches(node, m.UnaryOperation(operator=m.Not())):
            node = cast(cst.UnaryOperation, node)
            replacement_expression = self._get_async_expr_replacement(
                node.expression)
            if replacement_expression is not None:
                return node.with_changes(expression=replacement_expression)
        elif m.matches(node, m.BooleanOperation()):
            node = cast(cst.BooleanOperation, node)
            maybe_left = self._get_async_expr_replacement(node.left)
            maybe_right = self._get_async_expr_replacement(node.right)
            if maybe_left is not None or maybe_right is not None:
                # Keep the untouched operand as-is; only the awaitable side(s) change.
                left_replacement = maybe_left if maybe_left is not None else node.left
                right_replacement = (maybe_right if maybe_right is not None
                                     else node.right)
                return node.with_changes(left=left_replacement,
                                         right=right_replacement)
        return None

    def _maybe_autofix_node(self, node: cst.CSTNode,
                            attribute_name: str) -> None:
        """Report ``node`` with an autofix if its child expression ``attribute_name`` needs ``await``."""
        replacement_value = self._get_async_expr_replacement(
            getattr(node, attribute_name))
        if replacement_value is not None:
            replacement = node.with_changes(
                **{attribute_name: replacement_value})
            self.report(node, replacement=replacement)

    def visit_If(self, node: cst.If) -> None:
        self._maybe_autofix_node(node, "test")

    def visit_While(self, node: cst.While) -> None:
        self._maybe_autofix_node(node, "test")

    def visit_Assign(self, node: cst.Assign) -> None:
        self._maybe_autofix_node(node, "value")

    def visit_Expr(self, node: cst.Expr) -> None:
        self._maybe_autofix_node(node, "value")
class RequireTypeHintRule(CstLintRule):
    """
    Requires a return annotation on every function and an annotation on every
    parameter, except parameters listed in ``IGNORE_PARAM`` and parameters of
    ``lambda`` expressions (which cannot be annotated).
    """

    VALID = [
        Valid("""
            def func() -> str:
                pass
            """),
        Valid("""
            def func() -> None:
                pass
            """),
        Valid("""
            def func(some: str, other: int) -> None:
                pass
            """),
        Valid("""
            class Random:
                def random_method(self, value: int) -> None:
                    pass
            """),
        Valid("""
            class Random:
                @classmethod
                def initiate(cls, value: str) -> str:
                    pass
            """),
        Valid("""
            lambda ignore: ignore
            """),
        Valid("""
            lambda closure: lambda inside: closure + inside
            """),
    ]

    INVALID = [
        Invalid("""
            def func():
                pass
            """),
        Invalid("""
            def func(num: int, val: str):
                pass
            """),
        Invalid("""
            def func(num: int, val) -> None:
                pass
            """),
        Invalid("""
            class Random:
                def __init__(self, val) -> None:
                    pass
            """),
        Invalid("""
            class Random:
                @classmethod
                def from_class(cls, val) -> None:
                    pass
            """),
        Invalid("""
            def spam() -> None:
                foo = lambda bar: str(bar)

                def wrapper(call) -> None:
                    pass
                return wrapper(foo)
            """),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        # How many ``lambda`` expressions currently enclose the visited node.
        self._lambda_depth: int = 0

    def visit_Lambda(self, node: cst.Lambda) -> None:
        self._lambda_depth += 1

    def leave_Lambda(self, original_node: cst.Lambda) -> None:
        self._lambda_depth -= 1

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        if node.returns is not None:
            return
        function_name = node.name.value
        self.report(node,
                    MISSING_RETURN_TYPE_HINT.format(nodename=function_name))

    def visit_Param(self, node: cst.Param) -> None:
        # Parameters of a ``lambda`` cannot carry annotations, so skip them.
        if self._lambda_depth != 0:
            return
        param_name = node.name.value
        if node.annotation is not None or param_name in IGNORE_PARAM:
            return
        self.report(node, MISSING_TYPE_HINT.format(nodename=param_name))
# Example #11
# 0
class UseIsNoneOnOptionalRule(CstLintRule):
    """
    Enforces explicit use of ``is None`` or ``is not None`` when checking whether an Optional has a value.
    Directly testing an object (e.g. ``if x``) implicitly tests for a truth value which returns ``True`` unless the
    object's ``__bool__()`` method returns False, its ``__len__()`` method returns '0', or it is one of the constants ``None`` or ``False``.
    (https://docs.python.org/3.8/library/stdtypes.html#truth-value-testing).
    """

    METADATA_DEPENDENCIES = (TypeInferenceProvider,)
    MESSAGE: str = (
        "When checking if an `Optional` has a value, avoid using it as a boolean since it implicitly checks the object's `__bool__()`, `__len__()` is not `0`, or the value is not `None`. "
        + "Instead, use `is None` or `is not None` to be explicit."
    )

    VALID: List[Valid] = [
        Valid(
            """
            from typing import Optional
            a: Optional[str]
            if a is not None:
                pass
            """,
        ),
        Valid(
            """
            a: bool
            if a:
                pass
            """,
        ),
    ]

    INVALID: List[Invalid] = [
        Invalid(
            code="""
            from typing import Optional

            a: Optional[str] = None
            if a:
                pass
            """,
            expected_replacement="""
            from typing import Optional

            a: Optional[str] = None
            if a is not None:
                pass
            """,
        ),
        Invalid(
            code="""
            from typing import Optional
            a: Optional[str] = None
            x: bool = False
            if x and a:
                ...
            """,
            expected_replacement="""
            from typing import Optional
            a: Optional[str] = None
            x: bool = False
            if x and a is not None:
                ...
            """,
        ),
        Invalid(
            code="""
            from typing import Optional
            a: Optional[str] = None
            x: bool = False
            if a and x:
                ...
            """,
            expected_replacement="""
            from typing import Optional
            a: Optional[str] = None
            x: bool = False
            if a is not None and x:
                ...
            """,
        ),
        Invalid(
            code="""
            from typing import Optional
            a: Optional[str] = None
            x: bool = not a
            """,
            expected_replacement="""
            from typing import Optional
            a: Optional[str] = None
            x: bool = a is None
            """,
        ),
        Invalid(
            code="""
            from typing import Optional
            a: Optional[str]
            x: bool
            if x or a: pass
            """,
            expected_replacement="""
            from typing import Optional
            a: Optional[str]
            x: bool
            if x or a is not None: pass
            """,
        ),
        Invalid(
            code="""
            from typing import Optional
            a: Optional[str]
            x: bool
            if x: pass
            elif a: pass
            """,
            expected_replacement="""
            from typing import Optional
            a: Optional[str]
            x: bool
            if x: pass
            elif a is not None: pass
            """,
        ),
        Invalid(
            code="""
            from typing import Optional
            a: Optional[str] = None
            b: Optional[str] = None
            if a: pass
            elif b: pass
            """,
            expected_replacement="""
            from typing import Optional
            a: Optional[str] = None
            b: Optional[str] = None
            if a is not None: pass
            elif b is not None: pass
            """,
        ),
    ]

    def leave_If(self, original_node: cst.If) -> None:
        """Rewrite ``if x`` to ``if x is not None`` and fold an ``elif`` fix into its parent.

        Runs on leave so that any report already produced for a nested ``elif``
        (an ``If`` stored in ``orelse``) can be replaced by a single report on
        this ``If`` whose replacement carries both fixes.
        """
        changes: Dict[str, cst.CSTNode] = {}
        test_expression: cst.BaseExpression = original_node.test
        if m.matches(test_expression, m.Name()):
            # We are inside a simple check such as "if x".
            test_expression = cast(cst.Name, test_expression)
            if self._is_optional_type(test_expression):
                # We want to replace "if x" with "if x is not None".
                replacement_comparison: cst.Comparison = self._gen_comparison_to_none(
                    variable_name=test_expression.value, operator=cst.IsNot()
                )
                changes["test"] = replacement_comparison

        orelse = original_node.orelse
        if orelse is not None and m.matches(orelse, m.If()):
            # We want to catch this case upon leaving an `If` node so that we generate an `elif` statement correctly.
            # We check if the orelse node was reported, and if so, remove the report and generate a new report on
            # the current parent `If` node.
            new_reports = []
            orelse_report: Optional[CstLintRuleReport] = None
            for report in self.context.reports:
                if isinstance(report, CstLintRuleReport):
                    # Check whether the lint rule code matches this lint rule's code so we don't remove another
                    # lint rule's report.
                    if report.node is orelse and report.code == self.__class__.__name__:
                        orelse_report = report
                    else:
                        new_reports.append(report)
                else:
                    new_reports.append(report)
            if orelse_report is not None:
                self.context.reports = new_reports
                replacement_orelse = orelse_report.replacement_node
                changes["orelse"] = cst.ensure_type(replacement_orelse, cst.CSTNode)

        if changes:
            self.report(
                original_node, replacement=original_node.with_changes(**changes)
            )

    def visit_BooleanOperation(self, node: cst.BooleanOperation) -> None:
        """Rewrite bare Optional operands of ``and``/``or`` to ``x is not None``.

        Both operands are inspected first and a single report is emitted whose
        replacement fixes every offending operand. Reporting each side
        separately would attach two conflicting replacements to the same node
        (e.g. for ``a and b`` with both Optional), each fixing only one side.
        """
        changes: Dict[str, cst.CSTNode] = {}
        for field_name, operand in (("left", node.left), ("right", node.right)):
            # Eg: "x and y".
            if not m.matches(operand, m.Name()):
                continue
            name_node = cast(cst.Name, operand)
            if self._is_optional_type(name_node):
                changes[field_name] = self._gen_comparison_to_none(
                    variable_name=name_node.value, operator=cst.IsNot()
                )
        if changes:
            self.report(node, replacement=node.with_changes(**changes))

    def visit_UnaryOperation(self, node: cst.UnaryOperation) -> None:
        """Rewrite ``not x`` on an Optional name to ``x is None``."""
        if m.matches(node, m.UnaryOperation(operator=m.Not(), expression=m.Name())):
            # Eg: "not x".
            expression: cst.Name = cast(cst.Name, node.expression)
            if self._is_optional_type(expression):
                replacement_comparison = self._gen_comparison_to_none(
                    variable_name=expression.value, operator=cst.Is()
                )
                self.report(node, replacement=replacement_comparison)

    def _is_optional_type(self, node: cst.Name) -> bool:
        """Return True if the inferred type of *node* is ``typing.Optional[...]``."""
        reported_type = self.get_metadata(TypeInferenceProvider, node, None)
        # We want to use `startswith()` here since the type data will take on the form 'typing.Optional[SomeType]'.
        return reported_type is not None and reported_type.startswith("typing.Optional")

    def _gen_comparison_to_none(
        self, variable_name: str, operator: Union[cst.Is, cst.IsNot]
    ) -> cst.Comparison:
        """Build ``<variable_name> is None`` or ``<variable_name> is not None``."""
        return cst.Comparison(
            left=cst.Name(value=variable_name),
            comparisons=[
                cst.ComparisonTarget(
                    operator=operator, comparator=cst.Name(value="None")
                )
            ],
        )
Beispiel #12
0
class NamingConventionRule(CstLintRule):
    """Check identifiers against the project naming conventions.

    Function, parameter, variable, loop-target, walrus-target and ``self``
    attribute names are validated with ``NamingConvention.SNAKE_CASE``; class
    names with ``NamingConvention.CAMEL_CASE``. Plain assignments whose value
    resolves to a ``typing``/``collections`` name are skipped entirely.
    """

    METADATA_DEPENDENCIES = (QualifiedNameProvider, )  # type: ignore

    VALID = [
        Valid("type_hint: str"),
        Valid("type_hint_var: int = 5"),
        Valid("CONSTANT_WITH_UNDERSCORE12 = 10"),
        Valid("hello = 'world'"),
        Valid("snake_case = 'assign'"),
        Valid("for iteration in range(5): pass"),
        Valid("class _PrivateClass: pass"),
        Valid("class SomeClass: pass"),
        Valid("class One: pass"),
        Valid("def oneword(): pass"),
        Valid("def some_extra_words(): pass"),
        Valid("all = names_are = valid_in_multiple_assign = 5"),
        Valid("(walrus := 'operator')"),
        Valid("multiple, valid, assignments = 1, 2, 3"),
        Valid("""
            class Spam:
                def __init__(self, valid, another_valid):
                    self.valid = valid
                    self.another_valid = another_valid
                    self._private = None
                    self.__extreme_private = None

                def bar(self):
                    # This is just to test that the access is not being tested.
                    return self.some_Invalid_NaMe
            """),
        Valid("""
            from typing import List
            from collections import namedtuple

            Matrix = List[int]
            Point = namedtuple("Point", "x, y")

            some_matrix: Matrix = [1, 2]
            """),
    ]

    INVALID = [
        Invalid("type_Hint_Var: int = 5"),
        Invalid("hellO = 'world'"),
        Invalid("ranDom_UpPercAse = 'testing'"),
        Invalid("for RandomCaps in range(5): pass"),
        Invalid("class _Invalid_PrivateClass: pass"),
        Invalid("class _invalidPrivateClass: pass"),
        Invalid("class lowerPascalCase: pass"),
        Invalid("class all_lower_case: pass"),
        Invalid("def oneWordInvalid(): pass"),
        Invalid("def Pascal_Case(): pass"),
        Invalid("valid = another_valid = Invalid = 5"),
        Invalid("(waLRus := 'operator')"),
        Invalid("def func(invalidParam, valid_param): pass"),
        Invalid("multiple, inValid, assignments = 1, 2, 3"),
        Invalid("[inside, list, inValid] = Invalid, 2, 3"),
        Invalid("""
            class Spam:
                def __init__(self, foo, bar):
                    self.foo = foo
                    self._Bar = bar
            """),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        # Nesting depth of ``AssignTarget`` nodes; positive while the visitor is
        # inside the left-hand side of an ``=`` assignment.
        self._assigntarget_counter: int = 0

    # --- assignment targets -------------------------------------------------

    def visit_AssignTarget(self, node: cst.AssignTarget) -> None:
        self._assigntarget_counter = self._assigntarget_counter + 1

    def leave_AssignTarget(self, node: cst.AssignTarget) -> None:
        self._assigntarget_counter = self._assigntarget_counter - 1

    def visit_Assign(self, node: cst.Assign) -> None:
        qualified_names: Optional[Collection[QualifiedName]] = self.get_metadata(
            QualifiedNameProvider, node.value, None)

        # A value coming from ``typing``/``collections`` is probably a type alias
        # or a ``collections.namedtuple`` class, so the name is left unchecked.
        if qualified_names is not None and any(
                qualname.name.startswith(("typing", "collections"))
                for qualname in qualified_names):
            return None

        for assign_target in node.targets:
            self._validate_if_name(node, assign_target.target)

    def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
        # Bare annotations without a value (``var: int``) are not checked.
        if node.value is not None:
            self._validate_if_name(node, node.target)

    def visit_Attribute(self, node: cst.Attribute) -> None:
        # Attributes are only checked while inside an assignment target, and
        # only when they are set on *self*; plain attribute reads are ignored.
        if self._assigntarget_counter <= 0:
            return
        if not m.matches(node.value, m.Name(value="self")):
            return
        self._validate_nodename(node, node.attr.value,
                                NamingConvention.SNAKE_CASE)

    def visit_Element(self, node: cst.Element) -> None:
        # List/tuple elements only count when they are names being unpacked
        # on the left-hand side of an assignment.
        if self._assigntarget_counter > 0:
            self._validate_if_name(node, node.value)

    # --- other binding sites ------------------------------------------------

    def visit_For(self, node: cst.For) -> None:
        self._validate_if_name(node, node.target)

    def visit_NamedExpr(self, node: cst.NamedExpr) -> None:
        self._validate_if_name(node, node.target)

    def visit_Param(self, node: cst.Param) -> None:
        self._validate_nodename(node, node.name.value,
                                NamingConvention.SNAKE_CASE)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        self._validate_nodename(node, node.name.value,
                                NamingConvention.SNAKE_CASE)

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        self._validate_nodename(node, node.name.value,
                                NamingConvention.CAMEL_CASE)

    # --- helpers ------------------------------------------------------------

    def _validate_if_name(self, report_node: cst.CSTNode,
                          candidate: cst.CSTNode) -> None:
        """Check *candidate* as snake_case when it is a bare ``Name``.

        Any report is attached to *report_node*, not to *candidate*.
        """
        if isinstance(candidate, cst.Name):
            self._validate_nodename(report_node, candidate.value,
                                    NamingConvention.SNAKE_CASE)

    def _validate_nodename(self, node: cst.CSTNode, nodename: str,
                           naming_convention: NamingConvention) -> None:
        """Report *node* if *nodename* breaks *naming_convention*.

        The convention's value doubles as the message template and is
        formatted with ``nodename``.
        """
        if naming_convention.valid(nodename):
            return
        self.report(node, naming_convention.value.format(nodename=nodename))
Beispiel #13
0
class AvoidOrInExceptRule(CstLintRule):
    """
    Discourages use of ``or`` in except clauses. If an except clause needs to catch multiple exceptions,
    they must be expressed as a parenthesized tuple, for example:
    ``except (ValueError, TypeError)``
    (https://docs.python.org/3/tutorial/errors.html#handling-exceptions)

    When ``or`` is used, only the first operand exception type of the conditional statement will be caught.
    Every ``except`` clause of a ``try`` statement is checked; the report is attached to the ``try``.
    For example::

        In [1]: class Exc1(Exception):
            ...:     pass
            ...:

        In [2]: class Exc2(Exception):
            ...:     pass
            ...:

        In [3]: try:
            ...:     raise Exception()
            ...: except Exc1 or Exc2:
            ...:     print("caught!")
            ...:
        ---------------------------------------------------------------------------
        Exception                                 Traceback (most recent call last)
        <ipython-input-3-3340d66a006c> in <module>
            1 try:
        ----> 2     raise Exception()
            3 except Exc1 or Exc2:
            4     print("caught!")
            5

        Exception:

        In [4]: try:
            ...:     raise Exc1()
            ...: except Exc1 or Exc2:
            ...:     print("caught!")
            ...:
            caught!

        In [5]: try:
            ...:     raise Exc2()
            ...: except Exc1 or Exc2:
            ...:     print("caught!")
            ...:
        ---------------------------------------------------------------------------
        Exc2                                      Traceback (most recent call last)
        <ipython-input-5-5d29c1589cc0> in <module>
            1 try:
        ----> 2     raise Exc2()
            3 except Exc1 or Exc2:
            4     print("caught!")
            5

        Exc2:
    """

    MESSAGE: str = (
        "Avoid using 'or' in an except block. For example: "
        + "'except ValueError or TypeError' only catches 'ValueError'. Instead, use "
        + "parentheses, 'except (ValueError, TypeError)'")
    VALID = [
        Valid("""
            try:
                print()
            except (ValueError, TypeError) as err:
                pass
            """),
        Valid("""
            try:
                print()
            except ValueError:
                pass
            except (KeyError, TypeError):
                pass
            """),
    ]

    INVALID = [
        Invalid(
            """
            try:
                print()
            except ValueError or TypeError:
                pass
            """, ),
        # ``or`` in a later clause of a multi-handler ``try`` must be caught too.
        Invalid(
            """
            try:
                print()
            except KeyError:
                pass
            except ValueError or TypeError:
                pass
            """, ),
    ]

    def visit_Try(self, node: cst.Try) -> None:
        """Report the ``try`` once if any of its handlers uses ``except A or B``."""
        # Check handlers one by one: a list matcher such as
        # ``handlers=[m.ExceptHandler(...)]`` only matches a ``try`` with exactly
        # one handler, which would miss ``or`` in multi-handler statements.
        or_handler = m.ExceptHandler(type=m.BooleanOperation(operator=m.Or()))
        if any(m.matches(handler, or_handler) for handler in node.handlers):
            self.report(node)
Beispiel #14
0
class NoTypedDictRule(CstLintRule):
    """
    Enforce the use of the ``dataclasses.dataclass`` decorator instead of ``TypedDict`` for cleaner customization and
    inheritance. Dataclasses support default values, combining fields through inheritance, and omitting optional
    fields at instantiation. See `PEP 557 <https://www.python.org/dev/peps/pep-0557>`_.

    A class is reported when one of its bases resolves to ``typing.TypedDict`` or ``typing_extensions.TypedDict``.
    The suggested replacement drops that base, together with the TypedDict-only ``total=`` class keyword, and adds
    ``@dataclass(frozen=True)`` as the innermost decorator. The ``dataclass`` import itself is not added.
    """

    MESSAGE: str = "Instead of TypedDict, consider using the @dataclass decorator from dataclasses instead for simplicity, efficiency and consistency."
    METADATA_DEPENDENCIES = (QualifiedNameProvider,)

    VALID = [
        Valid(
            """
            @dataclass(frozen=True)
            class Foo:
                pass
            """
        ),
        Valid(
            """
            @dataclass(frozen=False)
            class Foo:
                pass
            """
        ),
        Valid(
            """
            class Foo:
                pass
            """
        ),
        Valid(
            """
            class Foo(SomeOtherBase):
                pass
            """
        ),
        Valid(
            """
            @some_other_decorator
            class Foo:
                pass
            """
        ),
        Valid(
            """
            @some_other_decorator
            class Foo(SomeOtherBase):
                pass
            """
        ),
        # Other typing bases are outside the scope of this rule.
        Valid(
            """
            from typing import NamedTuple

            class Foo(NamedTuple):
                pass
            """
        ),
    ]
    INVALID = [
        Invalid(
            code="""
            from typing import TypedDict

            class Foo(TypedDict):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict

            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict as TD

            class Foo(TD):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict as TD

            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        Invalid(
            code="""
            import typing as typ

            class Foo(typ.TypedDict):
                pass
            """,
            expected_replacement="""
            import typing as typ

            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        Invalid(
            code="""
            from typing_extensions import TypedDict

            class Foo(TypedDict):
                pass
            """,
            expected_replacement="""
            from typing_extensions import TypedDict

            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict

            class Foo(TypedDict, AnotherBase, YetAnotherBase):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict

            @dataclass(frozen=True)
            class Foo(AnotherBase, YetAnotherBase):
                pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict

            class Foo(AnotherBase, TypedDict):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict

            @dataclass(frozen=True)
            class Foo(AnotherBase):
                pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict

            class Foo(TypedDict, total=False):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict

            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict

            class OuterClass(SomeBase):
                class InnerClass(TypedDict):
                    pass
            """,
            expected_replacement="""
            from typing import TypedDict

            class OuterClass(SomeBase):
                @dataclass(frozen=True)
                class InnerClass:
                    pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict

            @some_other_decorator
            class Foo(TypedDict):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict

            @some_other_decorator
            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
    ]

    # The previous value "typing_extensionsTypedDict" (missing the dot) could
    # never match, so the rule silently never fired.
    qualified_typeddict = QualifiedName(name="typing_extensions.TypedDict", source=QualifiedNameSource.IMPORT)
    # Every TypedDict implementation this rule flags.
    qualified_typeddicts = (
        QualifiedName(name="typing.TypedDict", source=QualifiedNameSource.IMPORT),
        qualified_typeddict,
    )

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        """Report a TypedDict subclass and suggest a ``@dataclass(frozen=True)`` class."""
        (typeddict_base, new_bases) = self.partition_bases(original_node.bases)
        if typeddict_base is None:
            return

        # ``total=`` is only meaningful for TypedDict; on a plain class it is
        # forwarded to ``__init_subclass__`` and raises TypeError at class creation.
        new_keywords = [
            keyword
            for keyword in original_node.keywords
            if not m.matches(keyword, m.Arg(keyword=m.Name("total")))
        ]
        new_bases, new_keywords = self._drop_trailing_comma(new_bases, new_keywords)
        call = ensure_type(parse_expression("dataclass(frozen=True)"), cst.Call)

        replacement = original_node.with_changes(
            lpar=MaybeSentinel.DEFAULT,
            rpar=MaybeSentinel.DEFAULT,
            bases=new_bases,
            keywords=new_keywords,
            decorators=list(original_node.decorators) + [cst.Decorator(decorator=call)],
        )
        self.report(original_node, replacement=replacement)

    @staticmethod
    def _drop_trailing_comma(
        bases: List[cst.Arg], keywords: List[cst.Arg]
    ) -> Tuple[List[cst.Arg], List[cst.Arg]]:
        """Reset the comma after the last remaining class argument.

        Removing a base can leave its predecessor's comma dangling
        (``class Foo(Base, TypedDict)`` would render as ``class Foo(Base, )``).
        ``MaybeSentinel.DEFAULT`` lets libcst emit separating commas only where needed.
        """
        if keywords:
            return bases, keywords[:-1] + [keywords[-1].with_changes(comma=MaybeSentinel.DEFAULT)]
        if bases:
            return bases[:-1] + [bases[-1].with_changes(comma=MaybeSentinel.DEFAULT)], keywords
        return bases, keywords

    def partition_bases(self, original_bases: Sequence[cst.Arg]) -> Tuple[Optional[cst.Arg], List[cst.Arg]]:
        """Split *original_bases* into the TypedDict base (if any) and all other bases."""
        typeddict_base: Optional[cst.Arg] = None
        new_bases: List[cst.Arg] = []
        for base_class in original_bases:
            if any(
                QualifiedNameProvider.has_name(self, base_class.value, name)
                for name in self.qualified_typeddicts
            ):
                typeddict_base = base_class
            else:
                new_bases.append(base_class)
        return (typeddict_base, new_bases)
class RewriteToComprehensionRule(CstLintRule):
    """
    A derivative of flake8-comprehensions's C400-C402 and C403-C404.
    Comprehensions are more efficient than function calls. C400-C402
    suggest using `dict/set/list` comprehensions rather than the respective
    function calls whenever possible. C403-C404 suggest removing an unnecessary
    list comprehension in a set/dict call and replacing it with a set/dict
    comprehension.
    """

    VALID = [
        Valid("[val for val in iterable]"),
        Valid("{val for val in iterable}"),
        Valid("{val: val+1 for val in iterable}"),
        # A function call is valid if the elt is a function that returns a tuple.
        Valid("dict(line.strip().split('=', 1) for line in attr_file)"),
        # Only an elt with exactly two plain items maps onto ``key: value``.
        Valid("dict((a, b, c) for a, b, c in iterable)"),
        Valid("dict((key, *rest) for key, *rest in iterable)"),
        # Argument unpacking changes what the call receives, so it is left alone.
        Valid("list(*[val for val in iterable])"),
    ]

    INVALID = [
        Invalid(
            "list(val for val in iterable)",
            expected_replacement="[val for val in iterable]",
        ),
        # Nested list comprehension
        Invalid(
            "list(val for row in matrix for val in row)",
            expected_replacement="[val for row in matrix for val in row]",
        ),
        Invalid(
            "set(val for val in iterable)",
            expected_replacement="{val for val in iterable}",
        ),
        Invalid(
            "dict((x, f(x)) for val in iterable)",
            expected_replacement="{x: f(x) for val in iterable}",
        ),
        Invalid(
            "dict((x, y) for y, x in iterable)",
            expected_replacement="{x: y for y, x in iterable}",
        ),
        Invalid(
            "dict([val, val+1] for val in iterable)",
            expected_replacement="{val: val+1 for val in iterable}",
        ),
        Invalid(
            'dict((x["name"], json.loads(x["data"])) for x in responses)',
            expected_replacement=
            '{x["name"]: json.loads(x["data"]) for x in responses}',
        ),
        # Nested dict comprehension
        Invalid(
            "dict((k, v) for k, v in iter for iter in iters)",
            expected_replacement="{k: v for k, v in iter for iter in iters}",
        ),
        Invalid(
            "set([val for val in iterable])",
            expected_replacement="{val for val in iterable}",
        ),
        Invalid(
            "dict([[val, val+1] for val in iterable])",
            expected_replacement="{val: val+1 for val in iterable}",
        ),
        Invalid(
            "dict([(x, f(x)) for x in foo])",
            expected_replacement="{x: f(x) for x in foo}",
        ),
        Invalid(
            "dict([(x, y) for y, x in iterable])",
            expected_replacement="{x: y for y, x in iterable}",
        ),
        Invalid(
            "set([val for row in matrix for val in row])",
            expected_replacement="{val for row in matrix for val in row}",
        ),
    ]

    def visit_Call(self, node: cst.Call) -> None:
        """Report ``list/set/dict(<genexp or listcomp>)`` and suggest a comprehension."""
        if not m.matches(
                node,
                m.Call(
                    func=m.Name("list") | m.Name("set") | m.Name("dict"),
                    # Exactly one plain positional argument: ``list(*[...])``
                    # unpacks it and is not equivalent to a comprehension.
                    args=[
                        m.Arg(value=m.GeneratorExp() | m.ListComp(),
                              keyword=None,
                              star="")
                    ],
                ),
        ):
            return

        call_name = cst.ensure_type(node.func, cst.Name).value
        argument = node.args[0].value

        if m.matches(argument, m.GeneratorExp()):
            exp = cst.ensure_type(argument, cst.GeneratorExp)
            message_formatter = UNNECESSARY_GENERATOR
        else:
            exp = cst.ensure_type(argument, cst.ListComp)
            message_formatter = UNNECESSARY_LIST_COMPREHENSION

        if call_name == "list":
            replacement = node.deep_replace(
                node, cst.ListComp(elt=exp.elt, for_in=exp.for_in))
        elif call_name == "set":
            replacement = node.deep_replace(
                node, cst.SetComp(elt=exp.elt, for_in=exp.for_in))
        else:  # call_name == "dict"
            # The elt must be a tuple/list of exactly two plain (non-starred)
            # elements. Positional matcher arguments bind to other fields
            # (``m.Tuple(a, b)`` sets ``elements`` and ``lpar``), so the length
            # has to be spelled out via ``elements=``; otherwise ``(x,)`` would
            # raise IndexError and ``(a, b, c)`` would silently lose ``c``.
            two_plain_items = [m.Element(), m.Element()]
            if not m.matches(
                    exp.elt,
                    m.Tuple(elements=two_plain_items)
                    | m.List(elements=two_plain_items),
            ):
                # Unrecognized form
                return

            pair_items = (cst.ensure_type(exp.elt, cst.Tuple).elements
                          if isinstance(exp.elt, cst.Tuple) else
                          cst.ensure_type(exp.elt, cst.List).elements)
            replacement = node.deep_replace(
                node,
                # pyre-fixme[6]: Expected `BaseAssignTargetExpression` for 1st
                #  param but got `BaseExpression`.
                cst.DictComp(key=pair_items[0].value,
                             value=pair_items[1].value,
                             for_in=exp.for_in),
            )

        self.report(node,
                    message_formatter.format(func=call_name),
                    replacement=replacement)
Beispiel #16
0
class RewriteToLiteralRule(CstLintRule):
    """
    A derivative of flake8-comprehensions' C405-C406 and C409-C410. It's
    unnecessary to use a list or tuple literal within a call to tuple, list,
    set, or dict since there is literal syntax for these types. Empty
    ``tuple()``, ``list()`` and ``dict()`` calls are reported as well.
    """

    VALID = [
        Valid("(1, 2)"),
        Valid("()"),
        Valid("[1, 2]"),
        Valid("[]"),
        Valid("{1, 2}"),
        Valid("set()"),
        Valid("{1: 2, 3: 4}"),
        Valid("{}"),
        # Unpacking the argument changes the result (this is ``[1, 2]``, not ``[[1, 2]]``).
        Valid("list(*[[1, 2]])"),
        # A starred item cannot be turned into a ``key: value`` entry.
        Valid("dict([(*pair, extra)])"),
    ]

    INVALID = [
        Invalid("tuple([1, 2])", expected_replacement="(1, 2)"),
        Invalid("tuple((1, 2))", expected_replacement="(1, 2)"),
        Invalid("tuple([])", expected_replacement="()"),
        Invalid("list([1, 2, 3])", expected_replacement="[1, 2, 3]"),
        Invalid("list((1, 2, 3))", expected_replacement="[1, 2, 3]"),
        Invalid("list([])", expected_replacement="[]"),
        Invalid("set([1, 2, 3])", expected_replacement="{1, 2, 3}"),
        Invalid("set((1, 2, 3))", expected_replacement="{1, 2, 3}"),
        Invalid("set([])", expected_replacement="set()"),
        Invalid(
            "dict([(1, 2), (3, 4)])",
            expected_replacement="{1: 2, 3: 4}",
        ),
        Invalid(
            "dict(((1, 2), (3, 4)))",
            expected_replacement="{1: 2, 3: 4}",
        ),
        Invalid(
            "dict([[1, 2], [3, 4], [5, 6]])",
            expected_replacement="{1: 2, 3: 4, 5: 6}",
        ),
        Invalid("dict([])", expected_replacement="{}"),
        Invalid("tuple()", expected_replacement="()"),
        Invalid("list()", expected_replacement="[]"),
        Invalid("dict()", expected_replacement="{}"),
    ]

    def visit_Call(self, node: cst.Call) -> None:
        """Report builtin container calls that can be written as literals."""
        # Only a single plain positional argument qualifies; ``list(*[[1, 2]])``
        # unpacks its argument, so rewriting it to ``[[1, 2]]`` would be wrong.
        single_literal_arg = m.Arg(value=m.List() | m.Tuple(),
                                   keyword=None,
                                   star="")
        if not (m.matches(
                node,
                m.Call(
                    func=m.Name("tuple") | m.Name("list") | m.Name("set")
                    | m.Name("dict"),
                    args=[single_literal_arg],
                ),
        ) or m.matches(
                node,
                m.Call(func=m.Name("tuple") | m.Name("list") | m.Name("dict"),
                       args=[]),
        )):
            return

        # Each item must be a two-element tuple/list of plain (non-starred)
        # elements so that it maps onto exactly one ``key: value`` entry.
        two_plain_items = [m.Element(), m.Element()]
        pairs_matcher = m.ZeroOrMore(
            m.Element(m.Tuple(elements=two_plain_items))
            | m.Element(m.List(elements=two_plain_items)))

        call_name = cst.ensure_type(node.func, cst.Name).value

        # If this is an empty call, it's an Unnecessary Call where we rewrite the call
        # to a literal, except set().
        if not node.args:
            elements = []
            message_formatter = UNNCESSARY_CALL
        else:
            arg = node.args[0].value
            elements = cst.ensure_type(
                arg, cst.List
                if isinstance(arg, cst.List) else cst.Tuple).elements
            message_formatter = UNNECESSARY_LITERAL

        if call_name == "tuple":
            new_node = cst.Tuple(elements=elements)
        elif call_name == "list":
            new_node = cst.List(elements=elements)
        elif call_name == "set":
            # set() doesn't have an equivalent literal. If it was matched here
            # with an empty literal, suggest dropping the literal instead.
            if len(elements) == 0:
                self.report(
                    node,
                    UNNECESSARY_LITERAL.format(func=call_name),
                    replacement=node.deep_replace(
                        node, cst.Call(func=cst.Name("set"))),
                )
                return
            new_node = cst.Set(elements=elements)
        elif len(elements) == 0 or m.matches(
                node.args[0].value,
                m.Tuple(elements=[pairs_matcher])
                | m.List(elements=[pairs_matcher]),
        ):
            new_node = cst.Dict(
                elements=[self._pair_to_dict_element(ele) for ele in elements])
        else:
            # Unrecognized form
            return

        self.report(
            node,
            message_formatter.format(func=call_name),
            replacement=node.deep_replace(node, new_node),
        )

    @staticmethod
    def _pair_to_dict_element(element: cst.BaseElement) -> cst.DictElement:
        """Turn a matched ``(key, value)`` or ``[key, value]`` item into ``key: value``."""
        pair = element.value
        items = (cst.ensure_type(pair, cst.Tuple).elements
                 if isinstance(pair, cst.Tuple) else
                 cst.ensure_type(pair, cst.List).elements)
        return cst.DictElement(items[0].value, items[1].value)
class NoAssertTrueForComparisonsRule(CstLintRule):
    """
    Finds incorrect use of ``assertTrue`` when the intention is to compare two values.
    These calls are replaced with ``assertEqual``.
    Comparisons with True, False and None are replaced with one-argument
    calls to ``assertTrue``, ``assertFalse`` and ``assertIsNone``.

    Only calls with exactly two plain positional arguments are considered. A keyword
    second argument (e.g. ``msg=None``) is an explicit failure message, and a starred
    argument cannot be reasoned about, so neither is rewritten.
    """

    MESSAGE: str = '"assertTrue" does not compare its arguments, use "assertEqual" or other ' + "appropriate functions."

    VALID = [
        Valid("self.assertTrue(a == b)"),
        Valid('self.assertTrue(data.is_valid(), "is_valid() method")'),
        Valid("self.assertTrue(validate(len(obj.getName(type=SHORT))))"),
        Valid("self.assertTrue(condition, message_string)"),
        # A keyword second argument is the failure message, not a comparison target.
        Valid("self.assertTrue(condition, msg=None)"),
        # Starred arguments are not rewritten.
        Valid("self.assertTrue(condition, *[1, 2])"),
    ]

    INVALID = [
        Invalid("self.assertTrue(a, 3)",
                expected_replacement="self.assertEqual(a, 3)"),
        Invalid(
            "self.assertTrue(hash(s[:4]), 0x1234)",
            expected_replacement="self.assertEqual(hash(s[:4]), 0x1234)",
        ),
        Invalid(
            "self.assertTrue(list, [1, 3])",
            expected_replacement="self.assertEqual(list, [1, 3])",
        ),
        Invalid(
            "self.assertTrue(optional, None)",
            expected_replacement="self.assertIsNone(optional)",
        ),
        Invalid(
            "self.assertTrue(b == a, True)",
            expected_replacement="self.assertTrue(b == a)",
        ),
        Invalid(
            "self.assertTrue(b == a, False)",
            expected_replacement="self.assertFalse(b == a)",
        ),
    ]

    def visit_Call(self, node: cst.Call) -> None:
        """Report ``self.assertTrue(x, <literal>)`` and suggest the intended assertion."""
        result = m.extract(
            node,
            m.Call(
                func=m.Attribute(value=m.Name("self"),
                                 attr=m.Name("assertTrue")),
                args=[
                    # Both arguments must be plain positional arguments: a keyword
                    # (e.g. `msg=None`) or a `*`/`**` unpacking changes their meaning.
                    m.Arg(keyword=None, star=""),
                    m.Arg(
                        value=m.SaveMatchedNode(
                            m.OneOf(
                                m.Integer(),
                                m.Float(),
                                m.Imaginary(),
                                m.Tuple(),
                                m.List(),
                                m.Set(),
                                m.Dict(),
                                m.Name("None"),
                                m.Name("True"),
                                m.Name("False"),
                            ),
                            "second",
                        ),
                        keyword=None,
                        star="",
                    ),
                ],
            ),
        )

        if not result:
            return

        second_arg = result["second"]
        if isinstance(second_arg, Sequence):
            second_arg = second_arg[0]

        # The `assertTrue` Name node inside `self.assertTrue`.
        method_name = cst.ensure_type(node.func, cst.Attribute).attr

        if isinstance(second_arg, cst.Name):
            # Only the singleton names None/True/False reach here (see the matcher).
            # They become a one-argument assertion on the first argument.
            one_arg_method = {
                "True": "assertTrue",
                "False": "assertFalse",
                "None": "assertIsNone",
            }[second_arg.value]
            new_call = node.with_changes(
                func=node.func.with_deep_changes(
                    old_node=method_name,
                    value=one_arg_method,
                ),
                # Drop the trailing comma that separated the removed second argument.
                args=[
                    node.args[0].with_changes(comma=cst.MaybeSentinel.DEFAULT)
                ],
            )
        else:
            # Any other literal: the author meant an equality check.
            new_call = node.with_deep_changes(
                old_node=method_name,
                value="assertEqual",
            )

        self.report(node, replacement=new_call)
class NoStaticIfConditionRule(CstLintRule):
    """
    Discourages ``if`` conditions which evaluate to a static value (e.g. ``or True``, ``and False``, etc).
    """

    MESSAGE: str = (
        "Your if condition appears to evaluate to a static value (e.g. `or True`, `and False`). "
        +
        "Please double check this logic and if it is actually temporary debug code."
    )
    VALID = [
        Valid("""
            if my_func() or not else_func():
                pass
            """),
        Valid("""
            if function_call(True):
                pass
            """),
        Valid("""
            # ew who would this???
            def true():
                return False
            if true() and else_call():  # True or False
                pass
            """),
        Valid("""
            # ew who would this???
            if False or some_func():
                pass
            """),
    ]
    INVALID = [
        Invalid(
            """
            if True:
                do_something()
            """, ),
        Invalid(
            """
            if crazy_expression or True:
                do_something()
            """, ),
        Invalid(
            """
            if crazy_expression and False:
                do_something()
            """, ),
        Invalid(
            """
            if crazy_expression and not True:
                do_something()
            """, ),
        Invalid(
            """
            if crazy_expression or not False:
                do_something()
            """, ),
        Invalid(
            """
            if crazy_expression or (something() or True):
                do_something()
            """, ),
        Invalid(
            """
            if crazy_expression and (something() and (not True)):
                do_something()
            """, ),
        Invalid(
            """
            if crazy_expression and (something() and (other_func() and not True)):
                do_something()
            """, ),
        Invalid(
            """
            if (crazy_expression and (something() and (not True))) or True:
                do_something()
            """, ),
        Invalid(
            """
            async def some_func() -> none:
                if (await expression()) and False:
                    pass
            """, ),
        # Both operands static, with no short-circuiting literal.
        Invalid(
            """
            if True and True:
                do_something()
            """, ),
        Invalid(
            """
            if False or (not True):
                do_something()
            """, ),
    ]

    @classmethod
    def _extract_static_bool(cls, node: cst.BaseExpression) -> Optional[bool]:
        """
        Statically evaluate the truthiness of ``node`` when possible.

        Understands the literals ``True``/``False``, ``not``, and ``and``/``or``.
        Returns ``None`` when the value cannot be determined statically.
        """
        if m.matches(node, m.Call()):
            # cannot reason about function calls
            return None
        if m.matches(node, m.UnaryOperation(operator=m.Not())):
            sub_value = cls._extract_static_bool(
                cst.ensure_type(node, cst.UnaryOperation).expression)
            if sub_value is None:
                return None
            return not sub_value

        if m.matches(node, m.Name("True")):
            return True

        if m.matches(node, m.Name("False")):
            return False

        if m.matches(node, m.BooleanOperation()):
            node = cst.ensure_type(node, cst.BooleanOperation)
            left_value = cls._extract_static_bool(node.left)
            right_value = cls._extract_static_bool(node.right)
            if m.matches(node.operator, m.Or()):
                # One truthy operand makes the whole `or` truthy, whatever the other is.
                if right_value is True or left_value is True:
                    return True
                # An `or` is falsy only when both sides are known to be falsy.
                if right_value is False and left_value is False:
                    return False

            if m.matches(node.operator, m.And()):
                # One falsy operand makes the whole `and` falsy, whatever the other is.
                if right_value is False or left_value is False:
                    return False
                # An `and` is truthy only when both sides are known to be truthy.
                if right_value is True and left_value is True:
                    return True

        return None

    def visit_If(self, node: cst.If) -> None:
        """Report ``if``/``elif`` statements whose test has a static truth value."""
        if self._extract_static_bool(node.test) is not None:
            self.report(node)
class UseTypesFromTypingRule(CstLintRule):
    """
    Flags builtin names such as ``list``, ``dict`` and ``tuple`` used inside type
    annotations, and suggests the title-cased ``typing`` equivalent (``List``, ``Dict``,
    ``Tuple``), because the type system does not accept the builtin spelling there.

    A replacement is only attached when the title-cased name is already visible in the
    scope of the annotation (for example because it was imported from ``typing``).
    """

    METADATA_DEPENDENCIES = (
        QualifiedNameProvider,
        ScopeProvider,
    )
    VALID = [
        Valid(
            """
            def fuction(list: List[str]) -> None:
                pass
            """
        ),
        Valid(
            """
            def function() -> None:
                thing: Dict[str, str] = {}
            """
        ),
        Valid(
            """
            def function() -> None:
                thing: Tuple[str]
            """
        ),
        Valid(
            """
            from typing import Dict, List
            def function() -> bool:
                    return Dict == List
            """
        ),
        Valid(
            """
            from typing import List as list
            from graphene import List

            def function(a: list[int]) -> List[int]:
                    return []
            """
        ),
    ]
    INVALID = [
        Invalid(
            """
            from typing import List
            def whatever(list: list[str]) -> None:
                pass
            """,
            expected_replacement="""
            from typing import List
            def whatever(list: List[str]) -> None:
                pass
            """,
        ),
        Invalid(
            """
            def function(list: list[str]) -> None:
                pass
            """,
        ),
        Invalid(
            """
            def func() -> None:
                thing: dict[str, str] = {}
            """,
        ),
        Invalid(
            """
            def func() -> None:
                thing: tuple[str]
            """,
        ),
        Invalid(
            """
            from typing import Dict
            def func() -> None:
                thing: dict[str, str] = {}
            """,
            expected_replacement="""
            from typing import Dict
            def func() -> None:
                thing: Dict[str, str] = {}
            """,
        ),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        # Number of `Annotation` nodes enclosing the node currently being visited.
        # Names are only inspected while this is positive.
        self.annotation_counter: int = 0

    def visit_Annotation(self, node: libcst.Annotation) -> None:
        self.annotation_counter += 1

    def leave_Annotation(self, original_node: libcst.Annotation) -> None:
        self.annotation_counter -= 1

    def visit_Name(self, node: libcst.Name) -> None:
        # Names outside of any annotation are never candidates.
        if self.annotation_counter <= 0:
            return

        builtin_name = node.value
        if builtin_name not in BUILTINS_TO_REPLACE:
            return

        # Skip names that resolve to something other than the builtin, e.g.
        #
        # ```
        # from typing import List as list
        # from graphene import List
        # ```
        resolved = self.get_metadata(QualifiedNameProvider, node, set())
        if any(
            candidate.name not in QUALIFIED_BUILTINS_TO_REPLACE
            for candidate in resolved
        ):
            return

        typing_name = builtin_name.title()
        enclosing_scope = self.get_metadata(ScopeProvider, node)
        # Only offer an autofix when the typing name is already available here.
        can_autofix = enclosing_scope is not None and typing_name in enclosing_scope
        fixed_node = node.with_changes(value=typing_name) if can_autofix else None

        self.report(
            node,
            REPLACE_BUILTIN_TYPE_ANNOTATION.format(builtin_type=builtin_name, correct_type=typing_name),
            replacement=fixed_node,
        )
class UseClassNameAsCodeRule(CstLintRule):
    """
    Meta lint rule for lint-rule sources that flags deprecated ``IG``-series codes.

    A string literal starting with an ``IG<digits> `` prefix has that prefix stripped.
    Inside an ``Invalid(...)`` call, a string argument consisting only of an
    ``IG<digits>`` code is removed, because the class name now serves as the code.
    """

    MESSAGE = "`IG`-series codes are deprecated. Use class name as code instead."
    VALID = [
        Valid(
            """
        MESSAGE = "This is a message"
        """
        ),
        Valid(
            """
        from fixit.common.base import CstLintRule
        class FakeRule(CstLintRule):
            MESSAGE = "This is a message"
        """
        ),
        Valid(
            """
        from fixit.common.base import CstLintRule
        class FakeRule(CstLintRule):
            INVALID = [
                Invalid(
                    code=""
                )
            ]
        """
        ),
    ]
    INVALID = [
        Invalid(
            code="""
            MESSAGE = "IG90000 Message"
            """,
            expected_replacement="""
            MESSAGE = "Message"
            """,
        ),
        Invalid(
            code="""
            from fixit.common.base import CstLintRule
            class FakeRule(CstLintRule):
                INVALID = [
                    Invalid(
                        code="",
                        kind="IG000"
                    )
                ]
            """,
            expected_replacement="""
            from fixit.common.base import CstLintRule
            class FakeRule(CstLintRule):
                INVALID = [
                    Invalid(
                        code="",
                        )
                ]
            """,
        ),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        # True while the visitor is somewhere inside an `Invalid(...)` call.
        self.inside_invalid_call: bool = False

    @staticmethod
    def _is_invalid_call(call: cst.Call) -> bool:
        """Return True when ``call`` invokes a bare name ``Invalid``."""
        callee = call.func
        return isinstance(callee, cst.Name) and callee.value == "Invalid"

    def visit_SimpleString(self, node: cst.SimpleString) -> None:
        prefix_match = re.match(r"^(\'|\")(?P<igcode>IG\d+ )\S", node.value)
        if prefix_match is None:
            return

        # Remove the first occurrence of "IG<digits> " (the one right after the quote).
        without_code = node.value.replace(prefix_match.group("igcode"), "", 1)
        self.report(
            node,
            self.MESSAGE,
            replacement=node.with_changes(value=without_code),
        )

    def visit_Call(self, node: cst.Call) -> None:
        if self._is_invalid_call(node):
            self.inside_invalid_call = True

    def leave_Call(self, original_node: cst.Call) -> None:
        if self._is_invalid_call(original_node):
            self.inside_invalid_call = False

    def visit_Arg(self, node: cst.Arg) -> None:
        # Remove `kind` arguments to Invalid test cases as they are no longer needed
        if not self.inside_invalid_call:
            return

        literal = node.value
        if not isinstance(literal, cst.SimpleString):
            return

        if re.match(r"^(\'|\")(?P<igcode>IG\d+)(\'|\")\Z", literal.value):
            self.report(
                node,
                self.MESSAGE,
                replacement=cst.RemovalSentinel.REMOVE,
            )
# Example #21
# 0
class ImportConstraintsRule(CstLintRule):
    """
    Rule to impose import constraints in certain directories to improve runtime performance.
    The directories specified in the ImportConstraintsRule setting in the ``.fixit.config.yaml`` file's
    ``rule_config`` section can impose import constraints for that directory and its children as follows::

        rule_config:
            ImportConstraintsRule:
                dir_under_repo_root:
                    rules: [
                        ["module_under_repo_root", "allow"],
                        ["another_module_under_repo_root", "deny"],
                        ["*", "deny"]
                    ]
                    ignore_tests: True
                    ignore_types: True

    Each rule under ``rules`` is evaluated in order from top to bottom and the last rule for each directory
    should be a wildcard rule.
    ``ignore_tests`` and ``ignore_types`` should carry boolean values and can be omitted. They are both set to
    ``True`` by default.
    If ``ignore_types`` is True, this rule will ignore imports inside ``if TYPE_CHECKING`` blocks since those
    imports do not have an effect on runtime performance.
    If ``ignore_tests`` is True, this rule will not lint any files found in a testing module.
    Directory keys that point outside the repository root are ignored.
    """

    _config: Optional[_ImportConfig]
    _repo_root: Path
    _type_checking_stack: List[cst.If]
    _abs_file_path: Path

    MESSAGE: str = (
        "According to the settings for this directory in the .fixit.config.yaml configuration file, "
        + "{imported} cannot be imported from within {current_file}. "
    )

    VALID = [
        # Everything is allowed
        Valid("import common"),
        Valid(
            "import common",
            config=_gen_testcase_config({"some_dir": {"rules": [["*", "allow"]]}}),
            filename="some_dir/file.py",
        ),
        # This import is allowlisted
        Valid(
            "import common",
            config=_gen_testcase_config(
                {"some_dir": {"rules": [["common", "allow"], ["*", "deny"]]}}
            ),
            filename="some_dir/file.py",
        ),
        # Allow children of a allowlisted module
        Valid(
            "from common.foo import bar",
            config=_gen_testcase_config(
                {"some_dir": {"rules": [["common", "allow"], ["*", "deny"]]}}
            ),
            filename="some_dir/file.py",
        ),
        # Validate rules are evaluted in order
        Valid(
            "from common.foo import bar",
            config=_gen_testcase_config(
                {
                    "some_dir": {
                        "rules": [
                            ["common.foo.bar", "allow"],
                            ["common", "deny"],
                            ["*", "deny"],
                        ]
                    }
                }
            ),
            filename="some_dir/file.py",
        ),
        # Built-in modules are fine
        Valid(
            "import ast",
            config=_gen_testcase_config({"some_dir": {"rules": [["*", "deny"]]}}),
            filename="some_dir/file.py",
        ),
        # Relative imports
        Valid(
            "from . import module",
            config=_gen_testcase_config(
                {".": {"rules": [["common.safe", "allow"], ["*", "deny"]]}}
            ),
            filename="common/safe/file.py",
        ),
        Valid(
            "from ..safe import module",
            config=_gen_testcase_config(
                {"common": {"rules": [["common.safe", "allow"], ["*", "deny"]]}}
            ),
            filename="common/unsafe/file.py",
        ),
        # Ignore some relative module that leaves the repo root
        Valid(
            "from ....................................... import module",
            config=_gen_testcase_config({".": {"rules": [["*", "deny"]]}}),
            filename="file.py",
        ),
        # File belongs to more than one directory setting (should enforce closest parent directory)
        Valid(
            "from common.foo import bar",
            config=_gen_testcase_config(
                {
                    "dir_1/dir_2": {
                        "rules": [["common.foo.bar", "allow"], ["*", "deny"]]
                    },
                    "dir_1": {"rules": [["common.foo.bar", "deny"], ["*", "deny"]]},
                }
            ),
            filename="dir_1/dir_2/file.py",
        ),
        # File belongs to more than one directory setting, flipped order (should enforce closest parent directory)
        Valid(
            "from common.foo import bar",
            config=_gen_testcase_config(
                {
                    "dir_1": {"rules": [["common.foo.bar", "deny"], ["*", "deny"]]},
                    "dir_1/dir_2": {
                        "rules": [["common.foo.bar", "allow"], ["*", "deny"]]
                    },
                }
            ),
            filename="dir_1/dir_2/file.py",
        ),
    ]

    INVALID = [
        # Everything is denied
        Invalid(
            "import common",
            config=_gen_testcase_config({"some_dir": {"rules": [["*", "deny"]]}}),
            filename="some_dir/file.py",
        ),
        # Validate rules are evaluated in order
        Invalid(
            "from common.foo import bar",
            config=_gen_testcase_config(
                {
                    "some_dir": {
                        "rules": [
                            ["common.foo.bar", "deny"],
                            ["common", "allow"],
                            ["*", "allow"],
                        ]
                    }
                }
            ),
            filename="some_dir/file.py",
        ),
        # We should match against the real name, not the aliased name
        Invalid(
            "import common as not_common",
            config=_gen_testcase_config(
                {"some_dir": {"rules": [["common", "deny"], ["*", "allow"]]}}
            ),
            filename="some_dir/file.py",
        ),
        Invalid(
            "from common import bar as not_bar",
            config=_gen_testcase_config(
                {"some_dir": {"rules": [["common.bar", "deny"], ["*", "allow"]]}}
            ),
            filename="some_dir/file.py",
        ),
        # Relative imports
        Invalid(
            "from . import b",
            config=_gen_testcase_config({"common": {"rules": [["*", "deny"]]}}),
            filename="common/a.py",
        ),
        # File belongs to more than one directory setting, import from
        Invalid(
            "from common.foo import bar",
            config=_gen_testcase_config(
                {
                    "dir_1/dir_2": {"rules": [["*", "deny"]]},
                    "dir_1": {
                        "rules": [["common.foo.bar", "allow"], ["*", "deny"]],
                    },
                }
            ),
            filename="dir_1/dir_2/file.py",
        ),
        # File belongs to more than one directory setting, import
        Invalid(
            "import common",
            config=_gen_testcase_config(
                {
                    "dir_1/dir_2": {"rules": [["*", "deny"]]},
                    "dir_1": {
                        "rules": [["common", "allow"], ["*", "deny"]],
                    },
                }
            ),
            filename="dir_1/dir_2/file.py",
        ),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        self._repo_root = Path(self.context.config.repo_root).resolve()
        self._config = None
        self._abs_file_path = (self._repo_root / context.file_path).resolve()
        import_constraints_config = self.context.config.rule_config.get(
            self.__class__.__name__, None
        )
        if import_constraints_config is not None:
            formatted_config: Dict[
                Path, Dict[object, object]
            ] = self._parse_and_format_config(import_constraints_config)
            # Run through logical ancestors of the filepath stopping early if a parent
            # directory is found in the config. The closest parent's settings will be used.
            for parent_dir in self._abs_file_path.parents:
                if parent_dir in formatted_config:
                    settings_for_dir = formatted_config[parent_dir]
                    self._config = _ImportConfig.from_config(settings_for_dir)
                    break
        self._type_checking_stack = []

    def _is_within_repo_root(self, path: Path) -> bool:
        """Return True if ``path`` is the repo root or a descendant of it.

        The comparison is component-wise, so a sibling such as ``/repo2`` is NOT
        considered to be inside ``/repo`` (a plain string-prefix check would be).
        """
        return path == self._repo_root or self._repo_root in path.parents

    def _parse_and_format_config(
        self, import_constraints_config: Dict[str, object]
    ) -> Dict[Path, Dict[object, object]]:
        """Map each configured directory to a normalized absolute path under repo_root.

        Relative keys are resolved against repo_root; absolute keys are used as given.
        Keys that end up outside repo_root (e.g. ``../other``) are dropped.

        Raises:
            ValueError: if a directory's settings are not a key-value mapping.
        """
        formatted_config: Dict[Path, Dict[object, object]] = {}
        for dirname, dir_settings in import_constraints_config.items():
            # `os.path.join` discards repo_root when `dirname` is already absolute, so
            # both kinds of key are handled uniformly; `normpath` collapses `..`/`.`.
            abs_dirpath = Path(os.path.normpath(os.path.join(self._repo_root, dirname)))

            if not isinstance(dir_settings, dict):
                raise ValueError(
                    f"Invalid entry `{dir_settings}`.\n"
                    + "You must specify settings in key-value format under a directory."
                )
            if self._is_within_repo_root(abs_dirpath):
                formatted_config[abs_dirpath] = dir_settings

        return formatted_config

    def should_skip_file(self) -> bool:
        config = self._config
        return config is None or (config.ignore_tests and self.context.in_tests)

    def visit_If(self, node: cst.If) -> None:
        # TODO: Handle stuff like typing.TYPE_CHECKING
        # NOTE(review): the node stays on the stack while its `orelse` (else/elif) is
        # visited as well, so imports in the runtime `else:` branch of
        # `if TYPE_CHECKING:` are also treated as type-checking-only.
        test = node.test
        if isinstance(test, cst.Name) and test.value == "TYPE_CHECKING":
            self._type_checking_stack.append(node)

    def leave_If(self, original_node: cst.If) -> None:
        if self._type_checking_stack and self._type_checking_stack[-1] is original_node:
            self._type_checking_stack.pop()

    def visit_Import(self, node: cst.Import) -> None:
        self._check_names(
            node, (get_full_name_for_node_or_raise(alias.name) for alias in node.names)
        )

    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        module = node.module
        abs_module = self._to_absolute_module(
            get_full_name_for_node_or_raise(module) if module is not None else "",
            len(node.relative),
        )
        if abs_module is not None:
            names = node.names
            # `from x import *` (ImportStar) has no individual names to check.
            if isinstance(names, Sequence):
                self._check_names(
                    node,
                    (
                        f"{abs_module}.{alias.name.value}" if abs_module else alias.name.value
                        for alias in names
                    ),
                )

    def _to_absolute_module(self, module: Optional[str], level: int) -> Optional[str]:
        """Convert a (possibly relative) module reference to a dotted absolute name.

        Args:
            module: the dotted module after the leading dots; ``""`` or ``None`` for
                ``from . import x``.
            level: the number of leading dots (0 for absolute imports).

        Returns:
            The absolute dotted module name, ``""`` when it denotes the repo root
            itself, or ``None`` if the relative import climbs above the repo root.
        """
        if level == 0:
            return module

        # Get the absolute path of the file
        current_dir = self._abs_file_path
        for __ in range(level):
            current_dir = current_dir.parent

        if not self._is_within_repo_root(current_dir):
            return None

        prefix = ".".join(current_dir.relative_to(self._repo_root).parts)
        # Join only non-empty pieces: `module` is "" for `from . import x` and `prefix`
        # is "" at the repo root; formatting both blindly yields stray dots.
        return ".".join(part for part in (prefix, module) if part)

    def _check_names(self, node: cst.CSTNode, names: Iterable[str]) -> None:
        """Report every name in ``names`` that is local to the repo and denied by config."""
        config = self._config
        if config is None or (config.ignore_types and self._type_checking_stack):
            return
        # Loop-invariant: the set of top-level local modules does not depend on `name`.
        local_roots = _get_local_roots(self._repo_root)
        for name in names:
            # Only imports of modules that live in this repo are constrained.
            if name.split(".", 1)[0] not in local_roots:
                continue
            rule = config.match(name)
            if not rule.allow:
                self.report(
                    node,
                    self.MESSAGE.format(
                        imported=name, current_file=self.context.file_path
                    ),
                )
class UsePrintfLoggingRule(CstLintRule):
    """
    Flags eager string formatting inside logging calls.

    Within a call recognized by ``match_logger_calls`` (calls on ``logging`` itself, or
    on a name assigned from an expression accepted by ``check_getLogger``), this rule
    reports f-strings, ``str.format()`` calls and ``%``-interpolation of string
    literals. Values should be passed as logging arguments instead.
    """

    MESSAGE: str = (
        "UsePrintfLoggingRule: Use %s style strings instead of f-string or format() "
        + "for python logging. Pass %s values as arguments or in the 'extra' argument "
        + "in loggers."
    )

    VALID = [
        # Printf strings in log statements are best practice
        Valid('logging.error("Hello %s", my_val)'),
        Valid('logging.info("Hello %(my_str)s!", {"my_str": "world"})'),
        Valid(
            """
            logger = logging.getLogger()
            logger.log(logging.DEBUG, "Concat " "Printf string %s", vals)
            logger.debug("Hello %(my_str)s!", {"my_str": "world"})
            """,
        ),
        Valid(
            """
            mylog = logging.getLogger()
            mylog.warning("A printf string %s, %d", 'George', 732)
            """,
        ),
        # Don't report if logger isn't a logging.getLogger
        Valid('logger.error("Hello %s", my_val)'),
        Valid(
            """
            logger = custom_logger.getLogger()
            logger.error("Hello %s", my_val)
            """,
        ),
        # fstrings should be allowed elsewhere
        Valid(
            """
            logging.warning("simple error %s", test)
            fn(f"formatted string {my_var}")
            """,
        ),
        Valid(
            """
            logger: logging.Logger = logging.getLogger()
            logger.warning("simple error %s", test)
            test_var = f"test string {msg}"
            func(3, other_var, "string format {}".format(test_var))
            """,
        ),
        # %s interpolation allowed outside of log calls
        Valid('test = "my %s" % var'),
    ]

    INVALID = [
        # Using fstring in a log
        Invalid('logging.error(f"Hello {my_var}")'),
        Invalid(
            """
            logger = logging.getLogger()
            logger.error("Cat" f"Hello {my_var}")
            """,
        ),
        # Using str.format() in a log
        Invalid('logging.info("Hello {}".format(my_str))'),
        # Also invalid to use either in loggers
        Invalid(
            """
            logger: logging.Logger = logging.getLogger()
            logger.log("DEBUG", f"Format string {vals}")
            """,
        ),
        Invalid(
            """
            log = logging.getLogger()
            log.warning("This string formats {}".format(foo))
            """,
        ),
        # Do not interpolate %s strings in log
        Invalid('logging.error("my error: %s" % msg)'),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        # Logging calls that enclose the node currently being visited.
        self.logging_stack: Set[cst.Call] = set()
        # Names treated as loggers; grows as getLogger() assignments are seen.
        self.logger_names: Set[str] = {"logging"}

    def _remember_logger(self, target: cst.BaseExpression) -> None:
        """Record ``target`` as a logger name when it is a plain identifier."""
        if isinstance(target, cst.Name):
            self.logger_names.add(target.value)

    def visit_Assign(self, node: cst.Assign) -> None:
        # e.g. `logger = logging.getLogger()` (only the first target is considered)
        if check_getLogger(node):
            self._remember_logger(node.targets[0].target)

    def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
        # e.g. `logger: logging.Logger = logging.getLogger()`
        if check_getLogger(node):
            self._remember_logger(node.target)

    def visit_Call(self, node: cst.Call) -> None:
        if match_logger_calls(node, self.logger_names):
            self.logging_stack.add(node)

        # A `.format()` call anywhere inside a logging call is reported.
        if self.logging_stack and match_calls_str_format(node):
            self.report(node)

    def leave_Call(self, original_node: cst.Call) -> None:
        # No-op for calls that were never pushed.
        self.logging_stack.discard(original_node)

    def visit_FormattedString(self, node: cst.FormattedString) -> None:
        if self.logging_stack:
            self.report(node)

    def visit_BinaryOperation(self, node: cst.BinaryOperation) -> None:
        # `"..." % value` with a string literal on the left, inside a logging call.
        percent_format = m.BinaryOperation(
            left=m.OneOf(m.SimpleString(), m.ConcatenatedString()),
            operator=m.Modulo(),
        )
        if self.logging_stack and m.matches(node, percent_format):
            self.report(node)
# Example #23
# 0
class UseFstringRule(CstLintRule):
    """
    Encourages the use of f-string instead of %-formatting or .format() for high code quality and efficiency.

    Following two cases not covered:

    1. arguments length greater than 30 characters: for better readability reason
        For example:

        1: "this is the answer: %d" % (a_long_function_call() + b_another_long_function_call())
        2: f"this is the answer: {a_long_function_call() + b_another_long_function_call()}"
        3: result = a_long_function_call() + b_another_long_function_call()
        f"this is the answer: {result}"

        Line 1 is more readable than line 2. Ideally, we'd like developers to manually fix this case to line 3

    2. only %s placeholders are linted against for now. We leave it as future work to support other placeholders.
        For example, %d raises TypeError for non-numeric objects, whereas f"{x:d}" raises ValueError.
        This discrepancy in the type of exception raised could potentially break the logic in the code where the exception is handled

    When the number of %s placeholders differs from the number of arguments,
    the issue is reported without an autofix.
    """

    MESSAGE: str = (
        "Do not use printf style formatting or .format(). " +
        "Use f-string instead to be more readable and efficient. " +
        "See https://www.python.org/dev/peps/pep-0498/")

    VALID = [
        Valid("somebody='you'; f\"Hey, {somebody}.\""),
        Valid('"hey"'),
        Valid('"hey" + "there"'),
        Valid('b"a type %s" % var'),
        Valid('logging.error("printf style logging %s", my_var)'),
    ]

    INVALID = [
        Invalid('"Hey, {somebody}.".format(somebody="you")'),
        Invalid('"%s" % "hi"', expected_replacement='''f"{'hi'}"'''),
        Invalid('"a name: %s" % name',
                expected_replacement='f"a name: {name}"'),
        Invalid(
            '"an attribute %s ." % obj.attr',
            expected_replacement='f"an attribute {obj.attr} ."',
        ),
        Invalid(
            'r"raw string value=%s" % val',
            expected_replacement='fr"raw string value={val}"',
        ),
        Invalid('"{%s}" % val', expected_replacement='f"{{{val}}}"'),
        Invalid('"{%s" % val', expected_replacement='f"{{{val}"'),
        Invalid(
            '"The type of var: %s" % type(var)',
            expected_replacement='f"The type of var: {type(var)}"',
        ),
        Invalid(
            '"%s" % obj.this_is_a_very_long_expression(parameter)["a_very_long_key"]',
        ),
        Invalid(
            '"type of var: %s, value of var: %s" % (type(var), var)',
            expected_replacement=
            'f"type of var: {type(var)}, value of var: {var}"',
        ),
        Invalid(
            "'%s\" double quote is used' % var",
            expected_replacement="f'{var}\" double quote is used'",
        ),
        Invalid(
            '"var1: %s, var2: %s, var3: %s, var4: %s" % (class_object.attribute, dict_lookup["some_key"], some_module.some_function(), var4)',
            expected_replacement=
            '''f"var1: {class_object.attribute}, var2: {dict_lookup['some_key']}, var3: {some_module.some_function()}, var4: {var4}"''',
        ),
        Invalid(
            '"a list: %s" % " ".join(var)',
            expected_replacement='''f"a list: {' '.join(var)}"''',
        ),
        # More arguments than placeholders: reported, but no lossy autofix.
        Invalid('"%s" % (a, b)'),
    ]

    def visit_Call(self, node: cst.Call) -> None:
        if m.matches(
                node,
                m.Call(func=m.Attribute(value=m.SimpleString(),
                                        attr=m.Name(value="format"))),
        ):
            self.report(node)

    def visit_BinaryOperation(self, node: cst.BinaryOperation) -> None:
        """Report ``str % args`` formatting, with an f-string autofix when safe.

        If the left operand satisfies ``_match_simple_string`` and the right
        operand satisfies ``_gen_match_simple_expression``, each ``%s`` is
        replaced by the corresponding argument (a tuple on the right supplies
        one argument per element). Otherwise any ``%`` applied to a text
        (non-bytes) ``SimpleString`` is reported without a fix.
        """
        expr_key = "expr"
        extracts = m.extract(
            node,
            m.BinaryOperation(
                left=m.MatchIfTrue(_match_simple_string),
                operator=m.Modulo(),
                right=m.SaveMatchedNode(
                    m.MatchIfTrue(
                        _gen_match_simple_expression(
                            self.context.wrapper.module)),
                    expr_key,
                ),
            ),
        )

        if extracts:
            expr = extracts[expr_key]
            simple_string = cst.ensure_type(node.left, cst.SimpleString)
            # Literal braces must be doubled to stay literal inside an f-string.
            innards = simple_string.raw_value.replace("{",
                                                      "{{").replace("}", "}}")
            tokens = innards.split("%s")
            expressions = ([elm.value for elm in expr.elements] if isinstance(
                expr, cst.Tuple) else [expr])
            if len(expressions) != len(tokens) - 1:
                # Placeholder/argument count mismatch (typically a TypeError at
                # runtime): no equivalent f-string exists, so only warn instead
                # of emitting a fix that silently drops arguments.
                self.report(node)
                return

            escape_transformer = EscapeStringQuote(simple_string.quote)
            parts = []
            if tokens[0]:
                parts.append(cst.FormattedStringText(value=tokens[0]))
            for expression, token in zip(expressions, tokens[1:]):
                try:
                    escaped = cast(
                        cst.BaseExpression,
                        expression.visit(escape_transformer),
                    )
                    parts.append(
                        cst.FormattedStringExpression(expression=escaped))
                except Exception:
                    # Best effort: if the argument cannot be re-emitted inside
                    # the f-string, fall back to a plain warning.
                    self.report(node)
                    return
                if token:
                    parts.append(cst.FormattedStringText(value=token))

            # A ``u`` prefix cannot be combined with ``f`` (``fu"..."`` is
            # invalid syntax) and is a no-op for Python 3 text strings, so drop
            # it. Other prefix letters such as ``r`` are kept.
            prefix = simple_string.prefix.replace("u", "").replace("U", "")
            replacement = cst.FormattedString(
                parts=parts,
                start=f"f{prefix}{simple_string.quote}",
                end=simple_string.quote,
            )
            self.report(node, replacement=replacement)
        elif m.matches(
                node,
                m.BinaryOperation(
                    left=m.SimpleString(),
                    operator=m.Modulo())) and isinstance(
                        cst.ensure_type(
                            node.left, cst.SimpleString).evaluated_value, str):
            self.report(node)
class SortedAttributesRule(CstLintRule):
    """
    Ever wanted to sort a bunch of class attributes alphabetically?
    Well now it's easy! Just add "@sorted-attributes" in the doc string of
    a class definition and lint will automatically sort all attributes alphabetically.

    Feel free to add other methods and such -- it should only affect class attributes.
    Only single-target assignments to a plain name (``NAME = value``) are treated
    as sortable attributes.
    """

    INVALID = [
        Invalid(
            """
            class MyUnsortedConstants:
                \"\"\"
                @sorted-attributes
                \"\"\"
                z = "hehehe"
                B = 'aaa234'
                A = 'zzz123'
                cab = "foo bar"
                Daaa = "banana"

                @classmethod
                def get_foo(cls) -> str:
                    return "some random thing"
           """,
            expected_replacement="""
            class MyUnsortedConstants:
                \"\"\"
                @sorted-attributes
                \"\"\"
                A = 'zzz123'
                B = 'aaa234'
                Daaa = "banana"
                cab = "foo bar"
                z = "hehehe"

                @classmethod
                def get_foo(cls) -> str:
                    return "some random thing"
           """,
        )
    ]
    VALID = [
        Valid("""
            class MyConstants:
                \"\"\"
                @sorted-attributes
                \"\"\"
                A = 'zzz123'
                B = 'aaa234'

            class MyUnsortedConstants:
                B = 'aaa234'
                A = 'zzz123'
           """)
    ]
    MESSAGE: str = "It appears you are using the @sorted-attributes directive and the class variables are unsorted. See the lint autofix suggestion."

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Offer a fix sorting ``NAME = value`` attributes of opted-in classes.

        Statements before the first sortable assignment keep their place ahead
        of the sorted attributes; every other non-sortable statement found after
        it is emitted after them, in its original relative order.
        """
        doc_string = node.get_docstring()
        if not doc_string or "@sorted-attributes" not in doc_string:
            return

        found_any_assign: bool = False
        pre_assign_lines: List[LineType] = []
        assign_lines: List[LineType] = []
        post_assign_lines: List[LineType] = []

        for line in node.body.body:
            # Restrict to plain-name targets: tuple, attribute or subscript
            # targets have no string ``.value`` to sort by (the sort key below
            # would raise), so they are handled like any other statement.
            if m.matches(
                    line,
                    m.SimpleStatementLine(body=[
                        m.Assign(targets=[m.AssignTarget(target=m.Name())])
                    ])):
                found_any_assign = True
                assign_lines.append(line)
            elif found_any_assign:
                post_assign_lines.append(line)
            else:
                pre_assign_lines.append(line)

        sorted_assign_lines = sorted(
            assign_lines,
            key=lambda line: line.body[0].targets[0].target.value)
        if sorted_assign_lines == assign_lines:
            return
        self.report(
            node,
            replacement=node.with_changes(body=node.body.with_changes(
                body=pre_assign_lines + sorted_assign_lines +
                post_assign_lines)),
        )
class NoStringTypeAnnotationRule(CstLintRule):
    """
    Enforce the use of type identifier instead of using string type hints for simplicity and better syntax highlighting.
    Starting in Python 3.7, ``from __future__ import annotations`` can postpone evaluation of type annotations
    `PEP 563 <https://www.python.org/dev/peps/pep-0563/#forward-references>`_
    and thus forward references no longer need to use string annotation style.

    Strings inside ``Literal[...]`` (from ``typing`` or ``typing_extensions``) are
    values rather than type hints and are left alone, as are strings that do not
    parse as an expression.
    """

    MESSAGE = "String type hints are no longer necessary in Python, use the type identifier directly."

    VALID = [
        # Usage of a Class for instantiation and typing.
        Valid("""
            from a.b import Class

            def foo() -> Class:
                return Class()
            """),
        Valid("""
            import typing
            from a.b import Class

            def foo() -> typing.Type[Class]:
                return Class
            """),
        Valid("""
            import typing
            from a.b import Class
            from c import func

            def foo() -> typing.Optional[typing.Type[Class]]:
                return Class if func() else None
            """),
        Valid("""
            from a.b import Class

            def foo(arg: Class) -> None:
                pass

            foo(Class())
            """),
        Valid("""
            from a.b import Class

            module_var: Class = Class()
            """),
        Valid("""
            from typing import Literal

            def foo() -> Literal["a", "b"]:
                return "a"
            """),
        Valid("""
            import typing

            def foo() -> typing.Optional[typing.Literal["a", "b"]]:
                return "a"
            """),
        # typing.Literal values must not be rewritten under postponed annotations.
        Valid("""
            from __future__ import annotations

            import typing

            def foo() -> typing.Literal["a", "b"]:
                return "a"
            """),
        # A string that is not a valid expression is not a type hint.
        Valid("""
            from __future__ import annotations

            from typing import Annotated

            def foo() -> Annotated[int, "not a type!"]:
                return 1
            """),
    ]

    INVALID = [
        # Using string type hints isn't needed
        Invalid(
            """
            from __future__ import annotations

            from a.b import Class

            def foo() -> "Class":
                return Class()
            """,
            line=5,
            expected_replacement="""
            from __future__ import annotations

            from a.b import Class

            def foo() -> Class:
                return Class()
            """,
        ),
        Invalid(
            """
            from __future__ import annotations

            from a.b import Class

            async def foo() -> "Class":
                return await Class()
            """,
            line=5,
            expected_replacement="""
            from __future__ import annotations

            from a.b import Class

            async def foo() -> Class:
                return await Class()
            """,
        ),
        Invalid(
            """
            from __future__ import annotations

            import typing
            from a.b import Class

            def foo() -> typing.Type["Class"]:
                return Class
            """,
            line=6,
            expected_replacement="""
            from __future__ import annotations

            import typing
            from a.b import Class

            def foo() -> typing.Type[Class]:
                return Class
            """,
        ),
        Invalid(
            """
            from __future__ import annotations

            import typing
            from a.b import Class
            from c import func

            def foo() -> Optional[typing.Type["Class"]]:
                return Class if func() else None
            """,
            line=7,
            expected_replacement="""
            from __future__ import annotations

            import typing
            from a.b import Class
            from c import func

            def foo() -> Optional[typing.Type[Class]]:
                return Class if func() else None
            """,
        ),
        Invalid(
            """
            from __future__ import annotations

            from a.b import Class

            def foo(arg: "Class") -> None:
                pass

            foo(Class())
            """,
            line=5,
            expected_replacement="""
            from __future__ import annotations

            from a.b import Class

            def foo(arg: Class) -> None:
                pass

            foo(Class())
            """,
        ),
        Invalid(
            """
            from __future__ import annotations

            from a.b import Class

            module_var: "Class" = Class()
            """,
            line=5,
            expected_replacement="""
            from __future__ import annotations

            from a.b import Class

            module_var: Class = Class()
            """,
        ),
        Invalid(
            """
            from __future__ import annotations

            import typing
            from typing_extensions import Literal
            from a.b import Class

            def foo() -> typing.Tuple[Literal["a", "b"], "Class"]:
                return Class()
            """,
            line=7,
            expected_replacement="""
            from __future__ import annotations

            import typing
            from typing_extensions import Literal
            from a.b import Class

            def foo() -> typing.Tuple[Literal["a", "b"], Class]:
                return Class()
            """,
        ),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        # Annotation nodes currently being visited (non-empty = inside one).
        self.in_annotation: Set[cst.Annotation] = set()
        # Literal[...] subscripts currently being visited inside an annotation.
        self.in_literal: Set[cst.Subscript] = set()
        # The rule only fires in modules with `from __future__ import annotations`.
        self.has_future_annotations_import = False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        if m.matches(
                node,
                m.ImportFrom(
                    module=m.Name("__future__"),
                    names=[
                        m.ZeroOrMore(),
                        m.ImportAlias(name=m.Name("annotations")),
                        m.ZeroOrMore(),
                    ],
                ),
        ):
            self.has_future_annotations_import = True

    def visit_Annotation(self, node: cst.Annotation) -> None:
        self.in_annotation.add(node)

    def leave_Annotation(self, original_node: cst.Annotation) -> None:
        self.in_annotation.remove(original_node)

    def visit_Subscript(self, node: cst.Subscript) -> None:
        if not self.has_future_annotations_import:
            return
        if self.in_annotation:
            # Strings inside Literal[...] are values, not forward references;
            # accept both the stdlib and the typing_extensions spelling.
            if m.matches(
                    node,
                    m.Subscript(metadata=m.MatchMetadataIfTrue(
                        QualifiedNameProvider,
                        lambda qualnames: any(
                            n.name in ("typing.Literal",
                                       "typing_extensions.Literal")
                            for n in qualnames),
                    )),
                    metadata_resolver=self.context.wrapper,
            ):
                self.in_literal.add(node)

    def leave_Subscript(self, original_node: cst.Subscript) -> None:
        if not self.has_future_annotations_import:
            return
        if original_node in self.in_literal:
            self.in_literal.remove(original_node)

    def visit_SimpleString(self, node: cst.SimpleString) -> None:
        if not self.has_future_annotations_import:
            return
        if self.in_annotation and not self.in_literal:
            value = node.evaluated_value
            if not isinstance(value, str):
                # Bytes literals cannot be forward references.
                return
            try:
                replacement = cst.parse_expression(
                    value,
                    config=self.context.wrapper.module.config_for_parsing,
                )
            except cst.ParserSyntaxError:
                # Not an expression (e.g. free-form Annotated metadata), so not
                # a string type hint; previously this crashed the rule.
                return
            # This is not allowed past Python3.7 since it's no longer necessary.
            self.report(node, replacement=replacement)
Beispiel #26
0
class RequireDoctestRule(CstLintRule):
    """Require a doctest in the docstring of every function except ``__init__``.

    The requirement is lifted for the whole module when the module docstring
    contains a doctest or the module has a top-level ``test_*`` function or
    ``Test*`` class, and for the methods of a class whose docstring contains a
    doctest.
    """

    VALID = [
        # Module-level docstring contains doctest.
        Valid("""
            '''
            Module-level docstring contains doctest
            >>> foo()
            None
            '''
            def foo():
                pass

            class Bar:
                def baz(self):
                    pass

            def bar():
                pass
            """),
        # Module contains a test function.
        Valid("""
            def foo():
                pass

            def bar():
                pass

            # Contains a test function
            def test_foo():
                pass

            class Baz:
                def baz(self):
                    pass

            def spam():
                pass
            """),
        # Module contains multiple test function.
        Valid("""
            def foo():
                pass

            def bar():
                pass

            def test_foo():
                pass

            def test_bar():
                pass

            class Baz:
                def baz(self):
                    pass

            def spam():
                pass
            """),
        # Module contains a test class.
        Valid("""
            def foo():
                pass

            class Baz:
                def baz(self):
                    pass

            def bar():
                pass

            # Contains a test class
            class TestSpam:
                def test_spam(self):
                    pass

            def egg():
                pass
            """),
        # Class level docstring contains doctest, so skip doctest checking only
        # for that class.
        Valid("""
            def foo():
                '''
                >>> foo()
                '''
                pass

            class Spam:
                '''
                Class-level docstring contains doctest
                >>> Spam()
                '''
                def foo(self):
                    pass

                def spam(self):
                    pass

            def bar():
                '''
                >>> bar()
                '''
                pass
            """),
        # No doctest required for the ``__init__`` function.
        Valid("""
            def spam():
                '''
                >>> spam()
                '''
                pass

            class Bar:
                # No doctest needed for the init function
                def __init__(self):
                    pass

                def bar(self):
                    '''
                    >>> bar()
                    '''
                    pass
            """),
    ]

    INVALID = [
        Invalid("""
            def bar():
                pass
            """),
        # Only the ``__init__`` function does not require doctest.
        Invalid("""
            def foo():
                '''
                >>> foo()
                '''
                pass

            class Spam:
                def __init__(self):
                    pass

                def spam(self):
                    pass
            """),
        # Check that `_skip_doctest` attribute is reset after leaving the class.
        Invalid("""
            def bar():
                '''
                >>> bar()
                '''
                pass

            class Spam:
                '''
                >>> Spam()
                '''
                def spam():
                    pass

            def egg():
                pass
            """),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        # When True, functions in the current scope need no doctest of their own.
        self._skip_doctest: bool = False
        # Values of ``_skip_doctest`` saved on entering each class and restored
        # on leaving it. A stack rather than a single slot, so that a nested
        # class cannot overwrite the value saved by its enclosing class.
        self._skip_doctest_stack: List[bool] = []

    def visit_Module(self, node: cst.Module) -> None:
        self._skip_doctest = self._has_testnode(node) or self._has_doctest(
            node)

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        # If the class-level docstring contains a doctest, the checks are
        # skipped for its methods only, not for other functions/classes in the
        # module. Save the current value so leave_ClassDef can restore it.
        self._skip_doctest_stack.append(self._skip_doctest)
        self._skip_doctest = self._has_doctest(node)

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        # visit/leave calls are always paired, so the stack is never empty here.
        self._skip_doctest = self._skip_doctest_stack.pop()

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        nodename = node.name.value
        if nodename != INIT and not self._has_doctest(node):
            self.report(
                node,
                MISSING_DOCTEST.format(filepath=self.context.file_path,
                                       nodename=nodename),
            )

    def _has_doctest(
            self, node: Union[cst.Module, cst.ClassDef,
                              cst.FunctionDef]) -> bool:
        """Check whether the given node contains doctests.

        If the ``_skip_doctest`` attribute is ``True``, the function will by default
        return ``True``, otherwise it will extract the docstring and look for doctest
        patterns (>>> ) in it. If there is no docstring for the node, this will mean
        the absence of doctest.
        """
        if not self._skip_doctest:
            docstring = node.get_docstring()
            if docstring is not None:
                for line in docstring.splitlines():
                    if line.strip().startswith(">>> "):
                        return True
            return False
        return True

    @staticmethod
    def _has_testnode(node: cst.Module) -> bool:
        """Return True if the module has a top-level ``test_*`` function or ``Test*`` class."""
        return m.matches(
            node,
            m.Module(body=[
                # Sequence wildcard matchers matches LibCAST nodes in a row in a
                # sequence. It does not implicitly match on partial sequences. So,
                # when matching against a sequence we will need to provide a
                # complete pattern. This often means using helpers such as
                # ``ZeroOrMore()`` as the first and last element of the sequence.
                m.ZeroOrMore(),
                m.AtLeastN(
                    n=1,
                    matcher=m.OneOf(
                        m.FunctionDef(name=m.Name(value=m.MatchIfTrue(
                            lambda value: value.startswith("test_")))),
                        m.ClassDef(name=m.Name(value=m.MatchIfTrue(
                            lambda value: value.startswith("Test")))),
                    ),
                ),
                m.ZeroOrMore(),
            ]),
        )
class ExplicitFrozenDataclassRule(CstLintRule):
    """
    Encourages the use of frozen dataclass objects by telling users to specify the
    kwarg.

    Without this lint rule, most users of dataclass won't know to use the kwarg, and
    may unintentionally end up with mutable objects.
    """

    MESSAGE: str = (
        "When using dataclasses, explicitly specify a frozen keyword argument. "
        + "Example: `@dataclass(frozen=True)` or `@dataclass(frozen=False)`. "
        + "Docs: https://docs.python.org/3/library/dataclasses.html"
    )
    METADATA_DEPENDENCIES = (QualifiedNameProvider,)
    VALID = [
        Valid(
            """
            @some_other_decorator
            class Cls: pass
            """
        ),
        Valid(
            """
            from dataclasses import dataclass
            @dataclass(frozen=False)
            class Cls: pass
            """
        ),
        Valid(
            """
            import dataclasses
            @dataclasses.dataclass(frozen=False)
            class Cls: pass
            """
        ),
        Valid(
            """
            import dataclasses as dc
            @dc.dataclass(frozen=False)
            class Cls: pass
            """
        ),
        Valid(
            """
            from dataclasses import dataclass as dc
            @dc(frozen=False)
            class Cls: pass
            """
        ),
    ]
    INVALID = [
        Invalid(
            """
            from dataclasses import dataclass
            @some_unrelated_decorator
            @dataclass  # not called as a function
            @another_unrelated_decorator
            class Cls: pass
            """,
            line=3,
            expected_replacement="""
            from dataclasses import dataclass
            @some_unrelated_decorator
            @dataclass(frozen=True)  # not called as a function
            @another_unrelated_decorator
            class Cls: pass
            """,
        ),
        Invalid(
            """
            from dataclasses import dataclass
            @dataclass()  # called as a function, no kwargs
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            from dataclasses import dataclass
            @dataclass(frozen=True)  # called as a function, no kwargs
            class Cls: pass
            """,
        ),
        Invalid(
            """
            from dataclasses import dataclass
            @dataclass(other_kwarg=False)
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            from dataclasses import dataclass
            @dataclass(other_kwarg=False, frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            import dataclasses
            @dataclasses.dataclass
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            import dataclasses
            @dataclasses.dataclass(frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            import dataclasses
            @dataclasses.dataclass()
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            import dataclasses
            @dataclasses.dataclass(frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            import dataclasses
            @dataclasses.dataclass(other_kwarg=False)
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            import dataclasses
            @dataclasses.dataclass(other_kwarg=False, frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            from dataclasses import dataclass as dc
            @dc
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            from dataclasses import dataclass as dc
            @dc(frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            from dataclasses import dataclass as dc
            @dc()
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            from dataclasses import dataclass as dc
            @dc(frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            from dataclasses import dataclass as dc
            @dc(other_kwarg=False)
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            from dataclasses import dataclass as dc
            @dc(other_kwarg=False, frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            import dataclasses as dc
            @dc.dataclass
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            import dataclasses as dc
            @dc.dataclass(frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            import dataclasses as dc
            @dc.dataclass()
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            import dataclasses as dc
            @dc.dataclass(frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            import dataclasses as dc
            @dc.dataclass(other_kwarg=False)
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            import dataclasses as dc
            @dc.dataclass(other_kwarg=False, frozen=True)
            class Cls: pass
            """,
        ),
    ]

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Report each ``dataclasses.dataclass`` decorator lacking ``frozen=``.

        The suggested fix turns the decorator into a call when it is not one
        already, and appends ``frozen=True`` after any existing arguments.
        """
        dataclass_name = QualifiedName(
            name="dataclasses.dataclass", source=QualifiedNameSource.IMPORT
        )
        for decorator_node in node.decorators:
            target = decorator_node.decorator
            if not QualifiedNameProvider.has_name(self, target, dataclass_name):
                continue

            # A bare ``@dataclass`` (a Name or Attribute) has no argument list
            # yet; normalise both shapes to (callee, list of args).
            if isinstance(target, cst.Call):
                callee = target.func
                existing_args = list(target.args)
            else:
                callee = target
                existing_args = []

            if any(m.matches(arg.keyword, m.Name("frozen")) for arg in existing_args):
                continue

            frozen_arg = cst.Arg(
                keyword=cst.Name("frozen"),
                value=cst.Name("True"),
                equal=cst.AssignEqual(
                    whitespace_before=SimpleWhitespace(value=""),
                    whitespace_after=SimpleWhitespace(value=""),
                ),
            )
            fixed_decorator = cst.Call(func=callee, args=existing_args + [frozen_arg])
            self.report(
                decorator_node,
                replacement=decorator_node.with_changes(decorator=fixed_decorator),
            )
Beispiel #28
0
class UseClsInClassmethodRule(CstLintRule):
    """
    Enforces using ``cls`` as the first argument in a ``@classmethod``.

    The first *positional* parameter is checked: the first positional-only
    parameter (before ``/``) when there is one, otherwise the first regular
    parameter. An autofix renames it and its same-scope references to ``cls``,
    or inserts a ``cls`` parameter when there are no positional parameters.
    """

    METADATA_DEPENDENCIES = (QualifiedNameProvider, ScopeProvider)
    MESSAGE = "When using @classmethod, the first argument must be `cls`."
    VALID = [
        Valid(
            """
            class foo:
                # classmethod with cls first arg.
                @classmethod
                def cm(cls, a, b, c):
                    pass
            """
        ),
        Valid(
            """
            class foo:
                # non-classmethod with non-cls first arg.
                def nm(self, a, b, c):
                    pass
            """
        ),
        Valid(
            """
            class foo:
                # staticmethod with non-cls first arg.
                @staticmethod
                def sm(a):
                    pass
            """
        ),
    ]
    INVALID = [
        Invalid(
            """
            class foo:
                # No args at all.
                @classmethod
                def cm():
                    pass
            """,
            expected_replacement="""
            class foo:
                # No args at all.
                @classmethod
                def cm(cls):
                    pass
            """,
        ),
        Invalid(
            """
            class foo:
                # Single arg + reference.
                @classmethod
                def cm(a):
                    return a
            """,
            expected_replacement="""
            class foo:
                # Single arg + reference.
                @classmethod
                def cm(cls):
                    return cls
            """,
        ),
        Invalid(
            """
            class foo:
                # Another "cls" exists: do not autofix.
                @classmethod
                def cm(a):
                    cls = 2
            """,
        ),
        Invalid(
            """
            class foo:
                # Multiple args + references.
                @classmethod
                async def cm(a, b):
                    b = a
                    b = a.__name__
            """,
            expected_replacement="""
            class foo:
                # Multiple args + references.
                @classmethod
                async def cm(cls, b):
                    b = cls
                    b = cls.__name__
            """,
        ),
        Invalid(
            """
            class foo:
                # Do not replace in nested scopes.
                @classmethod
                async def cm(a, b):
                    b = a
                    b = lambda _: a.__name__
                    def g():
                        return a.__name__

                    # Same-named vars in sub-scopes should not be replaced.
                    b = [a for a in [1,2,3]]
                    def f(a):
                        return a + 1
            """,
            expected_replacement="""
            class foo:
                # Do not replace in nested scopes.
                @classmethod
                async def cm(cls, b):
                    b = cls
                    b = lambda _: cls.__name__
                    def g():
                        return cls.__name__

                    # Same-named vars in sub-scopes should not be replaced.
                    b = [a for a in [1,2,3]]
                    def f(a):
                        return a + 1
            """,
        ),
        Invalid(
            """
            # Do not replace in surrounding scopes.
            a = 1

            class foo:
                a = 2

                def im(a):
                    a = a

                @classmethod
                def cm(a):
                    a[1] = foo.cm(a=a)
            """,
            expected_replacement="""
            # Do not replace in surrounding scopes.
            a = 1

            class foo:
                a = 2

                def im(a):
                    a = a

                @classmethod
                def cm(cls):
                    cls[1] = foo.cm(a=cls)
            """,
        ),
        Invalid(
            """
            def another_decorator(x): pass

            class foo:
                # Multiple decorators.
                @another_decorator
                @classmethod
                @another_decorator
                async def cm(a, b, c):
                    pass
            """,
            expected_replacement="""
            def another_decorator(x): pass

            class foo:
                # Multiple decorators.
                @another_decorator
                @classmethod
                @another_decorator
                async def cm(cls, b, c):
                    pass
            """,
        ),
    ]

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        """Report a ``@classmethod`` whose first positional parameter is not ``cls``.

        Functions without a decorator resolving to ``builtins.classmethod`` are
        ignored. When the first positional parameter has another name, the
        method is reported. A rename autofix is attached only when scope
        metadata is available and no ``cls`` name is already visible in the
        method's scope.
        """
        if not any(
            QualifiedNameProvider.has_name(
                self,
                decorator.decorator,
                QualifiedName(name="builtins.classmethod", source=QualifiedNameSource.BUILTIN),
            )
            for decorator in node.decorators
        ):
            return  # If it's not a @classmethod, we are not interested.

        # The implicit class argument binds to the first *positional* parameter.
        # Positional-only parameters (before ``/``) live in ``posonly_params``
        # and come first. Looking only at ``params`` would misjudge e.g.
        # ``def cm(cls, /)`` and produce broken autofixes for it.
        positional_params = (*node.params.posonly_params, *node.params.params)
        if not positional_params:
            # No positional params, but there must be the 'cls' param.
            # Note that pyre[47] already catches this, but we also generate
            # an autofix, so it still makes sense for us to report it here.
            new_params = node.params.with_changes(params=(cst.Param(name=cst.Name(value=CLS)),))
            repl = node.with_changes(params=new_params)
            self.report(node, replacement=repl)
            return

        p0_name = positional_params[0].name
        if p0_name.value == CLS:
            return  # All good.

        # Rename all assignments and references of the first param within the
        # function scope, as long as they are done via a Name node.
        # We rely on the parser to correctly derive all
        # assignments and references within the FunctionScope.
        # The Param node's scope is our classmethod's FunctionScope.
        scope = self.get_metadata(ScopeProvider, p0_name, None)
        if not scope:
            # Cannot autofix without scope metadata. Only report in this case.
            # Not sure how to repro+cover this in a unit test...
            # If metadata creation fails, then the whole lint fails, and if it succeeds,
            # then there is valid metadata. But many other lint rule implementations contain
            # a defensive scope None check like this one, so I assume it is necessary.
            self.report(node)
            return

        if scope[CLS]:
            # The scope already has another assignment to "cls".
            # Trying to rename the first param to "cls" as well may produce broken code.
            # We should therefore refrain from suggesting an autofix in this case.
            self.report(node)
            return

        refs: List[Union[cst.Name, cst.Attribute]] = []
        assignments = scope[p0_name.value]
        for a in assignments:
            if isinstance(a, Assignment):
                assign_node = a.node
                if isinstance(assign_node, cst.Name):
                    refs.append(assign_node)
                elif isinstance(assign_node, cst.Param):
                    refs.append(assign_node.name)
                # There are other types of possible assignment nodes: ClassDef,
                # FunctionDef, Import, etc. We deliberately do not handle those here.
            refs += [r.node for r in a.references]

        repl = node.visit(_RenameTransformer(refs, CLS))
        self.report(node, replacement=repl)
class ReplaceUnionWithOptionalRule(CstLintRule):
    """
    Enforces the use of ``Optional[T]`` over ``Union[T, None]`` and ``Union[None, T]``.
    See https://docs.python.org/3/library/typing.html#typing.Optional to learn more about Optionals.
    """

    MESSAGE: str = (
        "`Optional[T]` is preferred over `Union[T, None]` or `Union[None, T]`. "
        "Learn more: https://docs.python.org/3/library/typing.html#typing.Optional"
    )
    METADATA_DEPENDENCIES = (cst.metadata.ScopeProvider,)
    VALID = [
        Valid(
            """
            def func() -> Optional[str]:
                pass
            """
        ),
        Valid(
            """
            def func() -> Optional[Dict]:
                pass
            """
        ),
        Valid(
            """
            def func() -> Union[str, int, None]:
                pass
            """
        ),
    ]
    INVALID = [
        Invalid(
            """
            def func() -> Union[str, None]:
                pass
            """,
        ),
        Invalid(
            """
            from typing import Optional
            def func() -> Union[Dict[str, int], None]:
                pass
            """,
            expected_replacement="""
            from typing import Optional
            def func() -> Optional[Dict[str, int]]:
                pass
            """,
        ),
        Invalid(
            """
            from typing import Optional
            def func() -> Union[str, None]:
                pass
            """,
            expected_replacement="""
            from typing import Optional
            def func() -> Optional[str]:
                pass
            """,
        ),
        Invalid(
            """
            from typing import Optional
            def func() -> Union[Dict, None]:
                pass
            """,
            expected_replacement="""
            from typing import Optional
            def func() -> Optional[Dict]:
                pass
            """,
        ),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)

    def leave_Annotation(self, original_node: cst.Annotation) -> None:
        """Report a two-member ``Union`` annotation that contains ``None``.

        An ``Optional[...]`` replacement is offered only when the name
        ``Optional`` is visible in the annotation's scope, and the union holds
        exactly one non-``None`` member and at most one ``None``. Otherwise
        the annotation is reported without a replacement.
        """
        if not self.contains_union_with_none(original_node):
            return

        scope = self.get_metadata(cst.metadata.ScopeProvider, original_node, None)
        replacement = None
        # Without ``Optional`` in scope the rewrite would reference an undefined name.
        if scope is not None and "Optional" in scope:
            union = cst.ensure_type(original_node.annotation, cst.Subscript)
            none_member = m.SubscriptElement(m.Index(m.Name("None")))
            kept = [
                element.slice
                for element in union.slice
                if not m.matches(element, none_member)
            ]
            none_count = len(union.slice) - len(kept)
            if none_count <= 1 and len(kept) == 1:
                replacement = original_node.with_changes(
                    annotation=cst.Subscript(
                        value=cst.Name("Optional"),
                        slice=(cst.SubscriptElement(kept[0]),),
                    )
                )
                # TODO(T57106602) refactor lint replacement once extract exists
        self.report(original_node, replacement=replacement)

    def contains_union_with_none(self, node: cst.Annotation) -> bool:
        """Return True if ``node`` is a bare ``Union[...]`` of exactly two members, one being ``None``."""
        any_member = m.SubscriptElement(m.Index())
        none_member = m.SubscriptElement(m.Index(m.Name("None")))
        two_member_union_with_none = m.Annotation(
            m.Subscript(
                value=m.Name("Union"),
                slice=m.OneOf(
                    [any_member, none_member],
                    [none_member, any_member],
                ),
            )
        )
        return m.matches(node, two_member_union_with_none)
# Example #30
# 0
class GatherSequentialAwaitRule(CstLintRule):
    """
    Discourages awaiting coroutines in a loop as this will run them sequentially. Using ``asyncio.gather()`` will run them concurrently.
    """

    MESSAGE: str = (
        "Using await in a loop will run async function sequentially. Use " +
        "asyncio.gather() to run async functions concurrently.")
    VALID = [
        Valid("""
            async def async_foo():
                return await async_bar()
            """),
        Valid(
            """
            # await in a loop is fine if it's a test.
            # filename="foo/tests/test_foo.py"
            async def async_check_call():
                for _i in range(0, 2):
                    await async_foo()
            """,
            filename="foo/tests/test_foo.py",
        ),
    ]

    INVALID = [
        Invalid(
            """
            async def async_check_call():
                for _i in range(0, 2):
                    await async_foo()
            """,
            line=3,
        ),
        Invalid(
            """
            async def async_check_assignment():
                for _i in range(0, 2):
                    x = await async_foo()
            """,
            line=3,
        ),
        Invalid(
            """
            async def async_check_list_comprehension():
                [await async_foo() for _i in range(0, 2)]
            """,
            line=2,
        ),
        Invalid(
            """
            async def async_check_one_line_loop():
                for _i in range(0, 2): await async_foo()
            """,
            line=2,
        ),
    ]

    def should_skip_file(self) -> bool:
        # Sequential awaits are tolerated in test files.
        return self.context.in_tests

    def visit_Await(self, node: cst.Await) -> None:
        """Report ``node`` when it is awaited directly in a loop body or comprehension element.

        Two shapes are flagged:

        * an ``await`` that is the whole value of an expression statement or a
          plain assignment whose statement sits directly in a ``for``/``while``
          body, either an indented block or a one-line body;
        * an ``await`` that is the element expression of a list comprehension,
          set comprehension, or generator expression.
        """
        stack = self.context.node_stack
        # ``stack[-1]`` is ``node`` itself, so ``stack[-2]`` is its parent.
        parent = stack[-2]

        if isinstance(parent, (cst.Expr, cst.Assign)) and parent.value is node:
            # Expected ancestry, innermost node last:
            #   For/While -> IndentedBlock -> SimpleStatementLine -> Expr/Assign -> Await
            #   For/While -> SimpleStatementSuite -> Expr/Assign -> Await  (one-line body)
            # A statement in the loop's ``else`` clause has an ``Else`` node in
            # the loop position, so it is not flagged.
            if len(stack) >= 4 and isinstance(stack[-3], cst.SimpleStatementSuite):
                loop_candidate = stack[-4]
            elif len(stack) >= 5:
                loop_candidate = stack[-5]
            else:
                # Too shallow to be inside any loop (e.g. a top-level
                # ``x = await foo()``). Indexing ``stack[-5]`` here would raise
                # IndexError and abort the whole lint run.
                loop_candidate = None
            if isinstance(loop_candidate, (cst.For, cst.While)):
                self.report(node)

        if (isinstance(parent, (cst.ListComp, cst.SetComp, cst.GeneratorExp))
                and parent.elt is node):
            self.report(node)