예제 #1
0
def infer_typing_alias(
    node: Call, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """
    Infers the call to _alias function
    Insert ClassDef, with same name as aliased class,
    in mro to simulate _GenericAlias.

    The synthesized ClassDef is named after the single assignment target and
    uses the inferred first argument as its only base when that argument is a
    ClassDef. A ``__class_getitem__`` local is added when the second argument
    marks the alias as subscriptable; otherwise ``__class_getitem__`` lookups
    are forbidden on the new class.

    :param node: call node
    :param ctx: inference context
    :return: iterator over the single synthesized ClassDef
    :raises UseInferenceDefault: if the call is not the value of an assignment
        to exactly one name, or has fewer than two positional arguments
    :raises InferenceError: if inferring the first argument yields no value
    """
    if (
        not isinstance(node.parent, Assign)
        or len(node.parent.targets) != 1
        or not isinstance(node.parent.targets[0], AssignName)
        # Both args[0] and args[1] are read below; bail out instead of
        # crashing with IndexError on an unexpected call shape.
        or len(node.args) < 2
    ):
        raise UseInferenceDefault
    try:
        res = next(node.args[0].infer(context=ctx))
    except StopIteration as e:
        # Report the caller's inference context, not the `context` module.
        raise InferenceError(node=node.args[0], context=ctx) from e

    assign_name = node.parent.targets[0]

    class_def = ClassDef(
        name=assign_name.name,
        lineno=assign_name.lineno,
        col_offset=assign_name.col_offset,
        parent=node.parent,
    )
    if isinstance(res, ClassDef):
        # Only add `res` as base if it's a `ClassDef`
        # This isn't the case for `typing.Pattern` and `typing.Match`
        # (Uninferable is never a ClassDef, so it is excluded here as well.)
        class_def.postinit(bases=[res], body=[], decorators=None)

    maybe_type_var = node.args[1]
    if PY39_PLUS:
        # 3.9+: subscriptable when the second argument is a positive constant.
        is_subscriptable = (
            isinstance(maybe_type_var, Const) and maybe_type_var.value > 0
        )
    else:
        # Before 3.9: subscriptable unless the second argument is an empty tuple.
        is_subscriptable = not (
            isinstance(maybe_type_var, Tuple) and not maybe_type_var.elts
        )

    if is_subscriptable:
        # If typing alias is subscriptable, add `__class_getitem__` to ClassDef
        func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
        class_def.locals["__class_getitem__"] = [func_to_add]
    else:
        # If not, make sure that `__class_getitem__` access is forbidden.
        # This is an issue in cases where the aliased class implements it,
        # but the typing alias isn't subscriptable. E.g., `typing.ByteString` for PY39+
        _forbid_class_getitem_access(class_def)
    return iter([class_def])
예제 #2
0
def infer_typedDict(  # pylint: disable=invalid-name
    node: FunctionDef,
    ctx: typing.Optional[context.InferenceContext] = None,
) -> typing.Iterator[ClassDef]:
    """Replace TypedDict FunctionDef with ClassDef.

    Builds a ``TypedDict`` ClassDef at the FunctionDef's position and under
    the same parent. A ``dict`` node is its only base, and a second,
    independently extracted ``dict`` node is registered as its ``__call__``
    local.

    :param node: the ``TypedDict`` FunctionDef being replaced
    :param ctx: inference context (unused; ``None`` is accepted)
    :return: an iterator over the single synthesized ClassDef
    """
    class_def = ClassDef(
        name="TypedDict",
        lineno=node.lineno,
        col_offset=node.col_offset,
        parent=node.parent,
    )
    class_def.postinit(bases=[extract_node("dict")], body=[], decorators=None)
    func_to_add = extract_node("dict")
    class_def.locals["__call__"] = [func_to_add]
    return iter([class_def])
예제 #3
0
def _looks_like_subscriptable(node: ClassDef) -> bool:
    """
    Tell whether ``node`` is a subscriptable class from the collections modules.

    :param node: ClassDef node
    :return: True when the qualified name starts with ``_collections`` or
        ``collections`` and ``__class_getitem__`` can be looked up on the node;
        False otherwise
    """
    qualified_name = node.qname()
    if not qualified_name.startswith(("_collections", "collections")):
        return False
    try:
        node.getattr("__class_getitem__")
    except AttributeInferenceError:
        return False
    return True
예제 #4
0
def dataclass_transform(node: ClassDef) -> None:
    """Rewrite a dataclass to be easily understood by pylint.

    Marks ``node`` as a dataclass. Each attribute from
    ``_get_dataclass_attributes(node)`` is bound in ``node.instance_attrs`` to
    a transformed ``Unknown`` placeholder.

    When ``_check_generate_dataclass_init(node)`` is true, an ``__init__`` is
    also synthesized. It is built from the init fields of every dataclass in
    the reversed MRO and stored in ``node.locals``. If that ``__init__``
    parses, a ``DEFAULT_FACTORY`` name is registered in the module root's
    locals unless one is already present.

    :param node: the ClassDef to rewrite in place
    """
    node.is_dataclass = True

    for attr_assign in _get_dataclass_attributes(node):
        attr_name = attr_assign.target.name
        placeholder = Unknown(
            lineno=attr_assign.lineno,
            col_offset=attr_assign.col_offset,
            parent=attr_assign,
        )
        placeholder = AstroidManager().visit_transforms(placeholder)
        node.instance_attrs[attr_name] = [placeholder]

    if not _check_generate_dataclass_init(node):
        return

    # Visit ancestors from the base-most class down to `node`; fall back to
    # `node` alone when its MRO can't be computed.
    try:
        ancestry = list(reversed(node.mro()))
    except MroError:
        ancestry = [node]

    # Dicts keep first-insertion order, so a subclass redefining a field
    # replaces that field's node while keeping its original position.
    init_fields = {}
    for klass in ancestry:
        if not is_decorated_with_dataclass(klass):
            continue
        for field_assign in _get_dataclass_attributes(klass, init=True):
            init_fields[field_assign.target.name] = field_assign

    init_source = _generate_dataclass_init(list(init_fields.values()))
    try:
        init_node = parse(init_source)["__init__"]
    except AstroidSyntaxError:
        # Generated source didn't parse: skip installing the synthesized __init__.
        return

    init_node.parent = node
    init_node.lineno = None
    init_node.col_offset = None
    node.locals["__init__"] = [init_node]

    module = node.root()
    if DEFAULT_FACTORY in module.locals:
        return
    sentinel_assign = parse(f"{DEFAULT_FACTORY} = object()").body[0]
    sentinel_assign.parent = module
    module.locals[DEFAULT_FACTORY] = [sentinel_assign.targets[0]]
예제 #5
0
def attr_attributes_transform(node: ClassDef) -> None:
    """Rewrite attrs-declared class attributes as instance attributes.

    Expects ``node`` to be a ClassDef carrying an attrs decorator. Every
    ``Assign``/``AnnAssign`` in the class body whose value is a call to one of
    ``ATTRIB_NAMES`` has each plain-name target rebound to an ``Unknown`` node,
    in both ``node.locals`` and ``node.instance_attrs``.

    :param node: the ClassDef to rewrite in place
    """
    # Astroid can't infer this attribute properly
    # Prevents https://github.com/PyCQA/pylint/issues/1884
    node.locals["__attrs_attrs__"] = [Unknown(parent=node)]

    for stmt in node.body:
        if not isinstance(stmt, (Assign, AnnAssign)):
            continue
        value = stmt.value
        if not isinstance(value, Call):
            continue
        if value.func.as_string() not in ATTRIB_NAMES:
            continue

        # Assign has several targets; AnnAssign has exactly one.
        if hasattr(stmt, "targets"):
            targets = stmt.targets
        else:
            targets = [stmt.target]

        for target in targets:
            placeholder = Unknown(
                lineno=stmt.lineno,
                col_offset=stmt.col_offset,
                parent=stmt,
            )
            if not isinstance(target, AssignName):
                # Could be a subscript if the code analysed is
                # i = Optional[str] = ""
                # See https://github.com/PyCQA/pylint/issues/4439
                continue
            node.locals[target.name] = [placeholder]
            node.instance_attrs[target.name] = [placeholder]
예제 #6
0
def infer_old_typedDict(  # pylint: disable=invalid-name
    node: ClassDef,
    ctx: typing.Optional[context.InferenceContext] = None,
) -> typing.Iterator[ClassDef]:
    """Give an old-style TypedDict ClassDef a ``__call__`` local.

    The ``__call__`` local is a node built by ``extract_node("dict")``.

    :param node: the TypedDict ClassDef, modified in place
    :param ctx: inference context (unused)
    :return: an iterator yielding ``node`` itself
    """
    node.locals["__call__"] = [extract_node("dict")]
    return iter([node])
예제 #7
0
def infer_special_alias(
    node: Call,
    ctx: typing.Optional[context.InferenceContext] = None,
) -> typing.Iterator[ClassDef]:
    """Infer call to tuple alias as new subscriptable class typing.Tuple.

    The synthesized ClassDef is named after, and positioned at, the single
    assignment target. Its only base is the inferred first argument, and it
    always gets a ``__class_getitem__`` local.

    :param node: call node
    :param ctx: inference context
    :return: iterator over the single synthesized ClassDef
    :raises UseInferenceDefault: if the call is not the value of an assignment
        to exactly one name, or has no positional argument
    :raises InferenceError: if inferring the first argument yields no value
    """
    if not (
        isinstance(node.parent, Assign)
        and len(node.parent.targets) == 1
        and isinstance(node.parent.targets[0], AssignName)
        # args[0] is read below; bail out instead of raising IndexError.
        and node.args
    ):
        raise UseInferenceDefault
    try:
        res = next(node.args[0].infer(context=ctx))
    except StopIteration as e:
        # Report the caller's inference context, not the `context` module.
        raise InferenceError(node=node.args[0], context=ctx) from e

    assign_name = node.parent.targets[0]
    # Position the class like `infer_typing_alias` does, so messages about it
    # point at the assignment target.
    class_def = ClassDef(
        name=assign_name.name,
        lineno=assign_name.lineno,
        col_offset=assign_name.col_offset,
        parent=node.parent,
    )
    class_def.postinit(bases=[res], body=[], decorators=None)
    func_to_add = extract_node(CLASS_GETITEM_TEMPLATE)
    class_def.locals["__class_getitem__"] = [func_to_add]
    return iter([class_def])
예제 #8
0
def _forbid_class_getitem_access(node: ClassDef) -> None:
    """
    Hide the ``__class_getitem__`` method of ``node`` from attribute lookups.

    If ``node.getattr("__class_getitem__")`` currently succeeds,
    ``node.getattr`` is replaced by a wrapper. The wrapper raises
    AttributeInferenceError for that name and forwards every other lookup to
    the original ``getattr``. If the lookup already fails, ``node`` is left
    untouched.

    :param node: ClassDef node to patch in place
    """

    def _raise_on_class_getitem(original_getattr, attr, *args, **kwargs):
        """Reject ``__class_getitem__`` lookups; delegate all other names."""
        if attr != "__class_getitem__":
            return original_getattr(attr, *args, **kwargs)
        raise AttributeInferenceError(
            "__class_getitem__ access is not allowed")

    try:
        node.getattr("__class_getitem__")
    except AttributeInferenceError:
        # Already unreachable: nothing to hide.
        return

    # The class exposes __class_getitem__ (its origin being one of the
    # protocols defined in the collections module), but the typing module
    # considers the alias non-subscriptable, so shadow the lookup.
    node.getattr = partial(_raise_on_class_getitem, node.getattr)
예제 #9
0
    def test_clear_cache_clears_other_lru_caches(self) -> None:
        """MANAGER.clear_cache() must also empty the auxiliary lru caches."""
        cached_functions = (
            astroid.nodes.node_classes.LookupMixIn.lookup,
            astroid.modutils._cache_normalize_path_,
            util.is_namespace,
            astroid.interpreter.objectmodel.ObjectModel.attributes,
        )

        def snapshot():
            """Return the current cache_info() of every watched function."""
            return [func.cache_info() for func in cached_functions]

        # Baseline: cache state after simply calling bootstrap()
        baseline = snapshot()

        # Exercise every cache so it records at least one hit or miss
        ClassDef().lookup("garbage")
        is_standard_module("unittest", std_path=["garbage_path"])
        util.is_namespace("unittest")
        astroid.interpreter.objectmodel.ObjectModel().attributes()

        # Each cache must have seen more calls than at the baseline
        for incremented_cache, baseline_cache in zip(snapshot(), baseline):
            with self.subTest(incremented_cache=incremented_cache):
                calls_now = incremented_cache.hits + incremented_cache.misses
                calls_before = baseline_cache.hits + baseline_cache.misses
                self.assertGreater(calls_now, calls_before)

        astroid.MANAGER.clear_cache()  # also calls bootstrap()

        # After clearing, no cache may hold more entries than at the baseline
        for cleared_cache, baseline_cache in zip(snapshot(), baseline):
            with self.subTest(cleared_cache=cleared_cache):
                # <= rather than <: the baseline may reflect several bootstrap() calls
                self.assertLessEqual(
                    cleared_cache.currsize, baseline_cache.currsize
                )
예제 #10
0
def infer_old_typedDict(  # pylint: disable=invalid-name
    node: ClassDef, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """Give an old-style TypedDict ClassDef a ``__call__`` local.

    The ``__call__`` local is the node produced by
    ``_extract_single_node("dict")``.

    :param node: the TypedDict ClassDef, modified in place
    :param ctx: inference context (unused)
    :return: an iterator yielding ``node`` itself
    """
    call_node = _extract_single_node("dict")
    node.locals["__call__"] = [call_node]
    return iter([node])
예제 #11
0
 def assert_classes_equal(self, cls: ClassDef, other: ClassDef) -> None:
     """Assert that ``cls`` and ``other`` agree on name, parent and qname()."""
     check_equal = self.assertEqual
     check_equal(cls.name, other.name)
     check_equal(cls.parent, other.parent)
     check_equal(cls.qname(), other.qname())