def _process_dataclass(self, cls: type, obj: Any, fieldpath: str) -> Any:
    """Process a dataclass instance for output (or pure validation).

    Returns a json-friendly dict when the session is creating output;
    returns None when only validating.
    """
    # pylint: disable=too-many-locals
    # pylint: disable=too-many-branches
    prep = PrepSession(explicit=False).prep_dataclass(type(obj),
                                                      recursion_level=0)
    assert prep is not None

    # Only build an output dict when we're actually creating data.
    out: Optional[dict[str, Any]] = {} if self._create else None

    for fld in dataclasses.fields(obj):
        attrname = fld.name
        subfieldpath = (f'{fieldpath}.{attrname}' if fieldpath else attrname)
        anntype, ioattrs = _parse_annotated(prep.annotations[attrname])
        attrval = getattr(obj, attrname)

        # If we're not storing default values for this fella,
        # we can skip all output processing if we've got a default value.
        if ioattrs is not None and not ioattrs.store_default:
            fac: Any = fld.default_factory
            if fac is not dataclasses.MISSING:
                if fac() == attrval:
                    continue
            elif fld.default is not dataclasses.MISSING:
                if fld.default == attrval:
                    continue
            else:
                raise RuntimeError(
                    f'Field {attrname} of {cls.__name__} has'
                    f' neither a default nor a default_factory;'
                    f' store_default=False cannot be set for it.'
                    f' (AND THIS SHOULD HAVE BEEN CAUGHT IN PREP!)')

        outvalue = self._process_value(cls, subfieldpath, anntype, attrval,
                                       ioattrs)
        if self._create:
            assert out is not None
            # Honor a custom storage-name when one is defined.
            if ioattrs is not None and ioattrs.storagename is not None:
                out[ioattrs.storagename] = outvalue
            else:
                out[attrname] = outvalue

    # If there's extra-attrs stored on us, check/include them.
    extra_attrs = getattr(obj, EXTRA_ATTRS_ATTR, None)
    if isinstance(extra_attrs, dict):
        if not _is_valid_for_codec(extra_attrs, self._codec):
            raise TypeError(
                f'Extra attrs on {fieldpath} contains data type(s)'
                f' not supported by json.')
        if self._create:
            assert out is not None
            out.update(extra_attrs)
    return out
def __getattr__(self, name: str) -> _PathCapture:
    """Dive into a sub-field, returning a capture for the nested path."""
    # Only dataclass parents have sub-fields we can descend into.
    if not self._is_dataclass:
        raise TypeError(
            f"Field path cannot include attribute '{name}' "
            f'under parent {self._cls}; parent types must be dataclasses.')

    prep = PrepSession(explicit=False).prep_dataclass(self._cls,
                                                     recursion_level=0)
    assert prep is not None
    try:
        anntype = prep.annotations[name]
    except KeyError as exc:
        raise AttributeError(f'{type(self)} has no {name} field.') from exc
    anntype, ioattrs = _parse_annotated(anntype)

    # Use a custom storage-name if the field defines one.
    if ioattrs is not None and ioattrs.storagename is not None:
        storagename = ioattrs.storagename
    else:
        storagename = name

    return _PathCapture(_get_origin(anntype),
                        pathparts=self._pathparts + [storagename])
def _dataclass_from_input(self, cls: type, fieldpath: str,
                          values: dict) -> Any:
    """Given a dict, instantiates a dataclass of the given type.

    The dict must be in the json-friendly format as emitted from
    dataclass_to_dict. This means that sequence values such as tuples or
    sets should be passed as lists, enums should be passed as their
    associated values, and nested dataclasses should be passed as dicts.
    """
    # pylint: disable=too-many-locals
    # pylint: disable=too-many-branches
    if not isinstance(values, dict):
        raise TypeError(
            f'Expected a dict for {fieldpath} on {cls.__name__};'
            f' got a {type(values)}.')
    prep = PrepSession(explicit=False).prep_dataclass(cls, recursion_level=0)
    assert prep is not None

    extra_attrs = {}

    # noinspection PyDataclass
    fields_by_name = {f.name: f for f in dataclasses.fields(cls)}

    # Resolve each field's Annotated[] wrapper to its contained type
    # and IOAttrs up front; both passes below share these pairs.
    parsed_field_annotations = {
        fname: _parse_annotated(prep.annotations[fname])
        for fname in fields_by_name
    }

    # Walk the input data, routing each entry to either dataclass args
    # or the extra-attrs pile.
    args: dict[str, Any] = {}
    for rawkey, value in values.items():
        # Map custom storage names back to attr names where applicable.
        key = prep.storage_names_to_attr_names.get(rawkey, rawkey)
        if key not in fields_by_name:
            # Store unknown attrs off to the side (or error if desired).
            if not self._allow_unknown_attrs:
                raise AttributeError(
                    f"'{cls.__name__}' has no '{key}' field.")
            if self._discard_unknown_attrs:
                continue
            # Treat this like 'Any' data; ensure that it is valid
            # raw json.
            if not _is_valid_for_codec(value, self._codec):
                raise TypeError(
                    f'Unknown attr \'{key}\''
                    f' on {fieldpath} contains data type(s)'
                    f' not supported by the specified codec'
                    f' ({self._codec.name}).')
            extra_attrs[key] = value
        else:
            anntype, ioattrs = parsed_field_annotations[key]
            subfieldpath = (f'{fieldpath}.{key}' if fieldpath else key)
            args[key] = self._value_from_input(cls, subfieldpath, anntype,
                                               value, ioattrs)

    # Now look for fields still missing from our data; any with a
    # soft-default value or factory gets that value injected.
    for key, (anntype, ioattrs) in parsed_field_annotations.items():
        if key in args or ioattrs is None:
            continue
        have_soft = ioattrs.soft_default is not ioattrs.MISSING
        have_soft_factory = (ioattrs.soft_default_factory
                             is not ioattrs.MISSING)
        if not (have_soft or have_soft_factory):
            continue
        if have_soft:
            soft_default = ioattrs.soft_default
        else:
            assert callable(ioattrs.soft_default_factory)
            soft_default = ioattrs.soft_default_factory()
        args[key] = soft_default

        # Make sure these values are valid since we didn't run
        # them through our normal input type checking.
        self._type_check_soft_default(
            value=soft_default,
            anntype=anntype,
            fieldpath=(f'{fieldpath}.{key}' if fieldpath else key))

    try:
        out = cls(**args)
    except Exception as exc:
        raise ValueError(f'Error instantiating class {cls.__name__}'
                         f' at {fieldpath}: {exc}') from exc
    if extra_attrs:
        setattr(out, EXTRA_ATTRS_ATTR, extra_attrs)
    return out
def _dataclass_from_input(self, cls: type, fieldpath: str,
                          values: dict) -> Any:
    """Given a dict, instantiates a dataclass of the given type.

    The dict must be in the json-friendly format as emitted from
    dataclass_to_dict. This means that sequence values such as tuples or
    sets should be passed as lists, enums should be passed as their
    associated values, and nested dataclasses should be passed as dicts.

    Raises:
        TypeError: if values is not a dict, or an unknown attr holds data
            the active codec cannot represent.
        AttributeError: for unknown attrs when those are disallowed.
        ValueError: if constructing the dataclass from the args fails.
    """
    # pylint: disable=too-many-locals
    if not isinstance(values, dict):
        raise TypeError(
            f'Expected a dict for {fieldpath} on {cls.__name__};'
            f' got a {type(values)}.')
    prep = PrepSession(explicit=False).prep_dataclass(cls, recursion_level=0)
    assert prep is not None

    extra_attrs = {}

    # noinspection PyDataclass
    fields = dataclasses.fields(cls)
    fields_by_name = {f.name: f for f in fields}

    # Go through all data in the input, converting it to either dataclass
    # args or extra data.
    args: dict[str, Any] = {}
    for rawkey, value in values.items():
        # Map custom storage names back to attr names where applicable.
        key = prep.storage_names_to_attr_names.get(rawkey, rawkey)
        field = fields_by_name.get(key)

        # Store unknown attrs off to the side (or error if desired).
        if field is None:
            if self._allow_unknown_attrs:
                if self._discard_unknown_attrs:
                    continue

                # Treat this like 'Any' data; ensure that it is valid
                # raw json.
                if not _is_valid_for_codec(value, self._codec):
                    raise TypeError(
                        f'Unknown attr \'{key}\''
                        f' on {fieldpath} contains data type(s)'
                        f' not supported by the specified codec'
                        f' ({self._codec.name}).')
                extra_attrs[key] = value
            else:
                raise AttributeError(
                    f"'{cls.__name__}' has no '{key}' field.")
        else:
            fieldname = field.name
            anntype = prep.annotations[fieldname]
            anntype, ioattrs = _parse_annotated(anntype)

            subfieldpath = (f'{fieldpath}.{fieldname}'
                            if fieldpath else fieldname)
            args[key] = self._value_from_input(cls, subfieldpath, anntype,
                                               value, ioattrs)

    try:
        out = cls(**args)
    except Exception as exc:
        # Raise ValueError (not RuntimeError) here for consistency with
        # our other input paths: bad input data is a value problem the
        # caller should be able to handle uniformly.
        raise ValueError(f'Error instantiating class {cls.__name__}'
                         f' at {fieldpath}: {exc}') from exc
    if extra_attrs:
        setattr(out, EXTRA_ATTRS_ATTR, extra_attrs)
    return out
def prep_dataclass(self, cls: type,
                   recursion_level: int) -> Optional[PrepData]:
    """Run prep on a dataclass if necessary and return its prep data.

    The only case where this will return None is for recursive types if
    the type is already being prepped higher in the call order.
    """
    # pylint: disable=too-many-locals
    # pylint: disable=too-many-branches

    # We should only need to do this once per dataclass.
    existing_data = getattr(cls, PREP_ATTR, None)
    if existing_data is not None:
        assert isinstance(existing_data, PrepData)
        return existing_data

    # Sanity check.
    # Note that we now support recursive types via the PREP_SESSION_ATTR,
    # so we theoretically shouldn't run into this.
    if recursion_level > MAX_RECURSION:
        raise RuntimeError('Max recursion exceeded.')

    # We should only be passed classes which are dataclasses.
    if not isinstance(cls, type) or not dataclasses.is_dataclass(cls):
        raise TypeError(f'Passed arg {cls} is not a dataclass type.')

    # Add a pointer to the prep-session while doing the prep.
    # This way we can ignore types that we're already in the process
    # of prepping and can support recursive types.
    existing_prep = getattr(cls, PREP_SESSION_ATTR, None)
    if existing_prep is not None:
        if existing_prep is self:
            return None
        # We shouldn't need to support failed preps
        # or preps from multiple threads at once.
        raise RuntimeError('Found existing in-progress prep.')
    setattr(cls, PREP_SESSION_ATTR, self)

    # Generate a warning on non-explicit preps; we prefer prep to
    # happen explicitly at runtime so errors can be detected early on.
    if not self.explicit:
        logging.warning(
            'efro.dataclassio: implicitly prepping dataclass: %s.'
            ' It is highly recommended to explicitly prep dataclasses'
            ' as soon as possible after definition (via'
            ' efro.dataclassio.ioprep() or the'
            ' @efro.dataclassio.ioprepped decorator).', cls)

    # BUGFIX: previously a failure during annotation resolution or
    # type-prepping left PREP_SESSION_ATTR set on the class, so every
    # subsequent prep attempt errored with 'Found existing in-progress
    # prep.' Wrap the work in try/finally so the tag is always cleared.
    try:
        try:
            # NOTE: Now passing the class' __dict__ (vars()) as locals
            # which allows us to pick up nested classes, etc.
            resolved_annotations = get_type_hints(cls,
                                                  localns=vars(cls),
                                                  globalns=self.globalns,
                                                  include_extras=True)
        except Exception as exc:
            raise TypeError(
                f'dataclassio prep for {cls} failed with error: {exc}.'
                f' Make sure all types used in annotations are defined'
                f' at the module or class level or add them as part of an'
                f' explicit prep call.') from exc

        # noinspection PyDataclass
        fields = dataclasses.fields(cls)
        fields_by_name = {f.name: f for f in fields}

        all_storage_names: set[str] = set()
        storage_names_to_attr_names: dict[str, str] = {}

        # Ok; we've resolved actual types for this dataclass.
        # Now recurse through them, verifying that we support all
        # contained types and prepping any contained dataclass types.
        for attrname, anntype in resolved_annotations.items():
            anntype, ioattrs = _parse_annotated(anntype)

            # If we found attached IOAttrs data, make sure it contains
            # valid values for the field it is attached to.
            if ioattrs is not None:
                ioattrs.validate_for_field(cls, fields_by_name[attrname])
                if ioattrs.storagename is not None:
                    storagename = ioattrs.storagename
                    storage_names_to_attr_names[
                        ioattrs.storagename] = attrname
                else:
                    storagename = attrname
            else:
                storagename = attrname

            # Make sure we don't have any clashes in our storage names.
            if storagename in all_storage_names:
                raise TypeError(f'Multiple attrs on {cls} are using'
                                f' storage-name \'{storagename}\'')
            all_storage_names.add(storagename)

            self.prep_type(cls,
                           attrname,
                           anntype,
                           recursion_level=recursion_level + 1)

        # Success! Store our resolved stuff with the class.
        prepdata = PrepData(
            annotations=resolved_annotations,
            storage_names_to_attr_names=storage_names_to_attr_names)
        setattr(cls, PREP_ATTR, prepdata)
    finally:
        # Clear our prep-session tag on both success and failure.
        assert getattr(cls, PREP_SESSION_ATTR, None) is self
        delattr(cls, PREP_SESSION_ATTR)
    return prepdata