def _wrap_module_methods(cls):
  """Wrap the methods ``cls`` defines itself with state management.

  Method names come from ``_get_local_method_names``. Dataclass field names
  and the dunders ``__eq__``, ``__repr__``, ``__init__`` and ``__hash__`` are
  excluded from wrapping. When ``_use_named_call`` is truthy, each selected
  method other than ``setup`` is passed through ``named_call`` before
  ``wrap_method``; otherwise only ``wrap_method`` is applied. The class is
  modified in place via ``setattr``.

  Args:
    cls: A dataclass (``dataclasses.fields`` is called on it).

  Returns:
    The same ``cls`` object, with the selected methods replaced.
  """
  # NOTE(review): `_wrap_module_methods` is defined a second time later in this
  # module; if both definitions remain, the later one is the one in effect.
  skipped = [field.name for field in dataclasses.fields(cls)]
  skipped += ['__eq__', '__repr__', '__init__', '__hash__']
  for name in _get_local_method_names(cls, exclude=skipped):
    fn = getattr(cls, name)
    if _use_named_call and name != 'setup':
      # Imported lazily: a module-level import would be circular.
      from flax.linen.transforms import named_call  # pylint: disable=g-import-not-at-top
      fn = named_call(fn)
    setattr(cls, name, wrap_method(fn))
  return cls
def _wrap_module_methods(cls): # We only want to wrap user-defined non-inherited methods. exclusions = ([f.name for f in dataclasses.fields(cls)] + ['__eq__', '__repr__', '__init__']) for key in get_local_method_names(cls, exclude=exclusions): method = getattr(cls, key) if _use_named_call and key != 'setup': printkey = f'.{key}' if key != '__call__' else '' method_name = f'{cls.__name__}{printkey}' # We import named_call at runtime to avoid a circular import issue. from flax.linen.transforms import named_call # pylint: disable=g-import-not-at-top method = named_call(method, method_name) setattr(cls, key, wrap_method(method)) return cls