def _process_dataclass(self, cls: type, obj: Any, fieldpath: str) -> Any:
        # pylint: disable=too-many-locals
        # pylint: disable=too-many-branches
        prep = PrepSession(explicit=False).prep_dataclass(type(obj),
                                                          recursion_level=0)
        assert prep is not None
        fields = dataclasses.fields(obj)
        out: Optional[dict[str, Any]] = {} if self._create else None
        for field in fields:
            fieldname = field.name
            if fieldpath:
                subfieldpath = f'{fieldpath}.{fieldname}'
            else:
                subfieldpath = fieldname
            anntype = prep.annotations[fieldname]
            value = getattr(obj, fieldname)

            anntype, ioattrs = _parse_annotated(anntype)

            # If we're not storing default values for this fella, we can
            # skip all output processing when its value matches its default.
            if ioattrs is not None and not ioattrs.store_default:
                default_factory: Any = field.default_factory
                if default_factory is not dataclasses.MISSING:
                    if default_factory() == value:
                        continue
                elif field.default is not dataclasses.MISSING:
                    if field.default == value:
                        continue
                else:
                    raise RuntimeError(
                        f'Field {fieldname} of {cls.__name__} has'
                        f' neither a default nor a default_factory;'
                        f' store_default=False cannot be set for it.'
                        f' (AND THIS SHOULD HAVE BEEN CAUGHT IN PREP!)')

            outvalue = self._process_value(cls, subfieldpath, anntype, value,
                                           ioattrs)
            if self._create:
                assert out is not None
                storagename = (fieldname if
                               (ioattrs is None or ioattrs.storagename is None)
                               else ioattrs.storagename)
                out[storagename] = outvalue

        # If there's extra-attrs stored on us, check/include them.
        extra_attrs = getattr(obj, EXTRA_ATTRS_ATTR, None)
        if isinstance(extra_attrs, dict):
            if not _is_valid_for_codec(extra_attrs, self._codec):
                raise TypeError(
                    f'Extra attrs on {fieldpath} contain data type(s)'
                    f' not supported by the specified codec'
                    f' ({self._codec.name}).')
            if self._create:
                assert out is not None
                out.update(extra_attrs)
        return out
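
A minimal sketch of how the default-skipping logic above plays out, using the public efro.dataclassio names referenced in this code (ioprepped, IOAttrs, dataclass_to_dict); the exact constructor signatures here are assumptions:

from dataclasses import dataclass
from typing import Annotated

from efro.dataclassio import IOAttrs, dataclass_to_dict, ioprepped

@ioprepped
@dataclass
class Config:
    # store_default=False requires a default (or default_factory);
    # values matching that default are skipped entirely on output.
    volume: Annotated[float, IOAttrs('v', store_default=False)] = 1.0
    name: Annotated[str, IOAttrs('n')] = 'unnamed'

# volume equals its default so it is omitted; name is always emitted
# under its storage-name 'n'.
print(dataclass_to_dict(Config()))  # -> {'n': 'unnamed'}

The __getattr__ below belongs to the separate _PathCapture helper, which translates attribute access into storage-name paths.
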
    def __getattr__(self, name: str) -> _PathCapture:

        # We only allow diving into sub-objects if we are a dataclass.
        if not self._is_dataclass:
            raise TypeError(
                f"Field path cannot include attribute '{name}' "
                f'under parent {self._cls}; parent types must be dataclasses.')

        prep = PrepSession(explicit=False).prep_dataclass(self._cls,
                                                          recursion_level=0)
        assert prep is not None
        try:
            anntype = prep.annotations[name]
        except KeyError as exc:
            raise AttributeError(f'{type(self)} has no {name} field.') from exc
        anntype, ioattrs = _parse_annotated(anntype)
        storagename = (name if (ioattrs is None or ioattrs.storagename is None)
                       else ioattrs.storagename)
        origin = _get_origin(anntype)
        return _PathCapture(origin, pathparts=self._pathparts + [storagename])
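
A rough usage sketch for the path-capture above, assuming DataclassFieldLookup is the public wrapper around _PathCapture (treat that name and its path() signature as assumptions):

from dataclasses import dataclass, field
from typing import Annotated

from efro.dataclassio import DataclassFieldLookup, IOAttrs, ioprepped

@ioprepped
@dataclass
class Inner:
    count: Annotated[int, IOAttrs('c')] = 0

@ioprepped
@dataclass
class Outer:
    inner: Annotated[Inner, IOAttrs('i')] = field(default_factory=Inner)

# Each attribute access resolves through __getattr__ above, swapping
# attr names for storage-names; the result is the dotted storage path.
path = DataclassFieldLookup(Outer).path(lambda obj: obj.inner.count)
print(path)  # -> 'i.c'
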
    def _dataclass_from_input(self, cls: type, fieldpath: str,
                              values: dict) -> Any:
        """Given a dict, instantiates a dataclass of the given type.

        The dict must be in the json-friendly format as emitted from
        dataclass_to_dict. This means that sequence values such as tuples or
        sets should be passed as lists, enums should be passed as their
        associated values, and nested dataclasses should be passed as dicts.
        """
        # pylint: disable=too-many-locals
        # pylint: disable=too-many-branches
        if not isinstance(values, dict):
            raise TypeError(
                f'Expected a dict for {fieldpath} on {cls.__name__};'
                f' got a {type(values)}.')

        prep = PrepSession(explicit=False).prep_dataclass(cls,
                                                          recursion_level=0)
        assert prep is not None

        extra_attrs = {}

        # noinspection PyDataclass
        fields = dataclasses.fields(cls)
        fields_by_name = {f.name: f for f in fields}

        # Preprocess all fields to convert Annotated[] to contained types
        # and IOAttrs.
        parsed_field_annotations = {
            f.name: _parse_annotated(prep.annotations[f.name])
            for f in fields
        }

        # Go through all data in the input, converting it to either dataclass
        # args or extra data.
        args: dict[str, Any] = {}
        for rawkey, value in values.items():
            key = prep.storage_names_to_attr_names.get(rawkey, rawkey)
            field = fields_by_name.get(key)

            # Store unknown attrs off to the side (or error if desired).
            if field is None:
                if self._allow_unknown_attrs:
                    if self._discard_unknown_attrs:
                        continue

                    # Treat this like 'Any' data; ensure that it is valid
                    # raw json.
                    if not _is_valid_for_codec(value, self._codec):
                        raise TypeError(
                            f'Unknown attr \'{key}\''
                            f' on {fieldpath} contains data type(s)'
                            f' not supported by the specified codec'
                            f' ({self._codec.name}).')
                    extra_attrs[key] = value
                else:
                    raise AttributeError(
                        f"'{cls.__name__}' has no '{key}' field.")
            else:
                fieldname = field.name
                anntype, ioattrs = parsed_field_annotations[fieldname]
                subfieldpath = (f'{fieldpath}.{fieldname}'
                                if fieldpath else fieldname)
                args[key] = self._value_from_input(cls, subfieldpath, anntype,
                                                   value, ioattrs)

        # Go through all fields looking for any not yet present in our data.
        # If we find any such fields with a soft-default value or factory
        # defined, inject that soft value into our args.
        for key, aparsed in parsed_field_annotations.items():
            if key in args:
                continue
            ioattrs = aparsed[1]
            if (ioattrs is not None and
                (ioattrs.soft_default is not ioattrs.MISSING
                 or ioattrs.soft_default_factory is not ioattrs.MISSING)):
                if ioattrs.soft_default is not ioattrs.MISSING:
                    soft_default = ioattrs.soft_default
                else:
                    assert callable(ioattrs.soft_default_factory)
                    soft_default = ioattrs.soft_default_factory()
                args[key] = soft_default

                # Make sure these values are valid since we didn't run
                # them through our normal input type checking.

                self._type_check_soft_default(
                    value=soft_default,
                    anntype=aparsed[0],
                    fieldpath=(f'{fieldpath}.{key}' if fieldpath else key))

        try:
            out = cls(**args)
        except Exception as exc:
            raise ValueError(f'Error instantiating class {cls.__name__}'
                             f' at {fieldpath}: {exc}') from exc
        if extra_attrs:
            setattr(out, EXTRA_ATTRS_ATTR, extra_attrs)
        return out
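
A sketch of how the soft-default injection above lets older data load into a newer schema, assuming the public dataclass_from_dict function and the IOAttrs soft_default parameter used in this code:

from dataclasses import dataclass
from typing import Annotated

from efro.dataclassio import IOAttrs, dataclass_from_dict, ioprepped

@ioprepped
@dataclass
class Profile:
    name: Annotated[str, IOAttrs('n')]
    # soft_default fills in when the key is absent from input data,
    # without requiring a constructor default on the field itself.
    level: Annotated[int, IOAttrs('l', soft_default=1)]

# 'l' is missing from this (older) data; the soft default gets injected
# into args and validated via _type_check_soft_default() above.
obj = dataclass_from_dict(Profile, {'n': 'Flopsy'})
print(obj.level)  # -> 1

An earlier variant of the same method follows; it predates the soft-default handling.
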
    def _dataclass_from_input(self, cls: type, fieldpath: str,
                              values: dict) -> Any:
        """Given a dict, instantiates a dataclass of the given type.

        The dict must be in the json-friendly format as emitted from
        dataclass_to_dict. This means that sequence values such as tuples or
        sets should be passed as lists, enums should be passed as their
        associated values, and nested dataclasses should be passed as dicts.
        """
        # pylint: disable=too-many-locals
        if not isinstance(values, dict):
            raise TypeError(
                f'Expected a dict for {fieldpath} on {cls.__name__};'
                f' got a {type(values)}.')

        prep = PrepSession(explicit=False).prep_dataclass(cls,
                                                          recursion_level=0)
        assert prep is not None

        extra_attrs = {}

        # noinspection PyDataclass
        fields = dataclasses.fields(cls)
        fields_by_name = {f.name: f for f in fields}
        args: dict[str, Any] = {}
        for rawkey, value in values.items():
            key = prep.storage_names_to_attr_names.get(rawkey, rawkey)
            field = fields_by_name.get(key)

            # Store unknown attrs off to the side (or error if desired).
            if field is None:
                if self._allow_unknown_attrs:
                    if self._discard_unknown_attrs:
                        continue

                    # Treat this like 'Any' data; ensure that it is valid
                    # raw json.
                    if not _is_valid_for_codec(value, self._codec):
                        raise TypeError(
                            f'Unknown attr \'{key}\''
                            f' on {fieldpath} contains data type(s)'
                            f' not supported by the specified codec'
                            f' ({self._codec.name}).')
                    extra_attrs[key] = value
                else:
                    raise AttributeError(
                        f"'{cls.__name__}' has no '{key}' field.")
            else:
                fieldname = field.name
                anntype = prep.annotations[fieldname]
                anntype, ioattrs = _parse_annotated(anntype)

                subfieldpath = (f'{fieldpath}.{fieldname}'
                                if fieldpath else fieldname)
                args[key] = self._value_from_input(cls, subfieldpath, anntype,
                                                   value, ioattrs)
        try:
            out = cls(**args)
        except Exception as exc:
            raise RuntimeError(f'Error instantiating class {cls.__name__}'
                               f' at {fieldpath}: {exc}') from exc
        if extra_attrs:
            setattr(out, EXTRA_ATTRS_ATTR, extra_attrs)
        return out
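
A sketch of the unknown-attr handling shared by both variants, assuming dataclass_from_dict defaults to allowing (and preserving) unknown attrs:

from dataclasses import dataclass
from typing import Annotated

from efro.dataclassio import (IOAttrs, dataclass_from_dict,
                              dataclass_to_dict, ioprepped)

@ioprepped
@dataclass
class Point:
    x: Annotated[float, IOAttrs('x')]
    y: Annotated[float, IOAttrs('y')]

# 'color' matches no field, so it lands in extra-attrs storage instead
# of raising (assuming unknown attrs are allowed, as defaulted here).
pt = dataclass_from_dict(Point, {'x': 1.0, 'y': 2.0, 'color': 'red'})

# Extra attrs ride along on the instance and are re-emitted on output.
print(dataclass_to_dict(pt))  # -> {'x': 1.0, 'y': 2.0, 'color': 'red'}
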
    def prep_dataclass(self, cls: type,
                       recursion_level: int) -> Optional[PrepData]:
        """Run prep on a dataclass if necessary and return its prep data.

        The only case where this will return None is for recursive types
        if the type is already being prepped higher in the call order.
        """
        # pylint: disable=too-many-locals
        # pylint: disable=too-many-branches

        # We should only need to do this once per dataclass.
        existing_data = getattr(cls, PREP_ATTR, None)
        if existing_data is not None:
            assert isinstance(existing_data, PrepData)
            return existing_data

        # Sanity check.
        # Note that we now support recursive types via the PREP_SESSION_ATTR,
        # so we theoretically shouldn't run into this.
        if recursion_level > MAX_RECURSION:
            raise RuntimeError('Max recursion exceeded.')

        # We should only be passed classes which are dataclasses.
        if not isinstance(cls, type) or not dataclasses.is_dataclass(cls):
            raise TypeError(f'Passed arg {cls} is not a dataclass type.')

        # Add a pointer to the prep-session while doing the prep.
        # This way we can ignore types that we're already in the process
        # of prepping and can support recursive types.
        existing_prep = getattr(cls, PREP_SESSION_ATTR, None)
        if existing_prep is not None:
            if existing_prep is self:
                return None
            # We shouldn't need to support failed preps
            # or preps from multiple threads at once.
            raise RuntimeError('Found existing in-progress prep.')
        setattr(cls, PREP_SESSION_ATTR, self)

        # Generate a warning on non-explicit preps; we prefer prep to
        # happen explicitly at runtime so errors can be detected early on.
        if not self.explicit:
            logging.warning(
                'efro.dataclassio: implicitly prepping dataclass: %s.'
                ' It is highly recommended to explicitly prep dataclasses'
                ' as soon as possible after definition (via'
                ' efro.dataclassio.ioprep() or the'
                ' @efro.dataclassio.ioprepped decorator).', cls)

        try:
            # NOTE: Now passing the class' __dict__ (vars()) as locals
            # which allows us to pick up nested classes, etc.
            resolved_annotations = get_type_hints(cls,
                                                  localns=vars(cls),
                                                  globalns=self.globalns,
                                                  include_extras=True)
        except Exception as exc:
            raise TypeError(
                f'dataclassio prep for {cls} failed with error: {exc}.'
                f' Make sure all types used in annotations are defined'
                f' at the module or class level or add them as part of an'
                f' explicit prep call.') from exc

        # noinspection PyDataclass
        fields = dataclasses.fields(cls)
        fields_by_name = {f.name: f for f in fields}

        all_storage_names: set[str] = set()
        storage_names_to_attr_names: dict[str, str] = {}

        # Ok; we've resolved actual types for this dataclass.
        # Now recurse through them, verifying that we support all contained
        # types and prepping any contained dataclass types.
        for attrname, anntype in resolved_annotations.items():

            anntype, ioattrs = _parse_annotated(anntype)

            # If we found attached IOAttrs data, make sure it contains
            # valid values for the field it is attached to.
            if ioattrs is not None:
                ioattrs.validate_for_field(cls, fields_by_name[attrname])
                if ioattrs.storagename is not None:
                    storagename = ioattrs.storagename
                    storage_names_to_attr_names[storagename] = attrname
                else:
                    storagename = attrname
            else:
                storagename = attrname

            # Make sure we don't have any clashes in our storage names.
            if storagename in all_storage_names:
                raise TypeError(f'Multiple attrs on {cls} are using'
                                f' storage-name \'{storagename}\'')
            all_storage_names.add(storagename)

            self.prep_type(cls,
                           attrname,
                           anntype,
                           recursion_level=recursion_level + 1)

        # Success! Store our resolved stuff with the class and we're done.
        prepdata = PrepData(
            annotations=resolved_annotations,
            storage_names_to_attr_names=storage_names_to_attr_names)
        setattr(cls, PREP_ATTR, prepdata)

        # Clear our prep-session tag.
        assert getattr(cls, PREP_SESSION_ATTR, None) is self
        delattr(cls, PREP_SESSION_ATTR)
        return prepdata
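
Since implicit preps log the warning above, the intended pattern is to prep explicitly at definition time, via the ioprep() function or @ioprepped decorator named in that warning text. A minimal sketch:

from dataclasses import dataclass

from efro.dataclassio import ioprep, ioprepped

# Decorator form: annotations are resolved and validated immediately
# at class-definition time, surfacing errors early.
@ioprepped
@dataclass
class Settings:
    fullscreen: bool = False

# Explicit-call form: useful when types referenced in annotations are
# only defined later.
@dataclass
class LateSettings:
    fullscreen: bool = False

ioprep(LateSettings)
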