def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
    """Check a TypedDict.setdefault call and compute a precise return type.

    The hook only applies when the receiver is a TypedDictType and exactly one
    actual argument was supplied for each of the two formal arguments;
    otherwise the default return type is passed through unchanged.

    Returns the simplified union of the value types of all keys the first
    argument may denote.  If the key is not a string literal, a key is not
    declared in the TypedDict, or the default is not a subtype of a key's
    value type, an error is reported and AnyType(TypeOfAny.from_error) is
    returned.
    """
    receiver = ctx.type
    if not isinstance(receiver, TypedDictType):
        return ctx.default_return_type
    if (len(ctx.arg_types) != 2
            or len(ctx.arg_types[0]) != 1
            or len(ctx.arg_types[1]) != 1):
        return ctx.default_return_type

    literal_keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
    if literal_keys is None:
        ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
        return AnyType(TypeOfAny.from_error)

    fallback_type = ctx.arg_types[1][0]
    item_types = []
    for name in literal_keys:
        item_type = receiver.items.get(name)
        if item_type is None:
            ctx.api.msg.typeddict_key_not_found(receiver, name, ctx.context)
            return AnyType(TypeOfAny.from_error)

        # The signature callback can't always infer the right signature
        # (e.g. when the key expression is a variable that happens to be a
        # Literal str), so the default is verified here against every
        # key-value pair that may be updated.
        if not is_subtype(fallback_type, item_type):
            ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
                fallback_type, item_type, ctx.context)
            return AnyType(TypeOfAny.from_error)

        item_types.append(item_type)

    return make_simplified_union(item_types)
def typed_dict_get_callback(ctx: MethodContext) -> Type:
    """Compute a precise return type for TypedDict.get with a literal key.

    The hook only applies when the receiver is a TypedDictType and exactly one
    actual argument was supplied for the key.  When the key is not a string
    literal the default return type is used without reporting an error.

    With no default argument the result is the union of the matching value
    types and None.  With a single default argument the result is the union
    of the matching value types and the default's type, except that an empty
    dict display used as the default for a TypedDict-valued key contributes
    that TypedDict with no required keys instead.  An undeclared key is
    reported and yields AnyType(TypeOfAny.from_error).
    """
    receiver = ctx.type
    if not isinstance(receiver, TypedDictType):
        return ctx.default_return_type
    if len(ctx.arg_types) < 1 or len(ctx.arg_types[0]) != 1:
        return ctx.default_return_type

    literal_keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
    if literal_keys is None:
        return ctx.default_return_type

    no_default = len(ctx.arg_types) == 1
    result_types = []  # type: List[Type]
    for name in literal_keys:
        item_type = get_proper_type(receiver.items.get(name))
        if item_type is None:
            ctx.api.msg.typeddict_key_not_found(receiver, name, ctx.context)
            return AnyType(TypeOfAny.from_error)

        if no_default:
            result_types.append(item_type)
            continue

        single_default_given = (len(ctx.arg_types) == 2
                                and len(ctx.arg_types[1]) == 1
                                and len(ctx.args[1]) == 1)
        if not single_default_given:
            # Any other call shape contributes nothing for this key.
            continue

        default_expr = ctx.args[1][0]
        is_empty_dict_display = (isinstance(default_expr, DictExpr)
                                 and len(default_expr.items) == 0)
        if is_empty_dict_display and isinstance(item_type, TypedDictType):
            # Special case '{}' as the default for a typed dict type.
            result_types.append(item_type.copy_modified(required_keys=set()))
        else:
            result_types.extend([item_type, ctx.arg_types[1][0]])

    if no_default:
        # Without an explicit default, get() may return None.
        result_types.append(NoneType())

    return make_simplified_union(result_types)
def typed_dict_pop_callback(ctx: MethodContext) -> Type:
    """Type check and infer a precise return type for TypedDict.pop.

    The hook only applies when the receiver is a TypedDictType and exactly one
    actual argument was supplied for the key; otherwise the default return
    type is passed through unchanged.

    Reports an error when the key is not a string literal, when a key is
    required (and therefore cannot be deleted), or when a key is not declared
    in the TypedDict.  A non-literal or undeclared key yields
    AnyType(TypeOfAny.from_error).

    Without a default argument the result is the simplified union of the
    popped keys' value types; with a single default argument the default's
    type is added to that union.
    """
    if (isinstance(ctx.type, TypedDictType)
            and len(ctx.arg_types) >= 1
            and len(ctx.arg_types[0]) == 1):
        keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
        if keys is None:
            ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
            return AnyType(TypeOfAny.from_error)

        value_types = []
        for key in keys:
            if key in ctx.type.required_keys:
                ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)

            value_type = ctx.type.items.get(key)
            # Test for absence explicitly (as the sibling callbacks do) rather
            # than relying on the truthiness of a type object.
            if value_type is None:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
                return AnyType(TypeOfAny.from_error)
            value_types.append(value_type)

        # The guard above only guarantees one argument group, so make sure a
        # second group exists before indexing it.
        if len(ctx.args) < 2 or not ctx.args[1]:
            return make_simplified_union(value_types)
        if (len(ctx.arg_types) == 2
                and len(ctx.arg_types[1]) == 1
                and len(ctx.args[1]) == 1):
            return make_simplified_union([*value_types, ctx.arg_types[1][0]])
    return ctx.default_return_type
def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
    """Check a ``del td[key]`` operation (TypedDict.__delitem__).

    The hook only applies when the receiver is a TypedDictType and exactly one
    actual argument was supplied for the single formal argument.  A
    non-literal key is reported and yields AnyType(TypeOfAny.from_error).
    For each literal key, deleting a required key or naming an undeclared key
    is reported; in those cases the default return type is still returned.
    """
    receiver = ctx.type
    if not isinstance(receiver, TypedDictType):
        return ctx.default_return_type
    if len(ctx.arg_types) != 1 or len(ctx.arg_types[0]) != 1:
        return ctx.default_return_type

    literal_keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
    if literal_keys is None:
        ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
        return AnyType(TypeOfAny.from_error)

    for name in literal_keys:
        if name in receiver.required_keys:
            ctx.api.msg.typeddict_key_cannot_be_deleted(receiver, name, ctx.context)
            continue
        if name not in receiver.items:
            ctx.api.msg.typeddict_key_not_found(receiver, name, ctx.context)

    return ctx.default_return_type