Пример #1
0
def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
    """Check a ``TypedDict.setdefault`` call and infer a precise return type.

    Only handles a TypedDict receiver called with exactly two argument groups,
    each holding a single argument type (the key and the default). Any other
    shape falls back to ``ctx.default_return_type``.

    The key must resolve to one or more string literals. Every candidate key
    must exist in the TypedDict, and the default's type must be a subtype of
    each candidate key's value type. The first violation is reported and
    ``AnyType(TypeOfAny.from_error)`` is returned.

    Returns:
        The simplified union of the value types of all candidate keys.
    """
    if not isinstance(ctx.type, TypedDictType):
        return ctx.default_return_type
    if (len(ctx.arg_types) != 2
            or len(ctx.arg_types[0]) != 1
            or len(ctx.arg_types[1]) != 1):
        return ctx.default_return_type

    td_type = ctx.type
    candidate_keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
    if candidate_keys is None:
        ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
                     ctx.context)
        return AnyType(TypeOfAny.from_error)

    default_type = ctx.arg_types[1][0]
    result_types = []  # type: List[Type]
    for key in candidate_keys:
        expected_type = td_type.items.get(key)
        if expected_type is None:
            ctx.api.msg.typeddict_key_not_found(td_type, key, ctx.context)
            return AnyType(TypeOfAny.from_error)

        # The signature callback cannot always infer the right signature (e.g.
        # when the key is a variable whose type happens to be a str Literal),
        # so re-check here that the default fits every key being updated.
        if not is_subtype(default_type, expected_type):
            ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
                default_type, expected_type, ctx.context)
            return AnyType(TypeOfAny.from_error)

        result_types.append(expected_type)

    return make_simplified_union(result_types)
Пример #2
0
def typed_dict_get_callback(ctx: MethodContext) -> Type:
    """Infer a precise return type for TypedDict.get with literal first argument.

    Handles a TypedDict receiver whose key argument resolves to one or more
    string literals:

    * ``td.get(key)`` -> union of the candidate keys' value types and ``None``.
    * ``td.get(key, default)`` -> union of the candidate keys' value types and
      the default's type. As a special case, a literal ``{}`` default for a key
      whose value type is itself a TypedDict yields that TypedDict with no
      required keys.

    A candidate key missing from the TypedDict is reported, and
    ``AnyType(TypeOfAny.from_error)`` is returned. A non-literal key, or any
    other call shape, yields ``ctx.default_return_type``.
    """
    if (isinstance(ctx.type, TypedDictType) and len(ctx.arg_types) >= 1
            and len(ctx.arg_types[0]) == 1):
        keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
        if keys is None:
            return ctx.default_return_type

        # The call shape does not depend on the key, so classify it once
        # instead of re-checking it on every loop iteration.
        without_default = len(ctx.arg_types) == 1
        with_default = (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
                        and len(ctx.args[1]) == 1)
        default_arg = ctx.args[1][0] if with_default else None
        empty_dict_default = (isinstance(default_arg, DictExpr)
                              and len(default_arg.items) == 0)

        output_types = []  # type: List[Type]
        for key in keys:
            value_type = get_proper_type(ctx.type.items.get(key))
            if value_type is None:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
                return AnyType(TypeOfAny.from_error)

            if without_default:
                output_types.append(value_type)
            elif with_default:
                if empty_dict_default and isinstance(value_type, TypedDictType):
                    # Special case '{}' as the default for a typed dict type.
                    output_types.append(
                        value_type.copy_modified(required_keys=set()))
                else:
                    output_types.append(value_type)
                    output_types.append(ctx.arg_types[1][0])

        if without_default:
            output_types.append(NoneType())
        elif not with_default:
            # Unrecognised call shape: missing keys were still reported above,
            # but nothing was collected, so the union of an empty list would not
            # be a meaningful result. Keep the default return type instead.
            return ctx.default_return_type

        return make_simplified_union(output_types)
    return ctx.default_return_type
Пример #3
0
def typed_dict_pop_callback(ctx: MethodContext) -> Type:
    """Type check and infer a precise return type for TypedDict.pop.

    The key argument must resolve to one or more string literals. Otherwise an
    error is reported and ``AnyType(TypeOfAny.from_error)`` is returned.

    Each candidate key is then checked:

    * Popping a required key is reported, and checking continues.
    * A key missing from the TypedDict is reported and aborts with
      ``AnyType(TypeOfAny.from_error)``.

    Returns:
        * ``td.pop(key)`` -> union of the candidate keys' value types.
        * ``td.pop(key, default)`` -> that union plus the default's type.
        * Any other call shape -> ``ctx.default_return_type``.
    """
    if (isinstance(ctx.type, TypedDictType) and len(ctx.arg_types) >= 1
            and len(ctx.arg_types[0]) == 1):
        keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
        if keys is None:
            ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
                         ctx.context)
            return AnyType(TypeOfAny.from_error)

        value_types = []  # type: List[Type]
        for key in keys:
            if key in ctx.type.required_keys:
                # Reported, but not fatal: a return type can still be inferred.
                ctx.api.msg.typeddict_key_cannot_be_deleted(
                    ctx.type, key, ctx.context)

            value_type = ctx.type.items.get(key)
            if value_type is None:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
                return AnyType(TypeOfAny.from_error)
            value_types.append(value_type)

        # The guard above only ensures that a first formal exists. Check that a
        # second (default) formal exists before indexing it, rather than
        # raising IndexError.
        if len(ctx.args) < 2 or len(ctx.args[1]) == 0:
            # No default supplied: the result is one of the popped values.
            return make_simplified_union(value_types)
        elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
              and len(ctx.args[1]) == 1):
            return make_simplified_union([*value_types, ctx.arg_types[1][0]])
    return ctx.default_return_type
Пример #4
0
def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
    """Type check ``del td[key]`` on a TypedDict.

    Only handles a TypedDict receiver called with a single argument group
    holding a single argument type. Any other shape is passed through
    untouched.

    The key must resolve to one or more string literals. Otherwise an error is
    reported and ``AnyType(TypeOfAny.from_error)`` is returned. Each candidate
    key is then checked: deleting a required key is reported, as is deleting a
    key the TypedDict does not declare.

    Returns:
        ``ctx.default_return_type`` unless the key is not a string literal.
    """
    td_type = ctx.type
    if not isinstance(td_type, TypedDictType):
        return ctx.default_return_type
    if len(ctx.arg_types) != 1 or len(ctx.arg_types[0]) != 1:
        return ctx.default_return_type

    deleted_keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
    if deleted_keys is None:
        ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
        return AnyType(TypeOfAny.from_error)

    for name in deleted_keys:
        if name in td_type.required_keys:
            # Deleting a required key is reported.
            ctx.api.msg.typeddict_key_cannot_be_deleted(td_type, name, ctx.context)
        elif name not in td_type.items:
            ctx.api.msg.typeddict_key_not_found(td_type, name, ctx.context)

    return ctx.default_return_type