def __init__(self, module, sn_kwargs, pow_iter_collection, *args, **kwargs):
    """Constructs a wrapped Sonnet module with Spectral Normalization.

    The module expects a first argument which should be a Sonnet AbstractModule
    and a second argument which is a dictionary which is passed to the inner
    spectral_norm function as kwargs.

    When connecting this module to the graph, the argument 'pow_iter_collection'
    is treated specially for this wrapper (rather than for the _build method of
    the inner module). If pow_iter_collection is not None, the approximate
    first singular value for weights will *not* be updated based on the inputs
    passed at the given _build call. Instead, an op for updating the singular
    value will be placed into the pow_iter_collection global collection.

    If pow_iter_collection is None or not passed, a control dependency on the
    update op will be applied to the output of the _build function. Regardless,
    the kwarg is deleted from the list of keywords passed to the inner module.

    Args:
      module: A constructor/class reference for a Sonnet module you would like
        to construct.
      sn_kwargs: Keyword arguments to be passed to the spectral_norm function in
        addition to the weight tensor.
      pow_iter_collection: The name of a global collection for potentially
        storing ops for updating internal variables, or None. Stored as
        `self.pow_iter_collection`.
      *args: Construction-time arguments to the module.
      **kwargs: Construction-time keyword arguments to the module. If a
        non-None `name` is given, this wrapper is named `<name>_wrapper`;
        otherwise it is named `sn_wrapper`.
    """
    # Name the wrapper after the inner module. `kwargs.get('name', 'sn')` alone
    # raises a TypeError when the caller explicitly passes `name=None` (a value
    # AbstractModule accepts), so a None name falls back to 'sn' as well.
    inner_name = kwargs.get('name')
    if inner_name is None:
        inner_name = 'sn'
    name = inner_name + '_wrapper'

    # Our getter needs to be able to be disabled. sn_getter yields two getters
    # (presumably one that updates the singular-value estimate immediately and
    # one that defers the update); the Context below switches between them,
    # with the deferred-update getter as the default.
    getter_immediate_update, getter_deferred_update = self.sn_getter(sn_kwargs)

    def _route_weights_only(getter):
        # Apply `getter` only to variables whose full name ends in "/w"; the
        # identity name_fn means the regex sees the unmodified variable name.
        return util.custom_getter_router({'.*/w$': getter}, lambda s: s)

    getter_immediate_update = _route_weights_only(getter_immediate_update)
    getter_deferred_update = _route_weights_only(getter_deferred_update)
    self._context_getter = context.Context(
        getter_immediate_update, default_getter=getter_deferred_update)
    self.pow_iter_collection = pow_iter_collection
    super(SpectralNormWrapper, self).__init__(
        name=name, custom_getter=self._context_getter)

    # Construct the inner module inside this wrapper's variable scope, as
    # AbstractModule requires for sub-modules instantiated in __init__, so the
    # wrapper's custom getter is in effect for the inner module's variables.
    with self._enter_variable_scope():
        self._module = module(*args, **kwargs)
def __init__(self, _sentinel=None, custom_getter=None, name=None):  # pylint: disable=invalid-name
    """Performs the initialisation necessary for all AbstractModule instances.

    Every subclass of AbstractModule must begin their constructor with a call to
    this constructor, i.e.
    `super(MySubModule, self).__init__(custom_getter=custom_getter, name=name)`.

    If you instantiate sub-modules in __init__ you must create them within the
    `_enter_variable_scope` context manager to ensure they are in the module's
    variable scope. Alternatively, instantiate sub-modules in `_build`.

    Args:
      _sentinel: Variable that only carries a non-None value if `__init__` was
        called without named parameters. If this is the case, a `ValueError` is
        raised.
      custom_getter: Callable or dictionary of callables to use as custom getters
        inside the module. If a dictionary, the keys correspond to regexes to
        match variable names. See the `tf.get_variable` documentation for
        information about the custom_getter API.
      name: Name of this module. Used to construct the Templated build function.
        If `None` the module's class name is used (converted to snake case).

    Raises:
      TypeError: If `name` is not a string.
      TypeError: If a given `custom_getter` is not callable.
      ValueError: If `__init__` was called without named arguments.
    """
    # `collections.Mapping` was deprecated in Python 3.3 and removed in 3.10,
    # where referencing it raises AttributeError. Use `collections.abc`, and
    # fall back to `collections` on Python 2, which has no `collections.abc`.
    try:
        from collections.abc import Mapping  # pylint: disable=g-import-not-at-top
    except ImportError:  # Python 2.
        from collections import Mapping  # pylint: disable=g-import-not-at-top

    if _sentinel is not None:
        raise ValueError("Calling AbstractModule.__init__ without named "
                         "arguments is not supported.")

    if name is None:
        name = util.to_snake_case(self.__class__.__name__)
    elif not isinstance(name, six.string_types):
        raise TypeError("Name must be a string, not {} of type {}.".format(
            name, type(name)))

    self._is_connected = False
    self._connected_subgraphs = []

    # If the given custom getter is a dictionary with a per-variable custom
    # getter, wrap it into a single custom getter.
    if isinstance(custom_getter, Mapping):
        self._custom_getter = util.custom_getter_router(
            custom_getter_map=custom_getter,
            # Strip the "<scope_name>/" prefix so the dictionary's regexes are
            # matched against names relative to this module. `self.scope_name`
            # is read lazily, only when the router is invoked.
            name_fn=lambda var_name: var_name[len(self.scope_name) + 1:])
    elif custom_getter is not None and not callable(custom_getter):
        raise TypeError("Given custom_getter is not callable.")
    else:
        self._custom_getter = custom_getter

    self._template = tf.make_template(name,
                                      self._build_wrapper,
                                      create_scope_now_=True,
                                      custom_getter_=self._custom_getter)

    self._original_name = name
    # Last path component of the scope the template actually created; this can
    # presumably differ from `name` when the scope name had to be uniquified.
    self._unique_name = self._template.variable_scope.name.split("/")[-1]

    # Update __call__ and the object docstrings to enable better introspection.
    self.__doc__ = self._build.__doc__
    self.__call__.__func__.__doc__ = self._build.__doc__

    # Keep track of which graph this module has been connected to. Sonnet
    # modules cannot be connected to multiple graphs, as transparent variable
    # sharing is impossible in that case.
    self._graph = None

    # Container for all variables created in this module and its sub-modules.
    self._all_variables = set()

    # Calling `.defun()` causes the module's call method to become wrapped as
    # a graph function.
    self._defun_wrapped = False
def __init__(self, _sentinel=None, custom_getter=None, name=None):  # pylint: disable=invalid-name
    """Performs the initialisation necessary for all AbstractModule instances.

    Every subclass of AbstractModule must begin their constructor with a call to
    this constructor, i.e.
    `super(MySubModule, self).__init__(custom_getter=custom_getter, name=name)`.

    If you instantiate sub-modules in __init__ you must create them within the
    `_enter_variable_scope` context manager to ensure they are in the module's
    variable scope. Alternatively, instantiate sub-modules in `_build`.

    Args:
      _sentinel: Variable that only carries a non-None value if `__init__` was
        called without named parameters. If this is the case, a `ValueError` is
        raised.
      custom_getter: Callable or dictionary of callables to use as custom getters
        inside the module. If a dictionary, the keys correspond to regexes to
        match variable names. See the `tf.get_variable` documentation for
        information about the custom_getter API.
      name: Name of this module. Used to construct the Templated build function.
        If `None` the module's class name is used (converted to snake case).

    Raises:
      TypeError: If `name` is not a string.
      TypeError: If a given `custom_getter` is not callable.
      ValueError: If `__init__` was called without named arguments.
    """
    # `collections.Mapping` was deprecated in Python 3.3 and removed in 3.10,
    # where referencing it raises AttributeError. Use `collections.abc`, and
    # fall back to `collections` on Python 2, which has no `collections.abc`.
    try:
        from collections.abc import Mapping  # pylint: disable=g-import-not-at-top
    except ImportError:  # Python 2.
        from collections import Mapping  # pylint: disable=g-import-not-at-top

    if _sentinel is not None:
        raise ValueError("Calling AbstractModule.__init__ without named "
                         "arguments is not supported.")

    if name is None:
        name = util.to_snake_case(self.__class__.__name__)
    elif not isinstance(name, six.string_types):
        # Report the offending value and type to make misuse easy to diagnose.
        raise TypeError("Name must be a string, not {} of type {}.".format(
            name, type(name)))

    self._connected_subgraphs = []

    # If the given custom getter is a dictionary with a per-variable custom
    # getter, wrap it into a single custom getter.
    if isinstance(custom_getter, Mapping):
        self._custom_getter = util.custom_getter_router(
            custom_getter_map=custom_getter,
            # Strip the "<scope_name>/" prefix so the dictionary's regexes are
            # matched against names relative to this module. `self.scope_name`
            # is read lazily, only when the router is invoked.
            name_fn=lambda var_name: var_name[len(self.scope_name) + 1:])
    elif custom_getter is not None and not callable(custom_getter):
        raise TypeError("Given custom_getter is not callable.")
    else:
        self._custom_getter = custom_getter

    # Layer the variable-tracking getter around whatever getter was configured
    # above (possibly None). NOTE(review): presumably this is what records
    # created variables for `self._all_variables`; confirm in
    # `_variable_tracking_custom_getter`.
    self._custom_getter = _maybe_wrap_custom_getter(
        _variable_tracking_custom_getter, self._custom_getter)

    self._template = tf.make_template(name,
                                      self._build_wrapper,
                                      create_scope_now_=True,
                                      custom_getter_=self._custom_getter)

    self._original_name = name
    # Last path component of the scope the template actually created; this can
    # presumably differ from `name` when the scope name had to be uniquified.
    self._unique_name = self._template.variable_scope.name.split("/")[-1]

    # Update __call__ and the object docstrings to enable better introspection.
    self.__doc__ = self._build.__doc__
    self.__call__.__func__.__doc__ = self._build.__doc__

    # Keep track of which graph this module has been connected to. Sonnet
    # modules cannot be connected to multiple graphs, as transparent variable
    # sharing is impossible in that case.
    self._graph = None

    # Container for all variables created in this module and its sub-modules.
    self._all_variables = set()