Example 1
def constraint_type(tvar: TypeVar):
    # Collapse a TypeVar to a usable type: the Union of its constraints,
    # its bound, or `object` when it has neither.
    ts = get_constraints(tvar)
    if ts:
        return Union[ts]
    else:
        bound = get_bound(tvar)
        return object if bound is None else bound
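A minimal usage sketch (not part of the original example), assuming get_constraints and get_bound come from typing_inspect: a constrained TypeVar collapses to the Union of its constraints, a bounded one to its bound, and an unconstrained one to object. The TypeVar names here are made up for illustration.

from typing import TypeVar, Union
from typing_inspect import get_bound, get_constraints

AnyStrT = TypeVar('AnyStrT', bytes, str)  # constrained
NumT = TypeVar('NumT', bound=int)         # bounded
PlainT = TypeVar('PlainT')                # neither

assert constraint_type(AnyStrT) == Union[bytes, str]
assert constraint_type(NumT) is int
assert constraint_type(PlainT) is object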
Example 2
def is_dataclass_type(t: Type) -> bool:
    """Returns wether t is a dataclass type or a TypeVar of a dataclass type.

    Args:
        t (Type): Some type.

    Returns:
        bool: Whether it is a dataclass type.
    """
    return dataclasses.is_dataclass(t) or (
        tpi.is_typevar(t) and dataclasses.is_dataclass(tpi.get_bound(t)))
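A short usage sketch (illustrative, assuming tpi is typing_inspect): both a dataclass and a TypeVar bound to that dataclass are accepted, while plain types are not.

import dataclasses
from typing import TypeVar

@dataclasses.dataclass
class Point:
    x: int
    y: int

P = TypeVar('P', bound=Point)

assert is_dataclass_type(Point)    # plain dataclass
assert is_dataclass_type(P)        # TypeVar bound to a dataclass
assert not is_dataclass_type(int)  # neither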
Example 3
def _repr(val: t.Any) -> str:

    assert val is not None

    if types.is_none_type(val):
        return 'NoneType'
    elif ti.is_literal_type(val):
        return str(val)
    elif ti.is_new_type(val):
        nested_type = val.__supertype__
        return f'{_qualified_name(val)}[{get_repr(nested_type)}]'
    elif ti.is_typevar(val):
        tv_constraints = ti.get_constraints(val)
        tv_bound = ti.get_bound(val)
        if tv_constraints:
            constraints_repr = (get_repr(tt) for tt in tv_constraints)
            return f'typing.TypeVar(?, {", ".join(constraints_repr)})'
        elif tv_bound:
            return get_repr(tv_bound)
        else:
            return 'typing.Any'
    elif ti.is_optional_type(val):
        optional_args = ti.get_args(val, True)[:-1]
        nested_union = len(optional_args) > 1
        optional_reprs = (get_repr(tt) for tt in optional_args)
        if nested_union:
            return f'typing.Optional[typing.Union[{", ".join(optional_reprs)}]]'
        else:
            return f'typing.Optional[{", ".join(optional_reprs)}]'
    elif ti.is_union_type(val):
        union_reprs = (get_repr(tt) for tt in ti.get_args(val, True))
        return f'typing.Union[{", ".join(union_reprs)}]'
    elif ti.is_generic_type(val):
        attr_name = val._name
        generic_reprs = (get_repr(tt) for tt in ti.get_args(val, evaluate=True))
        return f'typing.{attr_name}[{", ".join(generic_reprs)}]'
    else:
        val_name = _qualified_name(val)
        maybe_td_entries = getattr(val, '__annotations__', {}).copy()
        if maybe_td_entries:
            # `val` exposes __annotations__, so treat it as a TypedDict and
            # render its annotated keys.
            td_keys = sorted(maybe_td_entries.keys())
            internal_members_repr = ', '.join(
                '{key}: {type}'.format(key=k, type=get_repr(maybe_td_entries.get(k)))
                for k in td_keys
            )
            return f'{val_name}{{{internal_members_repr}}}'
        elif 'TypedDict' == getattr(val, '__name__', ''):
            return 'typing_extensions.TypedDict'
        else:
            return val_name
Example 4
def evaluate(tp: t.Union[str, t.TypeVar, t.Type, t.ForwardRef], *, frame=None):
    if isinstance(tp, str):
        tp = t.ForwardRef(tp)

    if ti.is_typevar(tp):
        tp = ti.get_bound(tp)

    # TODO python versions
    return t._eval_type(
        tp,
        frame.f_globals if frame else None,
        frame.f_locals if frame else None,
    )
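A small sketch of how this helper might be used (the Vector alias and the call site are illustrative, not from the original source): a string annotation is wrapped in a ForwardRef and resolved against the supplied frame's globals and locals.

import inspect
import typing as t

Vector = t.List[float]

# Resolve the string annotation against the caller's frame.
resolved = evaluate("Vector", frame=inspect.currentframe())
assert resolved == t.List[float]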
Example 5
def normalize_pytype(typ: Type) -> Type:
    if typing_inspect.is_typevar(typ):
        # we treat type vars in the most general way possible (the bound, or as 'object')
        bound = typing_inspect.get_bound(typ)
        if bound is not None:
            return normalize_pytype(bound)
        constraints = typing_inspect.get_constraints(typ)
        if constraints:
            raise CrosshairUnsupported
            # TODO: not easy; interpreting as a Union allows the type to be
            # instantiated differently in different places. So, this doesn't work:
            # return Union.__getitem__(tuple(map(normalize_pytype, constraints)))
        return object
    if typ is Any:
        # The distinction between Any and object only matters for type
        # checking; CrossHair treats them the same.
        return object
    if typ is Type:
        return type
    return typ
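A few illustrative calls (names taken from the snippet above; the TypeVars are made up for the sketch): bounded TypeVars normalize to their bound, unconstrained ones to object, and Any and bare Type are mapped to builtins.

from typing import Any, Type, TypeVar

BoundedT = TypeVar('BoundedT', bound=dict)
FreeT = TypeVar('FreeT')

assert normalize_pytype(BoundedT) is dict  # bounded TypeVar -> its bound
assert normalize_pytype(FreeT) is object   # unconstrained TypeVar -> object
assert normalize_pytype(Any) is object     # Any and object are treated alike
assert normalize_pytype(Type) is type      # bare typing.Type -> builtin type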
Example 6
def get_parsing_fn(t: Type[T]) -> Callable[[Any], T]:
    """Gets a parsing function for the given type or type annotation.

    Args:
        t (Type[T]): A type or type annotation.

    Returns:
        Callable[[Any], T]: A function that will parse a value of the given type
            from the command-line when available, or a no-op function that
            will return the raw value, when a parsing fn cannot be found or
            constructed.
    """
    if t in _parsing_fns:
        logger.debug(f"The type {t} has a dedicated parsing function.")
        return _parsing_fns[t]

    elif t is Any:
        logger.debug(f"parsing an Any type: {t}")
        return no_op

    # TODO: Do we want to support parsing a Dict from command-line?
    # elif is_dict(t):
    #     logger.debug(f"parsing a Dict field: {t}")
    #     args = get_type_arguments(t)
    #     if len(args) != 2:
    #         args = (Any, Any)
    #     return parse_dict(*args)

    # TODO: This would require some sort of 'postprocessing' step to convert a
    # list to a Set or something like that.
    # elif is_set(t):
    #     logger.debug(f"parsing a Set field: {t}")
    #     args = get_type_arguments(t)
    #     if len(args) != 1:
    #         args = (Any,)
    #     return parse_set(args[0])

    elif is_tuple(t):
        logger.debug(f"parsing a Tuple field: {t}")
        args = get_type_arguments(t)
        if is_homogeneous_tuple_type(t):
            if not args:
                args = (str, ...)
            parsing_fn = get_parsing_fn(args[0])
        else:
            parsing_fn = parse_tuple(args)
            parsing_fn.__name__ = str(t)
        return parsing_fn

    elif is_list(t):
        logger.debug(f"parsing a List field: {t}")
        args = get_type_arguments(t)
        assert len(args) == 1
        return parse_list(args[0])

    elif is_union(t):
        logger.debug(f"parsing a Union field: {t}")
        args = get_type_arguments(t)
        return parse_union(*args)

    elif is_enum(t):
        logger.debug(f"Parsing an Enum field of type {t}")
        return parse_enum(t)
    import typing_inspect as tpi

    if tpi.is_forward_ref(t):
        forward_arg = tpi.get_forward_arg(t)
        for registered_t, fn in _parsing_fns.items():
            if getattr(registered_t, "__name__", str(registered_t)) == forward_arg:
                return fn

    if tpi.is_typevar(t):
        bound = tpi.get_bound(t)
        logger.debug(f"parsing a typevar: {t}, bound type is {bound}.")
        if bound is not None:
            return get_parsing_fn(bound)

    logger.debug(f"Couldn't find a parsing function for type {t}, will try "
                 f"to use the type directly.")
    return t
Example 7
    def get(cls, type_or_hint, *, is_argument: bool = False) -> "TypeChecker":
        # This ensures the validity of the type passed (see typing documentation for info)
        type_or_hint = is_valid_type(type_or_hint, "Invalid type.",
                                     is_argument)

        if type_or_hint is Any:
            return AnyTypeChecker()

        if is_type(type_or_hint):
            return TypeTypeChecker.make(type_or_hint, is_argument)

        if is_literal_type(type_or_hint):
            return LiteralTypeChecker.make(type_or_hint, is_argument)

        if is_generic_type(type_or_hint):
            origin = get_origin(type_or_hint)
            if issubclass(origin, MappingCol):
                return MappingTypeChecker.make(type_or_hint, is_argument)

            if issubclass(origin, Collection):
                return CollectionTypeChecker.make(type_or_hint, is_argument)

            # CONSIDER: how to cater for exhaustible generators?
            if issubclass(origin, Iterable):
                raise NotImplementedError(
                    "No type-checker is setup for iterables that exhaust.")

            return GenericTypeChecker.make(type_or_hint, is_argument)

        if is_tuple_type(type_or_hint):
            return TupleTypeChecker.make(type_or_hint, is_argument)

        if is_callable_type(type_or_hint):
            return CallableTypeChecker.make(type_or_hint, is_argument)

        if isclass(type_or_hint):
            if is_typed_dict(type_or_hint):
                return TypedDictChecker.make(type_or_hint, is_argument)
            return ConcreteTypeChecker.make(type_or_hint, is_argument)

        if is_union_type(type_or_hint):
            return UnionTypeChecker.make(type_or_hint, is_argument)

        if is_typevar(type_or_hint):
            bound_type = get_bound(type_or_hint)
            if bound_type:
                return cls.get(bound_type)
            constraints = get_constraints(type_or_hint)
            if constraints:
                union_type_checkers = tuple(
                    cls.get(type_) for type_ in constraints)
                return UnionTypeChecker(Union.__getitem__(constraints),
                                        union_type_checkers)
            else:
                return AnyTypeChecker()

        if is_new_type(type_or_hint):
            super_type = getattr(type_or_hint, "__supertype__", None)
            if super_type is None:
                raise TypeError(
                    f"No supertype for NewType: {type_or_hint}. This is not allowed."
                )
            return cls.get(super_type)

        if is_forward_ref(type_or_hint):
            return ForwardTypeChecker.make(type_or_hint,
                                           is_argument=is_argument)

        if is_classvar(type_or_hint):
            var_type = get_args(type_or_hint, evaluate=True)[0]
            return cls.get(var_type)

        raise NotImplementedError(
            f"No {TypeChecker.__qualname__} is available for type or hint: '{type_or_hint}'"
        )
Example 8
def test_bound(self):
    T = TypeVar('T')
    TB = TypeVar('TB', bound=int)
    self.assertEqual(get_bound(T), None)
    self.assertEqual(get_bound(TB), int)
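For comparison, a small companion sketch (not part of the original test): a constrained TypeVar has no bound, and its constraints are reported separately by get_constraints.

from typing import TypeVar
from typing_inspect import get_bound, get_constraints

TC = TypeVar('TC', int, str)  # constrained, not bounded
assert get_bound(TC) is None
assert get_constraints(TC) == (int, str)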
Example 9
def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]:
    """Fetches/Creates a decoding function for the given type annotation. 

    This decoding function can then be used to create an instance of the type
    when deserializing dicts (which could have been obtained with JSON or YAML).
    
    This function inspects the type annotation and creates the right decoding
    function recursively in a "dynamic-programming-ish" fashion.
    NOTE: We cache the results in a `functools.lru_cache` decorator to avoid
    wasteful calls to the function. This makes this process pretty efficient.
    
    Args:
        t (Type[T]):
            A type or type annotation. Can be arbitrarily nested.
            For example:
            - List[int]
            - Dict[str, Foo]
            - Tuple[int, str, Any],
            - Dict[Tuple[int, int], List[str]]
            - List[List[List[List[Tuple[int, str]]]]]
            - etc.

    Returns:
        Callable[[Any], T]:
            A function that decodes a 'raw' value to an instance of type `t`.

    """
    # cache_info = get_decoding_fn.cache_info()
    # logger.debug(f"called for type {t}! Cache info: {cache_info}")

    if t in _decoding_fns:
        # The type has a dedicated decoding function.
        return _decoding_fns[t]

    elif t is Any:
        logger.debug(f"Decoding an Any type: {t}")
        return no_op

    elif is_dict(t):
        logger.debug(f"Decoding a Dict field: {t}")
        args = get_type_arguments(t)
        if len(args) != 2:
            args = (Any, Any)
        return decode_dict(*args)

    elif is_set(t):
        logger.debug(f"Decoding a Set field: {t}")
        args = get_type_arguments(t)
        if len(args) != 1:
            args = (Any, )
        return decode_set(args[0])

    elif is_tuple(t):
        logger.debug(f"Decoding a Tuple field: {t}")
        args = get_type_arguments(t)
        return decode_tuple(*args)

    elif is_list(t):
        logger.debug(f"Decoding a List field: {t}")
        args = get_type_arguments(t)
        assert len(args) == 1
        return decode_list(args[0])

    elif is_union(t):
        logger.debug(f"Decoding a Union field: {t}")
        args = get_type_arguments(t)
        return decode_union(*args)

    import typing_inspect as tpi
    from .serializable import get_dataclass_type_from_forward_ref, Serializable

    if tpi.is_forward_ref(t):
        dc = get_dataclass_type_from_forward_ref(t)
        if dc is Serializable:
            # Since dc is Serializable, this means that we found more than one
            # matching dataclass for the given forward ref, and the right
            # subclass will be determined based on the matching fields.
            # Therefore we set drop_extra_fields=False.
            return partial(dc.from_dict, drop_extra_fields=False)
        if dc:
            return dc.from_dict

    if tpi.is_typevar(t):
        bound = tpi.get_bound(t)
        logger.debug(f"Decoding a typevar: {t}, bound type is {bound}.")
        if bound is not None:
            return get_decoding_fn(bound)

    # Unknown type.
    warnings.warn(
        UserWarning(f"Unable to find a decoding function for type {t}. "
                    f"Will try to use the type as a constructor."))
    return try_constructor(t)
Example 10
def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]:
    """Fetches/Creates a decoding function for the given type annotation.

    This decoding function can then be used to create an instance of the type
    when deserializing dicts (which could have been obtained with JSON or YAML).

    This function inspects the type annotation and creates the right decoding
    function recursively in a "dynamic-programming-ish" fashion.
    NOTE: We cache the results in a `functools.lru_cache` decorator to avoid
    wasteful calls to the function. This makes this process pretty efficient.

    Args:
        t (Type[T]):
            A type or type annotation. Can be arbitrarily nested.
            For example:
            - List[int]
            - Dict[str, Foo]
            - Tuple[int, str, Any],
            - Dict[Tuple[int, int], List[str]]
            - List[List[List[List[Tuple[int, str]]]]]
            - etc.

    Returns:
        Callable[[Any], T]:
            A function that decodes a 'raw' value to an instance of type `t`.

    """
    # cache_info = get_decoding_fn.cache_info()
    # logger.debug(f"called for type {t}! Cache info: {cache_info}")

    if t in _decoding_fns:
        # The type has a dedicated decoding function.
        return _decoding_fns[t]

    if t is Any:
        logger.debug(f"Decoding an Any type: {t}")
        return no_op

    if is_dict(t):
        logger.debug(f"Decoding a Dict field: {t}")
        args = get_type_arguments(t)
        if len(args) != 2:
            args = (Any, Any)
        return decode_dict(*args)

    if is_set(t):
        logger.debug(f"Decoding a Set field: {t}")
        args = get_type_arguments(t)
        if len(args) != 1:
            args = (Any,)
        return decode_set(args[0])

    if is_tuple(t):
        logger.debug(f"Decoding a Tuple field: {t}")
        args = get_type_arguments(t)
        return decode_tuple(*args)

    if is_list(t):
        logger.debug(f"Decoding a List field: {t}")
        args = get_type_arguments(t)
        if not args:
            # Using a `List` or `list` annotation, so we don't know what to decode
            # the items into!
            args = (Any,)
        assert len(args) == 1
        return decode_list(args[0])

    if is_union(t):
        logger.debug(f"Decoding a Union field: {t}")
        args = get_type_arguments(t)
        return decode_union(*args)

    import typing_inspect as tpi
    from .serializable import (
        get_dataclass_types_from_forward_ref,
        Serializable,
        SerializableMixin,
        FrozenSerializable,
    )

    if tpi.is_forward_ref(t):
        dcs = get_dataclass_types_from_forward_ref(t)
        if len(dcs) == 1:
            dc = dcs[0]
            return dc.from_dict
        if len(dcs) > 1:
            logger.warning(
                RuntimeWarning(
                    f"More than one potential Serializable dataclass was found with a name matching "
                    f"the type annotation {t}. This will simply try each one, and return the "
                    f"first one that works. Potential classes: {dcs}"
                )
            )
            return try_functions(*[partial(dc.from_dict, drop_extra_fields=False) for dc in dcs])
        else:
            # No idea what the forward ref refers to!
            logger.warning(
                f"Unable to find a dataclass that matches the forward ref {t} inside the "
                f"registered {SerializableMixin} subclasses. Leaving the value as-is."
                f"(Consider using Serializable or FrozenSerializable as a base class?)."
            )
            return no_op

    if tpi.is_typevar(t):
        bound = tpi.get_bound(t)
        logger.debug(f"Decoding a typevar: {t}, bound type is {bound}.")
        if bound is not None:
            return get_decoding_fn(bound)

    # Unknown type.
    warnings.warn(
        UserWarning(
            f"Unable to find a decoding function for type {t}. "
            f"Will try to use the type as a constructor."
        )
    )
    return try_constructor(t)