def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
    """Declares a parameter in this Module and returns its value.

    Parameters are read-only variables stored in the collection named
    "params"; see :mod:`flax.core.variables` for details on variables.

    `init_fn` receives a PRNG key as its first argument. That key is supplied
    automatically, so `init_args` only needs to carry the remaining
    arguments::

      mean = self.param('mean', lecun_normal(), (2, 2))

    Here `lecun_normal` takes `key` and `shape`; only `shape` is given
    explicitly, while `key` comes from the PRNG for `params` that is passed
    when initializing the module with :meth:`init`.

    Args:
      name: The parameter name.
      init_fn: The function that will be called to compute the initial value
        of this parameter. It is only called the first time this parameter is
        used in this module.
      *init_args: Arguments forwarded to `init_fn` after the PRNG key.

    Returns:
      The value of the initialized parameter.

    Raises:
      ValueError: if called outside `setup()` or a method wrapped in
        `@compact`.
      NameInUseError: if `_name_taken(name)` reports that `name` is already
        used in this Module.
    """
    # Guard: parameters may only be declared while initialization is allowed.
    if not self._initialization_allowed:
        message = ('Parameters must be initialized in `setup()` or in a method '
                   'wrapped in `@compact`')
        raise ValueError(message)

    # Guard: refuse to declare a name this Module already reports as taken.
    if self._name_taken(name):
        owner = self.__class__.__name__
        raise errors.NameInUseError('param', name, owner)

    # Let the scope create (or look up) the parameter, then record `name` in
    # this Module's children map under the "params" collection.
    initialized = self.scope.param(name, init_fn, *init_args)
    self._state.children[name] = 'params'
    return initialized
def __post_init__(self):
    # DO NOT REMOVE - Marker for internal logging.
    # In dataclasses, __init__ is overridden to process dataclass arguments,
    # and __post_init__ is called immediately afterwards. Here, depending on the
    # type of `parent` passed to initialize the Module, we either defer
    # initialization, attach this Module as a submodule of a parent, or bind
    # this Module at the top-level to variables and rngs.
    #
    # Outcomes, by `self.parent` after it has been resolved:
    #   None                        -> return early (deferred).
    #   Module, in setup, no name   -> return early (deferred).
    #   Module otherwise            -> auto-name if needed, register as a child
    #                                  of the parent, push a child scope.
    #   Scope                       -> use the Scope directly as `self.scope`.
    #   anything else               -> ValueError.
    # Only the non-deferred paths reach `self._state.is_initialized = True`.

    # Fresh per-instance internal state. Written with object.__setattr__,
    # presumably to bypass Module's own __setattr__ (not visible here) —
    # NOTE(review): confirm.
    object.__setattr__(self, '_state', _ModuleInternalState())

    # Typically we set the parent based on the dynamic module context.
    if self.parent is _unspecified_parent:  # pytype: disable=attribute-error
        # Use the top of the module context stack as the implicit parent.
        object.__setattr__(self, 'parent', _context.module_stack[-1])

    # Initialization is deferred for top level Modules or any other "orphan"
    # Modules until attachment by __setattr__ i.e. MyModule(..., parent=None)
    if self.parent is None:
        return

    # Register submodule on parent Module.
    if isinstance(self.parent, Module):
        # When initializing an unnamed Module inside setup()
        # initialization is deferred until attachment by __setattr__
        # i.e. self.mymodule = MyModule(...)
        if self.parent._state.in_setup and self.name is None:  # pytype: disable=attribute-error
            return
        if not self.parent._initialization_allowed:
            raise errors.AssignSubModuleError(self.__class__.__name__)
        # Autonaming of submodules.
        # The parent keeps one counter per class name, so successive unnamed
        # children of the same class get "<ClassName>_0", "<ClassName>_1", ...
        if self.name is None:  # pytype: disable=attribute-error
            prefix = f"{self.__class__.__name__}"
            cursor = self.parent._state.autoname_cursor.get(prefix, 0)
            # NOTE(review): unlike the other writes in this method, this is a
            # plain attribute assignment and so goes through the class's
            # __setattr__; statement order here is presumably significant.
            self.name = f"{prefix}_{cursor}"
            self.parent._state.autoname_cursor[prefix] = cursor + 1
        # `self` is passed along, presumably so the parent can tell whether
        # the name is held by this very Module — verify against _name_taken.
        if self.parent._name_taken(self.name, self):
            parent_class = self.parent.__class__.__name__
            raise errors.NameInUseError('submodule', self.name, parent_class)
        self.parent._state.children[self.name] = self
        # This Module's scope is a child scope of the parent's, keyed by name.
        object.__setattr__(self, 'scope', self.parent.scope.push(self.name))

    # Top-level invocation with a functional Scope.
    elif isinstance(self.parent, Scope):
        object.__setattr__(self, 'scope', self.parent)
    else:
        raise ValueError("parent must be None, Module or Scope")

    self._state.is_initialized = True
def variable(self, col: str, name: str, init_fn, *init_args) -> Variable:
    """Declares and returns a variable in this Module.

    See :mod:`flax.core.variables` for more information. See also
    :meth:`param` for a shorthand way to define read-only variables in the
    "params" collection.

    Unlike :meth:`param`, no PRNG key is supplied automatically: every
    argument that `init_fn` expects must be passed explicitly via
    `init_args`::

      key = self.make_rng('stats')
      mean = self.variable('stats', 'mean', lecun_normal(), key, (2, 2))

    In the example above, the function `lecun_normal` expects two arguments,
    `key` and `shape`, and both have to be passed explicitly. The PRNG for
    `stats` has to be provided explicitly when calling :meth:`init` and
    :meth:`apply`.

    Args:
      col: The variable collection name.
      name: The variable name.
      init_fn: The function that will be called to compute the initial value
        of this variable. This function will only be called the first time
        this variable is used in this module.
      *init_args: The arguments to pass to `init_fn`, forwarded unchanged.

    Returns:
      A :class:`flax.core.variables.Variable` that can be read or set via its
      ".value" attribute.

    Raises:
      ValueError: if called outside `setup()` or a method wrapped in
        `@compact`.
      NameInUseError: if `_name_taken(name)` reports that `name` is already
        used in this Module.
    """
    if not self._initialization_allowed:
        raise ValueError(
            'Variables must be initialized in `setup()` or in a method '
            'wrapped in `@compact`')
    if self._name_taken(name):
        raise errors.NameInUseError('variable', name, self.__class__.__name__)
    v = self.scope.variable(col, name, init_fn, *init_args)
    # Record which collection `name` belongs to in this Module's children map
    # (mirrors `param`, which records the "params" collection).
    self._state.children[name] = col
    return v