예제 #1
0
  def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
    """Declares a parameter called `name` in this Module and returns its value.

    Parameters are read-only variables that live in the "params" collection;
    :mod:`flax.core.variables` describes variables in more detail.

    `init_fn` is expected to take a PRNG key as its first argument. That key
    is supplied automatically, so `init_args` only needs to carry the
    remaining arguments::

      mean = self.param('mean', lecun_normal(), (2, 2))

    Here the initializer returned by `lecun_normal()` expects `key` and
    `shape`; only `shape` is given explicitly, while `key` is derived from the
    PRNG for `params` that is passed to :meth:`init` when the module is
    initialized.

    Args:
      name: Name of the parameter.
      init_fn: Callable that computes the parameter's initial value. It is
        only invoked the first time this parameter is used in this module.
      *init_args: Additional arguments for `init_fn` (everything except the
        automatically supplied PRNG key).

    Returns:
      The value of the initialized parameter.

    Raises:
      ValueError: If initialization is not currently allowed, i.e. outside
        `setup()` or a method wrapped in `@compact`.
      errors.NameInUseError: If `name` is already taken in this Module.
    """
    if not self._initialization_allowed:
      raise ValueError('Parameters must be initialized in `setup()` or in a '
                       'method wrapped in `@compact`')
    if self._name_taken(name):
      owner = self.__class__.__name__
      raise errors.NameInUseError('param', name, owner)
    value = self.scope.param(name, init_fn, *init_args)
    # Record in this Module's children bookkeeping that `name` belongs to the
    # "params" collection (only after the scope call above has succeeded).
    self._state.children[name] = 'params'
    return value
예제 #2
0
파일: module.py 프로젝트: davisyoshida/flax
    def __post_init__(self):
        # DO NOT REMOVE - Marker for internal logging.
        # The dataclass-generated __init__ consumes the field arguments and then
        # immediately calls this hook. What happens next depends on the kind of
        # `parent` this Module received: initialization is deferred, the Module
        # is attached to a parent Module as a named submodule, or it is bound at
        # the top level to a functional Scope.

        object.__setattr__(self, '_state', _ModuleInternalState())

        # Without an explicit parent, adopt the innermost entry of the dynamic
        # module context stack.
        if self.parent is _unspecified_parent:  # pytype: disable=attribute-error
            object.__setattr__(self, 'parent', _context.module_stack[-1])

        parent = self.parent

        # Orphan / top-level Modules (e.g. MyModule(..., parent=None)) stay
        # uninitialized until they are attached via __setattr__.
        if parent is None:
            return

        if isinstance(parent, Scope):
            # Top-level invocation with a functional Scope: bind to it directly.
            object.__setattr__(self, 'scope', parent)
        elif isinstance(parent, Module):
            # An unnamed Module constructed inside the parent's setup() is
            # attached later, when it is assigned (self.mymodule = MyModule(...)).
            if parent._state.in_setup and self.name is None:  # pytype: disable=attribute-error
                return
            if not parent._initialization_allowed:
                raise errors.AssignSubModuleError(self.__class__.__name__)
            if self.name is None:  # pytype: disable=attribute-error
                # Auto-name as "<ClassName>_<n>" using a per-class counter kept
                # on the parent.
                prefix = self.__class__.__name__
                index = parent._state.autoname_cursor.get(prefix, 0)
                self.name = f"{prefix}_{index}"
                parent._state.autoname_cursor[prefix] = index + 1
            if parent._name_taken(self.name, self):
                raise errors.NameInUseError('submodule', self.name,
                                            parent.__class__.__name__)
            # Register on the parent, then take a child scope named after us.
            parent._state.children[self.name] = self
            object.__setattr__(self, 'scope', parent.scope.push(self.name))
        else:
            raise ValueError("parent must be None, Module or Scope")

        self._state.is_initialized = True
예제 #3
0
파일: module.py 프로젝트: davisyoshida/flax
    def variable(self, col: str, name: str, init_fn, *init_args) -> Variable:
        """Declares and returns a variable in this Module.

        See :mod:`flax.core.variables` for more information. See also
        :meth:`param` for a shorthand way to define read-only variables in the
        "params" collection.

        Contrary to :meth:`param`, no argument of `init_fn` is supplied
        automatically: every argument it needs (including any PRNG key) must be
        passed explicitly through `init_args`::

          key = self.make_rng('stats')
          mean = self.variable('stats', 'mean', lecun_normal(), key, (2, 2))

        In the example above, the initializer returned by `lecun_normal()`
        expects two arguments, `key` and `shape`, and both have to be passed
        explicitly. The PRNG for `stats` has to be provided when calling
        :meth:`init` and :meth:`apply`.

        Args:
          col: The variable collection name.
          name: The variable name.
          init_fn: The function that will be called to compute the initial value
            of this variable. This function will only be called the first time
            this variable is used in this module.
          *init_args: The arguments to pass to `init_fn`.

        Returns:
          A :class:`flax.core.variables.Variable` whose value can be read or set
          via its ``.value`` attribute.

        Raises:
          ValueError: If variables cannot be initialized at this point, i.e.
            outside `setup()` or a method wrapped in `@compact`.
          errors.NameInUseError: If `name` is already in use in this Module.
        """
        if not self._initialization_allowed:
            raise ValueError(
                'Variables must be initialized in `setup()` or in a method '
                'wrapped in `@compact`')
        if self._name_taken(name):
            raise errors.NameInUseError('variable', name,
                                        self.__class__.__name__)
        v = self.scope.variable(col, name, init_fn, *init_args)
        # Record in this Module's children bookkeeping that `name` belongs to
        # collection `col` (only after the scope call above has succeeded).
        self._state.children[name] = col
        return v