Ejemplo n.º 1
0
    def test_literal_type(self) -> None:
        """Exercise meets and declared-type narrowing on the fixture's literal types."""
        base = self.fx.a
        one = self.fx.lit1
        two = self.fx.lit2
        three = self.fx.lit3

        def one_or_two() -> UnionType:
            # Build a fresh union for every operand instead of sharing one object.
            return UnionType([one, two])

        # Expected: meeting lit1 with itself or with `a` gives lit1 back.
        for other in (one, base):
            self.assert_meet(one, other, one)

        # Expected: lit1 has an uninhabited meet with lit3 and with lit2.
        for other in (three, two):
            self.assert_meet_uninhabited(one, other)

        # Meets that involve unions of literals.
        self.assert_meet(one_or_two(), one, one)
        self.assert_meet(one_or_two(), UnionType([two, three]), two)
        self.assert_meet(one_or_two(), one_or_two(), one_or_two())

        # Expected: meeting lit1 with `anyt` or `o` leaves lit1 unchanged.
        for other in (self.fx.anyt, self.fx.o):
            self.assert_meet(one, other, one)

        # Expected: narrowing a literal by `a` yields the same literal.
        for literal in (one, two):
            assert is_same_type(literal, narrow_declared_type(literal, base))
Ejemplo n.º 2
0
    def test_literal_type(self) -> None:
        """Exercise meets and declared-type narrowing on locally built literal types."""
        base = self.fx.a
        other_base = self.fx.d
        one = LiteralType(1, base)
        two = LiteralType(2, base)
        foo = LiteralType("foo", other_base)

        def one_or_two() -> UnionType:
            # Build a fresh union for every operand instead of sharing one object.
            return UnionType([one, two])

        # Expected: meeting the literal 1 with itself or with `a` gives it back.
        for other in (one, base):
            self.assert_meet(one, other, one)

        # Expected: the literal 1 has an uninhabited meet with "foo" and with 2.
        for other in (foo, two):
            self.assert_meet_uninhabited(one, other)

        # Meets that involve unions of literals.
        self.assert_meet(one_or_two(), one, one)
        self.assert_meet(one_or_two(), UnionType([two, foo]), two)
        self.assert_meet(one_or_two(), one_or_two(), one_or_two())

        # Expected: meeting the literal with `anyt` or `o` leaves it unchanged.
        for other in (self.fx.anyt, self.fx.o):
            self.assert_meet(one, other, one)

        # Expected: narrowing a literal by `a` yields the same literal.
        for literal in (one, two):
            assert_true(is_same_type(literal, narrow_declared_type(literal, base)))
Ejemplo n.º 3
0
    def visit_class_pattern(self, o: ClassPattern) -> PatternType:
        """Type-check a class pattern against the current subject type.

        The subject type is the top of ``self.type_context``. The method:

        1. resolves ``o.class_ref`` to a type, reporting an error if it is a
           generic type alias with arguments or is not a type at all;
        2. narrows the subject type to that class;
        3. maps positional sub-patterns to attribute names through
           ``__match_args__`` (or, when ``self.should_self_match(typ)`` holds,
           matches the single positional sub-pattern against the narrowed
           subject itself);
        4. rejects keywords that duplicate a positional or another keyword;
        5. checks every sub-pattern against the type of its attribute.

        Returns:
            A ``PatternType`` built from the narrowed type, the rest type and
            the collected captures. ``self.early_non_match()`` is returned
            whenever an error is reported that prevents further checking or
            the narrowed type is uninhabited.
        """
        current_type = get_proper_type(self.type_context[-1])

        #
        # Check class type
        #
        type_info = o.class_ref.node
        assert type_info is not None
        if isinstance(type_info, TypeAlias) and not type_info.no_args:
            self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o)
            return self.early_non_match()
        if isinstance(type_info, TypeInfo):
            # Use Any for every type variable declared on the class.
            any_type = AnyType(TypeOfAny.implementation_artifact)
            typ: Type = Instance(type_info,
                                 [any_type] * len(type_info.defn.type_vars))
        elif isinstance(type_info, TypeAlias):
            typ = type_info.target
        else:
            # Not a class or alias: pick a printable name for the error.
            if isinstance(type_info, Var):
                name = str(type_info.type)
            else:
                name = type_info.name
            self.msg.fail(
                message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(name),
                o.class_ref)
            return self.early_non_match()

        new_type, rest_type = self.chk.conditional_types_with_intersection(
            current_type, [get_type_range(typ)], o, default=current_type)
        if is_uninhabited(new_type):
            return self.early_non_match()
        # TODO: Do I need this?
        narrowed_type = narrow_declared_type(current_type, new_type)

        #
        # Convert positional to keyword patterns
        #
        keyword_pairs: List[Tuple[Optional[str], Pattern]] = []
        match_arg_set: Set[str] = set()

        captures: Dict[Expression, Type] = {}

        if o.positionals:
            if self.should_self_match(typ):
                # The single positional sub-pattern is matched against the
                # narrowed subject itself rather than against an attribute.
                if len(o.positionals) > 1:
                    self.msg.fail(
                        message_registry.
                        CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
                pattern_type = self.accept(o.positionals[0], narrowed_type)
                if not is_uninhabited(pattern_type.type):
                    return PatternType(
                        pattern_type.type,
                        join_types(rest_type, pattern_type.rest_type),
                        pattern_type.captures)
                captures = pattern_type.captures
            else:
                # Look up __match_args__ with a scratch message builder so a
                # failed lookup can be reported with a dedicated message.
                local_errors = self.msg.clean_copy()
                match_args_type = analyze_member_access("__match_args__",
                                                        typ,
                                                        o,
                                                        False,
                                                        False,
                                                        False,
                                                        local_errors,
                                                        original_type=typ,
                                                        chk=self.chk)

                if local_errors.is_errors():
                    self.msg.fail(
                        message_registry.MISSING_MATCH_ARGS.format(typ), o)
                    return self.early_non_match()

                proper_match_args_type = get_proper_type(match_args_type)
                if isinstance(proper_match_args_type, TupleType):
                    match_arg_names = get_match_arg_names(
                        proper_match_args_type)

                    if len(o.positionals) > len(match_arg_names):
                        self.msg.fail(
                            message_registry.
                            CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
                        return self.early_non_match()
                else:
                    # __match_args__ is not a tuple type: the attribute names
                    # are unknown, so every positional gets a None keyword.
                    match_arg_names = [None] * len(o.positionals)

                for arg_name, pos in zip(match_arg_names, o.positionals):
                    keyword_pairs.append((arg_name, pos))
                    if arg_name is not None:
                        match_arg_set.add(arg_name)

        #
        # Check for duplicate patterns
        #
        keyword_arg_set: Set[str] = set()
        has_duplicates = False
        for key, value in zip(o.keyword_keys, o.keyword_values):
            keyword_pairs.append((key, value))
            if key in match_arg_set:
                self.msg.fail(
                    message_registry.CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL.
                    format(key), value)
                has_duplicates = True
            elif key in keyword_arg_set:
                self.msg.fail(
                    message_registry.CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN.
                    format(key), value)
                has_duplicates = True
            keyword_arg_set.add(key)

        if has_duplicates:
            return self.early_non_match()

        #
        # Check keyword patterns
        #
        can_match = True
        for keyword, pattern in keyword_pairs:
            key_type: Optional[Type] = None
            local_errors = self.msg.clean_copy()
            if keyword is not None:
                key_type = analyze_member_access(keyword,
                                                 narrowed_type,
                                                 pattern,
                                                 False,
                                                 False,
                                                 False,
                                                 local_errors,
                                                 original_type=new_type,
                                                 chk=self.chk)
            else:
                key_type = AnyType(TypeOfAny.from_error)
            if local_errors.is_errors() or key_type is None:
                key_type = AnyType(TypeOfAny.from_error)
                # Report on the sub-pattern being checked in this iteration.
                # The duplicate-check loop's `value` is either unbound (no
                # keyword sub-patterns) or refers to the last keyword value.
                self.msg.fail(
                    message_registry.CLASS_PATTERN_UNKNOWN_KEYWORD.format(
                        typ, keyword), pattern)

            inner_type, inner_rest_type, inner_captures = self.accept(
                pattern, key_type)
            if is_uninhabited(inner_type):
                can_match = False
            else:
                self.update_type_map(captures, inner_captures)
                if not is_uninhabited(inner_rest_type):
                    # The sub-pattern may fail to match, so nothing can be
                    # excluded from the subject type on the non-match path.
                    rest_type = current_type

        if not can_match:
            new_type = UninhabitedType()
        return PatternType(new_type, rest_type, captures)