示例#1
0
文件: module.py 项目: hmph/flax
  def wrapped_module_method(*args, **kwargs):
    """Calls the wrapped callable `fun` as a Module method, managing state.

    If the first positional argument is not a `Module`, `fun` is not being
    used as a method and is called unchanged with the original arguments.

    Otherwise, with `self = args[0]`:
      * for any method other than `setup`, `self._try_setup()` is called first
        (setup is run lazily, only when needed);
      * `self._state.in_setup` / `self._state.in_compact_method` are set while
        `setup` / a compact method (one with a `compact` attribute) runs;
      * `self` is pushed onto `_context.module_stack` for the duration of the
        call;
      * if the last filter on `_context.capture_stack` accepts
        `(self, fun.__name__)`, the return value is sown into the
        'intermediates' collection under the method's name;
      * after a compact method, `self.scope` is replaced by
        `self.scope.rewound()`;
      * after the outermost setup/compact call, `self._state.reset()` is
        called.

    Returns:
      Whatever `fun` returns.

    Raises:
      errors.CallCompactUnboundModuleError: if a compact method is called on a
        module whose `scope` is None.
    """
    # We might have incorrectly wrapped a callable that is not a method.
    # Check whether the first arg is self, otherwise call the wrapped
    # function as is.
    if args and isinstance(args[0], Module):
      self, args = args[0], args[1:]
    else:
      return fun(*args, **kwargs)
    is_compact_method = hasattr(fun, 'compact')
    is_setup_method = fun.__name__ == 'setup'
    # True when an enclosing setup/compact call on this same module is already
    # active (e.g. an override calling super().setup()). Only the outermost
    # call may reset the internal state; initialized here so it is always
    # bound when read in the `finally` clause below.
    is_recurrent = False
    # We lazily call setup() only when needed.
    if is_setup_method:
      is_recurrent = self._state.in_setup
      self._state.in_setup = True
    else:
      self._try_setup()

    if is_compact_method:
      if self.scope is None:
        raise errors.CallCompactUnboundModuleError()
      is_recurrent = self._state.in_compact_method
      self._state.in_compact_method = True
    _context.module_stack.append(self)
    try:
      y = fun(self, *args, **kwargs)
      if _context.capture_stack:
        filter_fn = _context.capture_stack[-1]
        if filter_fn and filter_fn(self, fun.__name__):
          self.sow('intermediates', fun.__name__, y)
      return y
    finally:
      _context.module_stack.pop()
      if is_compact_method:
        # object.__setattr__ bypasses any __setattr__ override on the Module
        # (presumably Module restricts attribute assignment — confirm).
        object.__setattr__(self, 'scope', self.scope.rewound())
      # setup or compact calls can be recurrent, for example due to super
      # calls; resetting the state in an inner call would set the
      # compact/setup flags to False prematurely for the outer call.
      if (is_compact_method or is_setup_method) and not is_recurrent:
        self._state.reset()
示例#2
0
  def wrapped_module_method(*args, **kwargs):
    """Invokes the wrapped callable `fun` as a method of a `Module`.

    Calls without a `Module` receiver as the first positional argument are
    forwarded to `fun` untouched. For a real method call this wrapper:
      * runs `self._try_setup()` beforehand unless the method is `setup`;
      * flags `self._state.in_setup` / `self._state.in_compact_method` while
        `setup` / a compact method (one carrying a `compact` attribute) runs;
      * keeps `self` on `_context.module_stack` for the duration of the call;
      * sows the result into 'intermediates' under the method's name when the
        last filter on `_context.capture_stack` accepts it;
      * replaces `self.scope` by `self.scope.rewound()` after compact methods;
      * calls `self._state.reset()` only from the outermost setup/compact call.

    Returns:
      The value returned by `fun`.

    Raises:
      errors.CallCompactUnboundModuleError: when a compact method runs on a
        module without a scope.
    """
    # Guard clause: the decorator may have been applied to a plain callable
    # rather than a method; without a Module receiver, just forward the call.
    if not (args and isinstance(args[0], Module)):
      return fun(*args, **kwargs)
    self, rest = args[0], args[1:]

    method_name = fun.__name__
    is_compact_method = hasattr(fun, 'compact')
    is_setup_method = method_name == 'setup'

    # Set when an outer setup/compact call on this module is already in
    # progress (e.g. via super().setup()); such nested calls must not reset
    # the shared state.
    is_recurrent = False
    if not is_setup_method:
      # Setup is performed lazily, right before the first non-setup method.
      self._try_setup()
    else:
      is_recurrent = self._state.in_setup
      self._state.in_setup = True

    if is_compact_method:
      if self.scope is None:
        raise errors.CallCompactUnboundModuleError()
      is_recurrent = self._state.in_compact_method
      self._state.in_compact_method = True

    # Decided up front: none of these flags change during the call.
    owns_state = (is_setup_method or is_compact_method) and not is_recurrent

    _context.module_stack.append(self)
    try:
      result = fun(self, *rest, **kwargs)
      captures = _context.capture_stack
      if captures:
        accept = captures[-1]
        if accept and accept(self, method_name):
          self.sow('intermediates', method_name, result)
      return result
    finally:
      _context.module_stack.pop()
      if is_compact_method:
        # Bypass any __setattr__ override on the Module when swapping scopes.
        object.__setattr__(self, 'scope', self.scope.rewound())
      if owns_state:
        self._state.reset()