Example #1
0
File: module.py Project: hmph/flax
 def _customized_dataclass_transform(cls):
   """Handles final optional dataclass attributes: `parent` and `name`.

   Appends `parent` and `name` as trailing, defaulted fields of `cls` and
   applies `dataclasses.dataclass` to `cls` in place. While the transform
   runs, `parent` and `name` are removed from the inherited
   `__dataclass_fields__` so that they are re-added after the fields that
   `cls` declares itself. The inherited field table is put back afterwards,
   including when the transform raises.

   Raises:
     errors.ReservedModuleAttributeError: if `cls` itself annotates `parent`
       or `name`.
   """
   # Use cls.__dict__ to get annotations of cls itself (no parent class).
   annotations = dict(cls.__dict__.get('__annotations__', {}))
   if 'parent' in annotations or 'name' in annotations:
     raise errors.ReservedModuleAttributeError(annotations)
   # Add `parent` and `name` default fields at end.
   # We temporarily modify base class __dataclass_fields__ to force desired
   # argument behavior and ordering from dataclass class-transform.
   # Keep a copy of the inherited field table so it can be restored below.
   parent_dataclass_fields = dict(getattr(cls, '__dataclass_fields__', {}))
   try:
     # Remove 'parent' and 'name' from parents because we always want parent
     # and name to show up last in the dataclass args.
     # At this point cls has no field table of its own, so these pops mutate
     # the dict inherited from a base class.
     if 'parent' in parent_dataclass_fields:
       cls.__dataclass_fields__.pop('parent')  # pytype: disable=attribute-error
     if 'name' in parent_dataclass_fields:
       cls.__dataclass_fields__.pop('name')  # pytype: disable=attribute-error
     annotations['parent'] = Union[Type["Module"], Type["Scope"],
                                   Type["_Sentinel"], None]
     cls.parent = dataclasses.field(repr=False, default=_unspecified_parent)
     annotations['name'] = str
     cls.name = None  # default value of name is None.
     cls.__annotations__ = annotations
     # Now apply dataclass transform (which operates in-place).
     dataclasses.dataclass(cls, unsafe_hash=True, repr=False)  # pytype: disable=wrong-keyword-args
     cls.__hash__ = _wrap_hash(cls.__hash__)
   finally:
     # Restore original base class __dataclass_fields__. Doing this in a
     # `finally` keeps the base class intact even if the transform above
     # raises (e.g. on an invalid field declaration in `cls`).
     if dataclasses.is_dataclass(cls.__bases__[0]):
       cls.__bases__[0].__dataclass_fields__ = parent_dataclass_fields
Example #2
0
  def _customized_dataclass_transform(cls):
    """Handles final optional dataclass attributes: `parent` and `name`.

    Appends `parent` and `name` as trailing, defaulted fields of `cls` and
    applies `dataclasses.dataclass` to `cls` in place. While the transform
    runs, `parent` and `name` are removed from the `__dataclass_fields__` of
    every class after `cls` in its MRO so that they are re-added after the
    fields that `cls` declares itself. The saved field tables are put back
    afterwards, including when the transform raises.

    Raises:
      errors.ReservedModuleAttributeError: if `cls` itself annotates `parent`
        or `name` with a type other than the one this method assigns.
    """
    # Use cls.__dict__ to get annotations of cls itself (no parent class).
    annotations = dict(cls.__dict__.get('__annotations__', {}))
    parent_annotation = Union[Type["Module"], Type["Scope"],
                              Type["_Sentinel"], None]
    # Re-declaring `parent`/`name` is tolerated only with the exact
    # annotation that would be assigned below.
    if ('parent' in annotations
        and annotations['parent'] != parent_annotation):
      raise errors.ReservedModuleAttributeError(annotations)
    if 'name' in annotations and annotations['name'] != str:
      raise errors.ReservedModuleAttributeError(annotations)
    # Add `parent` and `name` default fields at end.
    # We temporarily modify base class __dataclass_fields__ to force desired
    # argument behavior and ordering from dataclass class-transform.
    # One snapshot per class in cls.__mro__[1:], in MRO order.
    parent_dataclass_fields = []
    try:
      for clz in cls.__mro__[1:]:
        # Record the snapshot before mutating so it can always be restored.
        pdf = dict(getattr(clz, '__dataclass_fields__', {}))
        parent_dataclass_fields.append(pdf)

        # Remove 'parent' and 'name' from parents because we always want
        # parent and name to show up last in the dataclass args.
        if 'parent' in pdf:
          clz.__dataclass_fields__.pop('parent')  # pytype: disable=attribute-error
        if 'name' in pdf:
          clz.__dataclass_fields__.pop('name')  # pytype: disable=attribute-error

      annotations['parent'] = parent_annotation
      cls.parent = dataclasses.field(repr=False, default=_unspecified_parent)
      annotations['name'] = str
      cls.name = None  # default value of name is None.
      cls.__annotations__ = annotations
      # Now apply dataclass transform (which operates in-place).
      # Do generate a hash function only if not provided by the class.
      dataclasses.dataclass(
        cls, unsafe_hash="__hash__" not in cls.__dict__, repr=False)  # pytype: disable=wrong-keyword-args
      cls.__hash__ = _wrap_hash(cls.__hash__)
    finally:
      # Restore original base class __dataclass_fields__. Doing this in a
      # `finally` keeps the ancestors intact even if the transform above
      # raises; `zip` stops at the snapshots that were actually recorded.
      for clz, pdf in zip(cls.__mro__[1:], parent_dataclass_fields):
        if dataclasses.is_dataclass(clz):
          clz.__dataclass_fields__ = pdf