Example #1
0
def _access_member(typ: Type, name: str, ctx: MethodContext) -> Type:
    """Return the type of attribute ``name`` read from ``typ``.

    Delegates to ``checkmember.analyze_member_access`` as a plain read
    (not an lvalue, not through ``super()``, not an operator lookup).
    Any errors are reported through the checker's message builder.
    ``ctx.api`` is required to be a ``checker.TypeChecker``.
    """
    type_checker = ctx.api
    assert isinstance(type_checker, checker.TypeChecker)
    return checkmember.analyze_member_access(
        name=name,
        typ=typ,
        context=ctx.context,
        is_lvalue=False,
        is_super=False,
        is_operator=False,
        msg=type_checker.msg,
        original_type=typ,
        chk=type_checker,
    )
Example #2
0
File: kind.py  Project: nurumaik/returns
def attribute_access(ctx: AttributeContext) -> MypyType:
    """
    Ensures that attribute access to ``KindN`` is correct.

    In other words:

    .. code:: python

        from typing import TypeVar
        from returns.primitives.hkt import KindN
        from returns.interfaces.mappable import MappableN

        _MappableType = TypeVar('_MappableType', bound=MappableN)

        kind: KindN[_MappableType, int, int, int]
        reveal_type(kind.map)  # will work correctly!

    The attribute is looked up on the first type argument of ``KindN``.
    For a type variable, its upper bound is used instead. In both cases the
    type arguments are rewritten by ``_crop_kind_args``. Callable results
    are passed through ``detach_callable``. If the first argument is
    neither an instance nor a type variable, ``ctx.default_attr_type`` is
    returned unchanged.
    """
    assert isinstance(ctx.type, Instance)
    kind_instance = ctx.type
    wrapped = kind_instance.args[0]

    if isinstance(wrapped, Instance):
        accessed = wrapped.copy_modified(args=_crop_kind_args(kind_instance))
    elif isinstance(wrapped, TypeVarType):
        # Look the attribute up on the bound of the type variable.
        upper = get_proper_type(wrapped.upper_bound)
        assert isinstance(upper, Instance)
        accessed = upper.copy_modified(
            args=_crop_kind_args(kind_instance, upper.args),
        )
    else:
        return ctx.default_attr_type

    expr_checker = ctx.api.expr_checker  # type: ignore
    member_type = analyze_member_access(
        ctx.context.name,  # type: ignore
        accessed,
        ctx.context,
        is_lvalue=False,
        is_super=False,
        is_operator=False,
        msg=ctx.api.msg,
        original_type=wrapped,
        chk=ctx.api,  # type: ignore
        in_literal_context=expr_checker.is_literal_context(),
    )
    return (
        detach_callable(member_type)
        if isinstance(member_type, CallableType)
        else member_type
    )
Example #3
0
def _get_callable_type(type_: Type,
                       context: FunctionContext) -> t.Optional[CallableType]:
    """Return the plain callable signature of ``type_``, if it has one.

    - A ``CallableType`` is returned unchanged.
    - For an ``Instance`` with a readable ``__call__`` member, that member
      is resolved with ``checkmember.analyze_member_access``. The result is
      returned only when it actually is a ``CallableType``.
    - Anything else yields ``None``.
    """
    if isinstance(type_, CallableType):
        return type_
    if isinstance(type_, Instance) and type_.has_readable_member('__call__'):
        # Called with an object: resolve its ``__call__`` as a plain read.
        chk: TypeChecker = t.cast(TypeChecker, context.api)
        call_type = checkmember.analyze_member_access(
            '__call__',
            type_,
            context.context,
            is_lvalue=False,
            is_super=False,
            is_operator=False,
            msg=context.api.msg,
            original_type=type_,
            chk=chk,
        )
        # ``__call__`` can resolve to something other than a CallableType
        # (e.g. ``Overloaded`` for an overloaded ``__call__``, or
        # ``AnyType``). Report "no plain signature" instead of mislabelling it.
        if isinstance(call_type, CallableType):
            return call_type
    return None
Example #4
0
def analyze(ctx: FunctionContext) -> MypyType:
    """
    Analyzes several pointfree functions.

    Removes intermediate Protocol instances: the inferred return type of
    the call is replaced by the type of its ``__call__`` member, as
    resolved by ``analyze_member_access`` (an operator-style lookup).
    """
    expr_checker = ctx.api.expr_checker  # type: ignore
    protocol_instance = ctx.default_return_type
    return analyze_member_access(
        '__call__',
        protocol_instance,
        ctx.context,
        msg=expr_checker.msg,
        chk=expr_checker.chk,
        original_type=protocol_instance,
        is_lvalue=False,
        is_super=False,
        is_operator=True,
        in_literal_context=expr_checker.is_literal_context(),
    )
Example #5
0
def translate_to_function(
    function_def: MypyType,
    ctx: CallableContext,
) -> MypyType:
    """
    Tries to translate a type into a callable by accessing its ``__call__`` attr.

    This might fail with ``mypy`` errors, and that is how it must work.
    This also preserves all type arguments as-is.
    """
    expr_checker = ctx.api.expr_checker  # type: ignore
    literal_context = expr_checker.is_literal_context()
    return analyze_member_access(
        '__call__',
        function_def,
        ctx.context,
        msg=expr_checker.msg,
        chk=expr_checker.chk,
        original_type=function_def,
        in_literal_context=literal_context,
        is_lvalue=False,
        is_super=False,
        is_operator=True,
    )
Example #6
0
    def visit_class_pattern(self, o: ClassPattern) -> PatternType:
        """Infer the match result of a class pattern ``Cls(p1, ..., key=p)``.

        Steps:
        1. Resolve ``o.class_ref`` to a type, rejecting generic type aliases
           and references that are not types. Then narrow the subject type
           (the innermost ``self.type_context``) to it.
        2. Map positional sub-patterns to attribute names. When
           ``should_self_match`` holds for the class, the first positional
           sub-pattern is matched against the subject itself. Otherwise the
           names come from ``__match_args__``.
        3. Reject keywords that repeat a positional name or another keyword.
        4. Check each keyword sub-pattern against the type of the
           corresponding attribute and collect its captures.

        Returns a ``PatternType`` of (matched type, rest type, captures).
        ``self.early_non_match()`` is returned early in these cases:
        - an unusable class reference;
        - an uninhabited narrowed type;
        - a missing or too-short ``__match_args__``;
        - duplicate keyword patterns.
        """
        current_type = get_proper_type(self.type_context[-1])

        #
        # Check class type
        #
        type_info = o.class_ref.node
        assert type_info is not None
        if isinstance(type_info, TypeAlias) and not type_info.no_args:
            self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o)
            return self.early_non_match()
        if isinstance(type_info, TypeInfo):
            # Every type variable of a generic class is filled with Any.
            any_type = AnyType(TypeOfAny.implementation_artifact)
            typ: Type = Instance(type_info,
                                 [any_type] * len(type_info.defn.type_vars))
        elif isinstance(type_info, TypeAlias):
            typ = type_info.target
        else:
            if isinstance(type_info, Var):
                name = str(type_info.type)
            else:
                name = type_info.name
            self.msg.fail(
                message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(name),
                o.class_ref)
            return self.early_non_match()

        new_type, rest_type = self.chk.conditional_types_with_intersection(
            current_type, [get_type_range(typ)], o, default=current_type)
        if is_uninhabited(new_type):
            return self.early_non_match()
        # TODO: Do I need this?
        narrowed_type = narrow_declared_type(current_type, new_type)

        #
        # Convert positional to keyword patterns
        #
        keyword_pairs: List[Tuple[Optional[str], Pattern]] = []
        match_arg_set: Set[str] = set()

        captures: Dict[Expression, Type] = {}

        if len(o.positionals) != 0:
            if self.should_self_match(typ):
                # Self-matching classes accept one positional sub-pattern,
                # which is matched against the subject value itself.
                if len(o.positionals) > 1:
                    self.msg.fail(
                        message_registry.
                        CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
                pattern_type = self.accept(o.positionals[0], narrowed_type)
                if not is_uninhabited(pattern_type.type):
                    return PatternType(
                        pattern_type.type,
                        join_types(rest_type, pattern_type.rest_type),
                        pattern_type.captures)
                captures = pattern_type.captures
            else:
                # Collect errors separately so that a missing
                # ``__match_args__`` is reported with a dedicated message.
                local_errors = self.msg.clean_copy()
                match_args_type = analyze_member_access("__match_args__",
                                                        typ,
                                                        o,
                                                        is_lvalue=False,
                                                        is_super=False,
                                                        is_operator=False,
                                                        msg=local_errors,
                                                        original_type=typ,
                                                        chk=self.chk)

                if local_errors.is_errors():
                    self.msg.fail(
                        message_registry.MISSING_MATCH_ARGS.format(typ), o)
                    return self.early_non_match()

                proper_match_args_type = get_proper_type(match_args_type)
                if isinstance(proper_match_args_type, TupleType):
                    match_arg_names = get_match_arg_names(
                        proper_match_args_type)

                    if len(o.positionals) > len(match_arg_names):
                        self.msg.fail(
                            message_registry.
                            CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
                        return self.early_non_match()
                else:
                    # Names are not statically known; these positionals are
                    # checked against Any below.
                    match_arg_names = [None] * len(o.positionals)

                for arg_name, pos in zip(match_arg_names, o.positionals):
                    keyword_pairs.append((arg_name, pos))
                    if arg_name is not None:
                        match_arg_set.add(arg_name)

        #
        # Check for duplicate patterns
        #
        keyword_arg_set = set()
        has_duplicates = False
        for key, value in zip(o.keyword_keys, o.keyword_values):
            keyword_pairs.append((key, value))
            if key in match_arg_set:
                self.msg.fail(
                    message_registry.CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL.
                    format(key), value)
                has_duplicates = True
            elif key in keyword_arg_set:
                self.msg.fail(
                    message_registry.CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN.
                    format(key), value)
                has_duplicates = True
            keyword_arg_set.add(key)

        if has_duplicates:
            return self.early_non_match()

        #
        # Check keyword patterns
        #
        can_match = True
        for keyword, pattern in keyword_pairs:
            key_type: Optional[Type] = None
            local_errors = self.msg.clean_copy()
            if keyword is not None:
                key_type = analyze_member_access(keyword,
                                                 narrowed_type,
                                                 pattern,
                                                 is_lvalue=False,
                                                 is_super=False,
                                                 is_operator=False,
                                                 msg=local_errors,
                                                 original_type=new_type,
                                                 chk=self.chk)
            else:
                key_type = AnyType(TypeOfAny.from_error)
            if local_errors.is_errors() or key_type is None:
                key_type = AnyType(TypeOfAny.from_error)
                # Anchor the error at the sub-pattern being checked. The
                # duplicate-check loop variable ``value`` would point at the
                # wrong node, and is unbound when there are no keyword patterns.
                self.msg.fail(
                    message_registry.CLASS_PATTERN_UNKNOWN_KEYWORD.format(
                        typ, keyword), pattern)

            inner_type, inner_rest_type, inner_captures = self.accept(
                pattern, key_type)
            if is_uninhabited(inner_type):
                can_match = False
            else:
                self.update_type_map(captures, inner_captures)
                if not is_uninhabited(inner_rest_type):
                    # A sub-pattern may fail, so the subject type is not
                    # narrowed away in the negative case.
                    rest_type = current_type

        if not can_match:
            new_type = UninhabitedType()
        return PatternType(new_type, rest_type, captures)
        def process_gql_schema(ctx: AttributeContext) -> Type:
            """
            Actually perform the type-checking logic for each graphene `ObjectType` child class.
            The plugin is invoked at type-checking time.

            Looks up the info recorded for ``ctx.type`` in ``self._graphene_objects`` and reports,
            via ``ctx.api.fail``:

            * resolvers whose field name matches no field (except the special ``type`` resolver
              of an `Interface`);
            * a "previous" (first) resolver argument whose type is not equivalent to the
              object's runtime type;
            * a resolver return type that is not a subtype of the field type;
            * field arguments that are missing from the resolver, or whose type is not
              equivalent to the resolver argument's type;
            * for `ObjectType`s only: fields without a resolver whose attribute on the runtime
              type is not a subtype of the field type (graphene's default resolver).

            The checks only emit diagnostics. The function always returns ``ctx.default_attr_type``.
            """

            assert isinstance(ctx.type, Instance)
            object_info = self._graphene_objects[ctx.type.type.fullname]
            all_fields = object_info.all_fields if isinstance(
                object_info, ObjectTypeInfo) else object_info.fields

            # Check that resolver methods are annotated with the correct types
            for resolver in object_info.resolvers.values():
                gql_field = all_fields.get(resolver.field_name)

                if not gql_field:
                    if isinstance(
                            object_info,
                            InterfaceInfo) and resolver.field_name == 'type':
                        # This is not a field resolver. It is the special `Interface` resolver that determines which
                        # `ObjectType` to use at runtime.
                        continue

                    ctx.api.fail(
                        f'No field with name "{resolver.field_name}" defined',
                        resolver.context)
                    continue

                # Check that the resolver's "previous" (first) argument has the correct type
                if not is_equivalent(resolver.previous_argument.type,
                                     object_info.runtime_type):
                    ctx.api.fail(
                        _get_type_mismatch_error_message(
                            resolver.previous_argument.name,
                            graphene_type=object_info.runtime_type,
                            resolver_type=resolver.previous_argument.type,
                        ),
                        resolver.previous_argument.context,
                    )
                    continue

                # Check that the resolver returns the correct type
                if not is_subtype(resolver.return_type, gql_field.type):
                    ctx.api.fail(
                        f'Resolver returns type {resolver.return_type}, expected type {gql_field.type}',
                        resolver.context,
                    )
                    continue

                for field_argument in gql_field.arguments.values():
                    resolver_argument = resolver.arguments.get(
                        field_argument.name)

                    # Check that the resolver has an argument for each argument the `Field()` defines
                    if not resolver_argument:
                        ctx.api.fail(
                            f'Parameter "{field_argument.name}" of type {field_argument.type} is missing,'
                            ' but required in resolver definition',
                            resolver.context,
                        )
                        continue

                    # Check that the resolver's argument has the correct type annotation
                    if not is_equivalent(field_argument.type,
                                         resolver_argument.type):
                        ctx.api.fail(
                            _get_type_mismatch_error_message(
                                field_argument.name,
                                graphene_type=field_argument.type,
                                resolver_type=resolver_argument.type,
                            ),
                            resolver_argument.context,
                        )

            # If no resolver function is defined, type-check the behavior of the graphene default resolver
            if isinstance(object_info, ObjectTypeInfo):
                # The default resolver doesn't apply to `Interface`s because the `ObjectType`s that implement them could
                # have resolvers for their fields.
                # TODO: Detect if any of an `Interface`'s `ObjectType`s do _not_ define their own resolver for this
                # field. In that case, we _do_ want to type-check the default resolver.
                #
                # Walk `fields` in its own iteration order rather than iterating a set difference. Set order of strings
                # depends on hash randomization, which would make the emission order of the diagnostics below vary
                # between runs.
                fields_without_resolver_names = [
                    name for name in object_info.fields
                    if name not in object_info.resolvers
                ]
                for name in fields_without_resolver_names:
                    gql_field = object_info.fields[name]
                    # Note: `analyze_member_access` will call `ctx.api.fail()` if the provided type doesn't have
                    # a member with the given name at all. So our code only needs to do the subtype check.
                    default_resolver_return_type = analyze_member_access(
                        gql_field.name,
                        object_info.runtime_type,
                        gql_field.context,
                        False,  # is_lvalue
                        False,  # is_super
                        False,  # is_operator
                        ctx.api.msg,
                        original_type=object_info.runtime_type,
                        chk=cast(TypeChecker, ctx.api),
                    )
                    if not is_subtype(default_resolver_return_type,
                                      gql_field.type):
                        ctx.api.fail(
                            f'Field expects type {gql_field.type} but {object_info.runtime_type}.{gql_field.name} has '
                            f'type {default_resolver_return_type}',
                            gql_field.context,
                        )

            return ctx.default_attr_type