def typed_dict_get_callback(ctx: MethodContext) -> Type:
    """Infer a precise return type for TypedDict.get with literal first argument.

    Falls back to ``ctx.default_return_type`` when the receiver is not a
    ``TypedDictType``, when the first formal argument did not receive exactly
    one actual, when ``try_getting_str_literal`` yields ``None`` for the key,
    or when the shape of the default argument is not one handled below.

    If the key is not an item of the TypedDict, the problem is reported via
    ``ctx.api.msg.typeddict_key_not_found`` and ``AnyType(TypeOfAny.from_error)``
    is returned.
    """
    receiver = ctx.type
    arg_types = ctx.arg_types

    if not isinstance(receiver, TypedDictType):
        return ctx.default_return_type
    if len(arg_types) < 1 or len(arg_types[0]) != 1:
        return ctx.default_return_type

    key = try_getting_str_literal(ctx.args[0][0], arg_types[0][0])
    if key is None:
        return ctx.default_return_type

    item_type = receiver.items.get(key)
    if not item_type:
        ctx.api.msg.typeddict_key_not_found(receiver, key, ctx.context)
        return AnyType(TypeOfAny.from_error)

    # Form with a key only: the item type unioned with NoneTyp().
    if len(arg_types) == 1:
        return UnionType.make_simplified_union([item_type, NoneTyp()])

    # Form with a key and exactly one default expression.
    single_default = (len(arg_types) == 2
                      and len(arg_types[1]) == 1
                      and len(ctx.args[1]) == 1)
    if not single_default:
        return ctx.default_return_type

    fallback_expr = ctx.args[1][0]
    is_empty_dict_literal = (isinstance(fallback_expr, DictExpr)
                             and len(fallback_expr.items) == 0)
    if is_empty_dict_literal and isinstance(item_type, TypedDictType):
        # Special case '{}' as the default for a typed dict type: the result
        # is the item's TypedDict type with no required keys.
        return item_type.copy_modified(required_keys=set())

    return UnionType.make_simplified_union([item_type, arg_types[1][0]])
def typed_dict_pop_callback(ctx: MethodContext) -> Type:
    """Type check and infer a precise return type for TypedDict.pop.

    Returns ``ctx.default_return_type`` when the receiver is not a
    ``TypedDictType``, when the key argument did not receive exactly one
    actual, or when the default argument has a shape not handled below.

    Errors reported through ``ctx.api``:
      * the key is not a string literal (result is an error ``AnyType``);
      * the key is listed in the TypedDict's ``required_keys`` (inference
        still continues after the report);
      * the key is not an item of the TypedDict (result is an error
        ``AnyType``).
    """
    # NOTE(review): this file contains another definition of
    # typed_dict_pop_callback with the same logic; the later definition is
    # the one bound at import time. Confirm which one should remain.
    receiver = ctx.type
    arg_types = ctx.arg_types

    if not isinstance(receiver, TypedDictType):
        return ctx.default_return_type
    if len(arg_types) < 1 or len(arg_types[0]) != 1:
        return ctx.default_return_type

    key = try_getting_str_literal(ctx.args[0][0], arg_types[0][0])
    if key is None:
        ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
        return AnyType(TypeOfAny.from_error)

    if key in receiver.required_keys:
        ctx.api.msg.typeddict_key_cannot_be_deleted(receiver, key, ctx.context)

    item_type = receiver.items.get(key)
    if not item_type:
        ctx.api.msg.typeddict_key_not_found(receiver, key, ctx.context)
        return AnyType(TypeOfAny.from_error)

    # No default supplied: the result is exactly the item type.
    if len(ctx.args[1]) == 0:
        return item_type

    # Exactly one default supplied: the result may be either type.
    if len(arg_types) == 2 and len(arg_types[1]) == 1 and len(ctx.args[1]) == 1:
        return UnionType.make_simplified_union([item_type, arg_types[1][0]])

    return ctx.default_return_type
def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
    """Type check TypedDict.__delitem__.

    Always yields ``ctx.default_return_type`` except when the key is not a
    string literal, in which case the failure is reported through
    ``ctx.api.fail`` and ``AnyType(TypeOfAny.from_error)`` is returned.

    For a literal key, an error is reported if the key is in the TypedDict's
    ``required_keys``, or otherwise if it is not one of its items.
    """
    receiver = ctx.type
    arg_types = ctx.arg_types

    if not isinstance(receiver, TypedDictType):
        return ctx.default_return_type
    if len(arg_types) != 1 or len(arg_types[0]) != 1:
        return ctx.default_return_type

    key = try_getting_str_literal(ctx.args[0][0], arg_types[0][0])
    if key is None:
        ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
        return AnyType(TypeOfAny.from_error)

    # The required-key check takes precedence over the unknown-key check.
    if key in receiver.required_keys:
        ctx.api.msg.typeddict_key_cannot_be_deleted(receiver, key, ctx.context)
    elif key not in receiver.items:
        ctx.api.msg.typeddict_key_not_found(receiver, key, ctx.context)

    return ctx.default_return_type
def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
    """Type check TypedDict.setdefault and infer a precise return type.

    Applies only when the receiver is a ``TypedDictType``, there are two
    formal arguments, and the key received exactly one actual; otherwise
    ``ctx.default_return_type`` is returned.

    For a string-literal key that is an item of the TypedDict, the result is
    that item's type. A non-literal key or an unknown key is reported through
    ``ctx.api`` and yields ``AnyType(TypeOfAny.from_error)``.
    """
    receiver = ctx.type
    arg_types = ctx.arg_types

    if not isinstance(receiver, TypedDictType):
        return ctx.default_return_type
    if len(arg_types) != 2 or len(arg_types[0]) != 1:
        return ctx.default_return_type

    key = try_getting_str_literal(ctx.args[0][0], arg_types[0][0])
    if key is None:
        ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
        return AnyType(TypeOfAny.from_error)

    item_type = receiver.items.get(key)
    if item_type:
        return item_type

    ctx.api.msg.typeddict_key_not_found(receiver, key, ctx.context)
    return AnyType(TypeOfAny.from_error)
def typed_dict_pop_callback(ctx: MethodContext) -> Type:
    """Type check and infer a precise return type for TypedDict.pop.

    The callback only acts when the receiver is a ``TypedDictType`` and the
    key argument received exactly one actual; in every other case, and when
    the default argument has an unhandled shape, it returns
    ``ctx.default_return_type``.

    A non-literal key or a key that is not an item of the TypedDict is
    reported through ``ctx.api`` and produces ``AnyType(TypeOfAny.from_error)``.
    Popping a key listed in ``required_keys`` is reported as well, but the
    item type is still inferred afterwards.
    """
    # NOTE(review): typed_dict_pop_callback is also defined earlier in this
    # file with the same logic; this later definition is the one that ends up
    # bound to the name. Confirm whether the earlier one can be removed.
    td_type = ctx.type
    if not isinstance(td_type, TypedDictType):
        return ctx.default_return_type

    formal_types = ctx.arg_types
    if not (len(formal_types) >= 1 and len(formal_types[0]) == 1):
        return ctx.default_return_type

    key = try_getting_str_literal(ctx.args[0][0], formal_types[0][0])
    if key is None:
        ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
        return AnyType(TypeOfAny.from_error)

    if key in td_type.required_keys:
        ctx.api.msg.typeddict_key_cannot_be_deleted(td_type, key, ctx.context)

    popped_type = td_type.items.get(key)
    if not popped_type:
        ctx.api.msg.typeddict_key_not_found(td_type, key, ctx.context)
        return AnyType(TypeOfAny.from_error)

    default_actuals = ctx.args[1]
    if len(default_actuals) == 0:
        # Called without a default: only the item type can come back.
        return popped_type

    exactly_one_default = (len(formal_types) == 2
                           and len(formal_types[1]) == 1
                           and len(ctx.args[1]) == 1)
    if exactly_one_default:
        return UnionType.make_simplified_union([popped_type, formal_types[1][0]])

    return ctx.default_return_type