def _field_to_schema(field: Field) -> SchemaField:
    """Translate a single dataclass field into a ``SchemaField``.

    The field type is resolved in this order:

    1. a type listed in ``_BASIC_TYPES_TO_NAME``,
    2. a dataclass, emitted as ``BigQueryTypes.STRUCT`` with nested fields,
    3. a ``Union`` (delegated to ``_parse_optional``),
    4. a ``list`` (delegated to ``_parse_list``).

    Raises:
        TypeError: If the field type matches none of the cases above.
    """
    declared = field.type

    basic_name = _BASIC_TYPES_TO_NAME.get(declared)
    if basic_name:
        return SchemaField(
            name=field.name,
            field_type=basic_name,
            description=_parse_field_description(field),
            mode=BigQueryFieldModes.REQUIRED,
        )

    if is_dataclass(declared):
        return SchemaField(
            name=field.name,
            field_type=BigQueryTypes.STRUCT,
            mode=BigQueryFieldModes.REQUIRED,
            description=_parse_field_description(field),
            fields=_parse_fields(declared),
        )

    origin = get_origin(declared)
    # typing.Optional is the same as typing.Union[SomeType, NoneType]
    if origin is Union:
        return _parse_optional(field)
    if origin is list:
        return _parse_list(field)

    raise TypeError(f"Unsupported type: {field.type}.")
def _iter_imports(hint) -> Iterator[str]:
    """Yield the module name of the origin of `hint`, if one is needed.

    Nothing is yielded when ``repr(hint)`` starts with ``"typing."`` or when
    `hint` has no (truthy) origin.
    """
    # inspect.formatannotation strips "typing." from type annotations,
    # so our signatures won't have it in there.
    if repr(hint).startswith("typing."):
        return
    origin = get_origin(hint)
    if origin:
        yield origin.__module__
def get_dims(type_: Type[DataArrayLike[D, T]]) -> Tuple[str, ...]:
    """Extract dimensions (dims) from DataArrayLike[D, T].

    An ``Annotated`` wrapper is removed first. The dims hint is then read as
    the first type argument of the first type argument of ``type_``. Each
    dimension must be a ``ForwardRef`` or a ``Literal``; empty tuples and
    ``NoneType`` entries are ignored.

    Raises:
        TypeError: If a dimension entry has any other form.
    """
    if get_origin(type_) is Annotated:
        type_ = get_args(type_)[0]

    first_member = get_args(type_)[0]
    raw_dims = get_args(first_member)[0]

    # A tuple hint carries several dimensions; anything else is a single one.
    if get_origin(raw_dims) is tuple:
        candidates = get_args(raw_dims)
    else:
        candidates = (raw_dims,)

    names: List[str] = []
    for candidate in candidates:
        if candidate == () or candidate is NoneType:
            continue
        if isinstance(candidate, ForwardRef):
            names.append(candidate.__forward_arg__)
        elif get_origin(candidate) is Literal:
            names.append(str(get_args(candidate)[0]))
        else:
            raise TypeError("Could not extract dimension.")

    return tuple(names)
def small_validate(value, allowed_type: None | Any = None, name: str | None = None) -> None: """Type validation. It also works for Union and validate Literal values. Instead of typeguard validation, it define just subset of types, but is simplier and needs no extra import, therefore can be faster. Args: value (Any): Value that will be validated. allowed_type (Any, optional): For example int, str or list. It can be also Union or Literal. If Literal, validated value has to be one of Literal values. If None, it's skipped. Defaults to None. name (str | None, optional): If error raised, name will be printed. Defaults to None. Raises: TypeError: Type does not fit. Examples: >>> from typing_extensions import Literal ... >>> small_validate(1, int) >>> small_validate(None, Union[list, None]) >>> small_validate("two", Literal["one", "two"]) >>> small_validate("three", Literal["one", "two"]) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ValidationError: ... """ if allowed_type: # If Union if get_origin(allowed_type) == Union: if type(value) in get_args(allowed_type): return else: raise ValidationError( mylogging.format_str( f"Allowed type for variable '{name}' are {allowed_type}, but you try to set an {type(value)}" ) ) # If Literal - parse options elif get_origin(allowed_type) == Literal: options = getattr(allowed_type, "__args__") if value in options: return else: raise ValidationError( f"New value < {value} > for variable < {name} > is not in allowed options {options}." ) else: if isinstance(value, allowed_type): # type: ignore return else: raise ValidationError( f"Allowed allowed_type for variable < {name} > is {allowed_type}, but you try to set an {type(value)}" )
def flatten_greedy(item: T | Greedy[Any]) -> Generator[T, None, None]:
    """Yield the leaf entries of `item`, expanding ``Greedy`` and ``Union`` hints.

    A hint whose origin is neither ``Greedy`` nor ``Union`` is yielded as-is.

    Raises:
        TypeError: If an argument of a ``Greedy``/``Union`` hint is listed in
            ``INVALID_GREEDY_TYPES``.
    """
    if get_origin(item) not in (Greedy, Union):
        yield item
        return

    for member in get_args(item):
        if member in INVALID_GREEDY_TYPES:
            raise TypeError(f"Greedy[{member.__name__}] is invalid")
        # Recursing into a member that is not itself Greedy/Union simply
        # yields that member back, so no separate branch is needed.
        yield from flatten_greedy(member)
def add_api_route(
    self,
    path: str,
    endpoint: Callable[..., Any],
    **kwargs: Any,
) -> None:
    """Register `endpoint`, deriving response models from its return annotation.

    The endpoint's ``return`` type hint must carry exactly two type arguments:
    the success result and the failure result. The failure result is either a
    single error type or a ``Union`` of error types; each error type must
    expose ``status_code`` and a ``detail`` annotation.

    ``response_model`` and ``responses`` in `kwargs` are filled in from the
    annotation only when they are missing or ``None``; explicit values win.

    Raises:
        KeyError: If `endpoint` has no return annotation or an error type has
            no ``detail`` annotation.
    """
    return_type = get_type_hints(endpoint)["return"]
    success_result, failure_result = return_type.__args__
    endpoint = self._unpacked_container(endpoint)

    # Treat a single error type as a one-element collection so both shapes
    # share the same code path.
    if get_origin(failure_result) is Union:
        errors = failure_result.__args__
    else:
        errors = (failure_result,)

    responses = {}
    for error in errors:
        annotation = error.__annotations__["detail"]
        responses[error.status_code] = {"model": ApiErrorSchema[annotation]}  # type: ignore[valid-type]

    # .get() keeps direct calls that omit these keyword arguments working.
    if kwargs.get("response_model") is None:
        kwargs["response_model"] = success_result
    if kwargs.get("responses") is None:
        kwargs["responses"] = responses

    return super().add_api_route(path, endpoint, **kwargs)
def callback(self, function: CallbackType[P]) -> None:
    """Install `function` as the command callback.

    Resolves the callback's type hints (including ``Greedy`` hints whose
    converter is still a ``ForwardRef``), stores them back on the function and
    records its parameters, module and the function itself on ``self``.

    Raises:
        TypeError: If `function` is not a coroutine function.
        ClientException: If the callback takes no parameters at all.
    """
    if not inspect.iscoroutinefunction(function):
        raise TypeError(
            f"The callback for the command {function.__name__!r} must be a coroutine function."
        )

    # Bound methods (HelpCommand.command_callback) are unwrapped to the
    # underlying function.
    if inspect.ismethod(function):
        function = function.__func__

    hints = get_type_hints(function)
    for param_name in list(hints):
        hint = hints[param_name]
        needs_resolving = get_origin(hint) is converters.Greedy and isinstance(
            hint.converter, ForwardRef
        )
        if needs_resolving:
            # The forward reference comes from the callback's own annotations
            # and is evaluated in the callback's module globals.
            resolved = eval(hint.converter.__forward_code__, function.__globals__)
            hints[param_name] = converters.Greedy[resolved]
    function.__annotations__ = hints

    signature = inspect.signature(function)
    self.params: dict[str, inspect.Parameter] = dict(signature.parameters)
    if not self.params:
        raise ClientException(
            f'Callback for {self.name} command is missing a "ctx" parameter.'
        ) from None

    self.module = function.__module__
    self._callback = function
def decode(self, msg: RosMessage, buffer: bytes, offset: int = 0) -> int:
    """Fill the annotated fields of `msg` from `buffer`, starting at `offset`.

    Fields are visited in the order returned by ``get_type_hints``; fields
    whose hint origin is ``Final`` are skipped.

    Returns:
        The offset just past the last decoded field.
    """
    hints = get_type_hints(msg, include_extras=True)  # type: ignore
    for field_name, field_hint in hints.items():
        if get_origin(field_hint) is not Final:
            value, offset = self._decode_value(msg, buffer, offset, field_hint)
            setattr(msg, field_name, value)
    return offset
def args_to_matchers(
    function: Callable[..., Any],
) -> Dict[str, libcst_typing.Matcher]:
    """Build a matcher for every positional argument of `function`.

    Every argument except ``self`` starts with a ``DoNotCare`` matcher. An
    argument whose hint is ``Annotated`` and whose first metadata item is a
    ``BaseMatcherNode`` uses that item as its matcher instead.
    """
    code = function.__code__
    positional_names = code.co_varnames[: code.co_argcount]

    # Default matchers for all arguments.
    matchers: Dict[str, libcst_typing.Matcher] = {
        arg_name: libcst.matchers.DoNotCare()
        for arg_name in positional_names
        if arg_name != "self"
    }

    # Override with matchers supplied through Annotated metadata.
    hints = typing_extensions.get_type_hints(function, include_extras=True)
    for arg_name, hint in hints.items():
        if typing_extensions.get_origin(hint) is not Annotated:
            continue
        metadata = typing_extensions.get_args(hint)[1]
        if isinstance(metadata, libcst.matchers.BaseMatcherNode):
            matchers[arg_name] = metadata

    return matchers
def is_private(type_: object) -> bool:
    """Return True if `type_` is ``Annotated`` with a ``StrawberryPrivate`` argument."""
    if get_origin(type_) is not Annotated:
        return False
    for argument in get_args(type_):
        if isinstance(argument, StrawberryPrivate):
            return True
    return False
def _normalize_type(value: Any, annotation: Any) -> type:
    """Return the origin of `annotation`, or the type of `value` as a fallback.

    ``type(value)`` is used when `annotation` is falsy or is
    ``inspect.Parameter.empty``. Otherwise the annotation's origin is
    returned, or the annotation itself when it has no origin.
    """
    if not annotation or annotation is inspect.Parameter.empty:
        return type(value)
    return get_origin(annotation) or annotation
def get_dtype(type_: Type[DataArrayLike[D, T]]) -> Union[type, str, None]:
    """Extract a data type (dtype) from DataArrayLike[D, T].

    An ``Annotated`` wrapper is removed first. The dtype hint is then read as
    the second type argument of the first type argument of ``type_``.

    Returns:
        ``None`` for ``Any``, the referenced name for a ``ForwardRef``, the
        first value for a ``Literal``, and the hint itself otherwise.
    """
    if get_origin(type_) is Annotated:
        type_ = get_args(type_)[0]

    first_member = get_args(type_)[0]
    raw_dtype = get_args(first_member)[1]

    if raw_dtype is Any:
        result = None
    elif isinstance(raw_dtype, ForwardRef):
        result = raw_dtype.__forward_arg__
    elif get_origin(raw_dtype) is Literal:
        result = get_args(raw_dtype)[0]
    else:
        result = cast(type, raw_dtype)
    return result
def encode(self, msg: RosMessage, buffer: Optional[bytearray] = None) -> bytearray:
    """Append the encoded annotated fields of `msg` to `buffer`.

    Fields are visited in the order returned by ``get_type_hints``; fields
    whose hint origin is ``Final`` are skipped. When `msg` lacks an attribute
    for a field, a default is built: an empty list for an ``Annotated`` list
    hint, otherwise the result of calling the hint itself.

    Args:
        msg: Message whose fields are encoded.
        buffer: Target buffer; a new ``bytearray`` is created when ``None``.

    Returns:
        The buffer that was written to.
    """
    if buffer is None:
        buffer = bytearray()

    for name, t in get_type_hints(msg, include_extras=True).items():  # type: ignore
        origin = get_origin(t)
        if origin is Final:
            continue

        try:
            val = getattr(msg, name)
        except AttributeError:
            # Build the default only when it is needed, instead of constructing
            # (and discarding) one for every field that is already set.
            def_factory = t
            if origin is Annotated and get_origin(get_args(t)[0]) is list:
                def_factory = list
            val = def_factory()

        self._encode_value(val, t, buffer)

    return buffer
def get_origin(field_type: Type[Any]) -> Optional[Type[Any]]:
    """Generalized and robust get_origin function.

    This function is derived from work by pydantic, however, it avoids
    complications from various python versions.

    Returns:
        The result of ``typing_extensions.get_origin`` when it is truthy,
        otherwise the ``__origin__`` attribute of `field_type`, or ``None``
        when that attribute does not exist.
    """
    resolved = typing_extensions.get_origin(field_type)
    if resolved:
        return resolved
    # Fallback that allows a list to be generated from a constrained list.
    return getattr(field_type, '__origin__', None)
def extract_config(t: Type[TensorFlow2ONNX]) -> Tuple[Type[TensorFlow2ONNX], TensorFlow2ONNXConfig]:
    """Split `t` into its base type and its ``TensorFlow2ONNXConfig``.

    A type that is not ``Annotated`` is returned unchanged with ``None`` as
    its config.

    Raises:
        TypeTransformerFailedError: If `t` is ``Annotated`` but its metadata
            is not a ``TensorFlow2ONNXConfig``.
    """
    if get_origin(t) is not Annotated:
        return t, None

    base_type, config = get_args(t)
    if not isinstance(config, TensorFlow2ONNXConfig):
        raise TypeTransformerFailedError(f"{t}'s config isn't of type TensorFlow2ONNX")
    return base_type, config
def _decode_value(self, msg: RosMessage, buffer: bytes, offset, typ):
    """Decode one value of type `typ` from `buffer` at `offset`.

    A plain ``RosMessage`` subclass is instantiated and decoded recursively.
    Any other `typ` must yield ``(base, size, signed)`` from ``get_args``;
    `base` selects the wire format. All integers, including the 4-byte length
    prefixes, are read little-endian. A `size` of 0 for ``bytes`` and list
    types means the length is read from a 4-byte prefix.

    Returns:
        ``(value, new_offset)``. The value is ``None`` when `base` matches
        none of the handled kinds.
    """

    def read_u32(pos):
        # Unsigned 32-bit little-endian integer at `pos`.
        return int.from_bytes(buffer[pos : pos + 4], "little", signed=False)

    if get_origin(typ) is None and issubclass(typ, RosMessage):
        nested = typ()
        offset = self.decode(nested, buffer, offset)
        return nested, offset

    base, size, signed = get_args(typ)
    val: Any = None

    if base == str:
        length = read_u32(offset)
        start = offset + 4
        val = buffer[start : start + length].decode("utf-8")
        offset = start + length
    elif base == int:
        val = int.from_bytes(buffer[offset : offset + size], "little", signed=signed)
        offset += size
    elif base == float:
        fmt = "<f" if size == 4 else "<d"
        val = struct.unpack(fmt, buffer[offset : offset + size])[0]
        offset += size
    elif base == bytes:
        if size == 0:
            size = read_u32(offset)
            offset += 4
        val = buffer[offset : offset + size]
        offset += size
    elif base == Time or base == Duration:
        secs = read_u32(offset)
        nsecs = read_u32(offset + 4)
        val = base(secs, nsecs)
        offset += 8
    elif get_origin(base) == list:
        elem_type, *_ = get_args(base)
        if size == 0:
            size = read_u32(offset)
            offset += 4
        val = []
        for _ in range(size):
            item, offset = self._decode_value(msg, buffer, offset, elem_type)
            val.append(item)

    return val, offset
async def _convert(
    self,
    ctx: Context,
    converter: converters.Converters,
    param: inspect.Parameter,
    argument: str,
) -> Any:
    """Convert `argument` using `converter`.

    Three kinds of converter are handled:

    * a ``converters.ConverterBase`` (instantiated first when given as a
      class), whose ``convert`` coroutine is awaited;
    * a generic hint (anything with a ``get_origin``), whose arguments are
      tried in order. For ``Union`` a failing member is skipped; for
      ``Literal`` the converted value must equal the literal member. An
      exhausted ``Union`` ending in ``None`` (``typing.Optional``) yields the
      parameter default, or ``None`` when there is none;
    * any other callable, which is called with `argument`.

    Raises:
        BadArgument: If the conversion fails.
    """
    if isinstance(converter, converters.ConverterBase):
        if isinstance(converter, type):  # needs to be instantiated
            converter = converter()
        try:
            return await converter.convert(ctx, argument)
        except Exception as exc:
            try:
                name = converter.__name__
            except AttributeError:
                name = converter.__class__.__name__
            raise BadArgument(
                f"{argument!r} failed to convert to {name}") from exc

    origin = get_origin(converter)
    if origin is not None:
        args = get_args(converter)
        for arg in args:
            converter = self._get_converter(arg)
            try:
                ret = await self._convert(ctx, converter, param, argument)
            except BadArgument:
                # Only a Union may fall through to its next member.
                if origin is not Union:
                    raise
            else:
                if origin is not Literal:
                    return ret
                # For Literal the converted value has to match the member.
                if arg == ret:
                    return ret

        if origin is Union and args[-1] is type(None):  # typing.Optional
            try:
                return self._get_default(
                    ctx, param)  # get the default if possible
            except MissingRequiredArgument:
                return None  # fall back to None

        if origin is Literal:
            # Literal members need not be strings, so stringify before joining.
            raise BadArgument(
                f"Expected one of {', '.join(map(str, args))} not {argument!r}")

        raise BadArgument(f"Failed to parse {argument!r} to any type")

    try:
        return converter(argument)
    except Exception as exc:
        try:
            name = converter.__name__
        except AttributeError:
            name = converter.__class__.__name__
        raise BadArgument(
            f"{argument!r} failed to convert to {name}") from exc
def _eval_type(type: Any, globals: Dict[str, Any]) -> Any:
    """Evaluate all forward references in the given type.

    Strings are treated as forward references. ``typing._GenericAlias``
    instances are rebuilt from their origin with every argument evaluated
    recursively. Anything else is returned unchanged.

    Args:
        type: A type hint, forward reference or string.
        globals: Namespace in which forward references are evaluated.
    """
    if isinstance(type, str):
        type = ForwardRef(type)

    if isinstance(type, ForwardRef):
        import inspect

        # ForwardRef._evaluate is private and its signature differs between
        # Python versions (``recursive_guard`` became required in 3.9 and
        # ``type_params`` was added later), so pass only what this
        # interpreter accepts.
        accepted = inspect.signature(ForwardRef._evaluate).parameters
        extra: Dict[str, Any] = {}
        if "recursive_guard" in accepted:
            extra["recursive_guard"] = frozenset()
        if "type_params" in accepted:
            extra["type_params"] = ()
        return type._evaluate(globals, {}, **extra)

    if isinstance(type, _GenericAlias):
        args = tuple(_eval_type(arg, globals) for arg in get_args(type))
        return get_origin(type)[args]

    return type
def _is_valid_typeddict_item(
    td: type[TypedDict], key: str, value: Any  # type: ignore [valid-type]
) -> bool:
    """Check if `key` and `value` form a valid item for the TypedDict `td`.

    Unknown keys are invalid. For a ``Literal`` hint, `value` must be one of
    the literal values; for any other hint, it must pass ``isinstance``.
    """
    hints = get_type_hints(td)
    try:
        expected = hints[key]
    except KeyError:
        return False

    if get_origin(expected) is Literal:
        return value in get_args(expected)
    return isinstance(value, expected)
def annotation(self):
    """Return the type annotation for the parameter represented by the widget.

    ForwardRefs will be resolved when setting the annotation. If the widget
    is nullable (had a type annotation of Optional[Type]), the first argument
    of the Optional clause is returned instead of the Optional itself.
    """
    resolved = Widget.annotation.fget(self)  # type: ignore
    unwrap = self._nullable and get_origin(resolved) is Union
    return get_args(resolved)[0] if unwrap else resolved
def is_str_literal(hint: Any) -> bool:
    """Check if a type hint is a ``Literal`` holding exactly one string value."""
    if get_origin(hint) is not Literal:
        return False
    values = get_args(hint)
    return len(values) == 1 and isinstance(values[0], str)
def extract_cols_and_format(
    t: typing.Any,
) -> typing.Tuple[Type[T], Optional[typing.OrderedDict[str, Type]], Optional[str], Optional[pa.lib.Schema]]:
    """
    Helper function that walks the metadata of an Annotated type and pulls out:

    - the base type (the type itself when it is not Annotated),
    - column information, as a collections.OrderedDict,
    - the storage format, as a ``StructuredDatasetFormat`` (str),
    - a pa.lib.Schema

    Finding more than one item of the same kind is an error. Kinds that are not present
    are returned as None.

    If we add more things, we should put all the returned items in a dataclass instead of just a tuple.

    :param t: The incoming type which may or may not be Annotated
    :return: Tuple representing
        the original type,
        optional OrderedDict of columns,
        optional str for the format,
        optional pyarrow Schema
    :raises ValueError: If a format, column mapping or schema is given more than once.
    """
    if get_origin(t) is not Annotated:
        # We return None as the format instead of parquet or something because the transformer engine may find
        # a better default for the given dataframe type.
        return t, None, None, None

    base_type, *metadata = get_args(t)
    storage_format = None
    columns = None
    arrow_schema = None

    for item in metadata:
        if isinstance(item, StructuredDatasetFormat):
            if storage_format is not None:
                raise ValueError(
                    f"A format was already specified {storage_format}, cannot use {item}"
                )
            storage_format = item
        elif isinstance(item, collections.OrderedDict):
            if columns is not None:
                raise ValueError(
                    f"Column information was already found {columns}, cannot use {item}"
                )
            columns = item
        elif isinstance(item, pyarrow.Schema):
            if arrow_schema is not None:
                raise ValueError(
                    f"Arrow schema was already found {arrow_schema}, cannot use {item}"
                )
            arrow_schema = item

    return base_type, columns, storage_format, arrow_schema
async def _convert(
    self,
    ctx: "Context",
    converter: Union[converters.Converter, type, Callable[[str], Any]],
    param: inspect.Parameter,
    argument: str,
) -> Any:
    """Convert `argument` using `converter`.

    * A ``converters.Converter`` is called first when it is callable, then its
      ``convert`` coroutine is awaited.
    * ``bool`` is handled by ``to_bool``.
    * A generic hint (anything with a ``get_origin``) has its arguments tried
      in order; for ``Union`` a failing member is skipped. An exhausted
      ``Union`` containing ``None`` (``typing.Optional``) yields the parameter
      default, or ``None`` when there is none.
    * Any other callable is called with `argument`.

    Raises:
        BadArgument: If the conversion fails.
    """
    if isinstance(converter, converters.Converter):
        try:
            converter = converter() if callable(converter) else converter
            return await converter.convert(ctx, argument)
        except Exception as exc:
            try:
                name = converter.__name__
            except AttributeError:
                name = converter.__class__.__name__
            raise BadArgument(
                f"{argument} failed to convert to {name}") from exc

    if converter is bool:
        return to_bool(argument)

    origin = get_origin(converter)
    if origin is not None:
        # Capture the members up front: the loop must not change what the
        # Optional check below looks at.
        members = get_args(converter)
        for member in members:
            member_converter = self._get_converter(member)
            try:
                # `param` has to be forwarded; the signature requires it.
                return await self._convert(ctx, member_converter, param, argument)
            except BadArgument:
                if origin is Union:
                    continue
                raise

        if origin is Union and type(None) in members:  # typing.Optional
            try:
                return self._get_default(
                    ctx, param)  # get the default if possible
            except MissingRequiredArgument:
                return None  # fall back to None

        raise BadArgument(
            f"Failed to parse {argument} to any type")

    try:
        return converter(argument)
    except Exception as exc:
        try:
            name = converter.__name__
        except AttributeError:
            name = converter.__class__.__name__
        raise BadArgument(
            f"{argument!r} failed to convert to {name}") from exc
def _normalize_type(value: Any, annotation: Any) -> tuple[type, bool]:
    """Return the annotation's type (or the type of `value`) and whether it is optional.

    Args:
        value: Value whose type is used when no usable annotation is given.
        annotation: Type annotation; may be falsy or ``inspect.Parameter.empty``.

    Returns:
        A ``(type, is_optional)`` pair. For ``Optional[X]`` this is
        ``(X, True)``. Otherwise it is the annotation's origin (or the
        annotation itself when it has no origin) and ``False``. Without a
        usable annotation it is ``(type(value), False)``.
    """
    if not annotation or annotation is inspect.Parameter.empty:
        return type(value), False

    # look for Optional[Type], which manifests as Union[Type, None]
    origin = get_origin(annotation)
    args = get_args(annotation)
    if origin is Union and len(args) == 2 and type(None) in args:
        # Identity check rather than issubclass(): the other member may be a
        # non-class hint (e.g. List[int]), for which issubclass() raises.
        type_ = next(i for i in args if i is not type(None))
        return type_, True

    return (origin or annotation), False
def update_annotations(annotations: Dict[str, Any], globals: Dict[str, Any]) -> Dict[str, Any]:
    """Evaluate every annotation in place and return the same mapping.

    A helper loosely based off of typing's implementation of
    :meth:`typing.get_type_hints`. Its main purpose is evaluating postponed
    annotations (type hints in quotes); for more info see :pep:`563` /
    https://www.python.org/dev/peps/pep-0563
    """
    for key in list(annotations):
        evaluated = _eval_type(annotations[key], globals)
        if get_origin(evaluated) is Greedy:
            # update the old converter
            evaluated.converter = evaluated.__args__[0]
            # check if the evaluated type is valid
            Greedy[evaluated.converter]
        annotations[key] = evaluated
    return annotations
def _encode_value(self, val: Any, typ: Type[Any], buffer: bytearray) -> None:
    """Append the encoding of `val`, described by `typ`, to `buffer`.

    A ``RosMessage`` value is encoded through ``self.encode``. Any other value
    requires `typ` to be ``Annotated`` with ``(base, size, signed)``; `base`
    selects the wire format. All integers, including the 4-byte length
    prefixes, are written little-endian. A `size` of 0 for ``bytes`` and list
    types means the length is written as a 4-byte prefix; otherwise the value
    must have exactly `size` items.
    """
    if isinstance(val, RosMessage):
        self.encode(val, buffer)
        return

    o = get_origin(typ)
    assert o == Annotated
    base, size, signed = get_args(typ)

    if base == str:
        # The prefix must be the number of encoded bytes, not characters:
        # _decode_value slices that many bytes before decoding UTF-8.
        encoded = val.encode("utf-8")
        buffer.extend(len(encoded).to_bytes(4, "little", signed=False))
        buffer.extend(encoded)
    elif base == int:
        buffer.extend(val.to_bytes(size, "little", signed=signed))
    elif base == float:
        buffer.extend(struct.pack("<f" if size == 4 else "<d", val))
    elif base == bytes:
        if size == 0:
            byte_length = len(val)
            buffer.extend(byte_length.to_bytes(4, "little", signed=False))
        else:
            assert len(val) == size
        buffer.extend(val)
    elif base == Time or base == Duration:
        buffer.extend(val.secs.to_bytes(4, "little", signed=False))
        buffer.extend(val.nsecs.to_bytes(4, "little", signed=False))
    elif get_origin(base) == list:
        typ, *_ = get_args(base)
        if size == 0:
            byte_length = len(val)
            buffer.extend(byte_length.to_bytes(4, "little", signed=False))
        else:
            assert len(val) == size
        for v in val:
            self._encode_value(v, typ, buffer)
def sequence_of_paths(value, annotation) -> WidgetTuple | None:
    """Determine if value/annotation is a Sequence[pathlib.Path].

    A truthy `annotation` takes precedence over `value`, which is only
    inspected when there is no annotation.

    Returns:
        ``(widgets.FileEdit, {"mode": "rm"})`` on a match, otherwise ``None``.
    """
    if annotation:
        orig = get_origin(annotation)
        args = get_args(annotation)
        if inspect.isclass(orig) and args:
            is_sequence = _is_subclass(orig, abc.Sequence) or isinstance(orig, abc.Sequence)
            if is_sequence and _is_subclass(args[0], pathlib.Path):
                return widgets.FileEdit, {"mode": "rm"}
        return None

    if (
        value
        and isinstance(value, abc.Sequence)
        and all(isinstance(v, pathlib.Path) for v in value)
    ):
        return widgets.FileEdit, {"mode": "rm"}
    return None
def __class_getitem__(cls, converter: "GreedyTypes") -> "Greedy[T]":
    """Build ``Greedy[converter]`` and remember `converter` on the result.

    A one-element tuple is unwrapped. The converter must not be listed in
    ``INVALID_GREEDY_TYPES``, must not be a generic hint, and must be a
    ``Converter``, a ``str`` or a callable.

    Raises:
        TypeError: If more than one argument is given or the converter is
            not acceptable.
    """
    if isinstance(converter, tuple):
        if len(converter) != 1:
            raise TypeError("commands.Greedy only accepts one argument")
        (converter,) = converter

    if converter in INVALID_GREEDY_TYPES or get_origin(converter) is not None:
        acceptable = False
    else:
        acceptable = isinstance(converter, (Converter, str)) or callable(converter)

    if not acceptable:
        raise TypeError(f"Cannot type-hint Greedy with {converter!r}")

    annotation = super().__class_getitem__(converter)
    annotation.converter = converter
    return annotation
def generic_check_issubclass(
    cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...]]
) -> bool:
    """``issubclass`` that also understands generic and ``Union`` hints.

    When plain ``issubclass`` raises ``TypeError``:

    * a ``Union`` passes if every member other than ``NoneType`` passes;
    * any other hint with a truthy origin is checked through that origin;
    * everything else is not a subclass.
    """
    try:
        return issubclass(cls, class_or_tuple)
    except TypeError:
        origin = get_origin(cls)
        if origin is Union:
            return all(
                member is type(None)
                or generic_check_issubclass(member, class_or_tuple)
                for member in get_args(cls)
            )
        if origin:
            return issubclass(origin, class_or_tuple)
        return False
def field_type(field):
    """Render ``field.type`` as a compact, single-line string.

    ``Union`` members are joined with ``" | "`` and ``NoneType`` is shown as
    ``None``. The ``typing.`` / ``typing_extensions.`` prefixes are removed
    and ``<class 'x'>`` is shortened to ``x`` for str, int, float and bool.
    """
    declared = field.type
    if get_origin(declared) is Union:
        members = get_args(declared)
    else:
        members = [declared]

    rendered = [
        "None" if member == type(None) else str(member)  # type: ignore # noqa: E721
        for member in members
    ]
    text = " | ".join(rendered)

    # Collapse any line breaks in the rendered types into single spaces.
    text = " ".join(text.splitlines())

    for prefix in ("typing.", "typing_extensions."):
        text = text.replace(prefix, "")
    for tname in ("str", "int", "float", "bool"):
        text = text.replace(f"<class '{tname}'>", tname)
    return text