def type_check(item, attr_name, value):
    th_cls = typing.get_type_hints(item) if hasattr(item, '__annotations__') else {}
    try:
        tt = th_cls[attr_name]
    except KeyError:
        th_init = typing.get_type_hints(item.__init__)
        try:
            tt = th_init[attr_name]
        except KeyError:
            return
    allowed = typing_inspect.get_args(tt) or tt
    if not isinstance(value, allowed):
        one_of = "with one of the following types" if isinstance(allowed, tuple) else "of type"
        return "Expected value {one_of}: {exp}, got {got} for {cls}.{member}".format(
            one_of=one_of,
            exp=allowed,
            got=type(value),
            cls=item.__class__.__name__,
            member=attr_name)
    return None
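# Hedged usage sketch for type_check above (not part of the original source).
# `Point` and the attribute values are hypothetical; this assumes `typing` and
# `typing_inspect` are imported as the function requires.
import typing

class Point:
    x: int
    y: typing.Union[int, float]

    def __init__(self, x: int, y: typing.Union[int, float]) -> None:
        self.x = x
        self.y = y

p = Point(1, 2.5)
assert type_check(p, "x", 3) is None            # int satisfies the `x: int` hint
assert type_check(p, "y", 2.5) is None          # float is among the Union args
assert type_check(p, "x", "oops") is not None   # returns a descriptive message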
def match_expression(
        wildcards: typing.List[Expression],
        template: object,
        expr: object) -> typing.Tuple[TypeVarMapping, WildcardMapping]:
    """
    Returns a mapping of wildcards to the objects at that level, or raises
    NoMatch if it does not match.

    A wildcard can match either an expression or a value. If it matches two
    nodes, they must be equal.
    """
    if template in wildcards:
        # If we are matching against a placeholder and the expression is not resolved
        # to that placeholder, don't match.
        if (isinstance(template, PlaceholderExpression)
                and isinstance(expr, Expression)
                and not typing_inspect.is_typevar(
                    typing_inspect.get_args(
                        typing_inspect.get_generic_type(template))[0])):
            raise NoMatch
        # Match type of wildcard with type of expression
        try:
            return (
                match_values(template, expr),
                UnhashableMapping(Item(typing.cast(Expression, template), expr)),
            )
        except TypeError:
            raise NoMatch
    if isinstance(expr, Expression):
        if not isinstance(template, Expression):
            raise NoMatch
        # Any typevars in the template that are unbound should be matched with their
        # versions in the expr
        try:
            fn_type_mapping: TypeVarMapping = match_functions(
                template.function, expr.function)
        except TypeError:
            raise NoMatch
        if set(expr.kwargs.keys()) != set(template.kwargs.keys()):
            raise TypeError("Wrong kwargs in match")

        template_args: typing.Iterable[object]
        expr_args: typing.Iterable[object]

        # Process args in the template that can represent any number of args.
        # These are the "IteratedPlaceholder"s.
        # Allow one iterated placeholder in the template args.
        # For example fn(a, b, [...], *c, d, e, [...])
        # Here `c` should take as many args as it can between the ends,
        # and each of those should be matched against the inner wildcard.
        iterated_args = [
            arg for arg in template.args if isinstance(arg, IteratedPlaceholder)
        ]
        if iterated_args:
            # The template args, minus the iterator, is the minimum number of values.
            # If there are fewer values than this, raise an error.
            if len(expr.args) < len(template.args) - 1:
                raise TypeError("Wrong number of args in match")

            template_args_ = list(template.args)
            # Only support one iterated arg for now
            # TODO: Support more than one, would require branching
            template_iterated, = iterated_args
            template_index_iterated = list(template.args).index(template_iterated)
            # Swap template iterated with inner wildcard
            template_args_[template_index_iterated], = template_iterated.args
            template_args = template_args_

            expr_args = collapse_tuple(
                expr.args,
                template_index_iterated,
                # The number we should preserve on the right is the number of
                # template args after the index
                len(template.args) - template_index_iterated - 1,
            )
        else:
            if len(template.args) != len(expr.args):
                raise TypeError("Wrong number of args in match")
            template_args = template.args
            expr_args = expr.args

        type_mappings, expr_mappings = list(
            zip(
                *(match_expression(wildcards, arg_template, arg_value)
                  for arg_template, arg_value in zip(template_args, expr_args)),
                *(match_expression(wildcards, template.kwargs[key], expr.kwargs[key])
                  for key in template.kwargs.keys()),
            )) or ((), ())
        try:
            merged_typevars: TypeVarMapping = merge_typevars(
                fn_type_mapping, *type_mappings)
        except TypeError:
            raise NoMatch
        try:
            return (
                merged_typevars,
                safe_merge(*expr_mappings, dict_constructor=UnhashableMapping),
            )
        except ValueError:
            raise NoMatch
    if template != expr:
        raise NoMatch
    return match_values(template, expr), UnhashableMapping()
def encode_generic_set(encoder, typ, value):
    if not (inspect.is_generic_type(typ) and inspect.get_origin(typ) == set):
        return Unsupported
    check_type(set, value)
    item_type, = inspect.get_args(typ)
    return [encoder.encode(item, item_type) for item in value]
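# Self-contained illustration of the guard used above (assumes `inspect` is the
# `typing_inspect` module under an alias, as the calls suggest):
import typing
import typing_inspect

assert typing_inspect.is_generic_type(typing.Set[int])
assert typing_inspect.get_origin(typing.Set[int]) is set   # on Python 3.7+
assert typing_inspect.get_args(typing.Set[int]) == (int,)  # the single item type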
def resolve_type_info(
    typ: t.Type[t.Any],
    *,
    is_optional: bool = False,
    custom: t.Optional[t.Type[t.Any]] = None,
    _nonetype: t.Type[t.Any] = t.cast(t.Type[t.Any], type(None)),  # xxx
    _anytype: t.Type[t.Any] = t.cast(t.Type[t.Any], t.Any),  # xxx
    _primitives: t.Set[t.Type[t.Any]] = t.cast(
        t.Set[t.Type[t.Any]],
        set([str, int, bool, float, bytes, dict, list, t.Any])),
) -> TypeInfo:
    raw = typ
    args = typing_inspect.get_args(typ)
    underlying = getattr(typ, "__origin__", None)
    if underlying is None:
        if not hasattr(typ, "__iter__"):
            underlying = typ  # xxx
        elif issubclass(typ, str):
            underlying = typ
        elif issubclass(typ, t.Sequence):
            return _make_container(
                raw=raw,
                container="tuple" if issubclass(typ, tuple) else "list",
                args=(resolve_type_info(_anytype),),
            )
        elif issubclass(typ, t.Mapping):
            childinfo = resolve_type_info(_anytype)
            return _make_container(raw=raw, container="dict", args=(childinfo, childinfo))
        else:
            underlying = typ  # xxx
    else:
        if underlying == t.Union:
            if len(args) == 2:
                if args[0] == _nonetype:
                    is_optional = True
                    typ = underlying = args[1]
                elif args[1] == _nonetype:
                    is_optional = True
                    typ = underlying = args[0]
                else:
                    return _make_container(
                        container="union",
                        raw=raw,
                        args=tuple([resolve_type_info(t) for t in args]),
                        is_optional=is_optional,
                        is_composite=True,
                    )
            else:
                is_optional = _nonetype in args
                if is_optional:
                    args = [x for x in args if x != _nonetype]
                return _make_container(
                    container="union",
                    raw=raw,
                    args=tuple([resolve_type_info(t) for t in args]),
                    is_optional=is_optional,
                    is_composite=True,
                )
        if hasattr(typ, "__origin__"):
            underlying = typ.__origin__
            if underlying == tx.Literal:
                args = typing_inspect.get_args(typ)
                underlying = type(args[0])  # TODO: meta info
            elif issubclass(underlying, t.Sequence):
                args = typing_inspect.get_args(typ)
                return _make_container(
                    raw=raw,
                    container="tuple" if issubclass(underlying, tuple) else "list",
                    args=tuple([resolve_type_info(t) for t in args]),
                    is_optional=is_optional,
                )
            elif issubclass(underlying, t.Mapping):
                args = typing_inspect.get_args(typ)
                return _make_container(
                    raw=raw,
                    container="dict",
                    args=tuple([resolve_type_info(t) for t in args]),
                    is_optional=is_optional,
                )
            else:
                raise ValueError(f"unsupported type {typ}")

    while hasattr(underlying, "__supertype__"):
        underlying = underlying.__supertype__
    if underlying not in _primitives:
        custom = underlying
    return _make_atom(raw=raw, underlying=underlying, is_optional=is_optional, custom=custom)
def make(cls, type_or_hint, is_argument: bool) -> "TypeTypeChecker":
    var_type = get_args(type_or_hint, evaluate=True)[0]
    return cls(type_or_hint, cls.get(var_type))
def make(cls, type_or_hint, is_argument: bool) -> "CollectionTypeChecker":
    origin = get_origin(type_or_hint)
    origin_type_checker = ConcreteTypeChecker(origin)
    item_type = (get_args(type_or_hint, evaluate=True) or (Any,))[0]
    return cls(type_or_hint, origin_type_checker, cls.get(item_type))
def _validate_list(value: object, target_type: Type[List[object]]) -> None:
    if not isinstance(value, list):
        raise InvalidJson(f"`{value}` is not a list")
    (element_type,) = get_args(target_type)
    for element in value:
        _validate_value(element, element_type)
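# The single-element unpack above relies on List[...] carrying exactly one type
# argument; a quick self-contained check using the stdlib helper:
from typing import List, get_args

(element_type,) = get_args(List[int])
assert element_type is int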
def match_types(hint: typing.Type, t: typing.Type) -> TypeVarMapping:
    """
    Matches a type hint with a type, returning a mapping of any type vars to their values.
    """
    logger.debug("match_types hint=%s type=%s", hint, t)
    if hint == object:
        hint = typing.Any  # type: ignore
    if t == object:
        t = typing.Any  # type: ignore
    if hint == t:
        return {}
    # If it is an instance of OfType[Type[T]], then we should consider it as T
    if isinstance(t, OfType):
        (of_type,) = typing_inspect.get_args(get_type(t))
        assert issubclass(of_type, typing.Type)
        (t,) = typing_inspect.get_args(of_type)
        return match_types(hint, t)
    # If the type is an OfType[T] then we should really just consider it as T
    if issubclass(t, OfType) and not issubclass(hint, OfType):
        (t,) = typing_inspect.get_args(t)
        return match_types(hint, t)
    if issubclass(hint, OfType) and not issubclass(t, OfType):
        (hint,) = typing_inspect.get_args(hint)
        return match_types(hint, t)
    # Matching an expanded type is like matching just whatever it represents
    if issubclass(t, ExpandedType):
        (t,) = typing_inspect.get_args(t)
    if typing_inspect.is_typevar(hint):
        return {hint: t}
    # This happens with a match rule on conversion, like when the value is a TypeVar
    if typing_inspect.is_typevar(t):
        return {}
    # If both are generic sequences, verify they are the same and have the same contents
    if (typing_inspect.is_generic_type(hint)
            and typing_inspect.is_generic_type(t)
            and typing_inspect.get_origin(hint) == collections.abc.Sequence
            and typing_inspect.get_origin(t) == collections.abc.Sequence):
        t_inner = typing_inspect.get_args(t)[0]
        # If t's inner arg is just the default one for Sequence, it hasn't been
        # initialized, so assume it was created from an empty tuple and just return a match
        if t_inner == typing_inspect.get_args(typing.Sequence)[0]:
            return {}
        return match_types(typing_inspect.get_args(hint)[0], t_inner)
    if typing_inspect.is_union_type(hint):
        # If this is a union, iterate through and use the first that is a subclass
        for inner_type in typing_inspect.get_args(hint):
            if issubclass(t, inner_type):
                hint = inner_type
                break
        else:
            raise TypeError(f"Cannot match concrete type {t} with hint {hint}")
    logger.debug("checking if type subclass hint hint=%s type=%s", hint, t)
    if not issubclass(t, hint):
        logger.debug("not subclass")
        raise TypeError(f"Cannot match concrete type {t} with hint {hint}")
    return merge_typevars(*(
        match_types(inner_hint, inner_t)
        for inner_hint, inner_t in zip(get_inner_types(hint), get_inner_types(t))))
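# Illustration of the TypeVar branches above using typing_inspect primitives;
# self-contained and deliberately not touching the OfType/ExpandedType machinery:
import typing
import typing_inspect

T = typing.TypeVar("T")
assert typing_inspect.is_typevar(T)
# Per the branch above, match_types(T, int) would yield the mapping {T: int},
# while a TypeVar on the concrete side matches anything and yields {}.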
def _all_subclasses(typ, *, module=None):
    """
    Return all subclasses of a given type.

    The type must be one of

     - :class:`GTScriptAstNode` (returns all subclasses of the given class)
     - :class:`Union` (returns the subclasses of the united types)
     - :class:`ForwardRef` (resolves the reference given the specified module and returns its subclasses)
     - built-in python type: :class:`str`, :class:`int`, `type(None)` (returned as is)
    """
    if inspect.isclass(typ) and issubclass(typ, gtscript_ast.GTScriptASTNode):
        result = {
            typ,
            *typ.__subclasses__(),
            *[
                s for c in typ.__subclasses__()
                for s in PyToGTScript._all_subclasses(c)
                if not inspect.isabstract(c)
            ],
        }
        return result
    elif inspect.isclass(typ) and typ in [
        gtc_unstructured.irs.common.AssignmentKind,
        gtc_unstructured.irs.common.UnaryOperator,
        gtc_unstructured.irs.common.BinaryOperator,
        stable_gtc_common.UnaryOperator,
    ]:
        # note: other types in gtc_unstructured.irs.common, e.g.
        #  gtc_unstructured.irs.common.DataType, are not valid leaf nodes here as they
        #  map to symbols in the gtscript ast and are resolved there
        assert issubclass(typ, enum.Enum)
        return {typ}
    elif typing_inspect.is_union_type(typ):
        return {
            sub_cls
            for el_cls in typing_inspect.get_args(typ)
            for sub_cls in PyToGTScript._all_subclasses(el_cls, module=module)
        }
    elif isinstance(typ, typing.ForwardRef):
        type_name = typing_inspect.get_forward_arg(typ)
        if not hasattr(module, type_name):
            raise ValueError(
                f"Reference to type `{type_name}` in `ForwardRef` not found in module {module.__name__}"
            )
        return PyToGTScript._all_subclasses(getattr(module, type_name), module=module)
    elif typ in [
        type_definitions.Str,
        type_definitions.Int,
        type_definitions.Float,
        type_definitions.SymbolRef,
        type_definitions.SymbolName,
        str,
        int,
        float,
        type(None),
    ]:
        # TODO(tehrengruber): enhance
        return {typ}

    raise ValueError(f"Invalid field type {typ}")
def transform(self, node, eligible_node_types=None, node_init_args=None):
    """Transform python ast into GTScript ast recursively."""
    if eligible_node_types is None:
        eligible_node_types = [gtscript_ast.Computation]

    if isinstance(node, ast.AST):
        is_leaf_node = len(list(ast.iter_fields(node))) == 0
        if is_leaf_node:
            if not type(node) in self.leaf_map:
                raise ValueError(
                    f"Leaf node of type {type(node)}, found in the python ast, can not be mapped."
                )
            return self.leaf_map[type(node)]
        else:
            # visit node fields and transform
            # TODO(tehrengruber): check if multiple nodes match and throw an error in that case
            #  disadvantage: templates can be ambiguous
            for node_type in eligible_node_types:
                if not hasattr(self.Patterns, node_type.__name__):
                    continue
                captures = {}
                if not anm.match(node, getattr(self.Patterns, node_type.__name__),
                                 captures=captures):
                    continue
                module = sys.modules[node_type.__module__]
                if node_init_args:
                    captures = {**node_init_args, **captures}
                args = list(value for key, value in captures.items()
                            if isinstance(key, int))
                kwargs = {
                    key: value
                    for key, value in captures.items() if isinstance(key, str)
                }
                transformed_kwargs = {}
                for name, capture in kwargs.items():
                    assert (
                        name in node_type.__annotations__
                    ), f"Invalid capture. No field named `{name}` in `{str(node_type)}`"
                    field_type = node_type.__annotations__[name]
                    if typing_inspect.get_origin(field_type) == list:
                        # determine eligible capture types
                        el_type = typing_inspect.get_args(field_type)[0]
                        eligible_capture_types = self._all_subclasses(el_type, module=module)
                        # transform captures recursively
                        transformed_kwargs[name] = []
                        for child_capture in capture:
                            transformed_kwargs[name].append(
                                self.transform(child_capture, eligible_capture_types))
                    else:
                        # determine eligible capture types
                        eligible_capture_types = self._all_subclasses(field_type, module=module)
                        # transform captures recursively
                        transformed_kwargs[name] = self.transform(capture, eligible_capture_types)
                assert len(args) + len(transformed_kwargs) == len(captures)
                return node_type(*args, **transformed_kwargs)
            raise ValueError("Expected a node of type {}".format(
                ", ".join([ent.__name__ for ent in eligible_node_types])))
    elif type(node) in eligible_node_types:
        return node
    elif (type(node) in self.pseudo_polymorphic_types
          and len(self.pseudo_polymorphic_types[type(node)] & set(eligible_node_types)) > 0):
        valid_types = self.pseudo_polymorphic_types[type(node)] & set(eligible_node_types)
        if len(valid_types) > 1:
            raise RuntimeError(
                f"Invalid gtscript ast specification. The node {node} has multiple "
                f"valid types in the gtscript ast: {valid_types}"
            )
        return next(iter(valid_types))(node)

    raise ValueError("Expected a node of type {}, but got {}".format(
        {*eligible_node_types, ast.AST}, type(node)))
def field_for_schema(
    typ: type,
    default=marshmallow.missing,
    metadata: t.Mapping[str, t.Any] = None) -> marshmallow.fields.Field:
    """
    Get a marshmallow Field corresponding to the given python type.
    The metadata of the dataclass field is used as arguments to the marshmallow Field.

    >>> int_field = field_for_schema(int, default=9, metadata=dict(required=True))
    >>> int_field.__class__
    <class 'marshmallow.fields.Integer'>

    >>> int_field.default
    9

    >>> field_for_schema(t.Dict[str,str]).__class__
    <class 'marshmallow.fields.Dict'>

    >>> field_for_schema(t.Optional[str]).__class__
    <class 'marshmallow.fields.String'>

    >>> import marshmallow_enum
    >>> field_for_schema(enum.Enum("X", "a b c")).__class__
    <class 'marshmallow_enum.EnumField'>

    >>> import typing
    >>> field_for_schema(t.Union[int,str]).__class__
    <class 'marshmallow_union.Union'>

    >>> field_for_schema(t.NewType('UserId', int)).__class__
    <class 'marshmallow.fields.Integer'>

    >>> field_for_schema(t.NewType('UserId', int), default=0).default
    0

    >>> class Color(enum.Enum):
    ...     red = 1
    >>> field_for_schema(Color).__class__
    <class 'marshmallow_enum.EnumField'>

    >>> field_for_schema(t.Any).__class__
    <class 'marshmallow.fields.Raw'>
    """
    if metadata is None:
        metadata = {}
    else:
        metadata = dict(metadata)

    desert_metadata = dict(metadata).get(_DESERT_SENTINEL, {})
    metadata[_DESERT_SENTINEL] = desert_metadata

    if default is not marshmallow.missing:
        desert_metadata.setdefault("default", default)
        desert_metadata.setdefault("allow_none", True)
        if not desert_metadata.get("required"):
            # 'missing' must not be set for required fields.
            desert_metadata.setdefault("missing", default)

    field = None

    # If the field was already defined by the user
    predefined_field = desert_metadata.get("marshmallow_field")
    if predefined_field:
        field = predefined_field
        field.metadata.update(metadata)
        return field

    # Base types
    if not field and typ in _native_to_marshmallow:
        field = _native_to_marshmallow[typ](default=default)

    # Generic types
    origin = typing_inspect.get_origin(typ)
    if origin:
        arguments = typing_inspect.get_args(typ, True)
        if origin in (list, t.List):
            field = marshmallow.fields.List(field_for_schema(arguments[0]))
        if origin in (tuple, t.Tuple) and Ellipsis not in arguments:
            field = marshmallow.fields.Tuple(
                tuple(field_for_schema(arg) for arg in arguments))
        elif origin in (tuple, t.Tuple) and Ellipsis in arguments:
            field = VariadicTuple(
                field_for_schema(only(arg for arg in arguments if arg != Ellipsis)))
        elif origin in (dict, t.Dict):
            field = marshmallow.fields.Dict(
                keys=field_for_schema(arguments[0]),
                values=field_for_schema(arguments[1]),
            )
        elif typing_inspect.is_optional_type(typ):
            subtyp = next(t for t in arguments if t is not NoneType)
            # Treat optional types as types with a None default
            metadata[_DESERT_SENTINEL]["default"] = metadata.get("default", None)
            metadata[_DESERT_SENTINEL]["missing"] = metadata.get("missing", None)
            metadata[_DESERT_SENTINEL]["required"] = False
            field = field_for_schema(subtyp, metadata=metadata, default=None)
            field.default = None
            field.missing = None
            field.allow_none = True
        elif typing_inspect.is_union_type(typ):
            subfields = [field_for_schema(subtyp) for subtyp in arguments]
            import marshmallow_union
            field = marshmallow_union.Union(subfields)

    # t.NewType returns a function with a __supertype__ attribute
    newtype_supertype = getattr(typ, "__supertype__", None)
    if newtype_supertype and inspect.isfunction(typ):
        metadata.setdefault("description", typ.__name__)
        field = field_for_schema(newtype_supertype, default=default)

    # enumerations
    if type(typ) is enum.EnumMeta:
        import marshmallow_enum
        field = marshmallow_enum.EnumField(typ, metadata=metadata)

    # Nested dataclasses
    forward_reference = getattr(typ, "__forward_arg__", None)
    if field is None:
        nested = forward_reference or class_schema(typ)
        try:
            nested.help = typ.__doc__
        except AttributeError:
            # TODO need to handle the case where nested is a string forward reference.
            pass
        field = marshmallow.fields.Nested(nested)

    field.metadata.update(metadata)

    for key in ["default", "missing", "required", "marshmallow_field"]:
        if key in metadata.keys():
            metadata[_DESERT_SENTINEL][key] = metadata.pop(key)

    if field.default == field.missing == default == marshmallow.missing:
        field.required = True

    return field
def from_list(tp: Type[List[T]], v: List[Any]) -> List[T]:
    """Transform a list of JSON-like structures into JSON-compatible objects."""
    (inner_type,) = get_args(tp)
    return [_from_json_like(inner_type, value) for value in v]
def test_valid(self):
    variants = list(get_args(DelimitedVariant))
    for variant in variants:
        records_format = DelimitedRecordsFormat(variant=variant)
        records_format.validate(fail_if_cant_handle_hint=True)
def expand_typing(cls: 'Union[Model, Type[Model], RecordType, Type[RecordType]]',
                  exc: 'Optional[Type[ValueError]]' = None) -> 'ExpandedTyping':
    """Expand typing annotations.

    Args:
        cls (:class:`~zlogging.model.Model` or :class:`~zlogging.types.RecordType` object):
            a variadic class which supports `PEP 484`_ style attribute typing annotations
        exc: (:obj:`ValueError`, optional): exception to be used in case of inconsistent
            values for ``unset_field``, ``empty_field`` and ``set_separator``

    Returns:
        :obj:`Dict[str, Any]`: The returned dictionary contains the following directives:

        * ``fields`` (:obj:`OrderedDict` mapping :obj:`str` and :class:`~zlogging.types.BaseType`):
          a mapping proxy of field names and their corresponding data types, i.e. an
          instance of a :class:`~zlogging.types.BaseType` subclass

        * ``record_fields`` (:obj:`OrderedDict` mapping :obj:`str` and :class:`~zlogging.types.RecordType`):
          a mapping proxy for fields of ``record`` data type, i.e. an instance of
          :class:`~zlogging.types.RecordType`

        * ``unset_field`` (:obj:`bytes`): placeholder for unset field

        * ``empty_field`` (:obj:`bytes`): placeholder for empty field

        * ``set_separator`` (:obj:`bytes`): separator for ``set``/``vector`` fields

    Warns:
        BroDeprecationWarning: Use of ``bro_*`` prefixed typing annotations.

    Raises:
        :exc:`ValueError`: In case of inconsistent values for ``unset_field``,
            ``empty_field`` and ``set_separator``.

    Example:
        Define a custom log data model from :class:`~zlogging.model.Model` using the
        predefined Bro/Zeek data types, or subclasses of :class:`~zlogging.types.BaseType`::

            class MyLog(Model):
                field_one = StringType()
                field_two = SetType(element_type=PortType)

        Or you may use type annotations as `PEP 484`_ introduced when declaring data
        models. All available type hints can be found in :mod:`zlogging.typing`::

            class MyLog(Model):
                field_one: zeek_string
                field_two: zeek_set[zeek_port]

        However, when mixing annotations and direct assignments, annotations will take
        precedence, i.e. the function shall process first typing annotations then
        ``cls`` attribute assignments. Should there be any conflicts, the ``exc`` will
        be raised.

    Note:
        Fields of :class:`zlogging.types.RecordType` type will be expanded as plain
        fields of the ``cls``, i.e. for the variadic class as below::

            class MyLog(Model):
                record = RecordType(one=StringType(),
                                    two=VectorType(element_type=CountType()))

        will have the following fields:

        * ``record.one`` -> ``string`` data type
        * ``record.two`` -> ``vector[count]`` data type

    .. _PEP 484:
        https://www.python.org/dev/peps/pep-0484/

    """
    from zlogging.types import (BaseType, _GenericType,  # pylint: disable=import-outside-toplevel
                                _SimpleType, _VariadicType)

    if exc is None:
        exc = ValueError

    inited = False
    unset_field = b'-'
    empty_field = b'(empty)'
    set_separator = b','

    def register(name: str, field: 'Union[_SimpleType, _GenericType]') -> None:
        """Field registry."""
        existed = fields.get(name)
        if existed is not None and field.zeek_type != existed.zeek_type:
            raise exc(f'inconsistent data type of {name!r} field: {field!r} and {existed!r}')  # type: ignore[misc]
        fields[name] = field

    fields = collections.OrderedDict()  # type: OrderedDict[str, Union[_SimpleType, _GenericType]]
    record_fields = collections.OrderedDict()  # type: OrderedDict[str, _VariadicType]
    for name, attr in itertools.chain(
            getattr(cls, '__annotations__', dict()).items(), cls.__dict__.items()):
        # type instances
        if isinstance(attr, BaseType):
            if isinstance(attr, _VariadicType):
                for elm_name, elm_field in attr.element_mapping.items():
                    register(f'{name}.{elm_name}', elm_field)
                record_fields[name] = attr
            else:
                register(name, attr)  # type: ignore[arg-type]
        # uninitialised type classes
        elif isinstance(attr, type) and issubclass(attr, BaseType):
            attr = attr()
        # simple typing types
        elif is_typevar(attr):
            if TYPE_CHECKING:
                attr = cast('TypeVar', attr)
            bound = attr.__bound__
            if bound and issubclass(bound, _SimpleType):
                attr = bound()
            else:
                continue
        # generic typing types
        elif is_generic_type(attr) and issubclass(attr, _GenericType):
            origin = get_origin(attr)
            parameter = get_args(attr)[0]
            # uninitialised type classes
            if isinstance(parameter, type) and issubclass(parameter, _SimpleType):
                element_type = parameter()
            # simple typing types
            elif is_typevar(parameter):
                if TYPE_CHECKING:
                    parameter = cast('TypeVar', parameter)
                bound = parameter.__bound__
                if bound and issubclass(bound, _SimpleType):
                    element_type = bound()
                else:
                    element_type = bound  # type: ignore[assignment]
            else:
                element_type = parameter  # type: ignore[assignment]
            attr = origin(element_type=element_type)
        else:
            continue

        if not inited:
            unset_field = attr.unset_field
            empty_field = attr.empty_field
            set_separator = attr.set_separator
            inited = True
            continue

        if unset_field != attr.unset_field:
            raise exc(
                f"inconsistent value of 'unset_field': {unset_field!r} and {attr.unset_field!r}")
        if empty_field != attr.empty_field:
            raise exc(
                f"inconsistent value of 'empty_field': {empty_field!r} and {attr.empty_field!r}")
        if set_separator != attr.set_separator:
            raise exc(
                f"inconsistent value of 'set_separator': {set_separator!r} and {attr.set_separator!r}")

    return {
        '_inited': inited,
        'fields': fields,
        'record_fields': record_fields,
        'unset_field': unset_field,
        'empty_field': empty_field,
        'set_separator': set_separator,
    }
def inner_type_boundaries(typ: Type) -> Tuple:
    return insp.get_args(typ, evaluate=True)
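# Example of what inner_type_boundaries returns (assumes typing_inspect is
# imported as `insp`, matching the call above):
from typing import Tuple
import typing_inspect as insp

assert insp.get_args(Tuple[int, str], evaluate=True) == (int, str)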
def wildcard_inner_type(wildcard: object) -> typing.Type:
    """Return inner type for a wildcard."""
    return typing_inspect.get_args(
        typing_inspect.get_generic_type(wildcard))[0]
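# Self-contained sketch of the get_generic_type/get_args combination above;
# `Box` is a hypothetical generic stand-in for a wildcard class:
import typing
import typing_inspect

T = typing.TypeVar("T")

class Box(typing.Generic[T]):
    pass

b = Box[int]()  # instances remember their parameterization via __orig_class__
assert typing_inspect.get_generic_type(b) == Box[int]
assert typing_inspect.get_args(typing_inspect.get_generic_type(b))[0] is int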
def field_for_schema(
    typ: type,
    default=marshmallow.missing,
    metadata: Mapping[str, Any] = None,
    base_schema: Optional[Type[marshmallow.Schema]] = None,
) -> marshmallow.fields.Field:
    """
    Get a marshmallow Field corresponding to the given python type.
    The metadata of the dataclass field is used as arguments to the marshmallow Field.

    :param typ: The type for which a field should be generated
    :param default: value to use for (de)serialization when the field is missing
    :param metadata: Additional parameters to pass to the marshmallow field constructor
    :param base_schema: marshmallow schema used as a base class when deriving dataclass schema

    >>> int_field = field_for_schema(int, default=9, metadata=dict(required=True))
    >>> int_field.__class__
    <class 'marshmallow.fields.Integer'>

    >>> int_field.default
    9

    >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
    <class 'marshmallow.fields.Url'>
    """
    metadata = {} if metadata is None else dict(metadata)

    if default is not marshmallow.missing:
        metadata.setdefault("default", default)
        # 'missing' must not be set for required fields.
        if not metadata.get("required"):
            metadata.setdefault("missing", default)
    else:
        metadata.setdefault("required", True)

    # If the field was already defined by the user
    predefined_field = metadata.get("marshmallow_field")
    if predefined_field:
        return predefined_field

    # Generic types specified without type arguments
    if typ is list:
        typ = List[Any]
    elif typ is dict:
        typ = Dict[Any, Any]

    # Base types
    field = _field_by_type(typ, base_schema)
    if field:
        return field(**metadata)

    if typ is Any:
        metadata.setdefault("allow_none", True)
        return marshmallow.fields.Raw(**metadata)

    # Generic types
    origin = typing_inspect.get_origin(typ)
    if origin:
        arguments = typing_inspect.get_args(typ, True)
        # Override base_schema.TYPE_MAPPING to change the class used for generic types below
        type_mapping = base_schema.TYPE_MAPPING if base_schema else {}

        if origin in (list, List):
            child_type = field_for_schema(arguments[0], base_schema=base_schema)
            list_type = type_mapping.get(List, marshmallow.fields.List)
            return list_type(child_type, **metadata)
        if origin in (tuple, Tuple):
            children = tuple(
                field_for_schema(arg, base_schema=base_schema) for arg in arguments)
            tuple_type = type_mapping.get(Tuple, marshmallow.fields.Tuple)
            return tuple_type(children, **metadata)
        elif origin in (dict, Dict):
            dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
            return dict_type(
                keys=field_for_schema(arguments[0], base_schema=base_schema),
                values=field_for_schema(arguments[1], base_schema=base_schema),
                **metadata,
            )
        elif typing_inspect.is_optional_type(typ):
            subtyp = next(t for t in arguments if t is not NoneType)  # type: ignore
            # Treat optional types as types with a None default
            metadata["default"] = metadata.get("default", None)
            metadata["missing"] = metadata.get("missing", None)
            metadata["required"] = False
            return field_for_schema(subtyp, metadata=metadata, base_schema=base_schema)
        elif typing_inspect.is_union_type(typ):
            subfields = [
                field_for_schema(subtyp, metadata=metadata, base_schema=base_schema)
                for subtyp in arguments
            ]
            import marshmallow_union
            return marshmallow_union.Union(subfields, **metadata)

    # typing.NewType returns a function with a __supertype__ attribute
    newtype_supertype = getattr(typ, "__supertype__", None)
    if newtype_supertype and inspect.isfunction(typ):
        return _field_by_supertype(
            typ=typ,
            default=default,
            newtype_supertype=newtype_supertype,
            metadata=metadata,
            base_schema=base_schema,
        )

    # enumerations
    if isinstance(typ, EnumMeta):
        import marshmallow_enum
        return marshmallow_enum.EnumField(typ, **metadata)

    # Nested marshmallow dataclass
    nested_schema = getattr(typ, "Schema", None)

    # Nested dataclasses
    forward_reference = getattr(typ, "__forward_arg__", None)
    nested = (nested_schema or forward_reference
              or class_schema(typ, base_schema=base_schema))
    return marshmallow.fields.Nested(nested, **metadata)
def is_dataclass_or_optional_dataclass_type(t: Type) -> bool:
    """
    Returns whether `t` is a dataclass type or an Optional[<dataclass type>].
    """
    return is_dataclass(t) or (tpi.is_optional_type(t)
                               and is_dataclass(tpi.get_args(t)[0]))
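# Quick demonstration of both branches above (hypothetical dataclass `Cfg`;
# assumes `tpi` is typing_inspect and `is_dataclass` comes from dataclasses):
import dataclasses
import typing
import typing_inspect as tpi

@dataclasses.dataclass
class Cfg:
    name: str = ""

assert dataclasses.is_dataclass(Cfg)
assert tpi.is_optional_type(typing.Optional[Cfg])
assert tpi.get_args(typing.Optional[Cfg])[0] is Cfg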
    get_git_url,
    has_git,
    has_uncommitted_changes,
    is_option_arg,
    type_to_str,
    get_literals,
    boolean_type,
    TupleTypeEnforcer,
    define_python_object_encoder,
    as_python_object,
    fix_py36_copy,
    enforce_reproducibility,
)

# Constants
EMPTY_TYPE = get_args(List)[0] if len(get_args(List)) > 0 else tuple()
SUPPORTED_DEFAULT_BASE_TYPES = {str, int, float, bool}
SUPPORTED_DEFAULT_OPTIONAL_TYPES = {
    Optional, Optional[str], Optional[int], Optional[float], Optional[bool]
}
SUPPORTED_DEFAULT_LIST_TYPES = {List, List[str], List[int], List[float], List[bool]}
SUPPORTED_DEFAULT_SET_TYPES = {Set, Set[str], Set[int], Set[float], Set[bool]}
SUPPORTED_DEFAULT_COLLECTION_TYPES = SUPPORTED_DEFAULT_LIST_TYPES | SUPPORTED_DEFAULT_SET_TYPES | {Tuple}
SUPPORTED_DEFAULT_BOXED_TYPES = SUPPORTED_DEFAULT_OPTIONAL_TYPES | SUPPORTED_DEFAULT_COLLECTION_TYPES
SUPPORTED_DEFAULT_TYPES = set.union(SUPPORTED_DEFAULT_BASE_TYPES,
                                    SUPPORTED_DEFAULT_OPTIONAL_TYPES,
                                    SUPPORTED_DEFAULT_COLLECTION_TYPES)
def get(cls, type_or_hint, *, is_argument: bool = False) -> "TypeChecker":
    # This ensures the validity of the type passed (see typing documentation for info)
    type_or_hint = is_valid_type(type_or_hint, "Invalid type.", is_argument)

    if type_or_hint is Any:
        return AnyTypeChecker()

    if is_type(type_or_hint):
        return TypeTypeChecker.make(type_or_hint, is_argument)

    if is_literal_type(type_or_hint):
        return LiteralTypeChecker.make(type_or_hint, is_argument)

    if is_generic_type(type_or_hint):
        origin = get_origin_or_self(type_or_hint)
        if issubclass(origin, MappingCol):
            return MappingTypeChecker.make(type_or_hint, is_argument)
        if issubclass(origin, Collection):
            return CollectionTypeChecker.make(type_or_hint, is_argument)
        # CONSIDER: how to cater for exhaustible generators?
        if issubclass(origin, Iterable):
            raise NotImplementedError(
                "No type-checker is setup for iterables that exhaust.")
        return GenericTypeChecker.make(type_or_hint, is_argument)

    if is_tuple_type(type_or_hint):
        return TupleTypeChecker.make(type_or_hint, is_argument)

    if is_callable_type(type_or_hint):
        return CallableTypeChecker.make(type_or_hint, is_argument)

    if isclass(type_or_hint):
        if is_typed_dict(type_or_hint):
            return TypedDictChecker.make(type_or_hint, is_argument)
        return ConcreteTypeChecker.make(type_or_hint, is_argument)

    if is_union_type(type_or_hint):
        return UnionTypeChecker.make(type_or_hint, is_argument)

    if is_typevar(type_or_hint):
        bound_type = get_bound(type_or_hint)
        if bound_type:
            return cls.get(bound_type)
        constraints = get_constraints(type_or_hint)
        if constraints:
            union_type_checkers = tuple(cls.get(type_) for type_ in constraints)
            return UnionTypeChecker(Union.__getitem__(constraints), union_type_checkers)
        else:
            return AnyTypeChecker()

    if is_new_type(type_or_hint):
        super_type = getattr(type_or_hint, "__supertype__", None)
        if super_type is None:
            raise TypeError(
                f"No supertype for NewType: {type_or_hint}. This is not allowed.")
        return cls.get(super_type)

    if is_forward_ref(type_or_hint):
        return ForwardTypeChecker.make(type_or_hint, is_argument=is_argument)

    if is_classvar(type_or_hint):
        var_type = get_args(type_or_hint, evaluate=True)[0]
        return cls.get(var_type)

    raise NotImplementedError(
        f"No {TypeChecker.__qualname__} is available for type or hint: '{type_or_hint}'")
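# The TypeVar branch above distinguishes bound from constrained variables; a
# self-contained look at the two typing_inspect helpers it relies on:
import typing
from typing_inspect import get_bound, get_constraints

T_bound = typing.TypeVar("T_bound", bound=int)
T_constrained = typing.TypeVar("T_constrained", int, str)

assert get_bound(T_bound) is int                      # delegated to cls.get(int)
assert get_constraints(T_constrained) == (int, str)   # becomes a Union checker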
def _add_argument(self, *name_or_flags, **kwargs) -> None:
    """Adds an argument to self (i.e. the super class ArgumentParser).

    Sets the following attributes of kwargs when not explicitly provided:

    - type: Set to the type annotation of the argument.
    - default: Set to the default value of the argument (if provided).
    - required: True if a default value of the argument is not provided, False otherwise.
    - action: Set to "store_true" if the argument is a required bool or a bool with
              default value False. Set to "store_false" if the argument is a bool with
              default value True.
    - nargs: Set to "*" if the type annotation is List[str], List[int], or List[float].
    - help: Set to the argument documentation from the class docstring.

    :param name_or_flags: Either a name or a list of option strings, e.g. foo or -f, --foo.
    :param kwargs: Keyword arguments.
    """
    # Get variable name
    variable = get_argument_name(*name_or_flags)

    # Get default if not specified
    if hasattr(self, variable):
        kwargs['default'] = kwargs.get('default', getattr(self, variable))

    # Set required if option arg
    if (is_option_arg(*name_or_flags) and variable != 'help'
            and 'default' not in kwargs and kwargs.get('action') != 'version'):
        kwargs['required'] = kwargs.get('required', not hasattr(self, variable))

    # Set help if necessary
    if 'help' not in kwargs:
        kwargs['help'] = '('

        # Type
        if variable in self._annotations:
            kwargs['help'] += type_to_str(self._annotations[variable]) + ', '

        # Required/default
        if kwargs.get('required', False):
            kwargs['help'] += 'required'
        else:
            kwargs['help'] += f'default={kwargs.get("default", None)}'

        kwargs['help'] += ')'

        # Description
        if variable in self.class_variables:
            kwargs['help'] += ' ' + self.class_variables[variable]['comment']

    # Set other kwargs where not provided
    if variable in self._annotations:
        # Get type annotation
        var_type = self._annotations[variable]

        # If type is not explicitly provided, set it if it's one of our supported default types
        if 'type' not in kwargs:
            # First check whether it is a literal type or a boxed literal type
            if is_literal_type(var_type):
                var_type, kwargs['choices'] = get_literals(var_type, variable)
            elif (get_origin(var_type) in (List, list, Set, set)
                    and len(get_args(var_type)) > 0
                    and is_literal_type(get_args(var_type)[0])):
                var_type, kwargs['choices'] = get_literals(get_args(var_type)[0], variable)
                kwargs['nargs'] = kwargs.get('nargs', '*')
            # Handle Tuple type (with type args) by extracting types of Tuple elements and enforcing them
            elif get_origin(var_type) in (Tuple, tuple) and len(get_args(var_type)) > 0:
                loop = False
                types = get_args(var_type)

                # Don't allow Tuple[()]
                if len(types) == 1 and types[0] == tuple():
                    raise ValueError('Empty Tuples (i.e. Tuple[()]) are not a valid Tap type '
                                     'because they have no arguments.')

                # Handle Tuple[type, ...]
                if len(types) == 2 and types[1] == Ellipsis:
                    types = types[0:1]
                    loop = True
                    kwargs['nargs'] = '*'
                else:
                    kwargs['nargs'] = len(types)

                var_type = TupleTypeEnforcer(types=types, loop=loop)
            # To identify an Optional type, check if it's a union of a None and something else
            elif (is_union_type(var_type)
                    and len(get_args(var_type)) == 2
                    and isinstance(None, get_args(var_type)[1])
                    and is_literal_type(get_args(var_type)[0])):
                var_type, kwargs['choices'] = get_literals(get_args(var_type)[0], variable)
            elif var_type not in SUPPORTED_DEFAULT_TYPES:
                is_required = kwargs.get('required', False)
                arg_params = 'required=True' if is_required else f'default={getattr(self, variable)}'
                raise ValueError(
                    f'Variable "{variable}" has type "{var_type}" which is not supported by default.\n'
                    f'Please explicitly add the argument to the parser by writing:\n\n'
                    f'def configure(self) -> None:\n'
                    f'    self.add_argument("--{variable}", type=func, {arg_params})\n\n'
                    f'where "func" maps from str to {var_type}.')

            if var_type in SUPPORTED_DEFAULT_BOXED_TYPES:
                # If List or Set type, set nargs
                if (var_type in SUPPORTED_DEFAULT_COLLECTION_TYPES
                        and kwargs.get('action') not in {'append', 'append_const'}):
                    kwargs['nargs'] = kwargs.get('nargs', '*')

                # Extract boxed type for Optional, List, Set
                arg_types = get_args(var_type)

                # Set default type to str for Type and Type[()]
                if len(arg_types) == 0 or arg_types[0] == EMPTY_TYPE:
                    var_type = str
                else:
                    var_type = arg_types[0]

                # Handle the cases of Optional[bool], List[bool], Set[bool]
                if var_type == bool:
                    var_type = boolean_type

            # If bool then set action, otherwise set type
            if var_type == bool:
                if self._explicit_bool:
                    kwargs['type'] = boolean_type
                    kwargs['choices'] = [True, False]  # this makes the help message more helpful
                else:
                    action_cond = ("true" if kwargs.get("required", False)
                                   or not kwargs["default"] else "false")
                    kwargs['action'] = kwargs.get('action', f'store_{action_cond}')
            elif kwargs.get('action') not in {'count', 'append_const'}:
                kwargs['type'] = var_type

    if self._underscores_to_dashes:
        name_or_flags = [name_or_flag.replace('_', '-') for name_or_flag in name_or_flags]

    super(Tap, self).add_argument(*name_or_flags, **kwargs)
def make(cls, type_or_hint, is_argument: bool) -> "LiteralTypeChecker":
    literals_values = get_args(type_or_hint, evaluate=True)
    return cls(type_or_hint, literals_values)
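# What get_args yields for a Literal hint, which is what this checker stores
# (uses typing_extensions for portability; a hedged, self-contained check):
from typing_extensions import Literal
from typing_inspect import get_args

assert get_args(Literal["a", "b"], evaluate=True) == ("a", "b")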
def transform(self, node, eligible_node_types=None):
    """Transform python ast into GTScript ast recursively."""
    if eligible_node_types is None:
        eligible_node_types = [gtscript_ast.Computation]

    if isinstance(node, ast.AST):
        is_leaf_node = len(list(ast.iter_fields(node))) == 0
        if is_leaf_node:
            if not type(node) in self.leaf_map:
                raise ValueError(
                    f"Leaf node of type {type(node)}, found in the python ast, can not be mapped."
                )
            return self.leaf_map[type(node)]
        else:
            # visit node fields and transform
            # TODO(tehrengruber): check if multiple nodes match and throw an error in that case
            #  disadvantage: templates can be ambiguous
            for node_type in eligible_node_types:
                if not hasattr(self.Patterns, node_type.__name__):
                    continue
                captures = {}
                if not anm.match(node, getattr(self.Patterns, node_type.__name__),
                                 captures=captures):
                    continue
                module = sys.modules[node_type.__module__]
                transformed_captures = {}
                for name, capture in captures.items():
                    assert (
                        name in node_type.__annotations__
                    ), f"Invalid capture. No field named `{name}` in `{str(node_type)}`"
                    field_type = node_type.__annotations__[name]
                    if typing_inspect.get_origin(field_type) == list:
                        # determine eligible capture types
                        el_type = typing_inspect.get_args(field_type)[0]
                        eligible_capture_types = self._all_subclasses(el_type, module=module)
                        # transform captures recursively
                        transformed_captures[name] = []
                        for child_capture in capture:
                            transformed_captures[name].append(
                                self.transform(child_capture, eligible_capture_types))
                    else:
                        # determine eligible capture types
                        eligible_capture_types = self._all_subclasses(field_type, module=module)
                        # transform captures recursively
                        transformed_captures[name] = self.transform(capture, eligible_capture_types)
                return node_type(**transformed_captures)
            raise ValueError("Expected a node of type {}".format(
                ", ".join([ent.__name__ for ent in eligible_node_types])))
    elif type(node) in eligible_node_types:
        return node

    raise ValueError("Expected a node of type {}, but got {}".format(
        {*eligible_node_types, ast.AST}, type(node)))
def make(cls, type_or_hint, is_argument: bool) -> "UnionTypeChecker":
    union_types = get_args(type_or_hint, evaluate=True)
    union_type_checkers = tuple(cls.get(type_) for type_ in union_types)
    return cls(type_or_hint, union_type_checkers)
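# The per-member decomposition above in miniature (self-contained):
from typing import Union
from typing_inspect import get_args

assert get_args(Union[int, str], evaluate=True) == (int, str)
# cls(type_or_hint, ...) then wraps one checker per member type.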
def field_for_schema(
    typ: type,
    default=marshmallow.missing,
    metadata: Mapping[str, Any] = None) -> marshmallow.fields.Field:
    """
    Get a marshmallow Field corresponding to the given python type.
    The metadata of the dataclass field is used as arguments to the marshmallow Field.

    >>> int_field = field_for_schema(int, default=9, metadata=dict(required=True))
    >>> int_field.__class__
    <class 'marshmallow.fields.Integer'>

    >>> int_field.default
    9

    >>> int_field.required
    True

    >>> field_for_schema(Dict[str,str]).__class__
    <class 'marshmallow.fields.Dict'>

    >>> field_for_schema(Callable[[str],str]).__class__
    <class 'marshmallow.fields.Function'>

    >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
    <class 'marshmallow.fields.Url'>

    >>> field_for_schema(Optional[str]).__class__
    <class 'marshmallow.fields.String'>

    >>> field_for_schema(NewType('UserId', int)).__class__
    <class 'marshmallow.fields.Integer'>

    >>> field_for_schema(NewType('UserId', int), default=0).default
    0

    >>> class Color(Enum):
    ...     red = 1
    >>> field_for_schema(Color).__class__
    <class 'marshmallow_enum.EnumField'>

    >>> field_for_schema(Any).__class__
    <class 'marshmallow.fields.Raw'>
    """
    metadata = {} if metadata is None else dict(metadata)
    metadata.setdefault("required", True)

    if default is not marshmallow.missing:
        metadata.setdefault("default", default)
        metadata.setdefault("missing", default)

    # If the field was already defined by the user
    predefined_field = metadata.get("marshmallow_field")
    if predefined_field:
        return predefined_field

    # Base types
    if typ in _native_to_marshmallow:
        return _native_to_marshmallow[typ](**metadata)

    # Generic types
    origin: type = typing_inspect.get_origin(typ)
    if origin in (list, List):
        list_elements_type = typing_inspect.get_args(typ, True)[0]
        return marshmallow.fields.List(field_for_schema(list_elements_type), **metadata)
    elif origin in (dict, Dict):
        key_type, value_type = typing_inspect.get_args(typ, True)
        return marshmallow.fields.Dict(
            keys=field_for_schema(key_type),
            values=field_for_schema(value_type),
            **metadata,
        )
    elif origin in (collections.abc.Callable, Callable):
        return marshmallow.fields.Function(**metadata)
    elif typing_inspect.is_optional_type(typ):
        subtyp = next(t for t in typing_inspect.get_args(typ) if t is not NoneType)
        # Treat optional types as types with a None default
        metadata["default"] = metadata.get("default", None)
        metadata["missing"] = metadata.get("missing", None)
        metadata["required"] = False
        return field_for_schema(subtyp, metadata=metadata)

    # typing.NewType returns a function with a __supertype__ attribute
    newtype_supertype = getattr(typ, "__supertype__", None)
    if newtype_supertype and inspect.isfunction(typ):
        metadata.setdefault("description", typ.__name__)
        return field_for_schema(newtype_supertype, metadata=metadata, default=default)

    # enumerations
    if type(typ) is EnumMeta:
        import marshmallow_enum
        return marshmallow_enum.EnumField(typ, **metadata)

    # Nested attr
    forward_reference = getattr(typ, "__forward_arg__", None)
    nested = forward_reference or class_schema(typ)
    return marshmallow.fields.Nested(nested, **metadata)
def is_value_of_type(  # noqa: C901 "too complex"
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    value: Any,
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    expected_type: Any,
    invariant_check: bool = False,
) -> bool:
    """
    This method attempts to verify a given value is of a given type.
    If the type is not supported, it returns True but throws an exception in tests.
    It is similar to the typeguard / enforce pypi modules, but neither of those has
    permissive options for types they do not support.

    Supported types for now:
     - List/Set/Iterable
     - Dict/Mapping
     - base types (str, int, etc)
     - Literal
     - Unions
     - Tuples
     - Concrete Classes
     - ClassVar

    Not supported:
     - Callables, which will likely not be used in XHP anyways
     - Generics, Type Vars (treated as Any)
     - Generators
     - Forward Refs -- use `typing.get_type_hints` to resolve these
     - Type[...]
    """
    if is_classvar(expected_type):
        # `ClassVar` (no subscript) is implicitly `ClassVar[Any]`
        if hasattr(expected_type, "__type__"):  # py36
            expected_type = expected_type.__type__ or Any
        else:  # py37+
            classvar_args = get_args(expected_type)
            expected_type = (classvar_args[0] or Any) if classvar_args else Any

    if is_typevar(expected_type):
        # treat this the same as Any
        # TODO: evaluate bounds
        return True

    expected_origin_type = get_origin(expected_type) or expected_type

    if expected_origin_type == Any:
        return True
    elif is_union_type(expected_type):
        return any(
            is_value_of_type(value, subtype)
            for subtype in expected_type.__args__
        )
    elif isinstance(expected_origin_type, type(Literal)):
        if hasattr(expected_type, "__values__"):  # py36
            literal_values = expected_type.__values__
        else:  # py37+
            literal_values = get_args(expected_type, evaluate=True)
        return any(value == literal for literal in literal_values)
    elif isinstance(expected_origin_type, ForwardRef):
        # not much we can do here for now, lets just return :(
        return True
    # Handle `Tuple[A, B, C]`.
    # We don't want to include Tuple subclasses, like NamedTuple, because they're
    # unlikely to behave similarly.
    elif expected_origin_type in [Tuple, tuple]:  # py36 uses Tuple, py37+ uses tuple
        if not isinstance(value, tuple):
            return False
        type_args = get_args(expected_type, evaluate=True)
        if type_args is None:
            return True
        if len(type_args) == 0:
            # `Tuple` (no subscript) is implicitly `Tuple[Any, ...]`
            return True
        if len(value) != len(type_args):
            return False
        # TODO: Handle `Tuple[T, ...]` like `Iterable[T]`
        for subvalue, subtype in zip(value, type_args):
            if not is_value_of_type(subvalue, subtype):
                return False
        return True
    elif issubclass(expected_origin_type, Mapping):
        # We're expecting *some* kind of Mapping, but we also want to make sure it's
        # the correct Mapping subtype. That means we want {a: b, c: d} to match Mapping,
        # MutableMapping, and Dict, but we don't want MappingProxyType({a: b, c: d}) to
        # match MutableMapping or Dict.
        if not issubclass(type(value), expected_origin_type):
            return False
        type_args = get_args(expected_type, evaluate=True)
        if len(type_args) == 0:
            # `Mapping` (no subscript) is implicitly `Mapping[Any, Any]`.
            return True
        invariant_check = issubclass(expected_origin_type, MutableMapping)
        for subkey, subvalue in value.items():
            if not is_value_of_type(
                subkey,
                type_args[0],
                # key type is always invariant
                invariant_check=True,
            ):
                return False
            if not is_value_of_type(
                subvalue, type_args[1], invariant_check=invariant_check
            ):
                return False
        return True
    # While this does technically work fine for str and bytes (they are iterables),
    # it's better to use the default isinstance behavior for them.
    #
    # Similarly, tuple subclasses tend to have pretty different behavior, and we
    # should fall back to the default check.
    elif issubclass(expected_origin_type, Iterable) and not issubclass(
        expected_origin_type, (str, bytes, tuple),
    ):
        # We know this thing is *some* kind of Iterable, but we want to
        # allow subclasses. That means we want [1,2,3] to match both
        # List[int] and Iterable[int], but we do NOT want that
        # to match Set[int].
        if not issubclass(type(value), expected_origin_type):
            return False
        type_args = get_args(expected_type, evaluate=True)
        if len(type_args) == 0:
            # `Iterable` (no subscript) is implicitly `Iterable[Any]`.
            return True
        # We invariant check if it's a mutable sequence
        invariant_check = issubclass(expected_origin_type, MutableSequence)
        return all(
            is_value_of_type(subvalue, type_args[0], invariant_check=invariant_check)
            for subvalue in value
        )

    try:
        if not invariant_check:
            if expected_type is float:
                return isinstance(value, (int, float))
            else:
                return isinstance(value, expected_type)
        return type(value) is expected_type
    except Exception as e:
        raise NotImplementedError(
            f"the value {value!r} was compared to type {expected_type!r} "
            + f"but support for that has not been implemented yet! Exception: {e!r}"
        )
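# Hedged usage sketch for is_value_of_type above; the values are hypothetical,
# but each call exercises a branch documented in the docstring:
from typing import Dict, List, Optional, Tuple

assert is_value_of_type([1, 2, 3], List[int])          # Iterable branch
assert not is_value_of_type({1, 2, 3}, List[int])      # a set is not a list
assert is_value_of_type({"a": 1}, Dict[str, int])      # Mapping branch
assert is_value_of_type((1, "x"), Tuple[int, str])     # fixed-length tuple
assert is_value_of_type(None, Optional[int])           # Union branch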
def test_get_forward_arg(self):
    tp = List["FRef"]
    fr = get_args(tp)[0]
    self.assertEqual(get_forward_arg(fr), "FRef")
    self.assertEqual(get_forward_arg(tp), None)
def decode_generic_set(decoder, typ, json_value):
    if not (inspect.is_generic_type(typ) and inspect.get_origin(typ) == set):
        return Unsupported
    check_type(list, json_value)
    item_type = inspect.get_args(typ)[0]
    return set([decoder.decode(item_type, item) for item in json_value])
def from_pyvalue(cls, data, *, allow_missing=False):
    if not isinstance(data, dict):
        raise cls._err(f'expected a dict value, got {type(data)!r}')

    spec = config.get_settings()

    data = dict(data)
    tname = data.pop('_tname', None)
    if tname is not None:
        if '::' in tname:
            tname = s_name.Name(tname).name
        cls = spec.get_type_by_name(tname)

    fields = {f.name: f for f in dataclasses.fields(cls)}

    items = {}
    inv_keys = []
    for fieldname, value in data.items():
        field = fields.get(fieldname)
        if field is None:
            if value is None:
                # This may happen when data is produced by
                # a polymorphic config query.
                pass
            else:
                inv_keys.append(fieldname)
            continue

        f_type = field.type
        if value is None:
            # Config queries return empty pointer values as None.
            continue

        if typing_inspect.is_generic_type(f_type):
            container = typing_inspect.get_origin(f_type)
            if container not in (frozenset, list):
                raise RuntimeError(
                    f'invalid type annotation on '
                    f'{cls.__name__}.{fieldname}: '
                    f'{f_type!r} is not supported')
            eltype = typing_inspect.get_args(f_type, evaluate=True)[0]
            if isinstance(value, eltype):
                value = container((value,))
            elif (typeutils.is_container(value)
                    and all(isinstance(v, eltype) for v in value)):
                value = container(value)
            else:
                raise cls._err(
                    f'invalid {fieldname!r} field value: expecting '
                    f'{eltype.__name__} or a list thereof, but got '
                    f'{type(value).__name__}'
                )
        elif (issubclass(f_type, CompositeConfigType)
                and isinstance(value, dict)):
            tname = value.get('_tname', None)
            if tname is not None:
                if '::' in tname:
                    tname = s_name.Name(tname).name
                actual_f_type = spec.get_type_by_name(tname)
            else:
                actual_f_type = f_type
                value['_tname'] = f_type.__name__
            value = actual_f_type.from_pyvalue(value)
        elif not isinstance(value, f_type):
            raise cls._err(
                f'invalid {fieldname!r} field value: expecting '
                f'{f_type.__name__}, but got {type(value).__name__}'
            )

        items[fieldname] = value

    if inv_keys:
        inv_keys = ', '.join(repr(r) for r in inv_keys)
        raise cls._err(f'unknown fields: {inv_keys}')

    for fieldname, field in fields.items():
        if fieldname not in items and field.default is dataclasses.MISSING:
            if allow_missing:
                items[fieldname] = None
            else:
                raise cls._err(f'missing required field: {fieldname!r}')

    try:
        return cls(**items)
    except (TypeError, ValueError) as ex:
        raise cls._err(str(ex))
def _to_type_member(type_: type, metadata_: Dict[Any, Any]) -> ElementDecl:
    """Resolve attribute type hint to the type-related part of an element declaration."""
    result = ElementDecl()

    # Get argument of List[...] type hint
    is_list = get_origin(type_) is not None and get_origin(type_) is list
    if is_list:
        type_ = get_args(type_)[0]

    key_ = metadata_.get('key', None)
    if key_ is not None and type_ is str:
        key_type = ClassInfo.get_type(key_)
        result.key = _create_type_declaration_key(key_type.__module__, key_)
        return result

    meta_type = metadata_.get('type', None)

    # Primitive types
    if type_ is str:
        result.value = ValueDecl(type=ValueParamType.String)
    elif type_ is bool:
        result.value = ValueDecl(type=ValueParamType.NullableBool)
    elif type_ is float:
        result.value = ValueDecl(type=ValueParamType.NullableDouble)
    elif type_ is ObjectId:
        result.value = ValueDecl(type=ValueParamType.NullableTemporalId)
    # Date additional cases
    elif type_ is dt.date:
        result.value = ValueDecl(type=ValueParamType.NullableDate)
    elif type_ is dt.time:
        result.value = ValueDecl(type=ValueParamType.NullableTime)
    # dt.datetime depends on metadata
    elif type_ is dt.datetime:
        if meta_type == 'Instant':
            result.value = ValueDecl(type=ValueParamType.NullableInstant)
        elif meta_type is None:
            result.value = ValueDecl(type=ValueParamType.NullableDateTime)
        else:
            raise Exception(
                f'Unexpected dt.datetime and metadata type combination: dt.datetime + {meta_type}'
            )
    # Restore int/long/Local... separation using info from metadata
    elif type_ is int:
        if meta_type == 'long':
            result.value = ValueDecl(type=ValueParamType.NullableLong)
        elif meta_type == 'LocalDate':
            result.value = ValueDecl(type=ValueParamType.NullableDate)
        elif meta_type == 'LocalTime':
            result.value = ValueDecl(type=ValueParamType.NullableTime)
        elif meta_type == 'LocalMinute':
            result.value = ValueDecl(type=ValueParamType.NullableMinute)
        elif meta_type == 'LocalDateTime':
            result.value = ValueDecl(type=ValueParamType.NullableDateTime)
        elif meta_type is None:
            result.value = ValueDecl(type=ValueParamType.NullableInt)
        else:
            raise Exception(
                f'Unexpected int and metadata type combination: int + {meta_type}'
            )
    elif issubclass(type_, Data):
        result.data = _create_type_declaration_key(type_.__module__, type_.__name__)
    elif issubclass(type_, IntEnum):
        result.enum = _create_enum_declaration_key(str(type_.__module__), type_.__name__)
    else:
        raise Exception(f'Unexpected type {type_.__name__}')

    return result
def field_for_schema(
    typ: type,
    default=marshmallow.missing,
    metadata: Mapping[str, Any] = None,
    base_schema: Optional[Type[marshmallow.Schema]] = None,
) -> marshmallow.fields.Field:
    """
    Get a marshmallow Field corresponding to the given python type.
    The metadata of the dataclass field is used as arguments to the marshmallow Field.

    :param base_schema: marshmallow schema used as a base class when deriving dataclass schema

    >>> int_field = field_for_schema(int, default=9, metadata=dict(required=True))
    >>> int_field.__class__
    <class 'marshmallow.fields.Integer'>

    >>> int_field.default
    9

    >>> int_field.required
    True

    >>> field_for_schema(Dict[str,str]).__class__
    <class 'marshmallow.fields.Dict'>

    >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
    <class 'marshmallow.fields.Url'>

    >>> field_for_schema(Optional[str]).__class__
    <class 'marshmallow.fields.String'>

    >>> from enum import Enum
    >>> import marshmallow_enum
    >>> field_for_schema(Enum("X", "a b c")).__class__
    <class 'marshmallow_enum.EnumField'>

    >>> import typing
    >>> field_for_schema(typing.Union[int,str]).__class__
    <class 'marshmallow_union.Union'>

    >>> field_for_schema(typing.NewType('UserId', int)).__class__
    <class 'marshmallow.fields.Integer'>

    >>> field_for_schema(typing.NewType('UserId', int), default=0).default
    0

    >>> class Color(Enum):
    ...     red = 1
    >>> field_for_schema(Color).__class__
    <class 'marshmallow_enum.EnumField'>

    >>> field_for_schema(Any).__class__
    <class 'marshmallow.fields.Raw'>
    """
    metadata = {} if metadata is None else dict(metadata)

    if default is not marshmallow.missing:
        metadata.setdefault("default", default)
        if not metadata.get("required"):
            # 'missing' must not be set for required fields.
            metadata.setdefault("missing", default)
    else:
        metadata.setdefault("required", True)

    # If the field was already defined by the user
    predefined_field = metadata.get("marshmallow_field")
    if predefined_field:
        return predefined_field

    # Base types
    if typ in _native_to_marshmallow:
        return _native_to_marshmallow[typ](**metadata)

    # Generic types
    origin = typing_inspect.get_origin(typ)
    if origin:
        arguments = typing_inspect.get_args(typ, True)
        if origin in (list, List):
            return marshmallow.fields.List(
                field_for_schema(arguments[0], base_schema=base_schema), **metadata)
        if origin in (tuple, Tuple):
            return marshmallow.fields.Tuple(
                tuple(field_for_schema(arg, base_schema=base_schema) for arg in arguments),
                **metadata,
            )
        elif origin in (dict, Dict):
            return marshmallow.fields.Dict(
                keys=field_for_schema(arguments[0], base_schema=base_schema),
                values=field_for_schema(arguments[1], base_schema=base_schema),
                **metadata,
            )
        elif typing_inspect.is_optional_type(typ):
            subtyp = next(t for t in arguments if t is not NoneType)  # type: ignore
            # Treat optional types as types with a None default
            metadata["default"] = metadata.get("default", None)
            metadata["missing"] = metadata.get("missing", None)
            metadata["required"] = False
            return field_for_schema(subtyp, metadata=metadata, base_schema=base_schema)
        elif typing_inspect.is_union_type(typ):
            subfields = [
                field_for_schema(subtyp, metadata=metadata, base_schema=base_schema)
                for subtyp in arguments
            ]
            import marshmallow_union
            return marshmallow_union.Union(subfields, **metadata)

    # typing.NewType returns a function with a __supertype__ attribute
    newtype_supertype = getattr(typ, "__supertype__", None)
    if newtype_supertype and inspect.isfunction(typ):
        # Add the information coming from our custom NewType implementation
        metadata = {
            "description": typ.__name__,
            **getattr(typ, "_marshmallow_args", {}),
            **metadata,
        }
        field = getattr(typ, "_marshmallow_field", None)
        if field:
            return field(**metadata)
        else:
            return field_for_schema(
                newtype_supertype,
                metadata=metadata,
                default=default,
                base_schema=base_schema,
            )

    # enumerations
    if type(typ) is EnumMeta:
        import marshmallow_enum
        return marshmallow_enum.EnumField(typ, **metadata)

    # Nested dataclasses
    forward_reference = getattr(typ, "__forward_arg__", None)
    nested = forward_reference or class_schema(typ, base_schema=base_schema)
    return marshmallow.fields.Nested(nested, **metadata)