Exemplo n.º 1
0
    def __init__(self, module, sn_kwargs, pow_iter_collection, *args,
                 **kwargs):
        """Constructs a wrapped Sonnet module with Spectral Normalization.

        The first argument is a constructor/class for a Sonnet module. The
        second is a dictionary passed to the inner spectral_norm machinery
        (via `self.sn_getter`) as keyword arguments. The inner module itself is
        built here, inside this wrapper's variable scope, from `*args` and
        `**kwargs` (including any `name` kwarg, which is forwarded unchanged).

        Variables whose full name matches the regex `.*/w$` are routed through
        a spectral-normalization custom getter. Two variants of that getter are
        built, an "immediate update" one and a "deferred update" one, and
        combined in a `context.Context`, with the deferred-update getter passed
        as `default_getter`.

        When connecting this module to the graph, the argument
        'pow_iter_collection' is treated specially for this wrapper (rather
        than for the `_build` method of the inner module):

        * If pow_iter_collection names a global collection, the approximate
          first singular value of the weights is *not* updated based on the
          inputs passed at the given `_build` call. Instead, an op for updating
          the singular value is placed into that collection.
        * If pow_iter_collection is None, a control dependency on the update op
          is applied to the output of `_build`.

        In both cases the kwarg is not passed on to the inner module.

        Args:
          module: A constructor/class reference for a Sonnet module you would
              like to construct.
          sn_kwargs: Keyword arguments to be passed to the spectral_norm
              function in addition to the weight tensor.
          pow_iter_collection: The name of a global collection for potentially
              storing ops for updating internal variables, or None.
          *args: Construction-time arguments to the module.
          **kwargs: Construction-time keyword arguments to the module. If it
              contains `name` (and that name is not None), the wrapper is named
              `<name>_wrapper`; otherwise it is named `sn_wrapper`.
        """
        inner_name = kwargs.get('name')
        if inner_name is None:
            # `name=None` is a legal Sonnet module name ("use the class name"),
            # so treat it like a missing name instead of failing on
            # `None + '_wrapper'`.
            inner_name = 'sn'
        name = inner_name + '_wrapper'

        # Our getter needs to be able to be disabled, so build both the
        # immediate-update and deferred-update variants.
        getter_immediate_update, getter_deferred_update = self.sn_getter(
            sn_kwargs)

        def w_getter(getter):
            """Routes only variables matching `.*/w$` through `getter`."""
            # The identity name_fn matches the regex against the full name.
            return util.custom_getter_router({'.*/w$': getter}, lambda s: s)

        getter_immediate_update = w_getter(getter_immediate_update)
        getter_deferred_update = w_getter(getter_deferred_update)
        self._context_getter = context.Context(
            getter_immediate_update, default_getter=getter_deferred_update)
        # NOTE(review): the two pow_iter_collection behaviours described in the
        # docstring are implemented in `_build`, which is not shown alongside
        # this constructor; confirm the wording there.
        self.pow_iter_collection = pow_iter_collection
        super(SpectralNormWrapper,
              self).__init__(name=name, custom_getter=self._context_getter)

        # Construct the inner module inside this wrapper's variable scope so
        # its variables pick up the spectral-norm custom getter.
        with self._enter_variable_scope():
            self._module = module(*args, **kwargs)
Exemplo n.º 2
0
    def __init__(self, _sentinel=None, custom_getter=None, name=None):  # pylint: disable=invalid-name
        """Performs the initialisation necessary for all AbstractModule instances.

        Every subclass of AbstractModule must begin their constructor with a
        call to this constructor, i.e.

        `super(MySubModule, self).__init__(custom_getter=custom_getter, name=name)`.

        If you instantiate sub-modules in __init__ you must create them within
        the `_enter_variable_scope` context manager to ensure they are in the
        module's variable scope. Alternatively, instantiate sub-modules in
        `_build`.

        Args:
          _sentinel: Variable that only carries a non-None value if `__init__`
              was called without named parameters. If this is the case, a
              deprecation warning is issued in form of a `ValueError`.
          custom_getter: Callable or dictionary of callables to use as custom
              getters inside the module. If a dictionary, the keys correspond
              to regexes matched against variable names with this module's
              scope prefix ("<scope_name>/") stripped. See the
              `tf.get_variable` documentation for information about the
              custom_getter API.
          name: Name of this module. Used to construct the Templated build
              function. If `None` the module's class name is used (converted
              to snake case).

        Raises:
          TypeError: If `name` is not a string.
          TypeError: If a given `custom_getter` is neither a mapping nor
              callable.
          ValueError: If `__init__` was called without named arguments.
        """
        # `collections.Mapping` was removed in Python 3.10; the ABC lives in
        # `collections.abc` (the old name was only an alias). Fall back to the
        # old location on Python 2, where `collections.abc` does not exist.
        try:
            from collections.abc import Mapping as mapping_type
        except ImportError:
            mapping_type = collections.Mapping

        if _sentinel is not None:
            raise ValueError("Calling AbstractModule.__init__ without named "
                             "arguments is not supported.")

        if name is None:
            name = util.to_snake_case(self.__class__.__name__)
        elif not isinstance(name, six.string_types):
            raise TypeError("Name must be a string, not {} of type {}.".format(
                name, type(name)))

        self._is_connected = False
        self._connected_subgraphs = []

        # If the given custom getter is a dictionary with a per-variable custom
        # getter, wrap it into a single custom getter. The lambda reads
        # `self.scope_name` only when called, i.e. after the template below has
        # created the scope.
        if isinstance(custom_getter, mapping_type):
            self._custom_getter = util.custom_getter_router(
                custom_getter_map=custom_getter,
                name_fn=lambda var_name: var_name[len(self.scope_name) + 1:])
        elif custom_getter is not None and not callable(custom_getter):
            raise TypeError("Given custom_getter is not callable.")
        else:
            self._custom_getter = custom_getter

        self._template = tf.make_template(name,
                                          self._build_wrapper,
                                          create_scope_now_=True,
                                          custom_getter_=self._custom_getter)

        self._original_name = name
        # The scope actually created by the template may differ from `name`
        # (e.g. if it was uniquified); keep its last path component.
        self._unique_name = self._template.variable_scope.name.split("/")[-1]

        # Update __call__ and the object docstrings to enable better introspection.
        # NOTE(review): `self.__call__.__func__` is the function object defined
        # on the class, shared by every instance, so this overwrites that shared
        # docstring each time a module is constructed.
        self.__doc__ = self._build.__doc__
        self.__call__.__func__.__doc__ = self._build.__doc__

        # Keep track of which graph this module has been connected to. Sonnet
        # modules cannot be connected to multiple graphs, as transparent variable
        # sharing is impossible in that case.
        self._graph = None

        # Container for all variables created in this module and its sub-modules.
        self._all_variables = set()

        # Calling `.defun()` causes the module's call method to become wrapped as
        # a graph function.
        self._defun_wrapped = False
Exemplo n.º 3
0
  def __init__(self, _sentinel=None, custom_getter=None,
               name=None):  # pylint: disable=invalid-name
    """Performs the initialisation necessary for all AbstractModule instances.

    Every subclass of AbstractModule must begin their constructor with a call to
    this constructor, i.e.

    `super(MySubModule, self).__init__(custom_getter=custom_getter, name=name)`.

    If you instantiate sub-modules in __init__ you must create them within the
    `_enter_variable_scope` context manager to ensure they are in the module's
    variable scope. Alternatively, instantiate sub-modules in `_build`.

    Whatever custom getter results (possibly None) is passed through
    `_maybe_wrap_custom_getter(_variable_tracking_custom_getter, ...)` before
    being handed to `tf.make_template`; those helpers are defined elsewhere.

    Args:
      _sentinel: Variable that only carries a non-None value if `__init__` was
          called without named parameters. If this is the case, a deprecation
          warning is issued in form of a `ValueError`.
      custom_getter: Callable or dictionary of callables to use as
        custom getters inside the module. If a dictionary, the keys
        correspond to regexes matched against variable names with this
        module's scope prefix ("<scope_name>/") stripped. See the
        `tf.get_variable` documentation for information about the
        custom_getter API.
      name: Name of this module. Used to construct the Templated build function.
          If `None` the module's class name is used (converted to snake case).

    Raises:
      TypeError: If `name` is not a string.
      TypeError: If a given `custom_getter` is neither a mapping nor callable.
      ValueError: If `__init__` was called without named arguments.
    """
    if _sentinel is not None:
      raise ValueError("Calling AbstractModule.__init__ without named "
                       "arguments is not supported.")

    # Default the module name to the snake_case form of the class name.
    if name is None:
      name = util.to_snake_case(self.__class__.__name__)
    elif not isinstance(name, six.string_types):
      raise TypeError("Name must be a string.")

    self._connected_subgraphs = []

    # If the given custom getter is a dictionary with a per-variable custom
    # getter, wrap it into a single custom getter.
    # NOTE(review): `collections.Mapping` was removed in Python 3.10, so on
    # newer interpreters this check raises AttributeError for every module.
    # Consider `collections.abc.Mapping` (with a Python 2 fallback if needed).
    if isinstance(custom_getter, collections.Mapping):
      # name_fn strips the "<scope_name>/" prefix so regexes match names
      # relative to this module. The lambda reads `self.scope_name` only when
      # it is called, not here.
      self._custom_getter = util.custom_getter_router(
          custom_getter_map=custom_getter,
          name_fn=lambda name: name[len(self.scope_name) + 1:])
    else:
      if not (custom_getter is None or callable(custom_getter)):
        raise TypeError("Given custom_getter is not callable.")
      self._custom_getter = custom_getter
    # Compose the (possibly None) getter with the variable-tracking getter;
    # presumably this records variables created under this module — both
    # helpers are defined elsewhere, verify there.
    self._custom_getter = _maybe_wrap_custom_getter(
        _variable_tracking_custom_getter, self._custom_getter)

    self._template = tf.make_template(name,
                                      self._build_wrapper,
                                      create_scope_now_=True,
                                      custom_getter_=self._custom_getter)

    self._original_name = name
    # The scope created by the template may differ from `name` (e.g. if it was
    # uniquified); keep its last path component.
    self._unique_name = self._template.variable_scope.name.split("/")[-1]

    # Update __call__ and the object docstrings to enable better introspection.
    # NOTE(review): `self.__call__.__func__` is the function object defined on
    # the class and shared by all instances, so this overwrites that shared
    # docstring each time a module is constructed.
    self.__doc__ = self._build.__doc__
    self.__call__.__func__.__doc__ = self._build.__doc__

    # Keep track of which graph this module has been connected to. Sonnet
    # modules cannot be connected to multiple graphs, as transparent variable
    # sharing is impossible in that case.
    self._graph = None

    # Container for all variables created in this module and its sub-modules.
    self._all_variables = set([])